Ответ 1
TL; DR: Если вы хотите tf.cond()
выполнить побочный эффект (например, присвоение) в одной из ветвей, вы должны создать op, который выполняет побочный эффект внутри, который вы передаете в tf.cond()
.
Поведение tf.cond()
немного неинтуитивно. Поскольку выполнение в графе TensorFlow перемещается вперед по графику, все операции, которые вы указываете в ветке или, должны выполняться до вычисления условия. Это означает, что как истинная, так и ложная ветки получают зависимую зависимость от tf.assign()
op, поэтому y
всегда получает значение 2
, даже если pred is
False`.
Решение состоит в том, чтобы создать tf.assign()
op внутри функции, которая определяет истинную ветвь. Например, вы можете структурировать свой код следующим образом:
pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
with tf.control_dependencies([tf.assign(x, [2])]):
return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
session.run(tf.initialize_all_variables())
print(y.eval(feed_dict={pred: False})) # ==> [1]
print(y.eval(feed_dict={pred: True})) # ==> [2]