FactorVAEを実装した話
Factor VAEをTensorflow2.0で実装してみた話です。
github.com
Tensorflow 1.XやChainerの実装はありましたが、Tensorflow 2.0で書かれた物はなかったので、実装してみました。
トップの画像は実際にLatent traversalを行った結果です。
若干チューニングの甘さがありますが、まあある程度できていると言って差し支えないでしょう。
What's Factor VAE
β-VAEの改良版で、disentangled representationを抽出するネットワークを学習します。
β-VAEでは、disentangledrepresentationを得るための誤差関数がVAEの再構成誤差とトレードオフの関係にありましたが、FactorVAEはそこをAdversarial Trainingによって解消しています。
ところで、disentangled representationとはなにか?については下記が詳しい。
要は説明可能な特徴表現を指していて、上記のP.4をひっぱてきた。
http://www.cv.info.gifu-u.ac.jp/contents/workshop/contents/nips2018/ppt/NIPS_yamada.pdf
Factor VAEの仕組み
概要
Autoencoderについて以下の誤差を最小化して、誤差を計算します。
- Autoencoderの再構成誤差(入出力のL2距離を計算)を最小化
- Autoencoderの潜在変数のGaussian KL divergenceを最小化(ここまで普通のVAEの誤差関数)
- Autoencoderの潜在変数をDiscriminatorに入力してその出力についてTotal Correlationを最小化。(これがDisentangled representationを得るための仕組み)
また、Discriminatorの学習には、ある画像から得られる潜在変数を入力したときと、特定の次元についてSwapしたときの誤差がより大きくなるように学習を行います。
Total Correlation
en.wikipedia.org
Total Correlationって何?ってなりましたが、つまるところ潜在変数の各次元の分布と全潜在変数のJoint distributionが最小になる=になるので、すべての次元が独立になるということで良いでしょう。
以下の図において円同士が重なるところがTotal Correlationで計算しているところであり、ここを最小化しています。
Wikipediaから引用
このTotal Correlationを計算するためには当然、とを得る必要があります。
したがって、一般的なDiscriminatorは入力が1次元(入力が本物か偽物かを判定)ですが、このDiscriminatorは2次元の出力を持つことになります。
ちなみに、は通常のDiscriminatorの出力でを推定することがFactorVAE固有のものになります。
実装のためのテクニック
当たり前ですが論文の数式を愚直に実装するよりAdversarial Training特有の式変形を行って誤差関数を計算しやすくしたほうが当然学習は安定します
(というかそれなしで実装している実装をみかけたのですが、どうなってるんですかね?普通にNaN吐いて死ぬんですけど…)
実際に各誤差は以下の実装および計算を行います。
各記号はそれぞれ以下の図とのようになっています。
計算式は最終的な計算結果を最初に示して詳しい話は後に書くようにしています。
(1) 再構成誤差
実にシンプルな二乗誤差。
説明は不要でしょう。
(2) KL-divergence
[KL=tex: 0.5 \times (\textit{E}\[\theta\]+ Var\[\theta\] - lnVar\[\theta\] - 1)]
これは、以下の式にとに標準正規分布のを代入した式です。
def gaussian_kl_divergence(mean, ln_var, raxis=1): var = tf.exp(ln_var) mean_square = mean * mean return tf.reduce_sum((mean_square + var - ln_var - 1) * 0.5, axis=raxis)
(3) Total correlation
Discriminatorの出力にsigmoid関数を掛けずにlogitを計算するようにすればそのままエントロピーになりますので、差を計算するだけで大丈夫。
ただし、はDiscriminatorの2番目の出力です。
実際には以下の式変形になります。
実際のコードでは以下が該当の部分です。
logits_orig_z = disc(z) logits_shuffle_z = disc(z_shuffled) L_{TC} = tf.gather(logits_orig_z, 0, axis=1) - tf.gather(logits_orig_z, 1, axis=1)
(4) Discriminator loss
通常のBinary crossentropyの差を計算する代わりにsoftplusを使って計算することで、NaNが計算されないようにしています。
Binary CrossentropyとSoftplusの関係性は以下のブログが式変形まで詳しいです。
tatsukawa.hatenablog.com
ここで、の部分は特定のEncoderが抽出する潜在変数の特定の次元を入れ替える操作です。
実装上は以下のようになります。
あまりいい実装出ないように思いますが、まあ良しとしましょう。
indices = list(range(z_shape[1])) swap_index_pair = np.random.choice(indices, size=2, replace=False) tmp = indices[swap_index_pair[0]] indices[swap_index_pair[0]] = indices[swap_index_pair[1]] indices[swap_index_pair[1]] = tmp nd_indices = [ [i, j] for i, j in product(range(z_shape[0]), indices) ] z_shuffled = tf.reshape(tf.gather_nd(z, tf.convert_to_tensor(nd_indices)), z_shape) loss_disc = (0.5 * tf.keras.activations.softplus(-tf.gather(logits_orig_z, 0, axis=1))\ + 0.5 * tf.keras.activations.softplus(-tf.gather(logits_shuffle_z, 0, axis=1)))
チューニング
ここからは実際にDisentangled representationを得るために行ったチューニングの諸々。
結構完成まで紆余曲折あったので、まとめてみた。
VAEのConvolutionにResidual Connectionを加える
そもそもVAEの学習が遅かったのでResidual Connectionを加えた。
ついでに画像のボケ具合も解消された感じがあったので、そこそこ効果があった気がする。
この辺りはかなり初期に追加していたので、実際問題として実装が悪かったおかげで改善したのか、本当に意味があったのかは不明です。
一般論で言えば、効果あるでしょうけど。
reduce_meanを使わない
Tensorflow固有の話かは分からないけれど、再構成誤差の部分にreduce_meanの代わりにreduce_sumを必ず使おう。
サンプル数のNで割られた結果誤差が消失するのかまた、他の影響なのか画像がグレーになりました。
Total correlationについて
Total Correlationは意外と厄介で、これは数式上負の値を取りうる。
これのせいで、最適化序盤で破綻することがあった。
特にDiscriminatorが貧弱なときに起きる印象があるので、ある程度の層数のDiscriminatorを使うと全く起きなくなった。
もうちょっと頑張った方が良いと思うところ
以下の3点をどうにかしたいと思うけど、そろそろImage Transformerが気になって仕方がないので、頑張る気はないです。
1. Latent representationが正規分布じゃない
→Batch sizeを上げれば解消するけど、これすると次はTotal Correlationの最適化がうまく行かなくなってる。Leraning rate落として緩やかに学習すべきかね。
2. VAEの出力にSigmoidかけるの忘れてた。
→普通に忘れてた。もうちょっとキレイな結果が得られると思う。
3. 1次元分値を変えても出力が変わらない潜在変数が存在する。
→Latent representationの次数が多すぎるのでしょうかね。