Dark Dwarf Blog background

Bahdanau Attention

Bahdanau Attention

1. 引入

在 Encoder-Decoder 架构中,我们知道 Encoder 负责把输入压缩成一个隐藏状态向量,但是压缩成一个固定的向量必然会导致关键信息的损失,一个简单的想法是:不把所有内容压缩到一个向量中,而是把序列中的每个词都生成一个对应的向量,然后在 Decoder 输出时去“寻找”它对应的向量,利用自己找到的内容生成结果。这便是 Seq2Seq 中简单的注意力思想。

2. 具体流程

上面的思想在 Seq2Seq 中的具体流程如下:

  1. 在 Encode 过程中,对序列中的每个词或词组生成对应的向量。
  2. 在 Decode 过程中,每当生成一个词,Decoder 会查看 Encoder 生成的所有向量序列,然后就开始“寻找”过程:
  3. 它会用一个小型神经网络计算当前处理的词与哪个向量最相关,并计算对应的注意力分数、根据注意力分数分配权重。
  4. 然后根据这些权重,对源句子向量序列进行加权平均,形成一个用在这次生成的上下文向量。
  5. 然后,Decoder 利用这个上下文向量进行预测、输出结果。

这个设计被称作“注意力机制”也是因为 Decoder 会在这个过程中自己注意自己需要的结果。

具体而言,我们首先如下计算注意力分数:

eij=a(si1,hj)e_{ij} = a(s_{i-1}, h_j)

其中:

a(si1,hj)=vaTtanh(Wssi1+Whhj)a(s_{i-1}, h_j) = v_a^T \tanh(W_s s_{i-1} + W_h h_j)

这里的 vaTv_a^T 就是模型注意力的体现,它决定了模型会注意哪些自己需要的信息。然后我们将分数归一化得到权重:

αij=exp(eij)kexp(eik)\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k} \exp(e_{ik})}

最后根据权重,就能得到上下文向量了:

ci=jαijhjc_i = \sum_j \alpha_{ij} h_j

3. 简单实现

在实践中,我们用单独的全连接层来处理 hidden 和 Encoder output,将它们加和并输入到 tanh。然后我们用一个全连接层来表示 vaTv_a^T、让模型自己学习该注意哪些内容。最后将结果送入 Softmax、做点积即可:

class BahdanauAttention(nn.Module):
  def __init__(self, hidden_dim, encoder_dim):
    super(BahdanauAttention, self).__init__()
    self.hidden_dim = hidden_dim
    self.encoder_dim = encoder_dim

    # Attention layers
    self.attn_hidden = nn.Linear(hidden_dim, hidden_dim)
    self.attn_encoder = nn.Linear(encoder_dim, hidden_dim)
    self.attn_combine = nn.Linear(hidden_dim, 1, bias=False)

  def forward(self, hidden, encoder_outputs, mask=None):
    hidden_proj = self.attn_hidden(hidden).unsqueeze(1)
    encoder_proj = self.attn_encoder(encoder_outputs)
    energy = torch.tanh(hidden_proj + encoder_proj)
    attention_scores = self.attn_combine(energy).squeeze(2)
    attention_scores = attention_scores.masked_fill(mask, -1e10)

    attention_weights = F.softmax(attention_scores, dim=1)
    context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)

    return context, attention_weights

之后这个上下文向量就可以和原始的词嵌入向量拼接在一起、作为 Decoder 的参考上下文:

class Decoder(nn.Module):
  def __init__(self, ...):
    self.lstm = nn.LSTM(...)
  def forward(self, ...):
    context, attention_weights = self.attention(
      hidden[-1], encoder_outputs, src_mask
    )
    lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)
    output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
    ...