In [1]:
!pip install tokenizers -q
In [2]:
!pip install datasets nltk transformers -q
In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
Using device: cuda
In [ ]:
dataset = load_dataset("SetFit/20_newsgroups")
train_data = dataset['train']
test_data = dataset['test']

print(f"Train data size: {len(train_data)}")
print(f"Train data example: {train_data[0]}")
print(f"Test data size: {len(test_data)}")
print(f"Test data example: {test_data[0]}")
In [ ]:
train_data = train_data.to_pandas()

train_data['text_length'] = train_data['text'].str.len()
train_data['word_count'] = train_data['text'].str.split().str.len()

print("\n[Text Length Statistics]")
print(train_data[['text_length', 'word_count']].describe())

def has_email_headers(text):
  headers = ['From:', 'Subject:', 'Organization:', 'Lines:', 'NNTP-Posting-Host:']
  return any(header in text for header in headers)

train_data['has_headers'] = train_data['text'].apply(has_email_headers)
print(f"\n[Email Header Detection]")
print(f"Documents with email headers: {train_data['has_headers'].sum()} / {len(train_data)}")
print(f"Percentage: {train_data['has_headers'].mean():.2%}")

import re
def count_urls(text):
  return len(re.findall(r'http[s]?://\S+', text))

def count_emails(text):
  return len(re.findall(r'\S+@\S+', text))

def count_quote_lines(text):
  return len(re.findall(r'^{{>}}+', text, re.MULTILINE))

train_data['url_count'] = train_data['text'].apply(count_urls)
train_data['email_count'] = train_data['text'].apply(count_emails)
train_data['quote_count'] = train_data['text'].apply(count_quote_lines)

print(f"\n[Noise Statistics]")
print(f"Documents containing URLs: {(train_data['url_count'] > 0).sum()}")
print(f"Documents containing emails: {(train_data['email_count'] > 0).sum()}")
print(f"Documents containing quotes: {(train_data['quote_count'] > 0).sum()}")

# 5. Class distribution
print(f"\n[Class Distribution]")
label_counts = train_data['label'].value_counts().sort_index()
print(label_counts)
print(f"\nMax class samples / Min class samples = {label_counts.max() / label_counts.min():.2f}")

print("\n[Abnormal Sample Detection]")
print(f"Empty documents: {(train_data['text'].str.strip() == '').sum()}")
print(f"Very short documents (<50 chars): {(train_data['text_length'] < 50).sum()}")
print(f"Very long documents (>10000 chars): {(train_data['text_length'] > 10000).sum()}")
In [6]:
import re

class SimplePreprocessor:
  def __init__(self, min_length=20, max_length=50000):
    self.min_length = min_length
    self.max_length = max_length

  def clean_text(self, text):
    text = re.sub(r'(\S)\1{10,}', r'\1\1\1', text)

    text = re.sub(r'\s+', ' ', text)
    text = text.strip()

    if len(text) > self.max_length:
      text = text[:self.max_length]

    return text

  def is_valid(self, text):
    if not text or len(text.strip()) < self.min_length:
      return False

    return True

  def process_dataset(self, texts, labels):
    cleaned_texts = []
    valid_labels = []

    removed_empty = 0
    removed_short = 0
    kept = 0

    for text, label in zip(texts, labels):
      cleaned = self.clean_text(text)

      if not cleaned:
        removed_empty += 1
        continue

      if not self.is_valid(cleaned):
        removed_short += 1
        continue

      cleaned_texts.append(cleaned)
      valid_labels.append(label)
      kept += 1

    return cleaned_texts, valid_labels

preprocessor = SimplePreprocessor(min_length=20, max_length=50000)
train_texts_cleaned, train_labels_cleaned = preprocessor.process_dataset(
    train_data['text'],
    train_data['label']
)

test_texts_cleaned, test_labels_cleaned = preprocessor.process_dataset(
    test_data['text'],
    test_data['label']
)
In [ ]:
import nltk

nltk.download('punkt_tab')
nltk.download('punkt', quiet=True)

from nltk.tokenize import sent_tokenize

def quick_analyze_structure(texts, n_samples=500):
  """Quickly analyze sentence structure"""
  import random

  # Random sampling
  sample_texts = random.sample(texts, min(n_samples, len(texts)))

  sent_counts = []
  sent_lengths = []

  for text in sample_texts:
    sentences = sent_tokenize(text)
    sent_counts.append(len(sentences))

    for sent in sentences:
      sent_lengths.append(len(sent.split()))

  sent_counts = np.array(sent_counts)
  sent_lengths = np.array(sent_lengths)

  print("Sentence Structure Analysis:")
  print(f"\nSentences per document:")
  print(f"  Average: {sent_counts.mean():.1f}")
  print(f"  Median: {np.median(sent_counts):.0f}")
  print(f"  75th percentile: {np.percentile(sent_counts, 75):.0f}")
  print(f"  95th percentile: {np.percentile(sent_counts, 95):.0f}")

  print(f"\nWords per sentence:")
  print(f"  Average: {sent_lengths.mean():.1f}")
  print(f"  Median: {np.median(sent_lengths):.0f}")
  print(f"  75th percentile: {np.percentile(sent_lengths, 75):.0f}")
  print(f"  95th percentile: {np.percentile(sent_lengths, 95):.0f}")

  # Recommended hyperparameters (using 95th percentile)
  recommended_max_sents = int(np.percentile(sent_counts, 95))
  recommended_max_words = int(np.percentile(sent_lengths, 95))

  print(f"\n[Recommended Hyperparameters]")
  print(f"  max_sentences: {recommended_max_sents}")
  print(f"  max_words: {recommended_max_words}")

  return recommended_max_sents, recommended_max_words

max_sents, max_words = quick_analyze_structure(train_texts_cleaned)
In [ ]:
from collections import Counter
from nltk.tokenize import word_tokenize

# Build word-level vocabulary
word_counter = Counter()
for text in tqdm(train_texts_cleaned, desc="Building vocabulary"):
  for sent in sent_tokenize(text):
    tokens = [w.lower() for w in word_tokenize(sent)]
    word_counter.update(tokens)

# Create vocab with frequency threshold
vocab = {'[PAD]': 0, '[UNK]': 1}
for word, count in word_counter.items():
  if count >= 2:
    vocab[word] = len(vocab)

print(f"Vocabulary size: {len(vocab)}")
print(f"Most common words: {word_counter.most_common(20)}")

class WordLevelTokenizer:
  def __init__(self, vocab):
    self.vocab = vocab
    self.pad_id = vocab['[PAD]']
    self.unk_id = vocab['[UNK]']

  def encode(self, text):
    tokens = [w.lower() for w in word_tokenize(text)]
    ids = [self.vocab.get(token, self.unk_id) for token in tokens]
    return ids

  def get_vocab_size(self):
    return len(self.vocab)

  def get_vocab(self):
    return self.vocab

tokenizer = WordLevelTokenizer(vocab)
Building vocabulary: 100%|██████████| 10949/10949 [00:13<00:00, 797.00it/s]
Vocabulary size: 50758
Most common words: [(',', 108751), ('the', 105427), ('.', 101720), ('>', 71587), ('to', 52446), ("'ax", 52276), ('of', 46542), ('a', 43694), ('and', 42108), ('(', 38359), (')', 38147), ('i', 34404), ('is', 30765), ('in', 30619), ('that', 27825), (':', 26124), ('it', 23441), ('?', 21481), ("''", 20117), ('*', 19847)]
In [9]:
from nltk.tokenize import sent_tokenize
import torch.nn.utils.rnn as rnn_utils

class HierarchicalTextProcessor:
  def __init__(self, tokenizer, max_sentences, max_words):
    self.tokenizer = tokenizer
    self.max_sentences = max_sentences
    self.max_words = max_words
    self.pad_id = tokenizer.pad_id

  def encode_document(self, text):
    """Encode document into (list of lists of token ids, list of sentence lengths, num sentences)"""
    sentences = sent_tokenize(text)

    # Limit number of sentences
    if len(sentences) > self.max_sentences:
      sentences = sentences[:self.max_sentences]

    encoded_sentences = []
    sentence_lengths = []

    for sent in sentences:
      ids = self.tokenizer.encode(sent)
      ids = ids[:self.max_words]  # Truncate words

      # Store actual length of sentence after truncation
      sentence_lengths.append(len(ids))
      encoded_sentences.append(ids)

    return encoded_sentences, len(encoded_sentences), sentence_lengths

# Using recommended hyperparameters
text_processor_han = HierarchicalTextProcessor(
  tokenizer,
  max_sentences=max_sents,
  max_words=max_words
)
In [10]:
def process_data_hierarchical(texts, labels, processor):
  all_input_ids = [] # List of lists of lists (document -> sentence -> token_ids)
  all_sentence_lengths = [] # List of lists (document -> sentence_lengths)
  all_num_sentences = [] # List of ints (document -> num_sentences)

  for text in tqdm(texts, desc="Processing hierarchical texts"):
    # doc_ids: list of lists (sentences -> token_ids)
    # num_sents: int
    # sent_lens: list of ints
    doc_ids, num_sents, sent_lens = processor.encode_document(text)
    all_input_ids.append(doc_ids)
    all_num_sentences.append(num_sents)
    all_sentence_lengths.append(sent_lens)

  return {
    'input_ids': all_input_ids, # No tensor conversion yet
    'num_sentences': torch.tensor(all_num_sentences, dtype=torch.long),
    'sentence_lengths': all_sentence_lengths, # No tensor conversion yet
    'labels': torch.tensor(labels, dtype=torch.long)
  }

train_processed = process_data_hierarchical(
  train_texts_cleaned,
  train_labels_cleaned,
  text_processor_han
)
test_processed = process_data_hierarchical(
  test_texts_cleaned,
  test_labels_cleaned,
  text_processor_han
)
Processing hierarchical texts: 100%|██████████| 10949/10949 [00:10<00:00, 1027.36it/s]
Processing hierarchical texts: 100%|██████████| 7267/7267 [00:06<00:00, 1058.10it/s]
In [ ]:
import torch.nn.utils.rnn as rnn_utils

class HierarchicalTextDataset(Dataset):
  def __init__(self, input_ids, num_sentences, sentence_lengths, labels):
    self.input_ids = input_ids # This is now a list of lists of lists
    self.num_sentences = num_sentences # This is a tensor (batch_size,)
    self.sentence_lengths = sentence_lengths # This is a list of lists (doc -> sentence lengths)
    self.labels = labels # This is a tensor

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    # Returns raw data from processing, to be padded by collate_fn
    return {
      'input_ids': self.input_ids[idx], # list of lists (sentences -> token_ids)
      'num_sentences': self.num_sentences[idx], # int (actual num sentences in this document)
      'sentence_lengths': self.sentence_lengths[idx], # list of ints (actual lengths of sentences in this document)
      'labels': self.labels[idx]
    }

def collate_fn_hierarchical(batch):
  # batch is a list of dictionaries, where each dict is an item from __getitem__

  # Separate components from the batch
  input_ids_docs_list = [item['input_ids'] for item in batch] # List of (list of lists of token_ids)
  num_sentences_list = [item['num_sentences'] for item in batch] # List of ints
  sentence_lengths_docs_list = [item['sentence_lengths'] for item in batch] # List of (list of ints)
  labels = torch.stack([item['labels'] for item in batch]) # Tensor (batch_size,)

  batch_size = len(batch)
  max_doc_num_sentences_in_batch = max(num_sentences_list) if num_sentences_list else 0

  # Step 1: Prepare word-level data for packing
  # We need to flatten all sentences from all documents in the batch
  # and collect their actual lengths.
  all_sentences_in_batch = []
  all_sentence_lengths_in_batch = []
  doc_sentence_pointers = [] # Helps in reconstructing document structure after word GRU

  current_total_sentences = 0
  for i, doc_sentences in enumerate(input_ids_docs_list):
    # doc_sentences is a list of lists of token_ids
    # doc_sentence_lengths is a list of ints
    doc_sentence_lengths = sentence_lengths_docs_list[i]

    # Convert each sentence (list of token_ids) to a tensor
    sentences_as_tensors = [torch.tensor(s, dtype=torch.long) for s in doc_sentences]

    all_sentences_in_batch.extend(sentences_as_tensors)
    all_sentence_lengths_in_batch.extend(doc_sentence_lengths)

    doc_sentence_pointers.append((current_total_sentences, current_total_sentences + len(doc_sentences)))
    current_total_sentences += len(doc_sentences)

  padded_sentences_for_attention = rnn_utils.pad_sequence(all_sentences_in_batch, batch_first=True, padding_value=text_processor_han.pad_id)
  
  # Create a tensor for the sentence lengths per document, padded to max_doc_num_sentences_in_batch
  padded_sentence_lengths_per_doc = []
  for doc_len_list in sentence_lengths_docs_list:
      current_doc_lengths = torch.tensor(doc_len_list, dtype=torch.long)
      if len(current_doc_lengths) < max_doc_num_sentences_in_batch:
          pad_len = max_doc_num_sentences_in_batch - len(current_doc_lengths)
          current_doc_lengths = torch.cat([current_doc_lengths, torch.zeros(pad_len, dtype=torch.long)], dim=0)
      padded_sentence_lengths_per_doc.append(current_doc_lengths)
  padded_sentence_lengths_per_doc_tensor = torch.stack(padded_sentence_lengths_per_doc) # (batch_size, max_sentences_in_batch)

  # Create final input_ids tensor for the model. This will be (batch_size, max_sentences_in_batch, max_words_in_batch)
  # It needs to be constructed from padded_sentences_for_attention
  # Reconstruct documents from padded_sentences_for_attention
  final_input_ids = []
  for i, (start_idx, end_idx) in enumerate(doc_sentence_pointers):
    current_doc_sentences = padded_sentences_for_attention[start_idx:end_idx]
    if current_doc_sentences.size(0) < max_doc_num_sentences_in_batch:
      sentence_pad_shape = (max_doc_num_sentences_in_batch - current_doc_sentences.size(0), current_doc_sentences.size(1))
      sentence_pad = torch.full(sentence_pad_shape, text_processor_han.pad_id, dtype=torch.long)
      current_doc_sentences = torch.cat([current_doc_sentences, sentence_pad], dim=0)
    final_input_ids.append(current_doc_sentences)
  final_input_ids_tensor = torch.stack(final_input_ids) # (batch_size, max_sentences_in_batch, max_words_in_batch)

  return {
    'input_ids': final_input_ids_tensor, # (batch_size, max_sentences_in_batch, max_words_in_batch)
    'num_sentences': torch.tensor(num_sentences_list, dtype=torch.long), # (batch_size,) - actual counts
    'sentence_lengths': padded_sentence_lengths_per_doc_tensor, # (batch_size, max_sentences_in_batch) - actual lengths of each sentence, padded for consistent shape
    'labels': labels
  }

# Recreate datasets
train_dataset = HierarchicalTextDataset(
  train_processed['input_ids'],
  train_processed['num_sentences'],
  train_processed['sentence_lengths'],
  train_processed['labels']
)

test_dataset = HierarchicalTextDataset(
  test_processed['input_ids'],
  test_processed['num_sentences'],
  test_processed['sentence_lengths'],
  test_processed['labels']
)

print(f"Train Dataset size: {len(train_dataset)}")
print(f"Test Dataset size: {len(test_dataset)}")

batch_size = 64

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True, collate_fn=collate_fn_hierarchical)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             pin_memory=True, collate_fn=collate_fn_hierarchical)
Train Dataset size: 10949
Test Dataset size: 7267
In [12]:
!wget http://nlp.stanford.edu/data/glove.6B.zip -P /content
!unzip -q glove.6B.zip
--2025-11-14 08:09:33--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2025-11-14 08:09:34--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2025-11-14 08:09:34--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘/content/glove.6B.zip’

glove.6B.zip        100%[===================>] 822.24M  4.95MB/s    in 2m 55s  

2025-11-14 08:12:30 (4.70 MB/s) - ‘/content/glove.6B.zip’ saved [862182613/862182613]

In [13]:
import numpy as np
from tqdm import tqdm

def load_glove_embeddings(glove_file, embedding_dim):
  embeddings_index = {}

  with open(glove_file, 'r', encoding='utf-8') as f:
    for line in tqdm(f, desc="Loading GloVe"):
      values = line.split()
      word = values[0]
      vector = np.asarray(values[1:], dtype='float32')
      embeddings_index[word] = vector

  print(f'Loaded {len(embeddings_index)} word vectors.')
  return embeddings_index

glove_embeddings = load_glove_embeddings('/content/glove.6B.100d.txt', 100)
Loading GloVe: 400000it [00:08, 44774.83it/s]
Loaded 400000 word vectors.

In [14]:
embedding_size = 100

def create_embedding_matrix(vocab, glove_embeddings, embedding_size):
  vocab_size = len(vocab)
  embedding_matrix = np.zeros((vocab_size, embedding_size))

  matched = 0
  for word, idx in tqdm(vocab.items(), desc="Creating embedding matrix"):
    if word in glove_embeddings:
      embedding_matrix[idx] = glove_embeddings[word]
      matched += 1
    else:
      embedding_matrix[idx] = np.random.normal(0, 0.1, embedding_size)

  # Set PAD token to zeros
  pad_id = vocab['[PAD]']
  embedding_matrix[pad_id] = np.zeros(embedding_size)

  print(f"Matched {matched}/{vocab_size} words with GloVe ({matched/vocab_size*100:.1f}%)")

  return embedding_matrix

embedding_matrix = create_embedding_matrix(tokenizer.get_vocab(), glove_embeddings, embedding_size)
Creating embedding matrix: 100%|██████████| 50758/50758 [00:00<00:00, 300721.36it/s]
Matched 34519/50758 words with GloVe (68.0%)

In [ ]:
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

class Attention(nn.Module):
  def __init__(self, hidden_size):
    super(Attention, self).__init__()
    self.attention = nn.Linear(hidden_size, hidden_size)
    self.context_vector = nn.Linear(hidden_size, 1, bias=False)

  def forward(self, hidden_state, mask):
    u = torch.tanh(self.attention(hidden_state))
    score = self.context_vector(u).squeeze(-1)

    score.masked_fill_(mask == 0, -1e9) # Fill with a very small number for softmax

    attention_weights = F.softmax(score, dim=-1)
    context_vector = torch.bmm(
        attention_weights.unsqueeze(1),
        hidden_state
    ).squeeze(1)

    return context_vector, attention_weights

class HAN(nn.Module):
  def __init__(self, vocab_size, embedding_size, hidden_size, num_classes, dropout=0.5, pad_idx=0):
    super(HAN, self).__init__()
    self.pad_idx = pad_idx
    self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)
    self.word_gru = nn.GRU(
        embedding_size,
        hidden_size,
        bidirectional=True,
        batch_first=True
    )
    self.word_attention = Attention(hidden_size * 2)

    self.sentence_gru = nn.GRU(
        hidden_size * 2, # Output of word attention is hidden_size * 2
        hidden_size,
        bidirectional=True,
        batch_first=True
    )
    self.sentence_attention = Attention(hidden_size * 2)
    self.fc = nn.Linear(hidden_size * 2, num_classes)
    self.dropout = nn.Dropout(dropout)

  def forward(self, documents, num_sentences, sentence_lengths):
    batch_size, max_sentences, max_words = documents.size()
    word_input = documents.view(batch_size * max_sentences, max_words) 
    flat_sentence_lengths = sentence_lengths.view(-1) # (batch_size * max_sentences)

    # Filter out sentences that have length 0 (these are the padded sentences at the document level)
    non_zero_lengths_mask = (flat_sentence_lengths > 0)
    actual_sentences_input = word_input[non_zero_lengths_mask]
    actual_sentence_lengths = flat_sentence_lengths[non_zero_lengths_mask]

    embedded_words = self.embedding(actual_sentences_input) # (num_actual_sentences, max_words_in_batch, embedding_size)

    # Pack padded sequence for word GRU
    packed_embedded_words = rnn_utils.pack_padded_sequence(
        embedded_words,
        actual_sentence_lengths.cpu(),
        batch_first=True,
        enforce_sorted=False # Let pack_padded_sequence handle sorting
    )
    packed_word_gru_out, _ = self.word_gru(packed_embedded_words)
    word_gru_out, _ = rnn_utils.pad_packed_sequence(packed_word_gru_out, batch_first=True, total_length=max_words)
    # word_gru_out shape: (num_actual_sentences, max_words_in_batch, 2*hidden_size)

    # Expand word_gru_out back to (batch_size * max_sentences, max_words, 2*hidden_size)
    # for attention. Need to put zeros where the padded sentences were.
    full_word_gru_out = torch.zeros(
        batch_size * max_sentences, max_words, 2 * self.word_gru.hidden_size,
        device=documents.device, dtype=word_gru_out.dtype
    )
    full_word_gru_out[non_zero_lengths_mask] = word_gru_out

    # Create mask for word attention based on padding
    word_attention_mask = (word_input != self.pad_idx) # (batch_size * max_sentences, max_words)

    # Get sentence vectors using word attention
    sentence_vectors, _ = self.word_attention(full_word_gru_out.contiguous(), word_attention_mask) # (batch_size * max_sentences, 2*hidden_size)


    # --- Sentence-level GRU and Attention ---
    # Reshape sentence_vectors to (batch_size, max_sentences, 2*hidden_size)
    sentence_vectors_reshaped = sentence_vectors.view(batch_size, max_sentences, -1)

    # num_sentences contains the actual number of sentences per document
    non_zero_num_sentences_mask = (num_sentences > 0)
    actual_num_sentences = num_sentences[non_zero_num_sentences_mask]
    actual_sentence_vectors = sentence_vectors_reshaped[non_zero_num_sentences_mask]

    if actual_sentence_vectors.size(0) == 0:
        document_vectors = torch.zeros(batch_size, 2 * self.sentence_gru.hidden_size, device=documents.device)
        output = self.fc(self.dropout(document_vectors))
        return output, None, None

    # Pack padded sequence for sentence GRU
    packed_embedded_sentences = rnn_utils.pack_padded_sequence(
        actual_sentence_vectors,
        actual_num_sentences.cpu(),
        batch_first=True,
        enforce_sorted=False
    )
    packed_sentence_gru_out, _ = self.sentence_gru(packed_embedded_sentences)
    sentence_gru_out, _ = rnn_utils.pad_packed_sequence(packed_sentence_gru_out, batch_first=True, total_length=max_sentences)
    # sentence_gru_out shape: (num_actual_documents, max_sentences_in_batch, 2*hidden_size)

    # Expand sentence_gru_out back to (batch_size, max_sentences, 2*hidden_size)
    full_sentence_gru_out = torch.zeros(
        batch_size, max_sentences, 2 * self.sentence_gru.hidden_size,
        device=documents.device, dtype=sentence_gru_out.dtype
    )
    full_sentence_gru_out[non_zero_num_sentences_mask] = sentence_gru_out

    # Create mask for sentence attention based on actual number of sentences per document
    sentence_attention_mask = (torch.arange(max_sentences, device=documents.device).expand(batch_size, max_sentences) < num_sentences.unsqueeze(1))

    # Get document vectors using sentence attention
    document_vectors, sentence_att_weights = self.sentence_attention(full_sentence_gru_out, sentence_attention_mask) # (batch_size, 2*hidden_size)

    output = self.fc(self.dropout(document_vectors))
    return output, None, sentence_att_weights # Word attention weights are tricky to return in this setup.

# Instantiate and move model to device
vocab_size = tokenizer.get_vocab_size()
embedding_size = 100
hidden_size = 64
num_classes = len(set(train_labels_cleaned))

model = HAN(
    vocab_size=vocab_size,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
    dropout=0.5,
    pad_idx=vocab['[PAD]']
)

model.embedding.weight.data.copy_(torch.FloatTensor(embedding_matrix))
model.to(device)
Out[ ]:
HAN(
  (embedding): Embedding(50758, 100, padding_idx=0)
  (word_gru): GRU(100, 64, batch_first=True, bidirectional=True)
  (word_attention): Attention(
    (attention): Linear(in_features=128, out_features=128, bias=True)
    (context_vector): Linear(in_features=128, out_features=1, bias=False)
  )
  (sentence_gru): GRU(128, 64, batch_first=True, bidirectional=True)
  (sentence_attention): Attention(
    (attention): Linear(in_features=128, out_features=128, bias=True)
    (context_vector): Linear(in_features=128, out_features=1, bias=False)
  )
  (fc): Linear(in_features=128, out_features=20, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
In [16]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2
)

def train_epoch(model, dataloader, criterion, optimizer, device):
  model.train()
  total_loss = 0
  predictions = []
  true_labels = []

  for batch in tqdm(dataloader, desc="Training"):
    input_ids = batch['input_ids'].to(device)
    num_sentences = batch['num_sentences'].to(device)
    sentence_lengths = batch['sentence_lengths'].to(device)
    labels = batch['labels'].to(device)

    optimizer.zero_grad()
    outputs, _, _ = model(input_ids, num_sentences, sentence_lengths)
    loss = criterion(outputs, labels)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

    total_loss += loss.item()
    preds = outputs.argmax(dim=1).cpu().numpy()
    predictions.extend(preds)
    true_labels.extend(labels.cpu().numpy())
    optimizer.step()

  avg_loss = total_loss / len(dataloader)
  accuracy = accuracy_score(true_labels, predictions)
  return avg_loss, accuracy

def evaluate(model, dataloader, criterion, device):
  model.eval()
  total_loss = 0
  predictions = []
  true_labels = []

  with torch.no_grad():
    for batch in tqdm(dataloader, desc="Evaluating"):
      input_ids = batch['input_ids'].to(device)
      num_sentences = batch['num_sentences'].to(device)
      sentence_lengths = batch['sentence_lengths'].to(device)
      labels = batch['labels'].to(device)

      outputs, _, _ = model(input_ids, num_sentences, sentence_lengths)
      loss = criterion(outputs, labels)

      total_loss += loss.item()
      preds = outputs.argmax(dim=1).cpu().numpy()
      predictions.extend(preds)
      true_labels.extend(labels.cpu().numpy())

  avg_loss = total_loss / len(dataloader)
  accuracy = accuracy_score(true_labels, predictions)
  return avg_loss, accuracy, predictions, true_labels

# Training loop with early stopping
num_epochs = 50
best_accuracy = 0
patience = 3
patience_counter = 0

for epoch in range(num_epochs):
  train_loss, train_acc = train_epoch(model, train_dataloader, criterion, optimizer, device)
  val_loss, val_acc, _, _ = evaluate(model, test_dataloader, criterion, device)

  scheduler.step(val_acc) # Moved scheduler.step() here with validation accuracy

  print(f"Epoch {epoch+1}/{num_epochs}")
  print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
  print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

  if val_acc > best_accuracy:
    best_accuracy = val_acc
    torch.save(model.state_dict(), 'best_han_bpe_model.pth')
    patience_counter = 0
    print(f"  ✓ Saved best model (accuracy: {val_acc:.4f})")
  else:
    patience_counter += 1
    print(f"  No improvement for {patience_counter}/{patience} epochs")

  if patience_counter >= patience:
    print(f"\nEarly stopping triggered after {epoch+1} epochs!")
    break

  print()

print(f"✓ Training completed! Best validation accuracy: {best_accuracy:.4f}")
Training: 100%|██████████| 172/172 [00:07<00:00, 22.35it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 49.94it/s]
Epoch 1/50
  Train Loss: 2.9844, Train Acc: 0.0695
  Val Loss: 2.9632, Val Acc: 0.0936
  ✓ Saved best model (accuracy: 0.0936)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.42it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.97it/s]
Epoch 2/50
  Train Loss: 2.8932, Train Acc: 0.1460
  Val Loss: 2.7265, Val Acc: 0.2071
  ✓ Saved best model (accuracy: 0.2071)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.75it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.60it/s]
Epoch 3/50
  Train Loss: 2.5067, Train Acc: 0.2148
  Val Loss: 2.2809, Val Acc: 0.2952
  ✓ Saved best model (accuracy: 0.2952)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.36it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.46it/s]
Epoch 4/50
  Train Loss: 2.1404, Train Acc: 0.3105
  Val Loss: 2.0013, Val Acc: 0.3916
  ✓ Saved best model (accuracy: 0.3916)

Training: 100%|██████████| 172/172 [00:05<00:00, 28.71it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.89it/s]
Epoch 5/50
  Train Loss: 1.9193, Train Acc: 0.3879
  Val Loss: 1.8398, Val Acc: 0.4597
  ✓ Saved best model (accuracy: 0.4597)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.84it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.78it/s]
Epoch 6/50
  Train Loss: 1.7616, Train Acc: 0.4479
  Val Loss: 1.7218, Val Acc: 0.4965
  ✓ Saved best model (accuracy: 0.4965)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.36it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.49it/s]
Epoch 7/50
  Train Loss: 1.6300, Train Acc: 0.4893
  Val Loss: 1.6252, Val Acc: 0.5155
  ✓ Saved best model (accuracy: 0.5155)

Training: 100%|██████████| 172/172 [00:05<00:00, 30.02it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.31it/s]
Epoch 8/50
  Train Loss: 1.5276, Train Acc: 0.5230
  Val Loss: 1.5531, Val Acc: 0.5320
  ✓ Saved best model (accuracy: 0.5320)

Training: 100%|██████████| 172/172 [00:05<00:00, 30.33it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.89it/s]
Epoch 9/50
  Train Loss: 1.4372, Train Acc: 0.5566
  Val Loss: 1.4931, Val Acc: 0.5425
  ✓ Saved best model (accuracy: 0.5425)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.34it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.23it/s]
Epoch 10/50
  Train Loss: 1.3593, Train Acc: 0.5723
  Val Loss: 1.4533, Val Acc: 0.5496
  ✓ Saved best model (accuracy: 0.5496)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.17it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.87it/s]
Epoch 11/50
  Train Loss: 1.2891, Train Acc: 0.5997
  Val Loss: 1.4162, Val Acc: 0.5588
  ✓ Saved best model (accuracy: 0.5588)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.85it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.87it/s]
Epoch 12/50
  Train Loss: 1.2429, Train Acc: 0.6148
  Val Loss: 1.3898, Val Acc: 0.5656
  ✓ Saved best model (accuracy: 0.5656)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.02it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.43it/s]
Epoch 13/50
  Train Loss: 1.1878, Train Acc: 0.6261
  Val Loss: 1.3713, Val Acc: 0.5722
  ✓ Saved best model (accuracy: 0.5722)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.05it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.93it/s]
Epoch 14/50
  Train Loss: 1.1397, Train Acc: 0.6439
  Val Loss: 1.3511, Val Acc: 0.5797
  ✓ Saved best model (accuracy: 0.5797)

Training: 100%|██████████| 172/172 [00:05<00:00, 30.01it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.53it/s]
Epoch 15/50
  Train Loss: 1.0991, Train Acc: 0.6582
  Val Loss: 1.3357, Val Acc: 0.5835
  ✓ Saved best model (accuracy: 0.5835)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.33it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.00it/s]
Epoch 16/50
  Train Loss: 1.0537, Train Acc: 0.6652
  Val Loss: 1.3275, Val Acc: 0.5916
  ✓ Saved best model (accuracy: 0.5916)

Training: 100%|██████████| 172/172 [00:06<00:00, 28.42it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 46.00it/s]
Epoch 17/50
  Train Loss: 1.0160, Train Acc: 0.6830
  Val Loss: 1.3234, Val Acc: 0.5923
  ✓ Saved best model (accuracy: 0.5923)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.45it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.11it/s]
Epoch 18/50
  Train Loss: 0.9807, Train Acc: 0.6928
  Val Loss: 1.3156, Val Acc: 0.5945
  ✓ Saved best model (accuracy: 0.5945)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.65it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.56it/s]
Epoch 19/50
  Train Loss: 0.9505, Train Acc: 0.7032
  Val Loss: 1.3260, Val Acc: 0.5979
  ✓ Saved best model (accuracy: 0.5979)

Training: 100%|██████████| 172/172 [00:05<00:00, 28.87it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.18it/s]
Epoch 20/50
  Train Loss: 0.9164, Train Acc: 0.7181
  Val Loss: 1.3103, Val Acc: 0.6024
  ✓ Saved best model (accuracy: 0.6024)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.11it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.74it/s]
Epoch 21/50
  Train Loss: 0.8766, Train Acc: 0.7297
  Val Loss: 1.3111, Val Acc: 0.6066
  ✓ Saved best model (accuracy: 0.6066)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.27it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.43it/s]
Epoch 22/50
  Train Loss: 0.8431, Train Acc: 0.7419
  Val Loss: 1.3191, Val Acc: 0.6037
  No improvement for 1/3 epochs

Training: 100%|██████████| 172/172 [00:05<00:00, 28.97it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.71it/s]
Epoch 23/50
  Train Loss: 0.8196, Train Acc: 0.7502
  Val Loss: 1.3183, Val Acc: 0.6077
  ✓ Saved best model (accuracy: 0.6077)

Training: 100%|██████████| 172/172 [00:05<00:00, 28.75it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.61it/s]
Epoch 24/50
  Train Loss: 0.7874, Train Acc: 0.7628
  Val Loss: 1.3199, Val Acc: 0.6070
  No improvement for 1/3 epochs

Training: 100%|██████████| 172/172 [00:05<00:00, 29.67it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.84it/s]
Epoch 25/50
  Train Loss: 0.7555, Train Acc: 0.7735
  Val Loss: 1.3216, Val Acc: 0.6113
  ✓ Saved best model (accuracy: 0.6113)

Training: 100%|██████████| 172/172 [00:05<00:00, 29.86it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.14it/s]
Epoch 26/50
  Train Loss: 0.7309, Train Acc: 0.7824
  Val Loss: 1.3284, Val Acc: 0.6095
  No improvement for 1/3 epochs

Training: 100%|██████████| 172/172 [00:05<00:00, 29.73it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.87it/s]
Epoch 27/50
  Train Loss: 0.7022, Train Acc: 0.7902
  Val Loss: 1.3334, Val Acc: 0.6108
  No improvement for 2/3 epochs

Training: 100%|██████████| 172/172 [00:05<00:00, 29.25it/s]
Evaluating: 100%|██████████| 114/114 [00:02<00:00, 50.73it/s]
Epoch 28/50
  Train Loss: 0.6787, Train Acc: 0.7996
  Val Loss: 1.3473, Val Acc: 0.6108
  No improvement for 3/3 epochs

Early stopping triggered after 28 epochs!
✓ Training completed! Best validation accuracy: 0.6113