Luong Attention
1. 全局注意力
概述
全局注意力的概念在 Bahdanau Attention 笔记中已经介绍过了:在 Decode 过程中,每当生成一个词,Decoder 会查看 Encoder 输出的所有隐藏状态并关注自己需要的那个。Luong Attention 的这一流程如下:
- 获取 RNN 的隐藏状态 和 Encoder 所有隐藏状态 。
- 计算对齐分数 。
- 使用 Softmax 计算权重,最终使用权重得到上下文向量 。
注意力分数计算
与 Bahdanau Attention 单一的对齐分数计算不同,Luong Attention 提出了三种分数计算方法:
- 点积分数:
- general 分数:
- 拼接分数:
这里的拼接分数就是我们前面讲过的 Bahdanau Attention。虽然看起来形式不太一样,不过做一些变换即可。我们前面的计算方法为:
而对于 Luong 中提到的拼接分数,设 的维度为 dec_dim, 的维度为 enc_dim,则拼接后向量 维度为 dec_dim + enc_dim, 的形状为 (attention_dim, dec_dim + enc_dim)。 可以被拆分成下面两部分:
- : 中与 相乘的部分,形状为
(attention_dim, dec_dim)。 - : 中与 相乘的部分,形状为
(attention_dim, enc_dim)。
因此,矩阵乘法 就可以被展开为:
也就是我们前面提到的计算方法。
2. 局部注意力
概述
全局注意力在每次解码时都会查询所有的隐藏状态,当序列长度较大时,这个过程非常耗时。因此我们可以不查看所有的隐藏状态、只关注一个小的窗口。这就是局部注意力。
局部注意力的整体流程如下:
- 对于每个目标词 ,首先生成一个“对齐位置” 。
- 以 为中心,定义一个大小为 的窗口 ( 是一个凭经验选择的超参数,比如5)。
- 在这个窗口内,对 RNN 隐藏状态进行加权平均,得到 。
局部注意力的对齐向量 的维度是固定的 ,这使得计算更加快速稳定。
对齐方式
局部注意力的一个关键是如何选取对齐位置,Luong 的原论文提出了下面两种方法:
单调对齐
单调对齐 (Monotonic alignment, local-m) 是最为简单的对齐方法。我们简单地假设 。这个假设基于一个直觉:源句子和目标句子在很大程度上是单调对齐的。例如这个例子:
- “I love this beautiful cat” “我 爱 这只 美丽的 猫”
但是这个假设过强了,句子稍微复杂一些就用不了了,因此 Luong 提出了另一种方法:
预测对齐
预测对齐和我们之前在实现 Simple NMT 的双向 RNN 输出转换的想法类似:与其自己确定对齐位置,不如让模型自己去学习怎么对齐。类似 Bahdanau Attention,我们引入 和 这两个需要学习的参数,让模型通过训练更新这两个参数、做出合适的判断:
然后,为了让靠近窗口中心的位置的注意力权重增大,我们使用高斯权重分布:
3. 输入-反馈方法
我们前面讨论的注意力决策都是相对独立的:我们每次都利用当前 RNN 的状态来计算注意力。这并不是最优的,除了当前 RNN 的状态,我们还应该参考过去的对齐信息,这可以防止 RNN 重复处理或忘记处理某些序列。
让 Attention 考虑之前的对齐信息非常简单:我们只需要将上一步的注意力隐藏状态 和上一个生成词的 embedding 共同作为 Decoder 的输入,即 [Embedding(yt); h~_t]。
这种做法不仅让模型有了之前对齐信息的信息,还创建了更深的网络: 信息从 Decoder 的 output 流出,经过注意力计算再被注入回解码器的 input 。这相当于在网络的垂直结构上增加了一条“捷径连接”,类似于深度残差网络的思想,有助于梯度流动和训练更深的网络。
我们前面的 Bahdanau 也有使用这个方法,只是那篇论文没有集中论述这个东西:
context, attention_weights = self.attention(
hidden[-1], encoder_outputs, src_mask
)
lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)
4. 注意力计算路径
Luong Attention 使用了与 Bahdanau Attention 不同的计算路径。在 Bahdanau Attention 中,注意力计算是生成隐藏状态向量的过程:
而 Luong Attention 则使用 RNN 的顶层状态来决定注意力计算:
5. 简单实现
Luong Attention 和 Bahdanau Attention 的实现非常类似,只需要修改 Attention Layer 和 Decoder 来适配不同的注意力计算路径即可:
全局注意力
我们以 general 注意力分数为例,general 注意力有一个可学习参数 :
我们用一个全连接层表示它:
class LuongAttention(nn.Module):
def __init__(self, hidden_dim, encoder_dim):
super(LuongAttention, self).__init__()
self.attn = nn.Linear(hidden_dim, encoder_dim)
def forward(self, decoder_hidden, encoder_outputs, mask=None):
# decoder_hidden: (batch_size, hidden_dim) - This is ht
# encoder_outputs: (batch_size, seq_len, encoder_dim) - These are h_s
# Transform decoder hidden state: (batch_size, encoder_dim)
transformed_hidden = self.attn(decoder_hidden)
# Compute attention scores: (batch_size, seq_len)
attention_scores = torch.bmm(
encoder_outputs,
transformed_hidden.unsqueeze(2)
).squeeze(2)
attention_scores.masked_fill_(mask, -1e10)
attention_weights = F.softmax(attention_scores, dim=1)
context = torch.bmm(
attention_weights.unsqueeze(1),
encoder_outputs
).squeeze(1)
return context, attention_weights
然后是 Decoder 的实现,和 Bahdanau Attention 的数据流动不同,我们是使用当前 RNN 的顶层状态计算 Attention 的:
class Decoder(nn.Module):
def __init__(self, ...):
super(Decoder, self).__init__()
self.attention = LuongAttention(hidden_dim, encoder_dim)
self.lstm = nn.LSTM(
embed_dim + encoding_dim,
hidden_dim,
num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0
)
self.attention_combine = nn.Linear(hidden_dim + encoding_dim, hidden_dim)
self.fc = nn.Linear(hidden_dim, vocab_size)
def forward(self, tgt, hidden, cell, encoder_outputs, src_mask=None):
...
output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
context, attention_weights = self.attention(
output.squeeze(1), encoder_outputs, src_mask
)
concat_output = torch.tanh(self.attention_combine(torch.cat([output.squeeze(1), context], dim=1)))
prediction = self.fc(concat_output)
...
其他类型的注意力分数类似,只需要修改 Attention 类的注意力计算方法、引入或移除全连接层即可。
局部注意力
我们以 local-p 为例。实现局部注意力时,我们需要引入 、 这两个可学习参数:
import torch
import torch.nn as nn
import torch.nn.functional as F
class LuongLocalAttention(nn.Module):
def __init__(self, hidden_dim, encoder_dim, window_size=10):
super(LuongLocalAttention, self).__init__()
self.hidden_dim = hidden_dim
self.encoder_dim = encoder_dim
self.D = window_size // 2 # 窗口半径 D
# 对应论文中的 Wp 和 vp,用于预测对齐位置 pt
self.Wp = nn.Linear(hidden_dim, hidden_dim)
self.vp = nn.Parameter(torch.rand(hidden_dim))
# 评分函数,这里我们继续使用 "general" 方法
self.attn = nn.Linear(hidden_dim, encoder_dim)
# 高斯分布的标准差,根据论文设为 D/2
self.sigma = self.D / 2.0
def forward(self, decoder_hidden, encoder_outputs, mask=None):
"""
decoder_hidden: (batch_size, hidden_dim) - ht
encoder_outputs: (batch_size, seq_len, encoder_dim) - h_s
mask: (batch_size, seq_len)
"""
batch_size, seq_len, _ = encoder_outputs.size()
device = encoder_outputs.device
# Predict pt
pt_logits = self.vp * torch.tanh(self.Wp(decoder_hidden)) # (B, H)
pt_logits = pt_logits.sum(dim=1) # (B)
pt = seq_len * torch.sigmoid(pt_logits) # (B)
# Define window size
pt_int = pt.long()
start = torch.clamp(pt_int - self.D, min=0)
end = torch.clamp(pt_int + self.D + 1, max=seq_len)
# Calculate window attention
idx = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1) # (B, S)
window_mask = (idx < start.unsqueeze(1)) | (idx >= end.unsqueeze(1)) # (B, S)
transformed_hidden = self.attn(decoder_hidden)
attention_scores = torch.bmm(encoder_outputs, transformed_hidden.unsqueeze(2)).squeeze(2) # (B, S)
gaussian_penalty = -((idx.float() - pt.unsqueeze(1))**2) / (2 * self.sigma**2)
attention_scores = attention_scores + gaussian_penalty
attention_scores.masked_fill_(mask, -1e10)
attention_scores.masked_fill_(window_mask, -1e10)
# Softmax
attention_weights = F.softmax(attention_scores, dim=1) # (B, S)
context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1) # (B, E)
return context, attention_weights
而 Decoder 的实现和全局注意力中的一致:因为 Attention 接口并没有变。