Разница между переменной и get_variable в TensorFlow
Насколько мне известно, Variable
- это операция по умолчанию для создания переменной, а get_variable
в основном используется для распределения веса.
С одной стороны, есть люди, предлагающие использовать get_variable
вместо примитивной операции Variable
всякий раз, когда вам нужна переменная. С другой стороны, я просто вижу использование get_variable
в официальных документах и демонстрациях get_variable
.
Таким образом, я хочу знать некоторые эмпирические правила о том, как правильно использовать эти два механизма. Существуют ли какие-либо "стандартные" принципы?
Ответы
Ответ 1
Я бы рекомендовал всегда использовать tf.get_variable(...)
- это упростит рефакторинг вашего кода, если вам нужно совместно использовать переменные в любое время, например. в настройке с несколькими gpu (см. пример CIFAR с несколькими gpu). Нет недостатка в этом.
Чистый tf.Variable
является более низким уровнем; в какой-то момент tf.get_variable()
не существовало, поэтому какой-то код по-прежнему использует метод низкого уровня.
Ответ 2
tf.Variable является классом, и существует несколько способов создания tf.Variable, включая tf.Variable.__init__
и tf.get_variable
.
tf.Variable.__init__
: создает новую переменную с initial_value.
W = tf.Variable(<initial-value>, name=<optional-name>)
tf.get_variable
: получает существующую переменную с этими параметрами или создает новую. Вы также можете использовать инициализатор.
W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
regularizer=None, trainable=True, collections=None)
Очень полезно использовать инициализаторы, такие как xavier_initializer
:
W = tf.get_variable("W", shape=[784, 256],
initializer=tf.contrib.layers.xavier_initializer())
Больше информации здесь.
Ответ 3
Я могу найти два основных различия между одним и другим:
Во-первых, tf.Variable
всегда создает новую переменную, тогда как tf.get_variable
получает существующую переменную с указанными параметрами из графика, а если она не существует, создает новую.
tf.Variable
требует указания начального значения.
Важно уточнить, что функция tf.get_variable
добавляет префикс имени к текущей области действия переменной для выполнения повторного использования. Например:
with tf.variable_scope("one"):
a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
c = tf.get_variable("v", [1]) #c.name == "one/v:0"
with tf.variable_scope("two"):
d = tf.get_variable("v", [1]) #d.name == "two/v:0"
e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"
assert(a is c) #Assertion is true, they refer to the same object.
assert(a is d) #AssertionError: they are different objects
assert(d is e) #AssertionError: they are different objects
Последняя ошибка утверждения интересна: две переменные с одинаковыми именами в одной области видимости должны быть одной и той же переменной. Но если вы проверите имена переменных d
и e
, вы поймете, что Tensorflow изменил имя переменной e
:
d.name #d.name == "two/v:0"
e.name #e.name == "two/v_1:0"
Ответ 4
Другое отличие состоит в том, что один находится в коллекции ('variable_store',)
, а другой нет.
Пожалуйста, смотрите исходный код:
def _get_default_variable_store():
store = ops.get_collection(_VARSTORE_KEY)
if store:
return store[0]
store = _VariableStore()
ops.add_to_collection(_VARSTORE_KEY, store)
return store
Позвольте мне проиллюстрировать это:
import tensorflow as tf
from tensorflow.python.framework import ops
embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32)
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])
graph = tf.get_default_graph()
collections = graph.collections
for c in collections:
stores = ops.get_collection(c)
print('collection %s: ' % str(c))
for k, store in enumerate(stores):
try:
print('\t%d: %s' % (k, str(store._vars)))
except:
print('\t%d: %s' % (k, str(store)))
print('')
Вывод:
collection ('__variable_store',): 0: {'word_embeddings_2':
<tf.Variable 'word_embeddings_2:0' shape=(30522, 1024)
dtype=float32_ref>}