複数の時系列データでLSTMモデルをトレーニングする方法は?


13

複数の時系列データでLSTMモデルをトレーニングする方法は?

使用例:過去5年間、毎週20,000人のエージェントの売上があります。各エージェントの今後の週次売上を予測する必要があります。

バッチ処理手法に従う必要がありますか?一度に1つのエージェントを取得し、LSTMモデルをトレーニングしてから予測しますか?もっと良い方法は?


あなたはこれを理解しましたか?私は同様の問題を見ています。
vishnu viswanath 2017年

@vishnuviswanathは、すべてのエージェントの1つのモデルとなるニューラルネット(RNN)の開発に取り組んでいます。
Aljo Jose

ありがとう。複数のエージェントのモデルをどのようにトレーニングしていますか?バッチごとに1つのエージェントをトレーニングしていますか。
vishnu viswanath 2017年

まだ建設段階です。kaggle.com/c/web-traffic-time-series-forecasting/discussion/…が役立ちます。
Aljo Jose

回答:


11

エージェントのIDを機能の1つにして、すべてのデータをトレーニングします。おそらく一度に128のエージェントのミニバッチでトレーニングします。これらの128のエージェントの開始から終了まで時系列を実行し、エージェントの新しいミニバッチを選択します。ミニバッチごとに、たとえば50タイムステップのスライスを実行してから、バックプロップします。そのスライスの最終状態を保持し、それらの最終状態から開始して、次の50タイムステップを実行します。〜128エージェントのミニバッチについて、タイムステップの最後に到達するまですすぎ、繰り返します。

各エージェントのIDを機能の1つとすることで、ネットワークで次のことが可能になります。

  • すべてのデータから学び、それによってデータの利用率を最大化します。
  • 各エージェントの固有の特性を学習して、すべてのエージェントを平均化するだけではないようにする
  • 特定のエージェントの将来を予測する場合は、対応するエージェントID機能を使用してください。ネットワークはそれに応じて予測を調整します。

編集:アルポホセは書きました:

わかりました、エージェントのIDを作成するために1つのホットエンコーディングを使用する必要がありますか?

おお、それは本当だ。それらの20,000があります。それはたくさんのことです。あなたがしたいのはそれらを「埋め込む」ことだと思います。エージェントID(整数、インデックスとして表される)を取り込み、長さ50〜300のベクトルのような高次元ベクトルを出力するルックアップレイヤーを用意します。あなたのLSTM。

数学的には、「埋め込み層」とも呼ばれるルックアップテーブルは、エージェントIDをワンホットベクトルにして、線形(完全に接続された)層を通過することと同じです。ただし、埋め込みレイヤーの場合、メモリの要件は大幅に削減されます。

埋め込み層が何を学習するかという点では、トレーニングすると、埋め込み層は各エージェントのある種の潜在表現を形成します。潜在的な表現は、おそらくどのような方法でも読み取り可能/解釈可能ではありませんが、モデルが「ok this agent、1524」は比較的効果的ですが、週末ではないことを学ぶことができます。1526は毎日素晴らしい場所です。etc .... '。埋め込みベクトルの潜在的な次元は実際には何かを意味するかもしれませんが、誰もがそれらの意味を理解しようとすることはありません(私はそれが難しい/不可能だと思います)。ただし、エージェントごとの高次元の埋め込みにより、モデルは各エージェントの動作について何かを学習し、これを時系列予測でモデル化できます。


わかりました、エージェントのIDを作成するために1つのホットエンコーディングを使用する必要がありますか?
Aljo Jose

@AljoJoseはこの質問を考慮に入れるために回答を更新しました
Hugh Perkins

わかりました、試してみます。ヒューありがとう
Aljo Jose

同様の問題がありますが、ここのエージェントとは異なり、同じプロセスのインスタンスとして複数の時系列がありますが、それぞれが可変長です。したがって、機能としてエージェント(私の場合はプロセス)は必要ありません。それを処理する方法の提案はありますか?
Anakin
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.