TensorFlowでのVariableとget_variableの違い


125

私の知る限りでVariableは、は変数を作成するためのデフォルトの操作であり、get_variable主にウェイトシェアリングに使用されます。

一方で、変数が必要get_variableVariableときはいつでも、プリミティブ演算の代わりに使用することを提案する人がいます。一方、get_variableTensorFlowの公式ドキュメントとデモでの使用は見ただけです。

したがって、これら2つのメカニズムを正しく使用する方法について、いくつかの経験則を知りたいと思います。「標準」の原則はありますか?


6
ルカシュが言うようget_variableが新しい方法で、変数は(永遠にサポートされる可能性があります)古い方法である(PSは:彼は多くのTFでの変数名のスコープの書いた)
ヤロスラフBulatov

回答:


90

常に使用するtf.get_variable(...)ことをお勧めします。たとえば、マルチGPU設定などでいつでも変数を共有する必要がある場合は、コードをリファクタリングしやすくなります(マルチGPU CIFARの例を参照)。それにはマイナス面はありません。

純粋tf.Variableは下位レベルです。ある時点でtf.get_variable()は存在しなかったため、一部のコードはまだ低レベルの方法を使用しています。


5
回答ありがとうございます。しかし、私tf.Variabletf.get_variableどこにでも置き換える方法についてまだ質問があります。これは、変数をnumpy配列で初期化する場合、のようにクリーンで効率的な方法を見つけることができないためtf.Variableです。どのように解決しますか?ありがとう。
Lifu Huang 2016

68

tf.Variableはクラスであり、かつ含むtf.Variable作成するには、いくつかの方法がありますtf.Variable.__init__とはtf.get_variable

tf.Variable.__init__initial_valueで新しい変数を作成します。

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable:これらのパラメーターを持つ既存の変数を取得するか、新しい変数を作成します。イニシャライザを使用することもできます。

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

次のような初期化子を使用すると非常に便利ですxavier_initializer

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

詳細はこちら


はい、Variable実際にはそれを使用することを意味します__init__。のでget_variableとても便利ですが、私はなぜ最もTensorFlowコードIソーの使用を疑問に思うVariableの代わりにget_variable。それらの間で選択するときに考慮すべき慣習や要素はありますか?ありがとうございました!
Lifu Huang 2016

特定の値が必要な場合、Variableの使用は簡単です:x = tf.Variable(3)。
キム・ソン・キム

@SungKimを通常使用する場合tf.Variable()、切り捨てられた正規分布からランダムな値として初期化できます。これが私の例w1 = tf.Variable(tf.truncated_normal([5, 50], stddev = 0.01), name = 'w1')です。これに相当するものは何でしょうか?切り捨てられた法線が必要であることをどのように伝えるのですか?私はただやるべきw1 = tf.get_variable(name = 'w1', shape = [5,50], initializer = tf.truncated_normal, regularizer = tf.nn.l2_loss)ですか?
Euler_Salter 2017年

@Euler_Salter:を使用tf.truncated_normal_initializer()して、目的の結果を得ることができます。
ベータ版

46

1つと他の2つの主な違いを見つけることができます。

  1. 1つ目は、tf.Variable常に新しい変数を作成するのに対し、指定されたパラメーターを持つ既存の変数をグラフからtf.get_variable取得し、存在しない場合は新しい変数を作成することです。

  2. tf.Variable 初期値を指定する必要があります。

tf.get_variable再利用チェックを実行するには、関数が名前の前に現在の変数スコープを付けることを明確にすることが重要です。例えば:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

最後のアサーションエラーは興味深いものです。同じスコープで同じ名前の2つの変数は同じ変数であると想定されています。しかし、あなたは変数の名前をテストする場合de、あなたがTensorflowは、変数の名前を変更することを実現しますe

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"

素晴らしい例!私はちょうど渡って来ているテンソルグラフネーミング操作でこのTensorFlowドキュメント、それを説明する:d.namee.nameIf the default graph already contained an operation named "answer", the TensorFlow would append "_1", "_2", and so on to the name, in order to make it unique.
Atlas7

2

もう1つの違いは、1つは('variable_store',)コレクションにあるが、もう1つはコレクションにないという点です。

ソースコードをご覧ください:

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

それを説明しましょう:

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

出力:

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.