Tensorflow SavedModelで使用されているすべてのオペレーションをリストする方法は?


10

tensorflow.saved_model.saveSavedModel形式の関数を使用してモデルを保存した場合、このモデルで使用されているTensorflow Opsを後で取得するにはどうすればよいですか。モデルを復元できるので、これらの操作はグラフに保存されsaved_model.pbます。おそらくファイルにあると思います。このprotobuf(モデル全体ではない)をロードすると、protobufのライブラリ部分にこれらのリストが表示されますが、これはドキュメント化されておらず、現時点では実験的な機能としてタグ付けされていません。Tensorflow 1.xで作成されたモデルには、この部分はありません。

では、SavedModel形式のモデルから使用済みの操作(のようなMatchingFilesまたはWriteFile)のリストを取得するための高速で信頼性の高い方法は何でしょうか。

今は、全部をフリーズできtensorflowjs-converterます。サポートされている操作も確認します。これは現在、LSTMがモデルにある場合は機能しません。こちらを参照してください。Opsは間違いなくそこにいるので、これを行うより良い方法はありますか?

モデルの例:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

この場合、少なくとも以下を含むすべてのOpが出力で予期されます:


1
あなたが何を望んでいるのか正確に伝えるのは難しいです、saved_model.pbそれはtf.GraphDef、それともSavedModelprotobufメッセージですか?がtf.GraphDef呼び出されている場合は、gdを使用して使用中の演算のリストを取得できますsorted(set(n.op for n in gd.node))。ロードされたモデルがある場合は、実行できますsorted(set(op.type for op in tf.get_default_graph().get_operations()))。である場合は、SavedModelそれから取得できますtf.GraphDef(例:)saved_model.meta_graphs[0].graph_def
jdehesa

保存されたSavedModelからオペレーションを取得したい。実際、あなたが説明している最後のオプションです。saved_model最後の例の変数は何ですか?tf.saved_model.load('/path/to/model')saved_model.pbファイルのprotobuf の結果またはロード。
サンパー

回答:


1

saved_model.pbSavedModelprotobufメッセージの場合、そこから直接操作を取得します。次のようにモデルを作成するとします。

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

これで、このモデルで使用されている操作を次のように見つけることができます。

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

私はこのようなことを試しましたが、残念ながらこれは期待したことではありません:これを行うモデルがあるとします:input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')次に、ここにリストされている ReadFile Op がそこにありますが、印刷されません。
サンパー

1
@sampers私はあなたが提案するような例で答えを編集しました。ReadFile出力で操作を取得します。あなたの実際のケースでは、その操作が保存されたモデルの入力と出力の間ではない可能性はありますか?その場合、剪定される可能性があると思います。
jdehesa

確かに与えられたモデルで動作します。残念ながら、tf2で作成されたモジュールの場合は、そうではありません。私はとの1つの機能付きtf.Module作成する場合はfile_name、引数@tf.function私は私の以前のコメントに記載された通話を含む注釈を、それは次のリストを与える:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
sampers

私の質問にモデルを追加しました
sampers

@sampers回答を更新しました。以前はTF 1.xを使用していましたが、TF 2.xでのグラフ定義オブジェクトの変更に慣れていなかったので、保存されたモデルのすべてをカバーしていると思います。あなたが書いたPython関数に対応する操作は、その関数オブジェクト内saved_model.meta_graphs[0].graph_def.library.function[0]node_defコレクションだと思います。
jdehesa
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.