He escrito una capa de keras personalizada para AttentiveLSTMCell
y AttentiveLSTM(RNN)
en línea con el nuevo enfoque de keras para RNN. Bahdanau describe este mecanismo de atención donde, en un modelo de codificador/decodificador, se crea un vector de "contexto" a partir de todas las salidas del codificador y el estado oculto actual del decodificador. Luego agrego el vector de contexto, en cada paso de tiempo, a la entrada.
El modelo se está utilizando para crear un agente de diálogo, pero es muy similar a los modelos NMT en arquitectura (tareas similares).
Sin embargo, al agregar este mecanismo de atención, he ralentizado el entrenamiento de mi red 5 veces, y realmente me gustaría saber cómo podría escribir la parte del código que lo está ralentizando tanto de una manera más eficiente.
La mayor parte del cálculo se hace aquí:
h_tm1 = states[0] # previous memory state c_tm1 = states[1] # previous carry state # attention mechanism # repeat the hidden state to the length of the sequence _stm = K.repeat(h_tm1, self.annotation_timesteps) # multiplty the weight matrix with the repeated (current) hidden state _Wxstm = K.dot(_stm, self.kernel_w) # calculate the attention probabilities # self._uh is of shape (batch, timestep, self.units) et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v)) at = K.exp(et) at_sum = K.sum(at, axis=1) at_sum_repeated = K.repeat(at_sum, self.annotation_timesteps) at /= at_sum_repeated # vector of size (batchsize, timesteps, 1) # calculate the context vector context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1) # append the context vector to the inputs inputs = K.concatenate([inputs, context])
en el método de call
de AttentiveLSTMCell
(un paso de tiempo).
El código completo se puede encontrar aquí . Si es necesario que brinde algunos datos y formas de interactuar con el modelo, entonces puedo hacerlo.
¿Algunas ideas? Por supuesto, estoy entrenando en una GPU si hay algo inteligente aquí.
Recomendaría entrenar su modelo usando relu en lugar de tanh, ya que esta operación es significativamente más rápida de calcular. Esto le ahorrará tiempo de cálculo en el orden de sus ejemplos de entrenamiento * longitud de secuencia promedio por ejemplo * número de épocas.
Además, evaluaría la mejora del rendimiento al agregar el vector de contexto, teniendo en cuenta que esto ralentizará su ciclo de iteración en otros parámetros. Si no le está dando mucha mejora, podría valer la pena probar otros enfoques.
Modificaste la clase LSTM, que es buena para el cálculo de la CPU, pero mencionaste que estás entrenando en GPU.
Recomiendo buscar en la implementación cudnn-recurrent o más en la parte tf que se usa. Tal vez puedas extender el código allí.