Dark Dwarf Blog background

文档的分层注意力网络

文档的分层注意力网络

1. 引入

Hierarchical Attention Networks for Document Classification 这篇论文引入了一个分层的注意力网络:它使用结构化的注意力,先理解单词,再理解句子,最后理解整篇文档。

2. 分层注意力网络

a.a. 基本组件

HAN 由下面的组件组成:

  1. 单词序列 Encoder:使用 Encoder 对句子中的单词进行处理。这个和 Seq2Seq 中的 Encoder 类似。
  2. 单词级注意力层:在 Encoder 理解了单词后,我们使用一个单词级注意力层来判断哪些单词是重要的。我们将这些重要的单词信息汇总,得到一个代表整个句子的向量
  3. 句子 Encoder:得到代表整个句子的向量后,句子 Encoder 就会读取这些句子向量来理解句子内容。
  4. 句子级注意力层:在 Encoder 理解了句子后,同样使用一个句子级注意力层来收集对理解文档重要的句子信息,得到一个代表整个文档的向量

最终,我们将得到的文档向量送入分类器,即可判断文档的类别。

下面我们介绍一下这些组件的大致实现。

b.b. Encoder

Encoder 和 Seq2Seq 中的一样,HAN 原论文中使用的是 GRU,它是另一种有记忆化功能的 RNN。同时,为了知道单词的上下文,我们使用双向的 GRU。

c.c. 单词注意力

不管是单词注意力还是句子注意力,它们的实现方法都和 Luong Attention 以及其他的注意力差不多:先计算一个分数,然后对这个分数用 Softmax,使用 Softmax 得到的最终权重得到最终的上下文向量。

HAN 中的单词注意力计算如下:

  1. 计算隐藏表示 ui=tanh ⁣(Wshi+bs)u_i = \tanh\!\left(W_s h_i + b_s\right)

  2. 计算注意力权重 αi=softmax ⁣(uiuw)\alpha_i = \operatorname{softmax}\!\left(u_i \cdot u_w\right)

  3. 计算对应的上下文向量 s=iαihis = \sum_i \alpha_i h_i

句子注意力和单词注意力几乎是一样的,只是隐藏表示变成了整个句子的向量。

3. 简单实现

我在 IMDB 和 News20 数据集上实现了 HAN 网络。这里主要讲 News20 中实现的(IMDB的就是基于 News20 的改了一下)。

a.a. Tokenizer 处理

我们使用 nltk 中的 word_tokenize 和sent_tokenize 来分别从词层面和句子层面处理原句子,这符合论文中从词和句子两个不同角度审视文档的做法:

from nltk.tokenize import sent_tokenize
import torch.nn.utils.rnn as rnn_utils

class HierarchicalTextProcessor:
  def __init__(self, tokenizer, max_sentences, max_words):
    self.tokenizer = tokenizer
    self.max_sentences = max_sentences
    self.max_words = max_words
    self.pad_id = tokenizer.pad_id

  def encode_document(self, text):
    """Encode document into (list of lists of token ids, list of sentence lengths, num sentences)"""
    sentences = sent_tokenize(text)

    # Limit number of sentences
    if len(sentences) > self.max_sentences:
      sentences = sentences[:self.max_sentences]

    encoded_sentences = []
    sentence_lengths = []

    for sent in sentences:
      ids = self.tokenizer.encode(sent)
      ids = ids[:self.max_words]  # Truncate words

      # Store actual length of sentence after truncation
      sentence_lengths.append(len(ids))
      encoded_sentences.append(ids)

    return encoded_sentences, len(encoded_sentences), sentence_lengths

b.b. Padding 相关逻辑

然后是 Padding。由于我们需要在词与句子这两个不同层面处理文档,它的结构如下:

Document (文档)
├── Sentence 1 (句子1)
│   ├── word1, word2, word3, ...
├── Sentence 2 (句子2)
│   ├── word1, word2, ...
└── ...

因此,我们分别对词和句子做 padding。词的 Padding 在 Tokenizer 阶段,句子的 Padding 在 collate_fn 这个 dataload 阶段

def encode_document(self, text):
  sentences = sent_tokenize(text)

  for sent in sentences:
    ids = self.tokenizer.encode(sent)
    ids = ids[:self.max_words]  # 截断到最大单词数
    sentence_lengths.append(len(ids))  # 记录实际长度
    encoded_sentences.append(ids)
# 找出 batch 中最多的句子数
max_doc_num_sentences_in_batch = max(num_sentences_list)

# 对每个文档的句子进行 padding
for i, doc_sentences in enumerate(input_ids_docs_list):
  current_doc_sentences = padded_sentences_for_attention[start_idx:end_idx]

  # 如果句子数不足,用全 PAD 的句子填充
  if current_doc_sentences.size(0) < max_doc_num_sentences_in_batch:
    sentence_pad_shape = (max_doc_num_sentences_in_batch - current_doc_sentences.size(0),
                             current_doc_sentences.size(1))
    sentence_pad = torch.full(sentence_pad_shape, text_processor_han.pad_id, dtype=torch.long)
    current_doc_sentences = torch.cat([current_doc_sentences, sentence_pad], dim=0)

最终生成的 batch 张量的形状为 (batch_size, max_sentences_in_batch, max_words_in_batch)

由于 Padding 会引入噪音,我们需要 Mask 来告诉我们哪些是真实数据,在 forward 中使用 torch 的布尔索引即可:

batch_size, max_sentences, max_words = documents.size()
word_input = documents.view(batch_size * max_sentences, max_words)
flat_sentence_lengths = sentence_lengths.view(-1) # (batch_size * max_sentences)

# Filter out sentences that have length 0 (these are the padded sentences at the document level)
non_zero_lengths_mask = (flat_sentence_lengths > 0)
actual_sentences_input = word_input[non_zero_lengths_mask]
actual_sentence_lengths = flat_sentence_lengths[non_zero_lengths_mask]

然后在发送数据时使用 pack_padded_sequencepad_packed_sequence 进行压缩与解压缩即可:

packed_embedded_words = rnn_utils.pack_padded_sequence(
  embedded_words,
  actual_sentence_lengths.cpu(),
  batch_first=True,
  enforce_sorted=False
)

packed_word_gru_out, _ = self.word_gru(packed_embedded_words)
word_gru_out, _ = rnn_utils.pad_packed_sequence(packed_word_gru_out, batch_first=True, total_length=max_words)

p.s.p.s.一定要注意 Padding 逻辑和 Tokenizer 逻辑!!!不然 acc rate 会非常低!

c.c. 注意力层

HAN 的注意力层很简单、并且词注意力和句子注意力是一样的,不需要像 Luong 那样有很多 shape 操作:

class Attention(nn.Module):
  def __init__(self, hidden_size):
    super(Attention, self).__init__()
    self.attention = nn.Linear(hidden_size, hidden_size)
    self.context_vector = nn.Linear(hidden_size, 1, bias=False)

  def forward(self, hidden_state, mask):
    u = torch.tanh(self.attention(hidden_state))
    score = self.context_vector(u).squeeze(-1)

    score.masked_fill_(mask == 0, -1e9)

    attention_weights = F.softmax(score, dim=-1)
    context_vector = torch.bmm(
        attention_weights.unsqueeze(1),
        hidden_state
    ).squeeze(1)

    return context_vector, attention_weights

d.d. Notebook

i.i. News20

ii.ii. IMDB