Как добавить, если условие в графе TensorFlow?
Скажем, у меня есть следующий код:
x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shape=[1, 1], name = "condition")
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights")
b = tf.Variable(tf.zeros([label_option]), name = "bias")
if condition > 0:
y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
y = tf.nn.softmax(tf.matmul(x, W) - b)
Будет ли оператор if
работать в вычислении (я так не думаю)? Если нет, как я могу добавить оператор if
в график расчета TensorFlow?
Ответы
Ответ 1
Вы правы, что инструкция if
здесь не работает, потому что условие оценивается во время построения графика, тогда как, предположительно, вы хотите, чтобы условие зависело от значения, введенного в заполнитель во время выполнения. (Фактически, он всегда будет принимать первую ветвь, потому что condition > 0
оценивается как Tensor
, который "правдивый" в Python. )
Чтобы поддерживать поток условного управления, TensorFlow предоставляет оператор tf.cond()
, который оценивает одну из двух ветвей, в зависимости от логического условия. Чтобы показать вам, как его использовать, я переписал вашу программу, чтобы condition
было скалярным значением tf.int32
для простоты:
x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")
y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)
Ответ 2
TensorFlow 2.0
В TF 2.0 появилась функция AutoGraph, которая позволяет JIT компилировать код Python в графические исполнения. Это означает, что вы можете использовать операторы потока управления python (да, это включает операторы if
). Из документов,
Автограф поддерживает общие заявления Python, как while
, for
, if
, break
, continue
и return
с поддержкой вложенности. Это означает, что вы можете использовать выражения Tensor в условии операторов while
и if
или выполнять итерацию по Tensor в цикле for
.
Вам нужно будет определить функцию, реализующую вашу логику, и аннотировать ее с помощью функции tf.function
. Вот модифицированный пример из документации:
import tensorflow as tf
@tf.function
def sum_even(items):
s = 0
for c in items:
if tf.equal(c % 2, 0):
s += c
return s
sum_even(tf.constant([10, 12, 15, 20]))
# <tf.Tensor: id=1146, shape=(), dtype=int32, numpy=42>