Как установить состояние RNN TensorFlow, когда state_is_tuple = True?
Я написал модель языка RNN, используя TensorFlow. Модель реализована как класс RNN
. Структура графа встроена в конструктор, а методы RNN.train
и RNN.test
запускают его.
Я хочу иметь возможность reset состояния RNN при переходе к новому документу в наборе обучения или когда я хочу запустить проверку, установленную во время обучения. Я делаю это, управляя состоянием внутри цикла обучения, передавая его в график через словарь фида.
В конструкторе я определяю RNN следующим образом
cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
initial_state=self.state)
Цикл обучения выглядит следующим образом:
for document in document:
state = session.run(self.reset_state)
for x, y in document:
_, state = session.run([self.train_step, self.next_state],
feed_dict={self.x:x, self.y:y, self.state:state})
x
и y
- это пакеты данных обучения в документе. Идея состоит в том, что я передаю последнее состояние после каждой партии, за исключением случаев, когда я запускаю новый документ, когда я обнуляю состояние, запустив self.reset_state
.
Все это работает. Теперь я хочу изменить свой RNN, чтобы использовать рекомендуемый state_is_tuple=True
. Однако я не знаю, как передать более сложный объект состояния LSTM через словарь фида. Также я не знаю, какие аргументы передаются в строку self.state = tf.placeholder(...)
в моем конструкторе.
Какая здесь правильная стратегия? Для dynamic_rnn
доступно еще немного кода или документации для примера.
Проблемы с TensorFlow 2695 и 2838 отображаются соответствующие.
A сообщение в блоге на WILDML решает эти проблемы, но прямо не объясняет ответ.
См. также TensorFlow: запомните состояние LSTM для следующей партии (с сохранением состояния LSTM).
Ответы
Ответ 1
Одна проблема с заполнитель Tensorflow заключается в том, что вы можете подавать его только с помощью списка Python или массива Numpy (я думаю). Таким образом, вы не можете сохранить состояние между запусками в кортежах LSTMStateTuple.
Я решил это, сохранив состояние в тензоре, подобном этому
initial_state = np.zeros((num_layers, 2, batch_size, state_size))
У вас есть два компонента в слое LSTM, состояние ячейки и скрытое состояние, вот что происходит от "2". (эта статья замечательная: https://arxiv.org/pdf/1506.00019.pdf)
При создании графика вы распаковываете и создаете состояние кортежа следующим образом:
state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
[tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
for idx in range(num_layers)]
)
Затем вы получаете новое состояние обычным способом
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)
Это не должно быть так... возможно, они работают над решением.
Ответ 2
Простой способ подачи в состоянии RNN состоит в том, чтобы просто загружать оба компонента кортежа состояний индивидуально.
# Constructing the graph
self.state = rnn_cell.zero_state(...)
self.output, self.next_state = tf.nn.dynamic_rnn(
rnn_cell,
self.input,
initial_state=self.state)
# Running with initial state
output, state = sess.run([self.output, self.next_state], feed_dict={
self.input: input
})
# Running with subsequent state:
output, state = sess.run([self.output, self.next_state], feed_dict={
self.input: input,
self.state[0]: state[0],
self.state[1]: state[1]
})