Получить последний выход dynamic_rnn в тензорном потоке?
Я использую dynamic_rnn для обработки данных MNIST:
# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
forget_bias=1.0,
initializer=tf.random_normal)
# Initial state
istate = lstm.zero_state(batch_size, "float")
# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)
# Output at last time point T
output_at_T = output[:, 27, :]
Полный код: http://pastebin.com/bhf9MgMe
Вход в lstm равен (batch_size, sequence_length, input_size)
В результате размеры output_at_T
равны (batch_size, sequence_length, num_units)
, где num_units=200
.
Мне нужно получить последний результат по размеру sequence_length
. В приведенном выше коде это жестко закодировано как 27
. Тем не менее, я не знаю sequence_length
заранее, так как он может меняться от партии к партии в моем приложении.
Я пробовал:
output_at_T = output[:, -1, :]
но он говорит, что отрицательная индексация еще не реализована, и я попытался использовать переменную-заполнителя, а также константу (в которую я мог бы идеально комбинировать sequence_length
для определенной партии); ни работали.
Любой способ реализовать что-то подобное в тензорном потоке atm?
Ответы
Ответ 1
Вы заметили, что есть два выхода из dynamic_rnn?
- Выход 1, позвоните ему h, имеет все выходы на каждом шаге (т.е. h_1, h_2 и т.д.),
- Выход 2, final_state, имеет два элемента: cell_state и последний вывод для каждого элемента пакета (до тех пор, пока вы вводите длину последовательности в dynamic_rnn).
Итак, из:
h, final_state= tf.dynamic_rnn( ..., sequence_length=[batch_size_vector], ... )
последнее состояние для каждого элемента в партии:
final_state.h
Обратите внимание, что это включает случай, когда длина последовательности отличается для каждого элемента пакета, поскольку мы используем аргумент sequence_length.
Ответ 2
Я новичок в Stackoverflow и не могу комментировать, поэтому я пишу этот новый ответ. @VM_AI, последний индекс tf.shape(output)[1] - 1
.
Итак, повторное использование вашего ответа:
# Let first fetch the last index of seq length
# last_index would have a scalar value
last_index = tf.shape(output)[1] - 1
# Then let reshape the output to [sequence_length,batch_size,num_units]
# for convenience
output_rs = tf.transpose(output,[1,0,2])
# Last state of all batches
last_state = tf.nn.embedding_lookup(output_rs,last_index)
Это работает для меня.
Ответ 3
output[:, -1, :]
работает с Tensorflow 1.x сейчас!
Ответ 4
Это то, что gather_nd для!
def extract_axis_1(data, ind):
"""
Get specified elements along the first axis of tensor.
:param data: Tensorflow tensor that will be subsetted.
:param ind: Indices to take (one for each element along axis 0 of data).
:return: Subsetted tensor.
"""
batch_range = tf.range(tf.shape(data)[0])
indices = tf.stack([batch_range, ind], axis=1)
res = tf.gather_nd(data, indices)
return res
В вашем случае (предполагая, что sequence_length
является одномерным тензором с длиной каждого элемента оси 0):
output = extract_axis_1(output, sequence_length - 1)
Теперь вывод представляет собой тензор размера [batch_size, num_cells]
.
Ответ 5
Вы должны иметь доступ к форме тензора output
, используя tf.shape(output)
. Функция tf.shape()
вернет 1-й тензор, содержащий размеры тензора output
. В вашем примере это будет (batch_size, sequence_length, num_units)
Затем вы можете извлечь значение output_at_T
как output[:, tf.shape(output)[1], :]
Ответ 6
В TensorFlow tf.shape
есть функция, позволяющая получить символическую интерпретацию формы, а не None
, возвращаемую output._shape[1]
. И после получения последнего индекса вы можете искать с помощью tf.nn.embedding_lookup
, который рекомендуется особенно когда данные, которые будут выбраны, высоки, так как это делает параллельный поиск 32 by default
.
# Let first fetch the last index of seq length
# last_index would have a scalar value
last_index = tf.shape(output)[1]
# Then let reshape the output to [sequence_length,batch_size,num_units]
# for convenience
output_rs = tf.transpose(output,[1,0,2])
# Last state of all batches
last_state = tf.nn.embedding_lookup(output_rs,last_index)
Это должно работать.
Просто чтобы уточнить, что сказал @Benoit Steiner. Его решение не будет работать, так как tf.shape вернет символическую интерпретацию значения формы, и это не может быть использовано для нарезки тензоров, т.е. Прямой индексации