TensorFlow: получение переменной по имени
При использовании TensorFlow Python API я создал переменную (без указания ее name
в конструкторе), а ее свойство name
имеет значение "Variable_23:0"
. Когда я пытаюсь выбрать эту переменную с помощью tf.get_variable("Variable23")
, вместо нее создается новая переменная с именем "Variable_23_1:0"
. Как правильно выбрать "Variable_23"
вместо создания нового?
Что я хочу сделать, это выбрать переменную по имени и повторно инициализировать ее, чтобы я мог точно определить вес.
Ответы
Ответ 1
Самый простой способ получить переменную по имени - найти ее в коллекции tf.global_variables()
:
var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]
Это хорошо подходит для повторного использования существующих переменных. Более структурированный подход: когда вы хотите обмениваться переменными между несколькими частями модели, вы можете ознакомиться в Разделение переменных переменных.
Ответ 2
Функция get_variable()
создает новую переменную или возвращает ранее созданную get_variable()
. Он не будет возвращать переменную, созданную с помощью tf.Variable()
. Вот краткий пример:
>>> with tf.variable_scope("foo"):
... bar1 = tf.get_variable("bar", (2,3)) # create
...
>>> with tf.variable_scope("foo", reuse=True):
... bar2 = tf.get_variable("bar") # reuse
...
>>> with tf.variable_scope("", reuse=True): # root variable scope
... bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
...
>>> (bar1 is bar2) and (bar2 is bar3)
True
Если вы не создали переменную с помощью tf.get_variable()
, у вас есть пара вариантов. Во-первых, вы можете использовать tf.global_variable()
(как предлагает @mrry):
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True
Или вы можете использовать tf.get_collection()
так:
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True
Edit
Вы также можете использовать get_tensor_by_name()
:
>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = tf.get_tensor_by_name("bar:0")
>>> bar1 is bar2
True
Напомним, что тензор является результатом операции. Он имеет то же имя, что и операция, плюс :0
. Если операция имеет несколько выходов, они имеют то же имя, что и операция плюс :0
, :1
, :2
и т.д.