Какова цель коллекций графов в TensorFlow?

API обсуждает Коллекции графов, которые, судя по коду, являются универсальным ключом/хранилищем данных. Какова цель этих коллекций?

Ответы

Ответ 1

Помните, что под капотом Tensorflow - это система для указания и последующего выполнения графиков потока вычислительных данных. Коллекции графов используются как часть отслеживания построенных графиков и того, как они должны выполняться. Например, когда вы создаете определенные виды op, например tf.train.batch_join, код, который добавляет оп, также добавляет некоторые бегуны очереди в коллекцию графа QUEUE_RUNNERS. Позже, когда вы вызываете start_queue_runners(), по умолчанию он будет смотреть коллекцию QUEUE_RUNNERS, чтобы узнать, какие бегуны запускаться.

Ответ 2

Я думаю, что для меня есть как минимум две выгоды:

  1. когда вы распространяете свою программу на нескольких графических процессорах или машинах, удобно собирать потери с разных устройств, находящихся в одной коллекции. Используйте tf.add_n, чтобы добавить их для накопления потерь.
  2. Чтобы обновить определенный набор переменных, таких как вес и предубеждения, по-своему.

Например:

import tensorflow as tf    
w = tf.Variable([1,2,3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)    
w2 = tf.Variable([11,22,32], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
weight_init_op = tf.variables_initializer(tf.get_collection_ref(tf.GraphKeys.WEIGHTS))
sess = tf.InteractiveSession()
weight_init_op.run()
for vari in tf.get_collection_ref(tf.GraphKeys.WEIGHTS): 
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, vari.assign(0.2 * vari))
weight_update_ops = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
for op in weight_update_ops:
    print(op.eval())

Выход:

[0.2 0.4 0.6]
[2.2 4.4 6.4]