TensorFlow保存/グラフのファイルからの読み込み


98

これまでに収集したものから、TensorFlowグラフをファイルにダンプして別のプログラムにロードする方法はいくつかありますが、それらがどのように機能するかについての明確な例や情報を見つけることができませんでした。私がすでに知っているのはこれです:

  1. a tf.train.Saver()を使用してモデルの変数をチェックポイントファイル(.ckpt)に保存し、後で復元する(ソース
  2. モデルを.pbファイルに保存し、tf.train.write_graph()and tf.import_graph_def()source
  3. .pbファイルからモデルを読み込み、再トレーニングして、Bazel(ソースを使用して新しい.pbファイルにダンプします
  4. グラフをフリーズして、グラフと重みを一緒に保存します(ソース
  5. as_graph_def()モデルを保存するために使用し、重み/変数については、それらを定数にマッピングします(ソース

ただし、これらのさまざまな方法に関するいくつかの質問を解決することはできませんでした。

  1. チェックポイントファイルについては、モデルのトレーニング済みの重みのみを保存しますか?チェックポイントファイルを新しいプログラムにロードして、モデルを実行するために使用できますか、それとも、特定の時間/段階でモデルの重みを保存する方法として機能するだけですか?
  2. について tf.train.write_graph()、重み/変数も保存されますか?
  3. Bazelに関しては、再トレーニングのために.pbファイルにのみ保存/ロードできますか?グラフを.pbにダンプするだけの簡単なBazelコマンドはありますか?
  4. フリーズに関して、フリーズされたグラフを使用してロードできますか tf.import_graph_def()ますか?
  5. TensorFlowのAndroidデモは、.pbファイルからGoogleのInceptionモデルに読み込まれます。自分の.pbファイルを置き換えたい場合は、どうすればよいですか?ネイティブコード/メソッドを変更する必要がありますか?
  6. 一般的に、これらすべての方法の違いは何ですか?より広義には、as_graph_def()/。ckpt / .pb の違いは何ですか?

要するに、私が探しているのは、グラフ(さまざまな操作など)とその重み/変数の両方をファイルに保存する方法です。これを使用して、グラフと重みを別のプログラムにロードできます。 、使用するため(必ずしも継続/再トレーニングではない)。

このトピックに関するドキュメントは簡単ではないので、回答や情報があれば大歓迎です。


2
最新/最も完全なAPIはメタグラフ
Yaroslav Bulatov

回答:


80

TensorFlowでモデルを保存するという問題に取り組むには多くの方法があり、少し混乱する可能性があります。順番にあなたのサブ質問のそれぞれを取る:

  1. チェックポイントファイル(たとえばsaver.save()tf.train.Saverオブジェクトを呼び出すことによって生成される)には、重みと、同じプログラムで定義されている他の変数のみが含まれます。これらを別のプログラムで使用するには、関連するグラフ構造を再作成する必要があります(たとえば、コードを実行して再構築するかtf.import_graph_def()、を呼び出す)。これにより、TensorFlowにそれらの重みをどうするかを指示します。を呼び出すと、を含むsaver.save()ファイルも生成されます。このファイルMetaGraphDefには、グラフと、チェックポイントからの重みをそのグラフに関連付ける方法の詳細が含まれています。詳細については、チュートリアルを参照しください。

  2. tf.train.write_graph()グラフ構造のみを書き込みます。重みではありません。

  3. BazelはTensorFlowグラフの読み取りまたは書き込みとは無関係です。(おそらく私はあなたの質問を誤解しています:コメントでそれを明確にしてください。)

  4. フリーズしたグラフは、を使用してロードできますtf.import_graph_def()。この場合、重みは(通常)グラフに埋め込まれているため、個別のチェックポイントをロードする必要はありません。

  5. 主な変更は、モデルに供給されるテンソルの名前と、モデルからフェッチされるテンソルの名前を更新することです。TensorFlow Androidデモでは、これはに渡されるinputNameおよびoutputName文字列に対応しTensorFlowClassifier.initializeTensorFlow()ます。

  6. これGraphDefはプログラム構造であり、通常、トレーニングプロセスを通じて変更されません。チェックポイントは、トレーニングプロセスの状態のスナップショットであり、通常、トレーニングプロセスの各ステップで変化します。その結果、TensorFlowはこれらのタイプのデータにさまざまなストレージ形式を使用し、低レベルAPIはそれらを保存およびロードするさまざまな方法を提供します。など、より高いレベルのライブラリ、MetaGraphDefライブラリ、Keras、およびskflowモデル全体を保存して復元するより便利な方法を提供するために、これらのメカニズムの構築。


そのこれはどういう意味C ++ APIドキュメント、それはあなたが保存したグラフをロードできることを言うとき、嘘tf.train.write_graph()、それを実行しますか?
mnicky 2016年

2
C ++ APIドキュメントには嘘はありませんが、いくつかの詳細が欠けています。最も重要な詳細は、GraphDefによって保存されたものに加えてtf.train.write_graph()、グラフを実行するときにフィードおよびフェッチするテンソルの名前も覚えておく必要があることです(上記の項目5)。
mrry 2016年

@mrry:私はテンソルフローのDeepDreamの例を使用しようとしました。しかし、それはpb形式の事前訓練されたモデルを必要とするようです!Cifar10の例を実行しましたが、チェックポイントのみが作成されます!私はどんなpbファイルも何も見つけることができませんでした!チェックポイントを、deepdreamの例が使用するpb形式に変換するにはどうすればよいですか?
Rika

2
@ Coderx7チェックポイントには重みと変数しか含まれておらず、グラフの構造について何も知らないため、.ckptを.pbに変換できないと思います
davidivad

1
.pbファイルをロードして実行する簡単なコードはありますか?
コング

1

次のコードを試すことができます:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.