TensorFlow: получение всех состояний из RNN
Как вы получаете все скрытые состояния от tf.nn.rnn()
или tf.nn.dynamic_rnn()
в TensorFlow? API дает мне только конечное состояние.
Первым вариантом было бы написать цикл при построении модели, которая работает непосредственно на RNNCell. Однако количество временных меток не фиксировано для меня и зависит от входящей партии.
Некоторые параметры - использовать GRU или написать собственный RNNCell, который объединяет состояние с выходом. Первый выбор не является достаточно общим, и последний кажется слишком хриплым.
Другой вариант - сделать что-то вроде ответов в этом вопросе, получив все переменные из RNN. Однако я не уверен, как здесь отделять скрытые состояния от других переменных стандартным образом.
Есть ли хороший способ получить все скрытые состояния из RNN, все еще используя API RNN, предоставляемые библиотекой?
Ответы
Ответ 1
tf.nn.dynamic_rnn (также tf.nn.static_rnn) имеет два возвращаемых значения; "выходы" , "состояние" (https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
Как вы сказали, "состояние" является конечным состоянием RNN, но "выходы" - это все скрытые состояния RNN (какая форма [batch_size, max_time, cell.output_size])
Вы можете использовать "выходы" в качестве скрытых состояний RNN, потому что в большинстве предоставляемых библиотекой RNNCell "вывод" и "состояние" одинаковы. (кроме LSTMCell)
Ответ 2
Я уже создал PR здесь, и это может помочь вам справиться с простыми случаями
Позвольте мне кратко объяснить мою реализацию, поэтому вы можете написать свою версию, если вам нужно. Основная часть - это модификация функции _time_step
:
def _time_step(time, output_ta_t, state, *args):
Параметры остаются неизменными, за исключением того, что передается дополнительная *args
. Но почему args
? Это потому, что я хочу поддерживать привычное поведение тензорного потока. Вы можете вернуть конечное состояние, просто проигнорировав параметр args
:
if states_ta is not None:
# If you want to return all states, set `args` to be `states_ta`
loop_vars = (time, output_ta, state, states_ta)
else:
# If you want the final state only, ignore `args`
loop_vars = (time, output_ta, state)
Как его использовать?
if args:
args = tuple(
ta.write(time, out) for ta, out in zip(args[0], [new_state])
)
На самом деле это всего лишь модификация следующих (оригинальных) кодов:
output_ta_t = tuple(
ta.write(time, out) for ta, out in zip(output_ta_t, output)
)
Теперь args
должен содержать все состояния, которые вы хотите.
После всех выполненных выше работ вы можете выбрать состояния (или конечное состояние) со следующими кодами:
_, output_final_ta, *state_info = control_flow_ops.while_loop( ...
и
if states_ta is not None:
final_state, states_final_ta = state_info
else:
final_state, states_final_ta = state_info[0], None
Хотя я не тестировал его в сложных случаях, он должен работать в "простом" состоянии (вот мои тестовые примеры)