変分オートエンコーダーでのKLD損失と再構成損失の重み付け方法


26

VAEで見たほぼすべてのコード例で、損失関数は次のように定義されています(これはtensorflowコードですが、theano、torchなどでも同様です。これはconvnetでも見られますが、あまり関係ありません) 、単に合計が引き継がれる軸に影響します):

# latent space loss. KL divergence between latent space distribution and unit gaussian, for each batch.
# first half of eq 10. in https://arxiv.org/abs/1312.6114
kl_loss = -0.5 * tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - tf.exp(log_sigma_sq), axis=1)

# reconstruction error, using pixel-wise L2 loss, for each batch
rec_loss = tf.reduce_sum(tf.squared_difference(y, x), axis=[1,2,3])

# or binary cross entropy (assuming 0...1 values)
y = tf.clip_by_value(y, 1e-8, 1-1e-8) # prevent nan on log(0)
rec_loss = -tf.reduce_sum(x * tf.log(y) + (1-x) * tf.log(1-y), axis=[1,2,3])

# sum the two and average over batches
loss = tf.reduce_mean(kl_loss + rec_loss)

ただし、kl_lossとrec_lossの数値範囲は、それぞれ潜在空間の暗さと入力フィーチャサイズ(ピクセル解像度など)に大きく依存します。reduce_sumをreduce_meanに置き換えて、z-dim KLDおよびピクセル(または機能)LSEまたはBCEごとに取得するのが賢明でしょうか?さらに重要なことは、最終的な損失を合計するときに、潜在的な損失と再構成損失をどのように重み付けするかということです。試行錯誤だけですか?それとも何らかの理論(または少なくとも経験則)がありますか?これに関する情報はどこにも見つかりませんでした(元の論文を含む)。


私が抱えている問題は、入力フィーチャ(x)次元と潜在空間(z)次元のバランスが「最適」でない場合、再構築は非常に良好ですが、学習した潜在空間は構造化されていない(x次元の場合は非常に高く、再構築エラーがKLDを支配します)、またはその逆(再構築は良くありませんが、KLDが支配すれば学習した潜在空間はうまく構造化されます)。

再構成損失(入力フィーチャサイズで除算)とKLD(z次元で除算)を正規化し、KLD項を任意の重み係数で手動で重み付けする必要があることに気づきました(正規化は、同じまたはxまたはzの次元に依存しない同様の重み)。経験的には、再構築と、私にとって「スイートスポット」のように感じられる構造化された潜在的な空間とのバランスが良いことが0.1前後でわかりました。この分野での以前の仕事を探しています。


要求に応じて、上記の数学表記(再構成エラーのためのL2損失に焦点を当てる)

Llatent(i)=12j=1J(1+log(σj(i))2(μj(i))2(σj(i))2)

Lrecon(i)=k=1K(yk(i)xk(i))2

L(m)=1Mi=1M(Llatent(i)+Lrecon(i))

Jzμσ2KM(i)iL(m)m

回答:


17

この投稿につまずいて答えを探している人のために、このツイッタースレッドは多くの非常に有用な洞察を追加しました。

すなわち:

beta-VAE:制約付き変分フレームワークを使用した基本的な視覚概念の学習

βnorm

および関連する読書(同様の問題が議論されている場所)


7

この問題に関する論文もう1つ追加したいと思います(現時点では評判が悪いためコメントできません)。

論文のサブセクション3.1で、著者は、尤度とKLの発散を等しく重み付けするVAEの直線的な実装を訓練することに失敗したと指定しました。彼らの場合、KL損失は小さい値になると予想されていたが、望ましくないことにゼロに減少した。これを克服するために、彼らはKL発散項(青い曲線)の重み係数を0から1にゆっくりと増加させる「KLコストアニーリング」の使用を提案しました。

図2. Penn TreeBankのVAEのKL発散項の(重み付けされていない)値と一緒にプロットされた、典型的なシグモイドアニーリングスケジュールによる変分下限のKL発散項の重み。

この回避策は、ラダーVAEにも適用されます。

紙:

Bowman、SR、Vilnis、L.、Vinyals、O.、Dai、AM、Jozefowicz、R. and Bengio、S.、2015。連続空間から文章を生成する。arXivプレプリントarXiv:1511.06349。

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.