ValueError: попытка обмена переменной rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel
Это код:
X = tf.placeholder(tf.float32, [batch_size, seq_len_1, 1], name='X')
labels = tf.placeholder(tf.float32, [None, alpha_size], name='labels')
rnn_cell = tf.contrib.rnn.BasicLSTMCell(512)
m_rnn_cell = tf.contrib.rnn.MultiRNNCell([rnn_cell] * 3, state_is_tuple=True)
pre_prediction, state = tf.nn.dynamic_rnn(m_rnn_cell, X, dtype=tf.float32)
Это полная ошибка:
ValueError: Попытка обмена переменной rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel, но заданная форма (1024, 2048) и найденная форма (513, 2048).
Я использую GPU-версию тензорного потока.
Ответы
Ответ 1
Я столкнулся с аналогичной проблемой, когда обновился до v1.2 (tensorflow-gpu).
Вместо использования [rnn_cell]*3
я создал 3 rnn_cells
(stacked_rnn) циклом (так, чтобы они не делили переменные) и передал MultiRNNCell
с помощью stacked_rnn
, и проблема исчезла. Я не уверен, что это правильный способ сделать это.
stacked_rnn = []
for iiLyr in range(3):
stacked_rnn.append(tf.nn.rnn_cell.LSTMCell(num_units=512, state_is_tuple=True))
MultiLyr_cell = tf.nn.rnn_cell.MultiRNNCell(cells=stacked_rnn, state_is_tuple=True)
Ответ 2
Официальный учебник TensorFlow рекомендует этот способ определения сети LSTM:
def lstm_cell():
return tf.contrib.rnn.BasicLSTMCell(lstm_size)
stacked_lstm = tf.contrib.rnn.MultiRNNCell(
[lstm_cell() for _ in range(number_of_layers)])
Вы можете найти его здесь: https://www.tensorflow.org/tutorials/recurrent
На самом деле это почти тот же подход, который предложил Васи Ахмад и Маоси Чен, но, возможно, в немного более элегантной форме.
Ответ 3
Я предполагаю, что ваши ячейки RNN на каждом из ваших трех слоев имеют одну и ту же форму ввода и вывода.
В слое 1 размер входного файла равен 513 = 1 (ваш размер x) + 512 (размер скрытого слоя) для каждой отметки времени для каждой партии.
В слое 2 и 3 входное измерение равно 1024 = 512 (выход с предыдущего уровня) + 512 (выход из предыдущей метки времени).
То, как вы складываете свой MultiRNNCell, вероятно, подразумевает, что 3 ячейки используют одну и ту же форму ввода и вывода.
Я складываю MultiRNNCell, объявляя два отдельных типа ячеек, чтобы предотвратить их совместное использование формы ввода
rnn_cell1 = tf.contrib.rnn.BasicLSTMCell(512)
run_cell2 = tf.contrib.rnn.BasicLSTMCell(512)
stack_rnn = [rnn_cell1]
for i in range(1, 3):
stack_rnn.append(rnn_cell2)
m_rnn_cell = tf.contrib.rnn.MultiRNNCell(stack_rnn, state_is_tuple = True)
Тогда я смогу обучать свои данные без этой ошибки.
Я не уверен, правильна ли моя догадка, но она работает для меня. Надеюсь, это сработает для вас.