Tensorflow: как получить все переменные из rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell
У меня есть настройка, где мне нужно инициализировать LSTM после основной инициализации, которая использует tf.initialize_all_variables()
. То есть Я хочу позвонить tf.initialize_variables([var_list])
Есть ли способ собрать все внутренние обучаемые переменные для обоих:
- rnn_cell.BasicLSTM
- rnn_cell.MultiRNNCell
чтобы я мог инициализировать JUST эти параметры?
Основная причина, по которой я хочу это, состоит в том, что я не хочу повторно инициализировать некоторые обучаемые значения из более ранних версий.
Ответы
Ответ 1
Самый простой способ решить вашу проблему - использовать область переменной. Имена переменных в пределах области будут иметь префикс с именем. Вот короткий фрагмент:
cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
# Execute the LSTM cell here in any way, for example:
for i in range(num_steps):
output[i], state = cell(input_data[i], state)
# Retrieve just the LSTM variables.
lstm_variables = [v for v in tf.all_variables()
if v.name.startswith(vs.name)]
# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)
Он будет работать аналогично с MultiRNNCell
.
EDIT: изменено tf.trainable_variables
на tf.all_variables()
Ответ 2
Вы также можете использовать tf.get_collection()
:
cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
# Execute the LSTM cell here in any way, for example:
for i in range(num_steps):
output[i], state = cell(input_data[i], state)
lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)
(частично скопирован из ответа Рафаля)
Обратите внимание, что последняя строка эквивалентна пониманию списка в коде Rafal.
В принципе, tensorflow хранит глобальный набор переменных, который может быть выбран с помощью tf.all_variables()
или tf.get_collection(tf.GraphKeys.VARIABLES)
. Если вы укажете scope
(имя области) в tf.get_collection()
, то вы получите только тензоры (переменные в этом случае) в коллекции чьи области находятся под указанной областью.
EDIT:
Вы можете также использовать tf.GraphKeys.TRAINABLE_VARIABLES
для получения только обучаемых переменных. Но так как vanilla BasicLSTMCell не инициализирует какую-либо не обучаемую переменную, обе будут функционально эквивалентными. Для получения полного списка коллекций графов по умолчанию проверьте этот вне.