Dark Dwarf Blog background

Luong Attention

Luong Attention

1. 全局注意力

a.a. 概述

全局注意力的概念在 Bahdanau Attention 笔记中已经介绍过了:在 Decode 过程中,每当生成一个词,Decoder 会查看 Encoder 输出的所有隐藏状态并关注自己需要的那个。Luong Attention 的这一流程如下:

  1. 获取 RNN 的隐藏状态 hth_t 和 Encoder 所有隐藏状态 hsh_s
  2. 计算对齐分数 a(ht,hs)a(h_t, h_s)
  3. 使用 Softmax 计算权重,最终使用权重得到上下文向量 ct=sat(s)hsc_t = \sum_{s} a_t(s)\,h_s

b.b. 注意力分数计算

与 Bahdanau Attention 单一的对齐分数计算不同,Luong Attention 提出了三种分数计算方法:

  1. 点积分数:score(ht,hs)=htThs\text{score}(h_t, h_s) = h_t^ T h_s
  2. general 分数:score(ht,hs)=htTWahs\text{score}(h_t, h_s) = h_t^T W_a h_s
  3. 拼接分数:score(ht,hs)=vaTtanh(Wa[ht;hs])\text{score}(h_t, h_s) = va^T \, \text{tanh}(W_a [h_t; h_s])

这里的拼接分数就是我们前面讲过的 Bahdanau Attention。虽然看起来形式不太一样,不过做一些变换即可。我们前面的计算方法为:

score(ht,hs)=vaTtanh(Whht+Wshs)\text{score}(h_t, h_s) = v_a^T \tanh(W_h h_t + W_s h_s)

而对于 Luong 中提到的拼接分数,设 hth_t 的维度为 dec_dimhsh_s 的维度为 enc_dim,则拼接后向量 [ht,hs][h_t, h_s] 维度为 dec_dim + enc_dimWaW_a 的形状为 (attention_dim, dec_dim + enc_dim)WaW_a 可以被拆分成下面两部分:

  • WhW_hWaW_a 中与 hth_t 相乘的部分,形状为 (attention_dim, dec_dim)
  • WsW_sWaW_a 中与 hsh_s 相乘的部分,形状为 (attention_dim, enc_dim)

因此,矩阵乘法 Wa[ht;hs]W_a [h_t; h_s] 就可以被展开为:

Wa[ht;hs]=Whht+WshsW_a [h_t; h_s] = W_h h_t + W_s h_s

也就是我们前面提到的计算方法。

2. 局部注意力

a.a. 概述

全局注意力在每次解码时都会查询所有的隐藏状态,当序列长度较大时,这个过程非常耗时。因此我们可以不查看所有的隐藏状态、只关注一个小的窗口。这就是局部注意力。

局部注意力的整体流程如下:

  1. 对于每个目标词 yty_t,首先生成一个“对齐位置” ptp_t
  2. ptp_t 为中心,定义一个大小为 2D+12D+1 的窗口 [ptD,pt+D][p_t - D, p_t + D]DD 是一个凭经验选择的超参数,比如5)。
  3. 在这个窗口内,对 RNN 隐藏状态进行加权平均,得到 ctc_t

局部注意力的对齐向量 ata_t 的维度是固定的 2D+12D+1,这使得计算更加快速稳定。

b.b. 对齐方式

局部注意力的一个关键是如何选取对齐位置,Luong 的原论文提出了下面两种方法:

i.i. 单调对齐

单调对齐 (Monotonic alignment, local-m) 是最为简单的对齐方法。我们简单地假设 pt=tp_t = t。这个假设基于一个直觉:源句子和目标句子在很大程度上是单调对齐的。例如这个例子:

  • “I love this beautiful cat” \rightarrow “我 爱 这只 美丽的 猫”

但是这个假设过强了,句子稍微复杂一些就用不了了,因此 Luong 提出了另一种方法:

ii.ii. 预测对齐

预测对齐和我们之前在实现 Simple NMT 的双向 RNN 输出转换的想法类似:与其自己确定对齐位置,不如让模型自己去学习怎么对齐。类似 Bahdanau Attention,我们引入 WpW_pvpv_p 这两个需要学习的参数,让模型通过训练更新这两个参数、做出合适的判断:

pt=Ssigmoid ⁣(vpTtanh(Wpht))p_t = S \cdot \operatorname{sigmoid}\!\left(v_p^{T}\tanh\left(W_p h_t\right)\right)

然后,为了让靠近窗口中心的位置的注意力权重增大,我们使用高斯权重分布:

at(s)=align ⁣(ht,hs)exp ⁣((spt)22σ2)a_t(s) = \operatorname{align}\!\left(h_t, h_s\right)\,\exp\!\left(-\frac{(s - p_t)^2}{2\sigma^2}\right)

3. 输入-反馈方法

我们前面讨论的注意力决策都是相对独立的:我们每次都利用当前 RNN 的状态来计算注意力。这并不是最优的,除了当前 RNN 的状态,我们还应该参考过去的对齐信息,这可以防止 RNN 重复处理或忘记处理某些序列。

让 Attention 考虑之前的对齐信息非常简单:我们只需要将上一步的注意力隐藏状态 h~t\tilde{h}_t 和上一个生成词的 embedding 共同作为 Decoder 的输入,即 [Embedding(yt); h~_t]

这种做法不仅让模型有了之前对齐信息的信息,还创建了更深的网络: 信息从 Decoder 的 output 流出,经过注意力计算再被注入回解码器的 input 。这相当于在网络的垂直结构上增加了一条“捷径连接”,类似于深度残差网络的思想,有助于梯度流动和训练更深的网络。

我们前面的 Bahdanau 也有使用这个方法,只是那篇论文没有集中论述这个东西:

context, attention_weights = self.attention(
    hidden[-1], encoder_outputs, src_mask
)

lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)

4. 注意力计算路径

Luong Attention 使用了与 Bahdanau Attention 不同的计算路径。在 Bahdanau Attention 中,注意力计算是生成隐藏状态向量的过程:

ht1atcthth_{t-1} \rightarrow a_t \rightarrow c_t \rightarrow h_t

而 Luong Attention 则使用 RNN 的顶层状态来决定注意力计算:

htatcth~tpredictionh_t \rightarrow a_t \rightarrow c_t \rightarrow \tilde{h}_t \rightarrow \text{prediction}

5. 简单实现

Luong Attention 和 Bahdanau Attention 的实现非常类似,只需要修改 Attention Layer 和 Decoder 来适配不同的注意力计算路径即可:

a.a. 全局注意力

我们以 general 注意力分数为例,general 注意力有一个可学习参数 WTW^T

score(ht,hs)=htTWahs\text{score}(h_t, h_s) = h_t^T W_a h_s

我们用一个全连接层表示它:

class LuongAttention(nn.Module):
  def __init__(self, hidden_dim, encoder_dim):
    super(LuongAttention, self).__init__()
    self.attn = nn.Linear(hidden_dim, encoder_dim)

  def forward(self, decoder_hidden, encoder_outputs, mask=None):
    # decoder_hidden: (batch_size, hidden_dim) - This is ht
    # encoder_outputs: (batch_size, seq_len, encoder_dim) - These are h_s
    # Transform decoder hidden state: (batch_size, encoder_dim)
    transformed_hidden = self.attn(decoder_hidden)
    
    # Compute attention scores: (batch_size, seq_len)
    attention_scores = torch.bmm(
      encoder_outputs,
      transformed_hidden.unsqueeze(2)
    ).squeeze(2)
    
    attention_scores.masked_fill_(mask, -1e10)
    
    attention_weights = F.softmax(attention_scores, dim=1)
    
    context = torch.bmm(
      attention_weights.unsqueeze(1),
      encoder_outputs
    ).squeeze(1)
    
    return context, attention_weights

然后是 Decoder 的实现,和 Bahdanau Attention 的数据流动不同,我们是使用当前 RNN 的顶层状态计算 Attention 的:

class Decoder(nn.Module):
  def __init__(self, ...):
    super(Decoder, self).__init__()
    self.attention = LuongAttention(hidden_dim, encoder_dim)

    self.lstm = nn.LSTM(
      embed_dim + encoding_dim,  
      hidden_dim,
      num_layers,
      batch_first=True,
      dropout=dropout if num_layers > 1 else 0
    )

    self.attention_combine = nn.Linear(hidden_dim + encoding_dim, hidden_dim)
    self.fc = nn.Linear(hidden_dim, vocab_size)
  
  def forward(self, tgt, hidden, cell, encoder_outputs, src_mask=None):
    ...
    output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
    context, attention_weights = self.attention(
      output.squeeze(1), encoder_outputs, src_mask
    )
    concat_output = torch.tanh(self.attention_combine(torch.cat([output.squeeze(1), context], dim=1)))
    prediction = self.fc(concat_output)
    ...

其他类型的注意力分数类似,只需要修改 Attention 类的注意力计算方法、引入或移除全连接层即可。

b.b. 局部注意力

我们以 local-p 为例。实现局部注意力时,我们需要引入 WpW_pvpv_p 这两个可学习参数:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LuongLocalAttention(nn.Module):
  def __init__(self, hidden_dim, encoder_dim, window_size=10):
    super(LuongLocalAttention, self).__init__()
    self.hidden_dim = hidden_dim
    self.encoder_dim = encoder_dim
    self.D = window_size // 2  # 窗口半径 D

    # 对应论文中的 Wp 和 vp,用于预测对齐位置 pt
    self.Wp = nn.Linear(hidden_dim, hidden_dim)
    self.vp = nn.Parameter(torch.rand(hidden_dim))
    
    # 评分函数,这里我们继续使用 "general" 方法
    self.attn = nn.Linear(hidden_dim, encoder_dim)

    # 高斯分布的标准差,根据论文设为 D/2
    self.sigma = self.D / 2.0

  def forward(self, decoder_hidden, encoder_outputs, mask=None):
    """
    decoder_hidden: (batch_size, hidden_dim) - ht
    encoder_outputs: (batch_size, seq_len, encoder_dim) - h_s
    mask: (batch_size, seq_len)
    """
    batch_size, seq_len, _ = encoder_outputs.size()
    device = encoder_outputs.device
    
    # Predict pt 
    pt_logits = self.vp * torch.tanh(self.Wp(decoder_hidden)) # (B, H)
    pt_logits = pt_logits.sum(dim=1) # (B)
    
    pt = seq_len * torch.sigmoid(pt_logits) # (B)
    
    # Define window size 
    pt_int = pt.long()
    start = torch.clamp(pt_int - self.D, min=0)
    end = torch.clamp(pt_int + self.D + 1, max=seq_len)

    # Calculate window attention
    idx = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1) # (B, S)
    
    window_mask = (idx < start.unsqueeze(1)) | (idx >= end.unsqueeze(1)) # (B, S)
    
    transformed_hidden = self.attn(decoder_hidden)
    attention_scores = torch.bmm(encoder_outputs, transformed_hidden.unsqueeze(2)).squeeze(2) # (B, S)

    gaussian_penalty = -((idx.float() - pt.unsqueeze(1))**2) / (2 * self.sigma**2)
    
    attention_scores = attention_scores + gaussian_penalty
    
    attention_scores.masked_fill_(mask, -1e10)
    attention_scores.masked_fill_(window_mask, -1e10)
    
    # Softmax
    attention_weights = F.softmax(attention_scores, dim=1) # (B, S)
    context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1) # (B, E)
    
    return context, attention_weights

而 Decoder 的实现和全局注意力中的一致:因为 Attention 接口并没有变。