AIエンジニアの探求

計算論的神経科学で博士号取得後、AIエンジニアとして活動中。LLMの活用や脳とAIの関係などについて記事を書きます。

[論文紹介]Born-Again Neural Networks

こんにちは、Born Again Neural Networksというknowledge distillation系の論文で面白いものがあったのでそれを紹介していきます。

一部再現したコードをgithubにあげています。

github.com

ちなみにknowledge distillation(KD)については分かりやすい解説記事がたくさんあるのでここでは詳しい説明は省略します、、、

例えば

がめちゃくちゃオススメです。

生まれ変わって強くなる

KDでは通常の正解ラベル(=hard target)に加えて学習済み教師モデルの出力(=soft target)を正解ラベルとして使って生徒モデルを学習することで、普通に学習させるよりも精度を向上させることを目的としています。

ここで、通常教師モデルはパラメータ数の大きいネットワーク、生徒モデルはパラメータ数の小さいネットワークを使用します。つまり、KDの枠組みを利用して精度とメモリのトレードオフを改善しようという意図があるわけです。

 

一方Born Again Networksでは生徒モデルと教師モデルのネットワークを完全に同じにします。自分自身が教師になるというわけです。

あるいは一度学習させたパラメータを全て捨て去って、まっさらな状態から心機一転、学習し直す事で以前よりも高い精度を実現してしまおうという試みとも言えるかもしれません。

 

タイトルのつけ方とか、なんとなく厨二病っぽい感じが好きです。

 

さて、そんな虫のいい話なんてあるのか?と疑問に思う人も多いと思いますが、実は上手くいきます。

 

f:id:tripdancer0916:20181011215617p:plain

上の表はCIFAR-10に対する結果で、Teacherが元のネットワークの精度(test loss)でBANが生まれ変わったネットワークの精度です。ネットワークサイズが大きくなると生まれ変わりの効果は効かなくなってしまいます*が。。。

 

f:id:tripdancer0916:20181011215639p:plain

より難しいタスクであるCIFAR-100であればどのネットワークを使用してもBANの方が精度が高くなっています。

*おそらくネットワークが大きいと最初の学習の段階で十分「良く」学習できてしまうためそこでサチってしまうものと考えられます。

 

 

Dark Knowledge Under the Light*

*論文中の章のタイトル。これもかっこいい。

Soft targetのうち、本来の正解ラベルでない部分をDark knowledgeと言います。「教師モデルがどのように画像を認識しているか」を表すとされ、これがKDにとって本質的に大事な働きをしていると(Hinton et al., 2015)は言っています。

 

一方で「Dark knowledgeはそこまで重要ではなく、KDは『教師データがこのサンプルをどれぐらい自信を持って分類しているか』という指標を元にサンプル間の重み付けを変えることによってワークしているのではないか」、という仮説を立てることも可能です。

これにメスを入れているのがこのパートで、Dark knowledgeの効果を検証するために2つの実験を行なっています。

 

Confidence Weighted by Teacher Max(CWTM)

一つ目の実験では、soft targetのうち最大値以外をすべて0にしています。つまり、Dark knowledgeの寄与を消してサンプル間の重み付けの調整のためだけにsoft lossを使うとどうなるのか、を検証しています。

f:id:tripdancer0916:20181012201540p:plainf:id:tripdancer0916:20181012201554p:plain

zを生徒モデルのlogits(softmax層の直前の出力値)、z'を教師モデルのlogitsとして、

 p_i = \frac{e^{\frac{z_i}{T}}}{\sum^n_{j=1} e^{\frac{z_j}{T}}}

 q_i = \frac{e^{\frac{z^{\prime}_i}{T}}}{\sum^n_{j=1} e^{\frac{z^{\prime}_j}{T}}}

sをbatchの大きさとすると、通常のKDではsoft lossの誤差逆伝播を計算するときに

 \delta _j = \frac{1}{s} (p_j - q_j)

となるところが

 \delta _j = \frac{\max_i{q_{ij}}}{\sum^s_{u=1} \max_i q_{iu}} (p_j - \hat{y}_j)

と変更されます(ここでyは本物のラベルで、Tは温度項を表しています)。

 

Dark Knowledge with Permuted Predictions(DKPP)

次はDark Knowledgeを完全に0にするのではなく、でたらめに入れ替えてしまうというセットアップを考えます。ただしsoft targetで最大値をとる部分については変えません。

f:id:tripdancer0916:20181012201540p:plainf:id:tripdancer0916:20181012205045p:plain

上記の条件を満たしてランダムに入れ替える関数をφとおくと、計算式は以下のようになります。

 \delta _j = \frac{1}{s} (p_j - \phi(q)_j)

 

以上2つの設定を入れた結果が以下のようになっています。 学習対象はCIFAR-100です。ちなみに右半分はBorn-Againを何回か繰り返した時の結果とそれらのアンサンブルの結果です。

f:id:tripdancer0916:20181012205517p:plain

通常のBorn-Againが一番精度が高くなるのは自然ですが、CWTMとDKPPも元のモデルより精度が向上していることが確認できます。なお図中のBAN+Lはsoft lossとhard lossを組み合わせた時の結果です。普通のKDではsoft lossのみより採用されやすいですが、この実験では純粋なBANより悪くなっています。

 

結果の再現

CIFAR-10を対象に、論文中のネットワークよりも軽いものを使ってBorn-Againの効果を確かめてみました。畳み込み層7層、プーリング層3層のVGGを軽量化したようなシンプルなCNNです。data augmentationはなしです。

Teacherのaccuracyが88.18%で、同じネットワークを学習させたところ89.10%でした。

f:id:tripdancer0916:20181012212245p:plain

またepoch毎に精度を記録すると、上図のように常にteacher modelを上回っていたため(たまたまいい結果が得られたわけではなく)born againの効果が確認できたと言えそうです。