Soweit ich weiß, ist Variable
die Standardoperation zum Erstellen einer Variablen, und get_variable
wird hauptsächlich für die gemeinsame Nutzung von Gewicht verwendet.
Auf der einen Seite schlagen einige Leute vor, get_variable
anstelle der primitiven Variable
-Operation zu verwenden, wann immer Sie eine Variable benötigen. Andererseits sehe ich lediglich die Verwendung von get_variable
in den offiziellen Dokumenten und Demos von TensorFlow.
Daher möchte ich einige Faustregeln kennenlernen, um diese beiden Mechanismen richtig anzuwenden. Gibt es "Standard" -Prinzipien?
Ich würde empfehlen, immer tf.get_variable(...)
zu verwenden. Dadurch wird es einfacher, den Code umzuwandeln, wenn Sie Variablen jederzeit gemeinsam nutzen möchten, z. in einer Multi-GPU-Einstellung (siehe das Multi-GPU-CIFAR-Beispiel). Es gibt keinen Nachteil.
Reiner tf.Variable
ist untergeordnet; Irgendwann existierte tf.get_variable()
nicht, so dass ein Teil des Codes immer noch die Low-Level-Methode verwendet.
tf.Variable ist eine Klasse, und es gibt verschiedene Möglichkeiten, tf.Variable zu erstellen, einschließlich tf.Variable .__ init__ und tf.get_variable.
tf.Variable .__ init__: Erstellt eine neue Variable mit initial_value.
W = tf.Variable(<initial-value>, name=<optional-name>)
tf.get_variable: Ruft eine vorhandene Variable mit diesen Parametern ab oder erstellt eine neue. Sie können auch den Initialisierer verwenden.
W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
regularizer=None, trainable=True, collections=None)
Es ist sehr nützlich, Initialisierer wie xavier_initializer zu verwenden:
W = tf.get_variable("W", shape=[784, 256],
initializer=tf.contrib.layers.xavier_initializer())
Weitere Informationen finden Sie unter https://www.tensorflow.org/versions/r0.8/api_docs/python/state_ops.html#Variable .
Ich kann zwei Hauptunterschiede zwischen den beiden finden:
Erstens: tf.Variable
erstellt immer eine neue Variable, unabhängig davon, ob tf.get_variable
aus dem Graphen eine vorhandene Variable mit diesen Parametern abruft. Wenn sie nicht vorhanden ist, wird eine neue erstellt.
tf.Variable
erfordert, dass ein Anfangswert angegeben wird.
Es ist wichtig zu verdeutlichen, dass die Funktion tf.get_variable
dem Namen den aktuellen Variablenbereich voranstellt, um Wiederverwendungsprüfungen durchzuführen. Zum Beispiel:
with tf.variable_scope("one"):
a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
c = tf.get_variable("v", [1]) #c.name == "one/v:0"
with tf.variable_scope("two"):
d = tf.get_variable("v", [1]) #d.name == "two/v:0"
e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"
assert(a is c) #Assertion is true, they refer to the same object.
assert(a is d) #AssertionError: they are different objects
assert(d is e) #AssertionError: they are different objects
Der letzte Assertionsfehler ist interessant: Zwei Variablen mit demselben Namen unter demselben Gültigkeitsbereich sollen dieselbe Variable sein. Wenn Sie jedoch die Namen der Variablen d
und e
testen, werden Sie feststellen, dass Tensorflow den Namen der Variablen e
geändert hat:
d.name #d.name == "two/v:0"
e.name #e.name == "two/v_1:0"
Ein weiterer Unterschied besteht darin, dass sich eine in der ('variable_store',)
-Sammlung befindet, die andere jedoch nicht.
Bitte sehen Sie den Quellcode code :
def _get_default_variable_store():
store = ops.get_collection(_VARSTORE_KEY)
if store:
return store[0]
store = _VariableStore()
ops.add_to_collection(_VARSTORE_KEY, store)
return store
Lassen Sie mich das veranschaulichen:
import tensorflow as tf
from tensorflow.python.framework import ops
embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="Word_embeddings_1", dtype=tf.float32)
embedding_2 = tf.get_variable("Word_embeddings_2", shape=[30522, 1024])
graph = tf.get_default_graph()
collections = graph.collections
for c in collections:
stores = ops.get_collection(c)
print('collection %s: ' % str(c))
for k, store in enumerate(stores):
try:
print('\t%d: %s' % (k, str(store._vars)))
except:
print('\t%d: %s' % (k, str(store)))
print('')
Die Ausgabe:
collection ('__variable_store',): 0: {'Word_embeddings_2': <tf.Variable 'Word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}