Haskellでニューラルネットワークアーキテクチャを実装し、MNISTで使用しようとしています。
hmatrix
線形代数のパッケージを使用しています。私のトレーニングフレームワークは、pipes
パッケージを使用して構築されています。
私のコードはコンパイルされ、クラッシュしません。ただし、問題は、レイヤーサイズ(たとえば、1000)、ミニバッチサイズ、および学習率の特定の組み合わせによってNaN
、計算に値が生じることです。いくつかの検査の後、非常に小さな値(次数1e-100
)が最終的にアクティベーションに表示されることがわかります。しかし、それが起こらなくても、トレーニングはまだ機能しません。その損失や精度に改善はありません。
私は自分のコードをチェックして再チェックしましたが、問題の根本が何であるかについて途方に暮れています。
各レイヤーのデルタを計算するバックプロパゲーショントレーニングは次のとおりです。
backward lf n (out,tar) das = do
let δout = tr (derivate lf (tar, out)) -- dE/dy
deltas = scanr (\(l, a') δ ->
let w = weights l
in (tr a') * (w <> δ)) δout (zip (tail $ toList n) das)
return (deltas)
lf
損失関数は、あるn
ネットワーク(れるweight
マトリックス及びbias
各層のベクトル)、out
及びtar
ネットワークの実際の出力としているtarget
(所望の)出力、及びdas
各層の活性化誘導体です。
バッチモードではout
、tar
行列(行は出力ベクトルである)であり、das
行列のリストです。
実際の勾配計算は次のとおりです。
grad lf (n, (i,t)) = do
-- Forward propagation: compute layers outputs and activation derivatives
let (as, as') = unzip $ runLayers n i
(out) = last as
(ds) <- backward lf n (out, t) (init as') -- Compute deltas with backpropagation
let r = fromIntegral $ rows i -- Size of minibatch
let gs = zipWith (\δ a -> tr (δ <> a)) ds (i:init as) -- Gradients for weights
return $ GradBatch ((recip r .*) <$> gs, (recip r .*) <$> squeeze <$> ds)
ここで、lf
およびn
は上記と同じであり、i
は入力であり、t
はターゲット出力です(両方ともバッチ形式で、行列として)。
squeeze
各行を合計することにより、行列をベクトルに変換します。つまり、ds
はデルタの行列のリストです。各列は、ミニバッチの行のデルタに対応します。したがって、バイアスの勾配は、すべてのミニバッチにわたるデルタの平均です。の場合も同じですgs
。これは、重みの勾配に対応します。
実際の更新コードは次のとおりです。
move lr (n, (i,t)) (GradBatch (gs, ds)) = do
-- Update function
let update = (\(FC w b af) g δ -> FC (w + (lr).*g) (b + (lr).*δ) af)
n' = Network.fromList $ zipWith3 update (Network.toList n) gs ds
return (n', (i,t))
lr
学習率です。FC
はレイヤーコンストラクターであり、af
はそのレイヤーの活性化関数です。
最急降下アルゴリズムは、学習率に負の値を渡すようにします。勾配降下のための実際のコードは、単にの組成物の周りにループであるgrad
とmove
、パラメータ化された停止条件付き。
最後に、平均二乗誤差損失関数のコードは次のとおりです。
mse :: (Floating a) => LossFunction a a
mse = let f (y,y') = let gamma = y'-y in gamma**2 / 2
f' (y,y') = (y'-y)
in Evaluator f f'
Evaluator
損失関数とその導関数(出力層のデルタを計算するため)をバンドルするだけです。
残りのコードはGitHub:NeuralNetworkにあります。
したがって、誰かが問題についての洞察を持っている場合、または私がアルゴリズムを正しく実装していることをサニティチェックするだけでさえあれば、私は感謝するでしょう。