AIエンジニアの探求

計算論的神経科学で博士号取得後、AIエンジニアとして活動中。LLMの活用や脳とAIの関係などについて記事を書きます。

Biologically plausible backpropagationまとめ②:Feedback Alignment

こんにちは、タイトルの通り「生物学的に妥当な誤差逆伝播法」の分野で、具体的に提案されている手法を紹介していきます。

ちなみに概要的な話はここに書いてあります。

tripdancer0916.hatenablog.com


今回紹介するのはFeedback Alignmentという手法で、2016年にnature communicationsから出版されています(初出はarxivで2014年)。

Random synaptic feedback weights support error backpropagation for deep learning | Nature Communications

[1411.0247] Random feedback weights support learning in deep neural networks


Feedback Alignmentは「誤差逆伝播法の実行のためには順方向のシナプスと逆方向のシナプスが対称である必要があるが、現実の神経系と照らし合わせるとそれは不自然な仮定である(※)」という問題に対応するための手法です。

(※)この問題はweight transport problemと呼ばれ、あのフランシス・クリック言及しています

f:id:tripdancer0916:20181025215749p:plain
順方向に情報を伝えるシナプスと逆方向に情報を伝えるシナプスは別物で、それが等しいと仮定するのは不自然。

Back propagationとFeedback Alignmentの計算

 
Feedback Alignmentの発想は極めてシンプルで、「対称なシナプスの代わりに、誤差を伝えるためのランダム行列をあらかじめ用意しておいてそれで勾配を計算しても案外学習はうまくいくんじゃないか?」というものです。これでほぼ尽きています。

f:id:tripdancer0916:20181025215743p:plain
feedback alignmentの概念図(画像は論文中より)


具体的な計算を以下に書いていきます。ここでは簡単のために出力層の活性化関数はsoftmaxとして、タスクはクラス分類とします。
入力をx、出力をy、重み行列をW^i、活性化関数をf(\cdot)として、順伝播計算はi \in \{1,2,...,k\}に対して以下のようになります。

a^i = W^i\cdot h^{i-1}, \\ h^i = f(a^i), \\ y = softmax(h^k)

ここで、h^0 = xとします。
通常の逆伝播計算であれば、targetをdとしたときに誤差関数の各層毎の勾配\nabla_{W^i}Lは次のように計算されます。

\delta_i = (W^i)^T \delta_{i+1} \odot f'(a^i), \\
\delta_k = \frac{1}{N}(y-d), \\
\nabla_{W^i}L = \delta_i(h^{i-1})^T
ただし、Nはここではバッチサイズ、\odotは要素毎の積(Frobenius積)を表します。

これによってW^i
 W^i \leftarrow W^i - \lambda \odot \delta_{BP} W^i \\
= W^i - \lambda \odot \nabla_{W^i}L
と更新していきます。

一方でFeedback Alignmentではdelta_iを計算するときに、(W^i)^Tの代わりにランダム行列B^i(学習中固定しておく)を使います。具体的には、
\delta^{FA}_i = B^i \delta_{i+1} \odot f'(a^i)
とします。

なぜうまく行くのか?

この方法でMNISTを計算したときの結果が以下のようになります。

f:id:tripdancer0916:20181028135552p:plain
MNISTに対しては、feedback alignmentの精度はback propagationで計算した時に匹敵する。

これを見ると分かるとおり、back propagationとほぼ同じ精度を誇ることが確認できます。

ではなぜ、全然違う、ある意味出鱈目とすら言えるやり方で計算してもうまく学習できるのか?

ここからは大雑把な議論になりますが、勾配法の直感的なイメージから考えてBPで計算したときの\delta_{BP}W^iとFAで計算したときの\delta_{FA}W^iとの角度が90度以下ならば、その更新によって誤差関数は小さくなるはずです。

で、当然学習の初期はランダム行列BW^Tの間には何の相関もないので上の条件は満たされないのですが、学習していくとW
W^T \measuredangle B < 90^{\circ}という関係を満たすようになります。(下図参照)

f:id:tripdancer0916:20181028135549p:plain
学習とともに、Wの更新に制約が陰に加わっていく。

これによって学習が進みます。


なお、Dierct Feedback Alignmentという拡張も2016年のNIPSで提案されています。

Biologically plausible backpropagationまとめ①

こんにちは、機械学習の研究分野の中でニッチながらも少しずつ知名度をあげてきている分野、"biologically plausible backpropagation"(日本語に直訳すると『生物学的に妥当な誤差逆伝播法』)についてこれから何回かに分けて記事を書いていきます。

初回は全体の概要的な話を書いて、次回から具体的に提案されているアルゴリズム、例えばfeedback alignmenttarget propagationなどについて詳しく紹介していきたいと思います。

誤差逆伝播法の生物学的妥当性

誤差逆伝播法は現在ニューラルネットワークを学習させる際に最も広く用いられている方法で、実際それによって深層学習は大成功を収めている。しかしこれを脳内の学習メカニズムの候補として考えると(元々ニューラルネットワークは脳にヒントを得て作られたにも関わらず)いくつかの問題に晒される。そこで誤差逆伝播法に代わるより生物学的妥当性が高い学習アルゴリズムを研究しよう。そこから脳の理解につながる可能性も出てくるし、逆に深層学習そのものの理解も深まるかもしれない。

この研究分野の概要を大雑把に説明すると上のようになります。神経科学の分野と機械学習の分野をリンクさせようというDeepmind的モチベーションで、実際DeepmindからもSynthetic gradientという手法が提案されています。


それでは具体的な「誤差逆伝播法が神経科学的に見てイケてない」点をTowards Biologically Plausible Deep Learningに沿って見ていきたいと思います。

誤差逆伝播法において誤差を出力層から隠れ層に伝える計算はすべて線形な演算だが、生体内では本来線形演算子非線形演算子の組み合わせでできているはずである。
ニューラルネットでの逆向き計算は隠れ層間のシナプス行列と隠れ層の出力の微分係数を次々と掛けていくことで行われているけど、実際の神経系でそんなに長く線形演算だけを繰り返すことなんてできるの?という疑問です。ただこれは私見になりますが、「活性化関数にReLUを使ったときは微分係数が0か1になるのでon/offの切り替えを行う非線形演算が各層で挟まれる」と主張すれば一応ディフェンスにはなる気もしています。

②脳内に存在するフィードバックパスが誤差を伝える役割を果たしているとしたら、このシナプスは書くニューロンの活性化関数の微分係数を正確に知っている必要がある。
→これは直感的にも分かりやすいんじゃないでしょうか。誤差逆伝播法みたいな精密な計算を脳内で本当に実現できるのか?というある意味素朴な疑問です。

③2のようなフィードバックパスは前向きのパスと正確に対称である必要がある。
誤差逆伝播法の計算にはシナプス行列の転置行列を使います。これはコンピュータで行う計算としては全く問題がありませんが、実際のシナプスで実装するためには行きと帰りが正確に同じである必要がでてくるわけで、それはどのように設計されたのかという問題が残ります。

④通常ニューラルネットでは決定論的な連続値を扱うが、実際の神経系では(おそらく確率的な)スパイクでやりとりがなされる。
神経科学と機械学習のギャップを埋めるためには避けては通れない話題ですが、誤差逆伝播の問題と言うよりはニューラルネット全体の課題という気もします。

⑤前向き計算と逆向き計算のそれぞれが正確に同期しながら順番に切り替わっていく必要がある(逆向き計算は順向き計算で得られたそれぞれの隠れ層の値を利用して行われる必要があるため)。
→計算の順序がどのように制御されているのか?という問題ですね。

⑥出力層における"target"(教師データ)はどこから得られるのか?
誤差逆伝播法に限らず、教師あり学習全般に関わってくる問題ですね。例えばクラス分類にしても、どのように正解ラベルを表すone-hotベクトルが神経系の中でエンコードされているかは全く自明でないと思います。

研究の意義

現在世界で唯一実現されている汎用人工知能が脳であることを考えると、その動作・学習原理を解明することは確実に大きなインパクトをもたらすはずです。そしてその解明のためのアプローチとしてニューロンレベルからそのダイナミクスを理解していこうというボトムアップ的手法に対して、biologically plausible backpropagationの研究はトップダウンなアプローチをもたらしてくれるんじゃないかという期待があります。



次回は上のリストのうち③に対応することがメインの目的であるfeedback alignmentの紹介を行う予定です。

kerasで知識の蒸留(Distillation)

概要

kerasで知識の蒸留(knowledge distillation)を実装する際結構ハマったので備忘録も兼ねてポイントを整理します。

ちなみにsoft targetをあらかじめ計算しておくやり方であればそんなに難しくないのですが、Imagedatageneratorとfit_generatorを使ってバッチ毎にデータ拡張しながら学習させようと思うと結構大変でした。

github.com


※効率の悪い書き方をしてるかもしれないので色々指摘してくれると嬉しいです。

全体の流れ

knowledge distillation(以下KD)では次の2つの誤差関数を用意します。

Hard loss: 生徒モデルの出力と真のラベルの間に定義される誤差関数。いわゆる「普通のloss」
Soft loss: 生徒モデルの出力と教師モデルの出力の間に定義される誤差関数。正確には出力をそのまま用いるのではなく以下のように温度項を導入したものを用いる。
zを生徒モデルのlogits(softmax層の直前の出力値)、z'を教師モデルのlogitsとして、

 p_i = \frac{\exp \bigl({\frac{z_i}{T}}\bigr)}{\sum^n_{j=1} \exp \bigl({\frac{z_j}{T}}\bigr)}

 q_i = \frac{\exp \bigl({\frac{z^{\prime}_i}{T}} \bigr)}{\sum^n_{j=1} \exp \bigl({\frac{z^{\prime}_j}{T}} \bigr)}

上の2つの誤差関数をλで重みづけて足し合わせたものが最終的な誤差関数です。

 

全体的な流れを書くと下図のようになります。
f:id:tripdancer0916:20181014191731j:plain

誤差関数の計算

まずは誤差関数ですが、

https://github.com/TropComplique/knowledge-distillation-kerasを参考にして次のように定義しました。

#リストを引数にとる
def knowledge_distillation_loss(input_distillation):
    y_pred, y_true, y_soft, y_pred_soft = input_distillation
    return (1 - lambda_) * logloss(y_true, y_pred) + lambda_*T*T*logloss(y_soft, y_pred_soft)

 
y, \hat y, p, qの4つを引数として、求める合計の誤差関数を返り値とする関数です。ただし、copile時に指定できる誤差関数はy_predとy_trueの二値を引数とする関数だけなので、knowledge_distillation_lossもモデル内に組み込んでしまい、compile時には最終的な出力値をそのまま返すようなダミーの関数を指定します。

model.train_model.compile(
 optimizer=keras.optimizers.Adam(lr=0.003, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    loss=lambda y_true, y_pred: y_pred,
)

ダミーの入力の作り方は後述します。

モデルの構築

全体で定義するモデルは教師モデルと生徒モデルが結合しているものになります。教師モデルは保存済みのものをload_modelしたあと全てのレイヤーについてパラメータを固定する必要があります。

self.teacher_model = keras.models.load_model(teacher_model)
for i in range(len(self.teacher_model.layers)):
    self.teacher_model.layers[i].trainable = False
#trainable=Falseを有効にするためのcompile
self.teacher_model.compile(optimizer="adam", loss="categorical_crossentropy")

 qを計算する部分は次のように書きました。温度で割る計算のためにpop()によって一度softmax層を外すことが必要になります。

self.teacher_model.layers.pop()
input_layer = self.teacher_model.input
# q(=teacher_probabilities_T)の計算
teacher_logits = self.teacher_model.layers[-1].output
teacher_logits_T = Lambda(lambda x: x / T)(teacher_logits)
teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)

生徒モデルは上のinput_layerからレイヤーを付け足していけば構築できます。ただしその際レイヤー名が教師モデルとコンフリクトしないように変える必要があります。

学習時は生徒モデルと教師モデルが合わさったものが必要ですが、欲しいのは生徒モデルだけなのでここで2つのModelを定義します。

# 生徒モデル
with tf.device('/cpu:0'):
    distilled_model = Model(inputs=input_layer, outputs=output_softmax)
    input_true = Input(name='input_true', shape=[None], dtype='float32')

# モデル全体
output_loss = Lambda(knowledge_distillation_loss, output_shape=(1,), name='kd_')(
    [output_softmax, input_true, teacher_probabilities_T, probabilities_T]
)
inputs = [input_layer, input_true] 
with tf.device('/cpu:0'):
    train_model = Model(inputs=inputs, outputs=output_loss)

ここでのポイントですが、train_model(学習に使う方)はcompile時にtargetを与えることができないので、inputに組み込んでしまいます。input_trueは一切中で計算されずにそのまま誤差関数に渡されます。

モデルを定義する関数の全体像は次のようになります。

def prepare(self):
    self.teacher_model.layers.pop()
    input_layer = self.teacher_model.input
    teacher_logits = self.teacher_model.layers[-1].output
    teacher_logits_T = Lambda(lambda x: x / self.temperature)(teacher_logits)
    teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)

    x = Convolution2D(32, (3, 3), padding='same', name='conv2d1')(input_layer)
    x = BatchNormalization(name='bn1')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu1')(x)

    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(x)
    x = Dropout(0.3, name='drop1')(x)

    x = Convolution2D(64, (3, 3), padding='same', name='conv2d3')(x)
    x = BatchNormalization(name='bn3')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu3')(x)

    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(x)
    x = Dropout(0.3, name='drop2')(x)

    x = Convolution2D(128, (3, 3), padding='same', name='conv2d5')(x)
    x = BatchNormalization(name='bn5')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu5')(x)
    x = Convolution2D(128, (3, 3), padding='same', name='conv2d6')(x)
    x = BatchNormalization(name='bn6')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu6')(x)

    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(x)
    x = Dropout(0.3, name='drop3')(x)

    x = Flatten(name='flatten1')(x)
    x = Dense(256, activation=None, name='dense1')(x)
    x = BatchNormalization(name='bn8')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu8')(x)

    logits = Dense(num_classes, activation=None, name='dense2')(x)
    output_softmax = Activation('softmax', name='output_softmax')(logits)
    logits_T = Lambda(lambda x: x / self.temperature, name='logits')(logits)
    probabilities_T = Activation('softmax', name='probabilities')(logits_T)

    with tf.device('/cpu:0'):
        distilled_model = Model(inputs=input_layer, outputs=output_softmax)
        input_true = Input(name='input_true', shape=[None], dtype='float32')
    output_loss = Lambda(knowledge_distillation_loss, output_shape=(1,), name='kd_')(
        [output_softmax, input_true, teacher_probabilities_T, probabilities_T]
    )
    inputs = [input_layer, input_true]

    with tf.device('/cpu:0'):
        train_model = Model(inputs=inputs, outputs=output_loss)

    return train_model, distilled_model

イテレータ

ここまでくればあとは学習するだけなのですが、ここでもハマりました。上のような変則的な入力をする以上、fit_generatorには(train, target)のタプルではなく、([train, target], target)の形で渡す必要があります(2つめのtargetに意味はないのでここはなんでもいいのですが・・・)。

datagenator.flowは(train, target)の形で値を返すので、これを欲しい形に整形するためのイテレータを作ることにしました(このやり方が良いのかは分かりませんが、、、)。

class MyIterator(object):
    def __init__(self, iterator_org):
        self.iterator = iterator_org

    def __iter__(self):
        return self

    def __next__(self):
        tmp = next(self.iterator)
        return [tmp[0], tmp[1]], tmp[1]

tmp_iterator = datagen.flow(x_train, y_train, batch_size=batch_size)
iterator = MyIterator(tmp_iterator)

model.train_model.fit_generator(iterator,
                                steps_per_epoch=x_train.shape[0] // batch_size,
                                epochs=epochs,
                                workers=4, callbacks=[training_callback])

これで無事、学習させることができました。実際に蒸留の効果も確認することができたので、一応成功と言えそうです。


ここで使ったテクニックですが、知識の蒸留以外にも変則的な誤差関数が必要なモデルの学習に色々応用が効きそうです。