Какова цель коллекций графов в TensorFlow?
API обсуждает Коллекции графов, которые, судя по коду, являются универсальным ключом/хранилищем данных. Какова цель этих коллекций?
Ответы
Ответ 1
Помните, что под капотом Tensorflow - это система для указания и последующего выполнения графиков потока вычислительных данных. Коллекции графов используются как часть отслеживания построенных графиков и того, как они должны выполняться. Например, когда вы создаете определенные виды op, например tf.train.batch_join
, код, который добавляет оп, также добавляет некоторые бегуны очереди в коллекцию графа QUEUE_RUNNERS
. Позже, когда вы вызываете start_queue_runners()
, по умолчанию он будет смотреть коллекцию QUEUE_RUNNERS
, чтобы узнать, какие бегуны запускаться.
Ответ 2
Я думаю, что для меня есть как минимум две выгоды:
- когда вы распространяете свою программу на нескольких графических процессорах или машинах, удобно собирать потери с разных устройств, находящихся в одной коллекции. Используйте tf.add_n, чтобы добавить их для накопления потерь.
- Чтобы обновить определенный набор переменных, таких как вес и предубеждения, по-своему.
Например:
import tensorflow as tf
w = tf.Variable([1,2,3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
w2 = tf.Variable([11,22,32], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
weight_init_op = tf.variables_initializer(tf.get_collection_ref(tf.GraphKeys.WEIGHTS))
sess = tf.InteractiveSession()
weight_init_op.run()
for vari in tf.get_collection_ref(tf.GraphKeys.WEIGHTS):
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, vari.assign(0.2 * vari))
weight_update_ops = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
for op in weight_update_ops:
print(op.eval())
Выход:
[0.2 0.4 0.6]
[2.2 4.4 6.4]