Как мне получить текущее значение переменной?
Предположим, у нас есть переменная:
x = tf.Variable(...)
Эта переменная может быть обновлена в процессе обучения с помощью метода assign()
.
Каков наилучший способ получить текущее значение переменной?
Я знаю, что мы могли бы использовать это:
session.run(x)
Но я боюсь, что это вызовет целую цепочку операций.
В Теано, вы могли бы просто сделать
y = theano.shared(...)
y_vals = y.get_value()
Я ищу эквивалентную вещь в TensorFlow.
Ответы
Ответ 1
В общем, session.run(x)
будет оценивать только те узлы, которые необходимы для вычисления x
и ничего больше, поэтому он должен быть относительно дешевым, если вы хотите проверить значение переменной.
Посмотрите на этот отличный ответ fooobar.com/questions/56115/... для получения дополнительной информации.
Ответ 2
Единственный способ получить значение переменной - запустить ее в session
. В FAQ написано что:
Объект Tensor - это символический дескриптор результата операции, но на самом деле он не содержит значений выходных данных операции.
Таким образом, эквивалент TF будет:
import tensorflow as tf
x = tf.Variable([1.0, 2.0])
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
v = sess.run(x)
print(v) # will show you your variable.
Часть с init = global_variables_initializer()
важна и должна быть сделана для инициализации переменных.
Кроме того, взгляните на InteractiveSession, если вы работаете в IPython.
Ответ 3
tf.Print
может упростить вашу жизнь!
tf.Print
будет печатать значение тензора (ов), которое вы укажете ему для печати, в тот момент, когда строка кода tf.Print
вызывается в вашем коде при оценке вашего кода.
Итак, например:
import tensorflow as tf
x = tf.Variable([1.0, 2.0])
x = tf.Print(x,[x])
x = 2* x
tf.initialize_all_variables()
sess = tf.Session()
sess.run()
[1.0 2.0]
поскольку он печатает значение x
в тот момент, когда строка tf.Print
. Если вместо этого вы делаете
v = x.eval()
print(v)
вы получите:
[2.0 4.0]
потому что он даст вам окончательное значение x.
Ответ 4
Как они отменили tf.Variable()
в tensorflow 2.0.0
,
Если вы хотите извлечь значения из tensor(ie "net")
, вы можете использовать это,
net.[tf.newaxis,:,:].numpy().