webentwicklung-frage-antwort-db.com.de

Unterschied zwischen Variable und get_variable in TensorFlow

Soweit ich weiß, ist Variable die Standardoperation zum Erstellen einer Variablen, und get_variable wird hauptsächlich für die gemeinsame Nutzung von Gewicht verwendet.

Auf der einen Seite schlagen einige Leute vor, get_variable anstelle der primitiven Variable-Operation zu verwenden, wann immer Sie eine Variable benötigen. Andererseits sehe ich lediglich die Verwendung von get_variable in den offiziellen Dokumenten und Demos von TensorFlow.

Daher möchte ich einige Faustregeln kennenlernen, um diese beiden Mechanismen richtig anzuwenden. Gibt es "Standard" -Prinzipien?

99
Lifu Huang

Ich würde empfehlen, immer tf.get_variable(...) zu verwenden. Dadurch wird es einfacher, den Code umzuwandeln, wenn Sie Variablen jederzeit gemeinsam nutzen möchten, z. in einer Multi-GPU-Einstellung (siehe das Multi-GPU-CIFAR-Beispiel). Es gibt keinen Nachteil. 

Reiner tf.Variable ist untergeordnet; Irgendwann existierte tf.get_variable() nicht, so dass ein Teil des Codes immer noch die Low-Level-Methode verwendet.

82
Lukasz Kaiser

tf.Variable ist eine Klasse, und es gibt verschiedene Möglichkeiten, tf.Variable zu erstellen, einschließlich tf.Variable .__ init__ und tf.get_variable. 

tf.Variable .__ init__: Erstellt eine neue Variable mit initial_value.

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: Ruft eine vorhandene Variable mit diesen Parametern ab oder erstellt eine neue. Sie können auch den Initialisierer verwenden.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

Es ist sehr nützlich, Initialisierer wie xavier_initializer zu verwenden:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

Weitere Informationen finden Sie unter https://www.tensorflow.org/versions/r0.8/api_docs/python/state_ops.html#Variable .

61
Sung Kim

Ich kann zwei Hauptunterschiede zwischen den beiden finden:

  1. Erstens: tf.Variable erstellt immer eine neue Variable, unabhängig davon, ob tf.get_variable aus dem Graphen eine vorhandene Variable mit diesen Parametern abruft. Wenn sie nicht vorhanden ist, wird eine neue erstellt.

  2. tf.Variable erfordert, dass ein Anfangswert angegeben wird.

Es ist wichtig zu verdeutlichen, dass die Funktion tf.get_variable dem Namen den aktuellen Variablenbereich voranstellt, um Wiederverwendungsprüfungen durchzuführen. Zum Beispiel:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

Der letzte Assertionsfehler ist interessant: Zwei Variablen mit demselben Namen unter demselben Gültigkeitsbereich sollen dieselbe Variable sein. Wenn Sie jedoch die Namen der Variablen d und e testen, werden Sie feststellen, dass Tensorflow den Namen der Variablen e geändert hat:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"
38
Jadiel de Armas

Ein weiterer Unterschied besteht darin, dass sich eine in der ('variable_store',)-Sammlung befindet, die andere jedoch nicht. 

Bitte sehen Sie den Quellcode code

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Lassen Sie mich das veranschaulichen: 

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="Word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("Word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

Die Ausgabe: 

collection ('__variable_store',): 0: {'Word_embeddings_2': <tf.Variable 'Word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

1
lerner