AIエンジニアの探求

計算論的神経科学で博士号取得後、AIエンジニアとして活動中。LLMの活用や脳とAIの関係などについて記事を書きます。

Biologically plausible backpropagationまとめ②:Feedback Alignment

こんにちは、タイトルの通り「生物学的に妥当な誤差逆伝播法」の分野で、具体的に提案されている手法を紹介していきます。

ちなみに概要的な話はここに書いてあります。

tripdancer0916.hatenablog.com


今回紹介するのはFeedback Alignmentという手法で、2016年にnature communicationsから出版されています(初出はarxivで2014年)。

Random synaptic feedback weights support error backpropagation for deep learning | Nature Communications

[1411.0247] Random feedback weights support learning in deep neural networks


Feedback Alignmentは「誤差逆伝播法の実行のためには順方向のシナプスと逆方向のシナプスが対称である必要があるが、現実の神経系と照らし合わせるとそれは不自然な仮定である(※)」という問題に対応するための手法です。

(※)この問題はweight transport problemと呼ばれ、あのフランシス・クリック言及しています

f:id:tripdancer0916:20181025215749p:plain
順方向に情報を伝えるシナプスと逆方向に情報を伝えるシナプスは別物で、それが等しいと仮定するのは不自然。

Back propagationとFeedback Alignmentの計算

 
Feedback Alignmentの発想は極めてシンプルで、「対称なシナプスの代わりに、誤差を伝えるためのランダム行列をあらかじめ用意しておいてそれで勾配を計算しても案外学習はうまくいくんじゃないか?」というものです。これでほぼ尽きています。

f:id:tripdancer0916:20181025215743p:plain
feedback alignmentの概念図(画像は論文中より)


具体的な計算を以下に書いていきます。ここでは簡単のために出力層の活性化関数はsoftmaxとして、タスクはクラス分類とします。
入力をx、出力をy、重み行列をW^i、活性化関数をf(\cdot)として、順伝播計算はi \in \{1,2,...,k\}に対して以下のようになります。

a^i = W^i\cdot h^{i-1}, \\ h^i = f(a^i), \\ y = softmax(h^k)

ここで、h^0 = xとします。
通常の逆伝播計算であれば、targetをdとしたときに誤差関数の各層毎の勾配\nabla_{W^i}Lは次のように計算されます。

\delta_i = (W^i)^T \delta_{i+1} \odot f'(a^i), \\
\delta_k = \frac{1}{N}(y-d), \\
\nabla_{W^i}L = \delta_i(h^{i-1})^T
ただし、Nはここではバッチサイズ、\odotは要素毎の積(Frobenius積)を表します。

これによってW^i
 W^i \leftarrow W^i - \lambda \odot \delta_{BP} W^i \\
= W^i - \lambda \odot \nabla_{W^i}L
と更新していきます。

一方でFeedback Alignmentではdelta_iを計算するときに、(W^i)^Tの代わりにランダム行列B^i(学習中固定しておく)を使います。具体的には、
\delta^{FA}_i = B^i \delta_{i+1} \odot f'(a^i)
とします。

なぜうまく行くのか?

この方法でMNISTを計算したときの結果が以下のようになります。

f:id:tripdancer0916:20181028135552p:plain
MNISTに対しては、feedback alignmentの精度はback propagationで計算した時に匹敵する。

これを見ると分かるとおり、back propagationとほぼ同じ精度を誇ることが確認できます。

ではなぜ、全然違う、ある意味出鱈目とすら言えるやり方で計算してもうまく学習できるのか?

ここからは大雑把な議論になりますが、勾配法の直感的なイメージから考えてBPで計算したときの\delta_{BP}W^iとFAで計算したときの\delta_{FA}W^iとの角度が90度以下ならば、その更新によって誤差関数は小さくなるはずです。

で、当然学習の初期はランダム行列BW^Tの間には何の相関もないので上の条件は満たされないのですが、学習していくとW
W^T \measuredangle B < 90^{\circ}という関係を満たすようになります。(下図参照)

f:id:tripdancer0916:20181028135549p:plain
学習とともに、Wの更新に制約が陰に加わっていく。

これによって学習が進みます。


なお、Dierct Feedback Alignmentという拡張も2016年のNIPSで提案されています。