AIエンジニアの探求

計算論的神経科学で博士号取得後、AIエンジニアとして活動中。LLMの活用や脳とAIの関係などについて記事を書きます。

kerasで知識の蒸留(Distillation)

概要

kerasで知識の蒸留(knowledge distillation)を実装する際結構ハマったので備忘録も兼ねてポイントを整理します。

ちなみにsoft targetをあらかじめ計算しておくやり方であればそんなに難しくないのですが、Imagedatageneratorとfit_generatorを使ってバッチ毎にデータ拡張しながら学習させようと思うと結構大変でした。

github.com


※効率の悪い書き方をしてるかもしれないので色々指摘してくれると嬉しいです。

全体の流れ

knowledge distillation(以下KD)では次の2つの誤差関数を用意します。

Hard loss: 生徒モデルの出力と真のラベルの間に定義される誤差関数。いわゆる「普通のloss」
Soft loss: 生徒モデルの出力と教師モデルの出力の間に定義される誤差関数。正確には出力をそのまま用いるのではなく以下のように温度項を導入したものを用いる。
zを生徒モデルのlogits(softmax層の直前の出力値)、z'を教師モデルのlogitsとして、

 p_i = \frac{\exp \bigl({\frac{z_i}{T}}\bigr)}{\sum^n_{j=1} \exp \bigl({\frac{z_j}{T}}\bigr)}

 q_i = \frac{\exp \bigl({\frac{z^{\prime}_i}{T}} \bigr)}{\sum^n_{j=1} \exp \bigl({\frac{z^{\prime}_j}{T}} \bigr)}

上の2つの誤差関数をλで重みづけて足し合わせたものが最終的な誤差関数です。

 

全体的な流れを書くと下図のようになります。
f:id:tripdancer0916:20181014191731j:plain

誤差関数の計算

まずは誤差関数ですが、

https://github.com/TropComplique/knowledge-distillation-kerasを参考にして次のように定義しました。

#リストを引数にとる
def knowledge_distillation_loss(input_distillation):
    y_pred, y_true, y_soft, y_pred_soft = input_distillation
    return (1 - lambda_) * logloss(y_true, y_pred) + lambda_*T*T*logloss(y_soft, y_pred_soft)

 
y, \hat y, p, qの4つを引数として、求める合計の誤差関数を返り値とする関数です。ただし、copile時に指定できる誤差関数はy_predとy_trueの二値を引数とする関数だけなので、knowledge_distillation_lossもモデル内に組み込んでしまい、compile時には最終的な出力値をそのまま返すようなダミーの関数を指定します。

model.train_model.compile(
 optimizer=keras.optimizers.Adam(lr=0.003, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    loss=lambda y_true, y_pred: y_pred,
)

ダミーの入力の作り方は後述します。

モデルの構築

全体で定義するモデルは教師モデルと生徒モデルが結合しているものになります。教師モデルは保存済みのものをload_modelしたあと全てのレイヤーについてパラメータを固定する必要があります。

self.teacher_model = keras.models.load_model(teacher_model)
for i in range(len(self.teacher_model.layers)):
    self.teacher_model.layers[i].trainable = False
#trainable=Falseを有効にするためのcompile
self.teacher_model.compile(optimizer="adam", loss="categorical_crossentropy")

 qを計算する部分は次のように書きました。温度で割る計算のためにpop()によって一度softmax層を外すことが必要になります。

self.teacher_model.layers.pop()
input_layer = self.teacher_model.input
# q(=teacher_probabilities_T)の計算
teacher_logits = self.teacher_model.layers[-1].output
teacher_logits_T = Lambda(lambda x: x / T)(teacher_logits)
teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)

生徒モデルは上のinput_layerからレイヤーを付け足していけば構築できます。ただしその際レイヤー名が教師モデルとコンフリクトしないように変える必要があります。

学習時は生徒モデルと教師モデルが合わさったものが必要ですが、欲しいのは生徒モデルだけなのでここで2つのModelを定義します。

# 生徒モデル
with tf.device('/cpu:0'):
    distilled_model = Model(inputs=input_layer, outputs=output_softmax)
    input_true = Input(name='input_true', shape=[None], dtype='float32')

# モデル全体
output_loss = Lambda(knowledge_distillation_loss, output_shape=(1,), name='kd_')(
    [output_softmax, input_true, teacher_probabilities_T, probabilities_T]
)
inputs = [input_layer, input_true] 
with tf.device('/cpu:0'):
    train_model = Model(inputs=inputs, outputs=output_loss)

ここでのポイントですが、train_model(学習に使う方)はcompile時にtargetを与えることができないので、inputに組み込んでしまいます。input_trueは一切中で計算されずにそのまま誤差関数に渡されます。

モデルを定義する関数の全体像は次のようになります。

def prepare(self):
    self.teacher_model.layers.pop()
    input_layer = self.teacher_model.input
    teacher_logits = self.teacher_model.layers[-1].output
    teacher_logits_T = Lambda(lambda x: x / self.temperature)(teacher_logits)
    teacher_probabilities_T = Activation('softmax', name='softmax1_')(teacher_logits_T)

    x = Convolution2D(32, (3, 3), padding='same', name='conv2d1')(input_layer)
    x = BatchNormalization(name='bn1')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu1')(x)

    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(x)
    x = Dropout(0.3, name='drop1')(x)

    x = Convolution2D(64, (3, 3), padding='same', name='conv2d3')(x)
    x = BatchNormalization(name='bn3')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu3')(x)

    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(x)
    x = Dropout(0.3, name='drop2')(x)

    x = Convolution2D(128, (3, 3), padding='same', name='conv2d5')(x)
    x = BatchNormalization(name='bn5')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu5')(x)
    x = Convolution2D(128, (3, 3), padding='same', name='conv2d6')(x)
    x = BatchNormalization(name='bn6')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu6')(x)

    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(x)
    x = Dropout(0.3, name='drop3')(x)

    x = Flatten(name='flatten1')(x)
    x = Dense(256, activation=None, name='dense1')(x)
    x = BatchNormalization(name='bn8')(x)
    x = advanced_activations.LeakyReLU(alpha=0.1, name='lrelu8')(x)

    logits = Dense(num_classes, activation=None, name='dense2')(x)
    output_softmax = Activation('softmax', name='output_softmax')(logits)
    logits_T = Lambda(lambda x: x / self.temperature, name='logits')(logits)
    probabilities_T = Activation('softmax', name='probabilities')(logits_T)

    with tf.device('/cpu:0'):
        distilled_model = Model(inputs=input_layer, outputs=output_softmax)
        input_true = Input(name='input_true', shape=[None], dtype='float32')
    output_loss = Lambda(knowledge_distillation_loss, output_shape=(1,), name='kd_')(
        [output_softmax, input_true, teacher_probabilities_T, probabilities_T]
    )
    inputs = [input_layer, input_true]

    with tf.device('/cpu:0'):
        train_model = Model(inputs=inputs, outputs=output_loss)

    return train_model, distilled_model

イテレータ

ここまでくればあとは学習するだけなのですが、ここでもハマりました。上のような変則的な入力をする以上、fit_generatorには(train, target)のタプルではなく、([train, target], target)の形で渡す必要があります(2つめのtargetに意味はないのでここはなんでもいいのですが・・・)。

datagenator.flowは(train, target)の形で値を返すので、これを欲しい形に整形するためのイテレータを作ることにしました(このやり方が良いのかは分かりませんが、、、)。

class MyIterator(object):
    def __init__(self, iterator_org):
        self.iterator = iterator_org

    def __iter__(self):
        return self

    def __next__(self):
        tmp = next(self.iterator)
        return [tmp[0], tmp[1]], tmp[1]

tmp_iterator = datagen.flow(x_train, y_train, batch_size=batch_size)
iterator = MyIterator(tmp_iterator)

model.train_model.fit_generator(iterator,
                                steps_per_epoch=x_train.shape[0] // batch_size,
                                epochs=epochs,
                                workers=4, callbacks=[training_callback])

これで無事、学習させることができました。実際に蒸留の効果も確認することができたので、一応成功と言えそうです。


ここで使ったテクニックですが、知識の蒸留以外にも変則的な誤差関数が必要なモデルの学習に色々応用が効きそうです。