文档的分层注意力网络
1. 引入
Hierarchical Attention Networks for Document Classification 这篇论文引入了一个分层的注意力网络:它使用结构化的注意力,先理解单词,再理解句子,最后理解整篇文档。
2. 分层注意力网络
基本组件
HAN 由下面的组件组成:
- 单词序列 Encoder:使用 Encoder 对句子中的单词进行处理。这个和 Seq2Seq 中的 Encoder 类似。
- 单词级注意力层:在 Encoder 理解了单词后,我们使用一个单词级注意力层来判断哪些单词是重要的。我们将这些重要的单词信息汇总,得到一个代表整个句子的向量。
- 句子 Encoder:得到代表整个句子的向量后,句子 Encoder 就会读取这些句子向量来理解句子内容。
- 句子级注意力层:在 Encoder 理解了句子后,同样使用一个句子级注意力层来收集对理解文档重要的句子信息,得到一个代表整个文档的向量。
最终,我们将得到的文档向量送入分类器,即可判断文档的类别。
下面我们介绍一下这些组件的大致实现。
Encoder
Encoder 和 Seq2Seq 中的一样,HAN 原论文中使用的是 GRU,它是另一种有记忆化功能的 RNN。同时,为了知道单词的上下文,我们使用双向的 GRU。
单词注意力
不管是单词注意力还是句子注意力,它们的实现方法都和 Luong Attention 以及其他的注意力差不多:先计算一个分数,然后对这个分数用 Softmax,使用 Softmax 得到的最终权重得到最终的上下文向量。
HAN 中的单词注意力计算如下:
-
计算隐藏表示
-
计算注意力权重
-
计算对应的上下文向量
句子注意力和单词注意力几乎是一样的,只是隐藏表示变成了整个句子的向量。
3. 简单实现
我在 IMDB 和 News20 数据集上实现了 HAN 网络。这里主要讲 News20 中实现的(IMDB的就是基于 News20 的改了一下)。
Tokenizer 处理
我们使用 nltk 中的 word_tokenize 和sent_tokenize 来分别从词层面和句子层面处理原句子,这符合论文中从词和句子两个不同角度审视文档的做法:
from nltk.tokenize import sent_tokenize
import torch.nn.utils.rnn as rnn_utils
class HierarchicalTextProcessor:
def __init__(self, tokenizer, max_sentences, max_words):
self.tokenizer = tokenizer
self.max_sentences = max_sentences
self.max_words = max_words
self.pad_id = tokenizer.pad_id
def encode_document(self, text):
"""Encode document into (list of lists of token ids, list of sentence lengths, num sentences)"""
sentences = sent_tokenize(text)
# Limit number of sentences
if len(sentences) > self.max_sentences:
sentences = sentences[:self.max_sentences]
encoded_sentences = []
sentence_lengths = []
for sent in sentences:
ids = self.tokenizer.encode(sent)
ids = ids[:self.max_words] # Truncate words
# Store actual length of sentence after truncation
sentence_lengths.append(len(ids))
encoded_sentences.append(ids)
return encoded_sentences, len(encoded_sentences), sentence_lengths
Padding 相关逻辑
然后是 Padding。由于我们需要在词与句子这两个不同层面处理文档,它的结构如下:
Document (文档)
├── Sentence 1 (句子1)
│ ├── word1, word2, word3, ...
├── Sentence 2 (句子2)
│ ├── word1, word2, ...
└── ...
因此,我们分别对词和句子做 padding。词的 Padding 在 Tokenizer 阶段,句子的 Padding 在 collate_fn 这个 dataload 阶段
def encode_document(self, text):
sentences = sent_tokenize(text)
for sent in sentences:
ids = self.tokenizer.encode(sent)
ids = ids[:self.max_words] # 截断到最大单词数
sentence_lengths.append(len(ids)) # 记录实际长度
encoded_sentences.append(ids)
# 找出 batch 中最多的句子数
max_doc_num_sentences_in_batch = max(num_sentences_list)
# 对每个文档的句子进行 padding
for i, doc_sentences in enumerate(input_ids_docs_list):
current_doc_sentences = padded_sentences_for_attention[start_idx:end_idx]
# 如果句子数不足,用全 PAD 的句子填充
if current_doc_sentences.size(0) < max_doc_num_sentences_in_batch:
sentence_pad_shape = (max_doc_num_sentences_in_batch - current_doc_sentences.size(0),
current_doc_sentences.size(1))
sentence_pad = torch.full(sentence_pad_shape, text_processor_han.pad_id, dtype=torch.long)
current_doc_sentences = torch.cat([current_doc_sentences, sentence_pad], dim=0)
最终生成的 batch 张量的形状为 (batch_size, max_sentences_in_batch, max_words_in_batch)。
由于 Padding 会引入噪音,我们需要 Mask 来告诉我们哪些是真实数据,在 forward 中使用 torch 的布尔索引即可:
batch_size, max_sentences, max_words = documents.size()
word_input = documents.view(batch_size * max_sentences, max_words)
flat_sentence_lengths = sentence_lengths.view(-1) # (batch_size * max_sentences)
# Filter out sentences that have length 0 (these are the padded sentences at the document level)
non_zero_lengths_mask = (flat_sentence_lengths > 0)
actual_sentences_input = word_input[non_zero_lengths_mask]
actual_sentence_lengths = flat_sentence_lengths[non_zero_lengths_mask]
然后在发送数据时使用 pack_padded_sequence 和 pad_packed_sequence 进行压缩与解压缩即可:
packed_embedded_words = rnn_utils.pack_padded_sequence(
embedded_words,
actual_sentence_lengths.cpu(),
batch_first=True,
enforce_sorted=False
)
packed_word_gru_out, _ = self.word_gru(packed_embedded_words)
word_gru_out, _ = rnn_utils.pad_packed_sequence(packed_word_gru_out, batch_first=True, total_length=max_words)
一定要注意 Padding 逻辑和 Tokenizer 逻辑!!!不然 acc rate 会非常低!
注意力层
HAN 的注意力层很简单、并且词注意力和句子注意力是一样的,不需要像 Luong 那样有很多 shape 操作:
class Attention(nn.Module):
def __init__(self, hidden_size):
super(Attention, self).__init__()
self.attention = nn.Linear(hidden_size, hidden_size)
self.context_vector = nn.Linear(hidden_size, 1, bias=False)
def forward(self, hidden_state, mask):
u = torch.tanh(self.attention(hidden_state))
score = self.context_vector(u).squeeze(-1)
score.masked_fill_(mask == 0, -1e9)
attention_weights = F.softmax(score, dim=-1)
context_vector = torch.bmm(
attention_weights.unsqueeze(1),
hidden_state
).squeeze(1)
return context_vector, attention_weights