In [1]:
!pip install nltk -q
In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
Using device: cuda

1. Load IMDB Dataset¶

In [3]:
!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz -P /content
!tar -xzf /content/aclImdb_v1.tar.gz -C /content
--2025-11-14 10:00:52--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘/content/aclImdb_v1.tar.gz’

aclImdb_v1.tar.gz   100%[===================>]  80.23M  24.1MB/s    in 3.7s    

2025-11-14 10:00:56 (21.8 MB/s) - ‘/content/aclImdb_v1.tar.gz’ saved [84125825/84125825]

In [4]:
import os

def load_imdb_data(data_dir):
    texts = []
    labels = []
    for label_type in ['pos', 'neg']:
        file_dir = os.path.join(data_dir, label_type)
        for fname in os.listdir(file_dir):
            if fname.endswith('.txt'):
                with open(os.path.join(file_dir, fname), 'r', encoding='utf-8') as f:
                    texts.append(f.read())
                labels.append(1 if label_type == 'pos' else 0)

    return texts, labels

train_data_texts, train_data_labels = load_imdb_data('/content/aclImdb/train')
test_data_texts, test_data_labels = load_imdb_data('/content/aclImdb/test')

print(f"Train data size: {len(train_data_texts)}")
print(f"Test data size: {len(test_data_texts)}")
Train data size: 25000
Test data size: 25000

First review example:
Label: Positive
Text (first 200 chars): "The Odd Couple" is one of those movies that far surpasses its reputation. People all know it, they hum the theme song, they complain of living with a sloppy "Oscar" or a fastidious "Felix"...but they...
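The cell that produced this preview is not shown above; a minimal sketch that reproduces it, assuming the label mapping from load_imdb_data (1 = positive) and noting that the exact review shown depends on os.listdir ordering:

# Peek at the first loaded training review (illustrative)
print("First review example:")
print(f"Label: {'Positive' if train_data_labels[0] == 1 else 'Negative'}")
print(f"Text (first 200 chars): {train_data_texts[0][:200]}...")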

3. Data Preprocessing¶

In [6]:
import re

class IMDBPreprocessor:
    """Preprocessor specifically designed for IMDB reviews"""

    def __init__(self, min_length=20, max_length=50000):
        self.min_length = min_length
        self.max_length = max_length

    def clean_text(self, text):
        """Clean IMDB review text"""
        # Remove HTML tags
        text = re.sub(r'<[^>]+>', ' ', text)

        # Remove URLs
        text = re.sub(r'http[s]?://\S+', ' ', text)

        # Replace multiple spaces with single space
        text = re.sub(r'\s+', ' ', text)

        # Collapse runs of the same character longer than 3 down to 3 (e.g. 'soooooo' -> 'sooo')
        text = re.sub(r'(\w)\1{3,}', r'\1\1\1', text)

        text = text.strip()

        # Truncate if too long
        if len(text) > self.max_length:
            text = text[:self.max_length]

        return text

    def is_valid(self, text):
        """Check if text is valid"""
        if not text or len(text.strip()) < self.min_length:
            return False
        return True

    def process_dataset(self, texts, labels):
        """Process entire dataset"""
        cleaned_texts = []
        valid_labels = []

        removed_count = 0

        for text, label in tqdm(zip(texts, labels), total=len(texts), desc="Preprocessing"):
            cleaned = self.clean_text(text)

            if not self.is_valid(cleaned):
                removed_count += 1
                continue

            cleaned_texts.append(cleaned)
            valid_labels.append(label)

        return cleaned_texts, valid_labels

# Preprocess data
preprocessor = IMDBPreprocessor(min_length=20, max_length=50000)
train_texts_cleaned, train_labels_cleaned = preprocessor.process_dataset(
    train_data_texts,
    train_data_labels
)
test_texts_cleaned, test_labels_cleaned = preprocessor.process_dataset(
    test_data_texts,
    test_data_labels
)
Preprocessing: 100%|██████████| 25000/25000 [00:03<00:00, 6651.10it/s]
Preprocessing: 100%|██████████| 25000/25000 [00:03<00:00, 7033.82it/s]
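To make the cleaning rules concrete, here is a small illustrative check on a made-up snippet (the expected output is derived by tracing the regexes above):

# Illustrative: HTML tags and URLs are dropped, whitespace collapsed, character runs capped at 3
sample = "Sooooooo good!<br /><br />See http://example.com   for details."
print(preprocessor.clean_text(sample))
# -> "Sooo good! See for details."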

4. Analyze Hierarchical Structure¶

In [7]:
import nltk

nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)

from nltk.tokenize import sent_tokenize

def analyze_sentence_structure(texts, n_samples=1000):
    """Analyze sentence structure of IMDB reviews"""
    import random

    # Random sampling
    sample_texts = random.sample(texts, min(n_samples, len(texts)))

    sent_counts = []
    sent_lengths = []

    for text in sample_texts:
        sentences = sent_tokenize(text)
        sent_counts.append(len(sentences))

        for sent in sentences:
            sent_lengths.append(len(sent.split()))

    sent_counts = np.array(sent_counts)
    sent_lengths = np.array(sent_lengths)

    print("=" * 60)
    print("Sentence Structure Analysis")
    print("=" * 60)

    print(f"\n[Sentences per Review]")
    print(f"  Average: {sent_counts.mean():.1f}")
    print(f"  Median: {np.median(sent_counts):.0f}")
    print(f"  75th percentile: {np.percentile(sent_counts, 75):.0f}")
    print(f"  95th percentile: {np.percentile(sent_counts, 95):.0f}")
    print(f"  Max: {sent_counts.max():.0f}")

    print(f"\n[Words per Sentence]")
    print(f"  Average: {sent_lengths.mean():.1f}")
    print(f"  Median: {np.median(sent_lengths):.0f}")
    print(f"  75th percentile: {np.percentile(sent_lengths, 75):.0f}")
    print(f"  95th percentile: {np.percentile(sent_lengths, 95):.0f}")

    # For IMDB, use 95th percentile to handle long reviews
    recommended_max_sents = int(np.percentile(sent_counts, 95))
    recommended_max_words = int(np.percentile(sent_lengths, 95))

    print(f"\n[Recommended Hyperparameters]")
    print(f"  max_sentences: {recommended_max_sents}")
    print(f"  max_words: {recommended_max_words}")

    return recommended_max_sents, recommended_max_words

max_sents, max_words = analyze_sentence_structure(train_texts_cleaned)
============================================================
Sentence Structure Analysis
============================================================

[Sentences per Review]
  Average: 12.5
  Median: 10
  75th percentile: 15
  95th percentile: 29
  Max: 143

[Words per Sentence]
  Average: 18.8
  Median: 16
  75th percentile: 24
  95th percentile: 42

[Recommended Hyperparameters]
  max_sentences: 29
  max_words: 42
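As a rough sanity check on these limits, one can estimate how many reviews fit entirely within max_sentences; a sketch (the exact figure varies with the random sample and should land near 95% by construction):

# Estimate the share of reviews fully covered by the chosen sentence limit
import random
sample_docs = random.sample(train_texts_cleaned, 1000)
fits = sum(len(sent_tokenize(t)) <= max_sents for t in sample_docs)
print(f"Reviews within {max_sents} sentences: {fits / len(sample_docs) * 100:.1f}%")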

5. Build Word-Level Vocabulary¶

In [8]:
from collections import Counter
from nltk.tokenize import word_tokenize

# Build word-level vocabulary
word_counter = Counter()
for text in tqdm(train_texts_cleaned, desc="Building vocabulary"):
    for sent in sent_tokenize(text):
        tokens = [w.lower() for w in word_tokenize(sent)]
        word_counter.update(tokens)

# Create vocab with frequency threshold (min_freq=2 for IMDB)
vocab = {'[PAD]': 0, '[UNK]': 1}
for word, count in word_counter.items():
    if count >= 2:
        vocab[word] = len(vocab)

print(f"\nVocabulary Statistics:")
print(f"  Total vocabulary size: {len(vocab)}")
print(f"  Unique tokens in corpus: {len(word_counter)}")
print(f"  Coverage: {len(vocab)/len(word_counter)*100:.1f}%")
print(f"\nMost common words: {word_counter.most_common(20)}")

# Create simple tokenizer class
class WordLevelTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab
        self.pad_id = vocab['[PAD]']
        self.unk_id = vocab['[UNK]']

    def encode(self, text):
        tokens = [w.lower() for w in word_tokenize(text)]
        ids = [self.vocab.get(token, self.unk_id) for token in tokens]
        return ids

    def get_vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return self.vocab

tokenizer = WordLevelTokenizer(vocab)
print(f"\nTokenizer created with vocab size: {tokenizer.get_vocab_size()}")
Building vocabulary: 100%|██████████| 25000/25000 [00:32<00:00, 761.22it/s]
Vocabulary Statistics:
  Total vocabulary size: 52730
  Unique tokens in corpus: 103310
  Share of unique tokens kept (min_freq=2): 51.0%

Most common words: [('the', 334839), (',', 275878), ('.', 273305), ('and', 163470), ('a', 162302), ('of', 145433), ('to', 135202), ('is', 110432), ('it', 95869), ('in', 93270), ('i', 86739), ('this', 75736), ('that', 73167), ("'s", 62227), ('was', 50468), ('as', 46835), ('for', 44101), ('with', 44073), ('movie', 43341), ('but', 42432)]

Tokenizer created with vocab size: 52730
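A quick illustrative use of the tokenizer on an arbitrary sentence, showing how out-of-vocabulary words fall back to [UNK] (id 1):

# Encode a sample sentence and count unknown tokens
sample_sent = "This movie was surprisingly good."
ids = tokenizer.encode(sample_sent)
print(ids)
print(f"UNK tokens: {sum(i == tokenizer.unk_id for i in ids)}/{len(ids)}")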

6. Hierarchical Text Processing¶

In [9]:
from nltk.tokenize import sent_tokenize
import torch.nn.utils.rnn as rnn_utils

class HierarchicalTextProcessor:
    def __init__(self, tokenizer, max_sentences, max_words):
        self.tokenizer = tokenizer
        self.max_sentences = max_sentences
        self.max_words = max_words
        self.pad_id = tokenizer.pad_id

    def encode_document(self, text):
        """Encode document into (list of lists of token ids, list of sentence lengths, num sentences)"""
        sentences = sent_tokenize(text)

        # Limit number of sentences
        if len(sentences) > self.max_sentences:
            sentences = sentences[:self.max_sentences]

        encoded_sentences = []
        sentence_lengths = []

        for sent in sentences:
            ids = self.tokenizer.encode(sent)
            ids = ids[:self.max_words]  # Truncate words

            # Store actual length of sentence after truncation
            sentence_lengths.append(len(ids))
            encoded_sentences.append(ids)

        return encoded_sentences, len(encoded_sentences), sentence_lengths

# Create processor
text_processor = HierarchicalTextProcessor(
    tokenizer,
    max_sentences=max_sents,
    max_words=max_words
)

print(f"Hierarchical processor created:")
print(f"  Max sentences per document: {max_sents}")
print(f"  Max words per sentence: {max_words}")
Hierarchical processor created:
  Max sentences per document: 29
  Max words per sentence: 42
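To show the structure encode_document returns (token ids per sentence, sentence count, per-sentence lengths), a small illustrative call on a two-sentence snippet:

# Illustrative: encode a short two-sentence document
doc_ids, n_sents, sent_lens = text_processor.encode_document(
    "Great acting. The plot was a bit slow, though."
)
print(f"Sentences: {n_sents}, lengths: {sent_lens}")
print(doc_ids)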
In [10]:
def process_data_hierarchical(texts, labels, processor):
    """Process texts into hierarchical structure"""
    all_input_ids = []  # List of lists of lists (document -> sentence -> token_ids)
    all_sentence_lengths = []  # List of lists (document -> sentence_lengths)
    all_num_sentences = []  # List of ints (document -> num_sentences)

    for text in tqdm(texts, desc="Processing hierarchical texts"):
        # doc_ids: list of lists (sentences -> token_ids)
        # num_sents: int
        # sent_lens: list of ints
        doc_ids, num_sents, sent_lens = processor.encode_document(text)
        all_input_ids.append(doc_ids)
        all_num_sentences.append(num_sents)
        all_sentence_lengths.append(sent_lens)

    return {
        'input_ids': all_input_ids,
        'num_sentences': torch.tensor(all_num_sentences, dtype=torch.long),
        'sentence_lengths': all_sentence_lengths,
        'labels': torch.tensor(labels, dtype=torch.long)
    }

# Process train and test data
train_processed = process_data_hierarchical(
    train_texts_cleaned,
    train_labels_cleaned,
    text_processor
)
test_processed = process_data_hierarchical(
    test_texts_cleaned,
    test_labels_cleaned,
    text_processor
)

print(f"\nProcessing completed:")
print(f"  Train samples: {len(train_processed['labels'])}")
print(f"  Test samples: {len(test_processed['labels'])}")
Processing hierarchical texts: 100%|██████████| 25000/25000 [00:31<00:00, 787.71it/s]
Processing hierarchical texts: 100%|██████████| 25000/25000 [00:31<00:00, 800.14it/s]
Processing completed:
  Train samples: 25000
  Test samples: 25000
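A quick look at the hierarchical record for the first training document (illustrative; the concrete numbers depend on which review was loaded first):

# Inspect one processed document
print(f"Sentences in doc 0: {train_processed['num_sentences'][0].item()}")
print(f"Per-sentence lengths: {train_processed['sentence_lengths'][0]}")
print(f"Label: {train_processed['labels'][0].item()}")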

7. Create Dataset and DataLoader¶

In [11]:
import torch.nn.utils.rnn as rnn_utils

class HierarchicalTextDataset(Dataset):
    def __init__(self, input_ids, num_sentences, sentence_lengths, labels):
        self.input_ids = input_ids
        self.num_sentences = num_sentences
        self.sentence_lengths = sentence_lengths
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'num_sentences': self.num_sentences[idx],
            'sentence_lengths': self.sentence_lengths[idx],
            'labels': self.labels[idx]
        }

def collate_fn_hierarchical(batch):
    """Custom collate function for hierarchical data"""

    input_ids_docs_list = [item['input_ids'] for item in batch]
    num_sentences_list = [item['num_sentences'] for item in batch]
    sentence_lengths_docs_list = [item['sentence_lengths'] for item in batch]
    labels = torch.stack([item['labels'] for item in batch])

    batch_size = len(batch)
    max_doc_num_sentences_in_batch = max(num_sentences_list) if num_sentences_list else 0

    # Flatten all sentences from all documents
    all_sentences_in_batch = []
    all_sentence_lengths_in_batch = []
    doc_sentence_pointers = []

    current_total_sentences = 0
    for i, doc_sentences in enumerate(input_ids_docs_list):
        doc_sentence_lengths = sentence_lengths_docs_list[i]

        sentences_as_tensors = [torch.tensor(s, dtype=torch.long) for s in doc_sentences]

        all_sentences_in_batch.extend(sentences_as_tensors)
        all_sentence_lengths_in_batch.extend(doc_sentence_lengths)

        doc_sentence_pointers.append((current_total_sentences, current_total_sentences + len(doc_sentences)))
        current_total_sentences += len(doc_sentences)

    # Pad all sentences
    if all_sentences_in_batch:
        padded_sentences_for_attention = rnn_utils.pad_sequence(
            all_sentences_in_batch,
            batch_first=True,
            padding_value=text_processor.pad_id
        )
    else:
        padded_sentences_for_attention = torch.empty(0, text_processor.max_words, dtype=torch.long)

    # Create padded sentence lengths per document
    padded_sentence_lengths_per_doc = []
    for doc_len_list in sentence_lengths_docs_list:
        current_doc_lengths = torch.tensor(doc_len_list, dtype=torch.long)
        if len(current_doc_lengths) < max_doc_num_sentences_in_batch:
            pad_len = max_doc_num_sentences_in_batch - len(current_doc_lengths)
            current_doc_lengths = torch.cat([current_doc_lengths, torch.zeros(pad_len, dtype=torch.long)], dim=0)
        padded_sentence_lengths_per_doc.append(current_doc_lengths)
    padded_sentence_lengths_per_doc_tensor = torch.stack(padded_sentence_lengths_per_doc)

    # Reconstruct documents from padded sentences
    final_input_ids = []
    for i, (start_idx, end_idx) in enumerate(doc_sentence_pointers):
        current_doc_sentences = padded_sentences_for_attention[start_idx:end_idx]
        if current_doc_sentences.size(0) < max_doc_num_sentences_in_batch:
            sentence_pad_shape = (max_doc_num_sentences_in_batch - current_doc_sentences.size(0), current_doc_sentences.size(1))
            sentence_pad = torch.full(sentence_pad_shape, text_processor.pad_id, dtype=torch.long)
            current_doc_sentences = torch.cat([current_doc_sentences, sentence_pad], dim=0)
        final_input_ids.append(current_doc_sentences)
    final_input_ids_tensor = torch.stack(final_input_ids)

    return {
        'input_ids': final_input_ids_tensor,
        'num_sentences': torch.tensor(num_sentences_list, dtype=torch.long),
        'sentence_lengths': padded_sentence_lengths_per_doc_tensor,
        'labels': labels
    }

# Create datasets
train_dataset = HierarchicalTextDataset(
    train_processed['input_ids'],
    train_processed['num_sentences'],
    train_processed['sentence_lengths'],
    train_processed['labels']
)

test_dataset = HierarchicalTextDataset(
    test_processed['input_ids'],
    test_processed['num_sentences'],
    test_processed['sentence_lengths'],
    test_processed['labels']
)

print(f"Train Dataset size: {len(train_dataset)}")
print(f"Test Dataset size: {len(test_dataset)}")

# Create dataloaders
batch_size = 32  # Smaller batch size for longer IMDB reviews

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_fn_hierarchical
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn_hierarchical
)

print(f"Train batches: {len(train_dataloader)}")
print(f"Test batches: {len(test_dataloader)}")
Train Dataset size: 25000
Test Dataset size: 25000
Train batches: 782
Test batches: 782
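A shape check on one collated batch confirms what the model will receive (illustrative; the last two dimensions depend on the longest document and longest sentence in that particular batch):

# Sanity-check collated batch shapes
sample_batch = next(iter(train_dataloader))
print(f"input_ids:        {tuple(sample_batch['input_ids'].shape)}")    # (batch, max_sents_in_batch, max_words_in_batch)
print(f"num_sentences:    {tuple(sample_batch['num_sentences'].shape)}")
print(f"sentence_lengths: {tuple(sample_batch['sentence_lengths'].shape)}")
print(f"labels:           {tuple(sample_batch['labels'].shape)}")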

8. Load GloVe Embeddings¶

In [12]:
!wget http://nlp.stanford.edu/data/glove.6B.zip -P /content
!unzip -q /content/glove.6B.zip -d /content
--2025-11-14 10:02:48--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2025-11-14 10:02:48--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2025-11-14 10:02:48--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘/content/glove.6B.zip’

glove.6B.zip        100%[===================>] 822.24M  5.00MB/s    in 2m 39s  

2025-11-14 10:05:27 (5.18 MB/s) - ‘/content/glove.6B.zip’ saved [862182613/862182613]

In [13]:
import numpy as np
from tqdm import tqdm

def load_glove_embeddings(glove_file, embedding_dim):
    """Load GloVe embeddings from file"""
    embeddings_index = {}

    with open(glove_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Loading GloVe"):
            values = line.split()
            word = values[0]
            vector = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = vector

    print(f'Loaded {len(embeddings_index)} word vectors.')
    return embeddings_index

glove_embeddings = load_glove_embeddings('/content/glove.6B.100d.txt', 100)
Loading GloVe: 400000it [00:08, 44735.33it/s]
Loaded 400000 word vectors.

In [14]:
embedding_size = 100

def create_embedding_matrix(vocab, glove_embeddings, embedding_size):
    """Create embedding matrix from GloVe"""
    vocab_size = len(vocab)
    embedding_matrix = np.zeros((vocab_size, embedding_size))

    matched = 0
    for word, idx in tqdm(vocab.items(), desc="Creating embedding matrix"):
        if word in glove_embeddings:
            embedding_matrix[idx] = glove_embeddings[word]
            matched += 1
        else:
            embedding_matrix[idx] = np.random.normal(0, 0.1, embedding_size)

    # Set PAD token to zeros
    pad_id = vocab['[PAD]']
    embedding_matrix[pad_id] = np.zeros(embedding_size)

    print(f"Matched {matched}/{vocab_size} words with GloVe ({matched/vocab_size*100:.1f}%)")

    return embedding_matrix

embedding_matrix = create_embedding_matrix(
    tokenizer.get_vocab(),
    glove_embeddings,
    embedding_size
)
Creating embedding matrix: 100%|██████████| 52730/52730 [00:00<00:00, 390562.27it/s]
Matched 46075/52730 words with GloVe (87.4%)
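Two quick sanity checks on the resulting matrix: its shape matches the vocabulary, and the [PAD] row stays all zeros:

# Verify shape and the zeroed PAD row
print(embedding_matrix.shape)                             # (52730, 100)
print(np.allclose(embedding_matrix[vocab['[PAD]']], 0.0)) # True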

9. Build HAN Model¶

In [15]:
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

class Attention(nn.Module):
    """Attention mechanism"""

    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.attention = nn.Linear(hidden_size, hidden_size)
        self.context_vector = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden_state, mask):
        u = torch.tanh(self.attention(hidden_state))
        score = self.context_vector(u).squeeze(-1)

        score.masked_fill_(mask == 0, -1e9)

        attention_weights = F.softmax(score, dim=-1)
        context_vector = torch.bmm(
            attention_weights.unsqueeze(1),
            hidden_state
        ).squeeze(1)

        return context_vector, attention_weights

class HAN(nn.Module):
    """Hierarchical Attention Network"""

    def __init__(self, vocab_size, embedding_size, hidden_size, num_classes, dropout=0.5, pad_idx=0):
        super(HAN, self).__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)

        # Word-level GRU
        self.word_gru = nn.GRU(
            embedding_size,
            hidden_size,
            bidirectional=True,
            batch_first=True
        )
        self.word_attention = Attention(hidden_size * 2)

        # Sentence-level GRU
        self.sentence_gru = nn.GRU(
            hidden_size * 2,
            hidden_size,
            bidirectional=True,
            batch_first=True
        )
        self.sentence_attention = Attention(hidden_size * 2)

        self.fc = nn.Linear(hidden_size * 2, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, documents, num_sentences, sentence_lengths):
        batch_size, max_sentences, max_words = documents.size()

        # --- Word-level processing ---
        word_input = documents.view(batch_size * max_sentences, max_words)
        flat_sentence_lengths = sentence_lengths.view(-1)

        # Filter out padded sentences
        non_zero_lengths_mask = (flat_sentence_lengths > 0)
        actual_sentences_input = word_input[non_zero_lengths_mask]
        actual_sentence_lengths = flat_sentence_lengths[non_zero_lengths_mask]

        if actual_sentences_input.size(0) == 0:
            document_vectors = torch.zeros(batch_size, 2 * self.sentence_gru.hidden_size, device=documents.device)
            output = self.fc(self.dropout(document_vectors))
            return output, None, None

        # Embed words
        embedded_words = self.embedding(actual_sentences_input)

        # Pack padded sequence for word GRU
        packed_embedded_words = rnn_utils.pack_padded_sequence(
            embedded_words,
            actual_sentence_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        packed_word_gru_out, _ = self.word_gru(packed_embedded_words)
        word_gru_out, _ = rnn_utils.pad_packed_sequence(
            packed_word_gru_out,
            batch_first=True,
            total_length=max_words
        )

        # Expand back to full batch
        full_word_gru_out = torch.zeros(
            batch_size * max_sentences, max_words, 2 * self.word_gru.hidden_size,
            device=documents.device, dtype=word_gru_out.dtype
        )
        full_word_gru_out[non_zero_lengths_mask] = word_gru_out

        # Word attention
        word_attention_mask = (word_input != self.pad_idx)
        sentence_vectors, _ = self.word_attention(full_word_gru_out.contiguous(), word_attention_mask)

        # --- Sentence-level processing ---
        sentence_vectors_reshaped = sentence_vectors.view(batch_size, max_sentences, -1)

        non_zero_num_sentences_mask = (num_sentences > 0)
        actual_num_sentences = num_sentences[non_zero_num_sentences_mask]
        actual_sentence_vectors = sentence_vectors_reshaped[non_zero_num_sentences_mask]

        if actual_sentence_vectors.size(0) == 0:
            document_vectors = torch.zeros(batch_size, 2 * self.sentence_gru.hidden_size, device=documents.device)
            output = self.fc(self.dropout(document_vectors))
            return output, None, None

        # Pack padded sequence for sentence GRU
        packed_embedded_sentences = rnn_utils.pack_padded_sequence(
            actual_sentence_vectors,
            actual_num_sentences.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        packed_sentence_gru_out, _ = self.sentence_gru(packed_embedded_sentences)
        sentence_gru_out, _ = rnn_utils.pad_packed_sequence(
            packed_sentence_gru_out,
            batch_first=True,
            total_length=max_sentences
        )

        # Expand back to full batch
        full_sentence_gru_out = torch.zeros(
            batch_size, max_sentences, 2 * self.sentence_gru.hidden_size,
            device=documents.device, dtype=sentence_gru_out.dtype
        )
        full_sentence_gru_out[non_zero_num_sentences_mask] = sentence_gru_out

        # Sentence attention
        sentence_attention_mask = (torch.arange(max_sentences, device=documents.device).expand(batch_size, max_sentences) < num_sentences.unsqueeze(1))
        document_vectors, sentence_att_weights = self.sentence_attention(full_sentence_gru_out, sentence_attention_mask)

        output = self.fc(self.dropout(document_vectors))
        return output, None, sentence_att_weights

# Instantiate model
vocab_size = tokenizer.get_vocab_size()
embedding_size = 100
hidden_size = 64  # Can be adjusted for IMDB
num_classes = 2  # Binary classification

model = HAN(
    vocab_size=vocab_size,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
    dropout=0.5,
    pad_idx=vocab['[PAD]']
)

# Load pretrained GloVe embeddings
model.embedding.weight.data.copy_(torch.FloatTensor(embedding_matrix))
model.to(device)

print(f"\nModel created successfully!")
print(f"  Vocabulary size: {vocab_size}")
print(f"  Embedding size: {embedding_size}")
print(f"  Hidden size: {hidden_size}")
print(f"  Number of classes: {num_classes}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
Model created successfully!
  Vocabulary size: 52730
  Embedding size: 100
  Hidden size: 64
  Number of classes: 2
  Total parameters: 5,444,778
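Before training, a quick forward-pass smoke test on one batch confirms the expected output shapes (illustrative; batch contents are random because the train loader shuffles):

# Forward-pass smoke test
model.eval()
with torch.no_grad():
    b = next(iter(train_dataloader))
    logits, _, sent_att = model(
        b['input_ids'].to(device),
        b['num_sentences'].to(device),
        b['sentence_lengths'].to(device)
    )
print(logits.shape)    # (batch_size, num_classes), e.g. torch.Size([32, 2])
print(sent_att.shape)  # (batch_size, max_sentences_in_batch)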

10. Training¶

In [18]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2
)

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    predictions = []
    true_labels = []

    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        num_sentences = batch['num_sentences'].to(device)
        sentence_lengths = batch['sentence_lengths'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs, _, _ = model(input_ids, num_sentences, sentence_lengths)
        loss = criterion(outputs, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1).cpu().numpy()
        predictions.extend(preds)
        true_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='binary')

    return avg_loss, accuracy, precision, recall, f1

def evaluate(model, dataloader, criterion, device):
    """Evaluate model"""
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            num_sentences = batch['num_sentences'].to(device)
            sentence_lengths = batch['sentence_lengths'].to(device)
            labels = batch['labels'].to(device)

            outputs, _, _ = model(input_ids, num_sentences, sentence_lengths)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            preds = outputs.argmax(dim=1).cpu().numpy()
            predictions.extend(preds)
            true_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='binary')

    return avg_loss, accuracy, precision, recall, f1, predictions, true_labels
In [19]:
# Training loop
num_epochs = 20
best_accuracy = 0
patience = 3
patience_counter = 0

train_losses = []
train_accs = []
val_losses = []
val_accs = []

print("=" * 60)
print("Starting Training")
print("=" * 60)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 60)

    # Train
    train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch(
        model, train_dataloader, criterion, optimizer, device
    )

    # Evaluate
    val_loss, val_acc, val_prec, val_rec, val_f1, _, _ = evaluate(
        model, test_dataloader, criterion, device
    )

    # Update scheduler
    scheduler.step(val_acc)

    # Store metrics
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Print results
    print(f"\nTrain - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")

    # Save best model
    if val_acc > best_accuracy:
        best_accuracy = val_acc
        torch.save(model.state_dict(), 'best_han_imdb_model.pth')
        patience_counter = 0
        print(f"✓ Saved best model (accuracy: {val_acc:.4f})")
    else:
        patience_counter += 1
        print(f"No improvement for {patience_counter}/{patience} epochs")

    # Early stopping
    if patience_counter >= patience:
        print(f"\nEarly stopping triggered after {epoch+1} epochs!")
        break

print("\n" + "=" * 60)
print(f"Training completed! Best validation accuracy: {best_accuracy:.4f}")
print("=" * 60)
============================================================
Starting Training
============================================================

Epoch 1/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:24<00:00, 32.33it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 74.94it/s]
Train - Loss: 0.5511, Acc: 0.6979, F1: 0.7051
Val   - Loss: 0.4006, Acc: 0.8226, F1: 0.8048
✓ Saved best model (accuracy: 0.8226)

Epoch 2/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:22<00:00, 35.34it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 75.12it/s]
Train - Loss: 0.3296, Acc: 0.8636, F1: 0.8642
Val   - Loss: 0.3076, Acc: 0.8676, F1: 0.8625
✓ Saved best model (accuracy: 0.8676)

Epoch 3/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:23<00:00, 33.98it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 74.90it/s]
Train - Loss: 0.2797, Acc: 0.8871, F1: 0.8872
Val   - Loss: 0.2750, Acc: 0.8836, F1: 0.8826
✓ Saved best model (accuracy: 0.8836)

Epoch 4/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:22<00:00, 34.78it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 74.01it/s]
Train - Loss: 0.2510, Acc: 0.9023, F1: 0.9023
Val   - Loss: 0.2602, Acc: 0.8921, F1: 0.8927
✓ Saved best model (accuracy: 0.8921)

Epoch 5/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:22<00:00, 35.11it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 74.01it/s]
Train - Loss: 0.2241, Acc: 0.9147, F1: 0.9146
Val   - Loss: 0.2548, Acc: 0.8954, F1: 0.8933
✓ Saved best model (accuracy: 0.8954)

Epoch 6/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:22<00:00, 34.99it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 74.29it/s]
Train - Loss: 0.2019, Acc: 0.9244, F1: 0.9243
Val   - Loss: 0.2472, Acc: 0.9002, F1: 0.9008
✓ Saved best model (accuracy: 0.9002)

Epoch 7/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:22<00:00, 34.48it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 74.48it/s]
Train - Loss: 0.1777, Acc: 0.9345, F1: 0.9344
Val   - Loss: 0.2565, Acc: 0.8977, F1: 0.9004
No improvement for 1/3 epochs

Epoch 8/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:22<00:00, 34.79it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 74.46it/s]
Train - Loss: 0.1535, Acc: 0.9456, F1: 0.9455
Val   - Loss: 0.2572, Acc: 0.9000, F1: 0.8988
No improvement for 2/3 epochs

Epoch 9/20
------------------------------------------------------------
Training: 100%|██████████| 782/782 [00:22<00:00, 35.15it/s]
Evaluating: 100%|██████████| 782/782 [00:10<00:00, 73.80it/s]
Train - Loss: 0.1309, Acc: 0.9551, F1: 0.9551
Val   - Loss: 0.2748, Acc: 0.8980, F1: 0.8995
No improvement for 3/3 epochs

Early stopping triggered after 9 epochs!

============================================================
Training completed! Best validation accuracy: 0.9002
============================================================
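matplotlib was imported at the top but not yet used; a minimal sketch that plots the loss and accuracy histories collected during training:

# Plot training/validation curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(train_losses, label='train')
ax1.plot(val_losses, label='val')
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.legend()
ax2.plot(train_accs, label='train')
ax2.plot(val_accs, label='val')
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Accuracy'); ax2.legend()
plt.tight_layout()
plt.show()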