object_detection_tutorial TypeErrorの実行に関する問題:2つの必須の位置引数が欠落しているload()


11

私はtensorflowにかなり慣れていないので、object_detection_tutorialを実行しようとしています。TypeErrrorを取得していますが、修正方法がわかりません。

これは2つの引数がないload_model関数です:

タグ:必要なMetaGraphDefを識別する文字列タグのセット。これらは、SavedModel save()APIを使用して変数を保存するときに使用されるタグに対応している必要があります。

export_dir:SavedModelプロトコルバッファとロードされる変数が配置されているディレクトリ。

def load_model(model_name):
  base_url = 'http://download.tensorflow.org/models/object_detection/'
  model_file = model_name + '.tar.gz'
  model_dir = tf.keras.utils.get_file(
    fname=model_name, 
    origin=base_url + model_file,
    untar=True)

  model_dir = pathlib.Path(model_dir)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model
WARNING:tensorflow:From <ipython-input-9-f8a3c92a04a4>:11: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-e10c73a22cc9> in <module>
      1 model_name = 'ssd_mobilenet_v1_coco_2017_11_17'
----> 2 detection_model = load_model(model_name)

<ipython-input-9-f8a3c92a04a4> in load_model(model_name)
      9   model_dir = pathlib.Path(model_dir)/"saved_model"
     10 
---> 11   model = tf.saved_model.load(str(model_dir))
     12   model = model.signatures['serving_default']
     13 

~/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

TypeError: load() missing 2 required positional arguments: 'tags' and 'export_dir'

これを修正して私の最初のオブジェクト検出器:Dを実行するのを手伝ってくれませんか?

回答:


14

私は同じ問題を抱えていたので、1週間解決しようとしています。私はこれが解決策であるべきだと思います。

model = tf.compat.v2.saved_model.load(str(model_dir), None)

詳細は(公式ウェブサイトから)になります。

SavedModelをexport_dirからロードします。

tf.saved_model.load(
    export_dir,
    tags=None
)

エイリアス:

tf.compat.v1.saved_model.load_v2

tf.compat.v2.saved_model.load

1
私はあなたの解決策を使用し、別のエラーが発生しました。私は可能な限りすべてを更新し、それはうまくいきました!また、pathlibがインストールされていないというエラーも発生しました。
ドミニク

@Dominikより具体的にできますか?多分私はこのテンソルフローの冒険が私に多くの問題を解決するように導いたので私は助けることができます:D
Onur Baskin

4
@OnurBaskin後でエラーが発生します:TypeError:int()引数は、「Tensor」ではなく、文字列、バイトのようなオブジェクトまたは数値でなければなりません
kaitsu

@Dominik Tensorflowのバージョンだと思います。バージョン2.0(安定版)である必要があります。ここに私が尋ねた質問へのリンクがあります。多分あなたは正確なエラーを抱えています。また、「compat.v1」を必要とする古いインポートを検索します。後で多くのエラーが発生するはずですが、これが古いコードを移行する方法です。
Onur Baskin

@OnurBaskin私はかなり混乱しています。オブジェクト検出APIはTensorFlow 1バージョンとのみ互換性があると思いました。
Biiiiiird

0

私はそれがブランチの問題だったと思いtf_2_1_referenceブランチを使用することが私にとってはトリックでした:

igian@iGians-MBP models % git checkout tf_2_1_reference
M   research/object_detection/object_detection_tutorial.ipynb
Branch 'tf_2_1_reference' set up to track remote branch 'tf_2_1_reference' from 'origin'.
Switched to a new branch 'tf_2_1_reference'
igians@iGians-MBP models % jupyter notebook

次に、チュートリアルの各木星セルを初心者のように実行しました。

これは私が使用したブランチです:https : //github.com/tensorflow/models/tree/tf_2_1_reference

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.