Как я могу ускорить это вычисление Keras Attention?
Я написал собственный слой keras для AttentiveLSTMCell
и AttentiveLSTM(RNN)
в соответствии с новым подходом keras к AttentiveLSTM(RNN)
. Этот механизм внимания описан Bahdanau, где в модели кодировщика/декодера создается "контекстный" вектор из всех выходов кодировщика и скрытого состояния декодера. Затем я добавляю вектор контекста на каждый временной интервал к входу.
Модель используется для создания агента Dialog, но очень похожа на модели NMT в архитектуре (аналогичные задачи).
Однако, добавив этот механизм внимания, я несколько раз сократил обучение своей сети, и мне очень хотелось бы знать, как я мог бы написать часть кода, которая замедляет ее настолько эффективно.
Основная задача вычисления выполняется здесь:
h_tm1 = states[0] # previous memory state
c_tm1 = states[1] # previous carry state
# attention mechanism
# repeat the hidden state to the length of the sequence
_stm = K.repeat(h_tm1, self.annotation_timesteps)
# multiplty the weight matrix with the repeated (current) hidden state
_Wxstm = K.dot(_stm, self.kernel_w)
# calculate the attention probabilities
# self._uh is of shape (batch, timestep, self.units)
et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v))
at = K.exp(et)
at_sum = K.sum(at, axis=1)
at_sum_repeated = K.repeat(at_sum, self.annotation_timesteps)
at /= at_sum_repeated # vector of size (batchsize, timesteps, 1)
# calculate the context vector
context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1)
# append the context vector to the inputs
inputs = K.concatenate([inputs, context])
в call
методе в AttentiveLSTMCell
(один временный шаг).
Полный код можно найти здесь. Если необходимо предоставить некоторые данные и способы взаимодействия с моделью, то я могу это сделать.
Есть идеи? Я, конечно же, тренируюсь на GPU, если здесь есть что-то умное.
Ответы
Ответ 1
Я бы рекомендовал тренировать вашу модель, используя relu, а не tanh, так как эта операция значительно быстрее вычисляется. Это позволит сэкономить время вычислений по порядку ваших учебных примеров. * Средняя длина последовательности на пример * количество эпох.
Кроме того, я бы оценил улучшение производительности при добавлении вектора контекста, имея в виду, что это замедлит ваш цикл итерации по другим параметрам. Если это не даст вам большого улучшения, возможно, стоит попробовать другие подходы.
Ответ 2
Вы изменили класс LSTM, который хорош для вычислений ЦП, но вы упомянули, что вы тренируетесь на GPU.
Я рекомендую заглянуть в рекуррентную реализацию cudnn или далее в используемую часть tf. Может быть, вы можете расширить код там.