AIエンジニアの探求

計算論的神経科学で博士号取得後、AIエンジニアとして活動中。LLMの活用や脳とAIの関係などについて記事を書きます。

Deep Image Prior

衝撃的な出会い

「一切学習させていないニューラルネットワークに一枚だけ画像を渡して、その画像に対してノイズ除去したり解像度をあげたり空白を自然に埋めてくれるって手法が提案されてるよ」

そう言われてこの論文*1を渡された時、にわかには信じられなかった。というより、何を言われているのか正直良くわからなかった。それぐらいの衝撃だった。

しかも掲載されている写真を見るとかなり綺麗だし、ざっと眺めた感じ数式もあまり出てこないところからしてどうもシンプルな手法らしい。謎は深まるばかりだ。

Deep Image Priorの方法

著者の主張としては
「Deep Convolutional Networkは画像生成・復元に広く使われている。そしてその高いパフォーマンスは大量の学習データによって(画像らしさの)事前確率分布を得ることができるためとされてるけど、実際には「ネットワークの構造」そのものが事前確率分布を与えている。」
ということらしい。そしてこのことを示すために、ランダムに初期化されたネットワークを使って色々なタスクをやってみてその威力を示している。

具体的な手法としては次のような手順で行っている(大きな枠組みは共通しているけど、ここではノイズ除去を想定した)。


1. まずrandomもしくは"Meshgrid"な画像を用意する。
f:id:tripdancer0916:20180113023721p:plain
(図1:Meshgridな画像の例:Ulyanov, D. (n.d.). Deep Image Prior Supplementary Material, (9).より引用)

2. 次に画像を入力として画像を出力するようなネットワークを用意する(GANで使われるGeneratorのイメージ)。このネットワークはランダムに初期化したものを使用する。
f:id:tripdancer0916:20180113023906p:plain
(図2:Ulyanov, D. (n.d.). Deep Image Prior Supplementary Material, (9).より引用)

3.タスクに応じて不完全な画像(=target画像)を用意する
f:id:tripdancer0916:20180113024300p:plain
(図3:デモでも使われているかたつむりの絵にノイズを加えたもの)

4.ネットワークから出力される画像とtarget画像の間の誤差関数を定義して、それを誤差逆伝播法によって小さくする。このとき誤差関数はユークリッド距離のようなシンプルなものでもワークする。
{ \displaystyle
\theta ^* = argmin_{\theta}E(f_\theta (z) ; x_0)  
}
(\thetaはネットワークのパラメータ、f_\theta(\bullet)はGenerator関数、zは入力画像、x_0はtarget画像を表す。)
{ \displaystyle
x^* = f_{\theta ^*}(z)
}
(得られる画像)
{ \displaystyle
E(x;x_0) = \| x-x_0 \|^2
}
(誤差関数)

5.4を進めていく過程で、ノイズが除去された画像が生成される。
f:id:tripdancer0916:20180113030100p:plain
(図4:実際にできたもの。著者のページにはより綺麗な例がたくさん載っている)

なぜこれでうまくいくのか?かなり不思議だ。。。
まず、ネットワークは巨大なんだから誤差関数は限りなく0に近づけることができるはずで、結果としてtarget画像と同じものが生成されるだけではないのか?という疑問が浮かぶ。
それは実際正しくて、epoch数を進め過ぎるとそういう現象が起こる。
f:id:tripdancer0916:20180113030506j:plain
(図5)

だからパラメータの更新をちょうどいい感じのところで止めると、いい結果が得られる。そしてここにこのDeep Image Priorの本質部分がある(と自分では解釈した)。

復元にかかる"時間差"を利用

f:id:tripdancer0916:20180113031200p:plain
(図6:Ulyanov, D., Vedaldi, A., & Lempitsky, V. (n.d.). Deep Image Prior.より引用)
上のグラフは論文中から引用したもので、誤差関数(Mean-Squared Error)の推移を異なるtarget画像に対して表示したものだ。
下2つは普通の画像と普通の画像にノイズを加えたもの、そして上2つは画像中のピクセルをすべてシャッフルしたもの、そして完全なホワイトノイズに対する結果を表示している。

これを見ると、誤差関数が下がるタイミングに差があることがわかる。つまり、意味のある画像は意味のない画像よりも「早く」復元されるため、意味のない部分の復元が追いついていない領域においては結果的にノイズ除去ができている、ということだ。
実際図5を見るとそのような挙動になっている(900iterationsの画像では一回消えたノイズがまた復活している)ことが分かる。


これで感覚的には納得がいったけど、たぶん厳密な議論はこれからだと思う。そしてネットワークの構造やハイパーパラメーターによってタスクのパフォーマンスが大きく変わるらしくて、そこの部分も研究が進んでいくだろう。

もっと言うと「人間がどのように世界を知覚して認識しているのか」という基礎的な問いや「階層的な情報処理を行う本質的な意味とは」「畳み込み&プーリングによって何ができているのか」みたいな話にもどんどん応用が効きそうで、理論的にかなり面白いんじゃないかと期待している。

*1:Ulyanov, D., Vedaldi, A., & Lempitsky, V. (n.d.). Deep Image Prior.