Deep Q-Learning損失関数を正確に計算するにはどうすればよいですか?


10

Deep Q-Learning Networkの損失関数がどの程度正確にトレーニングされているのか疑問です。私は、線形出力層とRelu非表示層のある2層フィードフォワードネットワークを使用しています。

  1. 4つのアクションがあるとします。したがって、現在の状態に対する私のネットワークの出力はです。より具体的にするために、と仮定しましょうstQ(st)R4Q(st)=[1.3,0.4,4.3,1.5]
  2. 次に、値対応するアクション、つまり3番目のアクションを実行し、新しい状態到達します。at=24.3st+1
  3. 次に、状態フォワードパスを計算し、出力レイヤー次の値を取得するとします。また、報酬ととしましょう。st+1Q(st+1)=[9.1,2.4,0.1,0.3]rt=2γ=1.0
  4. 損失は​​以下によって与えられます:

    L=(11.14.3)2

    または

    L=14i=03([11.1,11.1,11.1,11.1][1.3,0.4,4.3,1.5])2

    または

    L=14i=03([11.1,4.4,2.1,2.3][1.3,0.4,4.3,1.5])2

ありがとう、申し訳ありませんが、これを非常に基本的な方法で書き出さなければなりませんでした。(正解は2番目だと思います...)


1
この明確な例のある質問は、私が過去1週間に読んだ他のメディアの記事よりもディープqラーニングを理解するのに役立ちました。
dhruvm

回答:


5

方程式をもう数回確認した後。正しい損失は次のとおりだと思います。

L=(11.14.3)2

私の推論では、一般的な場合のqラーニング更新ルールは、特定ペアのq値のみを更新するものです。state,action

Q(s,a)=r+γmaxaQ(s,a)

この方程式は、更新が1つの特定のペア、およびニューラルqネットワークでのみ発生することを意味します。つまり、損失は、特定の対応する1つの特定の出力ユニットに対してのみ計算されます。a c t i o nstate,actionaction

例では、で、はです。t a r g e t rQ(s,a)=4.3targetr+γmaxaQ(s,a)=11.1

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.