In [1]:
!pip install tokenizers -q
In [2]:
!pip install datasets nltk transformers -q
In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
Using device: cuda
In [ ]:
dataset = load_dataset("SetFit/20_newsgroups")
train_data = dataset['train']
test_data = dataset['test']
print(f"Train data size: {len(train_data)}")
print(f"Train data example: {train_data[0]}")
print(f"Test data size: {len(test_data)}")
print(f"Test data example: {test_data[0]}")
In [ ]:
train_data = train_data.to_pandas()
train_data['text_length'] = train_data['text'].str.len()
train_data['word_count'] = train_data['text'].str.split().str.len()
print("\n[Text Length Statistics]")
print(train_data[['text_length', 'word_count']].describe())
def has_email_headers(text):
headers = ['From:', 'Subject:', 'Organization:', 'Lines:', 'NNTP-Posting-Host:']
return any(header in text for header in headers)
train_data['has_headers'] = train_data['text'].apply(has_email_headers)
print(f"\n[Email Header Detection]")
print(f"Documents with email headers: {train_data['has_headers'].sum()} / {len(train_data)}")
print(f"Percentage: {train_data['has_headers'].mean():.2%}")
import re
def count_urls(text):
return len(re.findall(r'http[s]?://\S+', text))
def count_emails(text):
return len(re.findall(r'\S+@\S+', text))
def count_quote_lines(text):
    return len(re.findall(r'^>+', text, re.MULTILINE))
train_data['url_count'] = train_data['text'].apply(count_urls)
train_data['email_count'] = train_data['text'].apply(count_emails)
train_data['quote_count'] = train_data['text'].apply(count_quote_lines)
print(f"\n[Noise Statistics]")
print(f"Documents containing URLs: {(train_data['url_count'] > 0).sum()}")
print(f"Documents containing emails: {(train_data['email_count'] > 0).sum()}")
print(f"Documents containing quotes: {(train_data['quote_count'] > 0).sum()}")
# 5. Class distribution
print(f"\n[Class Distribution]")
label_counts = train_data['label'].value_counts().sort_index()
print(label_counts)
print(f"\nMax class samples / Min class samples = {label_counts.max() / label_counts.min():.2f}")
print("\n[Abnormal Sample Detection]")
print(f"Empty documents: {(train_data['text'].str.strip() == '').sum()}")
print(f"Very short documents (<50 chars): {(train_data['text_length'] < 50).sum()}")
print(f"Very long documents (>10000 chars): {(train_data['text_length'] > 10000).sum()}")
In [6]:
import re
class SimplePreprocessor:
def __init__(self, min_length=20, max_length=50000):
self.min_length = min_length
self.max_length = max_length
def clean_text(self, text):
text = re.sub(r'(\S)\1{10,}', r'\1\1\1', text)
text = re.sub(r'\s+', ' ', text)
text = text.strip()
if len(text) > self.max_length:
text = text[:self.max_length]
return text
def is_valid(self, text):
if not text or len(text.strip()) < self.min_length:
return False
return True
def process_dataset(self, texts, labels):
cleaned_texts = []
valid_labels = []
removed_empty = 0
removed_short = 0
kept = 0
for text, label in zip(texts, labels):
cleaned = self.clean_text(text)
if not cleaned:
removed_empty += 1
continue
if not self.is_valid(cleaned):
removed_short += 1
continue
cleaned_texts.append(cleaned)
valid_labels.append(label)
kept += 1
return cleaned_texts, valid_labels
preprocessor = SimplePreprocessor(min_length=20, max_length=50000)
train_texts_cleaned, train_labels_cleaned = preprocessor.process_dataset(
train_data['text'],
train_data['label']
)
test_texts_cleaned, test_labels_cleaned = preprocessor.process_dataset(
test_data['text'],
test_data['label']
)
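It is worth reporting how much data survives cleaning; the counters inside `process_dataset` are not returned, so this minimal sketch simply recomputes the kept counts from the outputs:
In [ ]:
# Summarize how many documents the cleaning step kept.
print(f"Train: kept {len(train_texts_cleaned)} / {len(train_data)} documents")
print(f"Test:  kept {len(test_texts_cleaned)} / {len(test_data)} documents")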
In [ ]:
import nltk
nltk.download('punkt_tab')
nltk.download('punkt', quiet=True)
from nltk.tokenize import sent_tokenize
def quick_analyze_structure(texts, n_samples=500):
"""Quickly analyze sentence structure"""
import random
# Random sampling
sample_texts = random.sample(texts, min(n_samples, len(texts)))
sent_counts = []
sent_lengths = []
for text in sample_texts:
sentences = sent_tokenize(text)
sent_counts.append(len(sentences))
for sent in sentences:
sent_lengths.append(len(sent.split()))
sent_counts = np.array(sent_counts)
sent_lengths = np.array(sent_lengths)
print("Sentence Structure Analysis:")
print(f"\nSentences per document:")
print(f" Average: {sent_counts.mean():.1f}")
print(f" Median: {np.median(sent_counts):.0f}")
print(f" 75th percentile: {np.percentile(sent_counts, 75):.0f}")
print(f" 95th percentile: {np.percentile(sent_counts, 95):.0f}")
print(f"\nWords per sentence:")
print(f" Average: {sent_lengths.mean():.1f}")
print(f" Median: {np.median(sent_lengths):.0f}")
print(f" 75th percentile: {np.percentile(sent_lengths, 75):.0f}")
print(f" 95th percentile: {np.percentile(sent_lengths, 95):.0f}")
# Recommended hyperparameters (using 95th percentile)
recommended_max_sents = int(np.percentile(sent_counts, 95))
recommended_max_words = int(np.percentile(sent_lengths, 95))
print(f"\n[Recommended Hyperparameters]")
print(f" max_sentences: {recommended_max_sents}")
print(f" max_words: {recommended_max_words}")
return recommended_max_sents, recommended_max_words
max_sents, max_words = quick_analyze_structure(train_texts_cleaned)
In [ ]:
from collections import Counter
from nltk.tokenize import word_tokenize
# Build word-level vocabulary
word_counter = Counter()
for text in tqdm(train_texts_cleaned, desc="Building vocabulary"):
for sent in sent_tokenize(text):
tokens = [w.lower() for w in word_tokenize(sent)]
word_counter.update(tokens)
# Create vocab with frequency threshold
vocab = {'[PAD]': 0, '[UNK]': 1}
for word, count in word_counter.items():
if count >= 2:
vocab[word] = len(vocab)
print(f"Vocabulary size: {len(vocab)}")
print(f"Most common words: {word_counter.most_common(20)}")
class WordLevelTokenizer:
def __init__(self, vocab):
self.vocab = vocab
self.pad_id = vocab['[PAD]']
self.unk_id = vocab['[UNK]']
def encode(self, text):
tokens = [w.lower() for w in word_tokenize(text)]
ids = [self.vocab.get(token, self.unk_id) for token in tokens]
return ids
def get_vocab_size(self):
return len(self.vocab)
def get_vocab(self):
return self.vocab
tokenizer = WordLevelTokenizer(vocab)
Building vocabulary: 100%|██████████| 10949/10949 [00:13<00:00, 797.00it/s]
Vocabulary size: 50758
Most common words: [(',', 108751), ('the', 105427), ('.', 101720), ('>', 71587), ('to', 52446), ("'ax", 52276), ('of', 46542), ('a', 43694), ('and', 42108), ('(', 38359), (')', 38147), ('i', 34404), ('is', 30765), ('in', 30619), ('that', 27825), (':', 26124), ('it', 23441), ('?', 21481), ("''", 20117), ('*', 19847)]
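To see how the word-level tokenizer behaves, including the `[UNK]` fallback for out-of-vocabulary words, it can be run on a short made-up sentence (illustrative only):
In [ ]:
# Encode a made-up sentence; words missing from the vocabulary map to the [UNK] id.
demo_sentence = "The spacecraft uses a qwertyzxcv thruster."
demo_ids = tokenizer.encode(demo_sentence)
print(demo_ids)
print([i == tokenizer.unk_id for i in demo_ids])  # True marks [UNK] positions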
In [9]:
from nltk.tokenize import sent_tokenize
import torch.nn.utils.rnn as rnn_utils
class HierarchicalTextProcessor:
def __init__(self, tokenizer, max_sentences, max_words):
self.tokenizer = tokenizer
self.max_sentences = max_sentences
self.max_words = max_words
self.pad_id = tokenizer.pad_id
def encode_document(self, text):
"""Encode document into (list of lists of token ids, list of sentence lengths, num sentences)"""
sentences = sent_tokenize(text)
# Limit number of sentences
if len(sentences) > self.max_sentences:
sentences = sentences[:self.max_sentences]
encoded_sentences = []
sentence_lengths = []
for sent in sentences:
ids = self.tokenizer.encode(sent)
ids = ids[:self.max_words] # Truncate words
# Store actual length of sentence after truncation
sentence_lengths.append(len(ids))
encoded_sentences.append(ids)
return encoded_sentences, len(encoded_sentences), sentence_lengths
# Using recommended hyperparameters
text_processor_han = HierarchicalTextProcessor(
tokenizer,
max_sentences=max_sents,
max_words=max_words
)
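A quick look at what `encode_document` returns for a tiny two-sentence example (made up for illustration) clarifies the nested structure the collate function has to pad later:
In [ ]:
# Inspect the hierarchical encoding of a toy document.
toy_doc = "This is the first sentence. Here is a second, slightly longer sentence."
toy_sents, toy_num_sents, toy_lens = text_processor_han.encode_document(toy_doc)
print(f"num sentences: {toy_num_sents}")    # expected: 2
print(f"sentence lengths: {toy_lens}")      # token count per sentence
print(f"first sentence ids: {toy_sents[0]}")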
In [10]:
def process_data_hierarchical(texts, labels, processor):
all_input_ids = [] # List of lists of lists (document -> sentence -> token_ids)
all_sentence_lengths = [] # List of lists (document -> sentence_lengths)
all_num_sentences = [] # List of ints (document -> num_sentences)
for text in tqdm(texts, desc="Processing hierarchical texts"):
# doc_ids: list of lists (sentences -> token_ids)
# num_sents: int
# sent_lens: list of ints
doc_ids, num_sents, sent_lens = processor.encode_document(text)
all_input_ids.append(doc_ids)
all_num_sentences.append(num_sents)
all_sentence_lengths.append(sent_lens)
return {
'input_ids': all_input_ids, # No tensor conversion yet
'num_sentences': torch.tensor(all_num_sentences, dtype=torch.long),
'sentence_lengths': all_sentence_lengths, # No tensor conversion yet
'labels': torch.tensor(labels, dtype=torch.long)
}
train_processed = process_data_hierarchical(
train_texts_cleaned,
train_labels_cleaned,
text_processor_han
)
test_processed = process_data_hierarchical(
test_texts_cleaned,
test_labels_cleaned,
text_processor_han
)
Processing hierarchical texts: 100%|██████████| 10949/10949 [00:10<00:00, 1027.36it/s]
Processing hierarchical texts: 100%|██████████| 7267/7267 [00:06<00:00, 1058.10it/s]
In [ ]:
import torch.nn.utils.rnn as rnn_utils
class HierarchicalTextDataset(Dataset):
def __init__(self, input_ids, num_sentences, sentence_lengths, labels):
self.input_ids = input_ids # This is now a list of lists of lists
self.num_sentences = num_sentences # This is a tensor (batch_size,)
self.sentence_lengths = sentence_lengths # This is a list of lists (doc -> sentence lengths)
self.labels = labels # This is a tensor
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
# Returns raw data from processing, to be padded by collate_fn
return {
'input_ids': self.input_ids[idx], # list of lists (sentences -> token_ids)
'num_sentences': self.num_sentences[idx], # int (actual num sentences in this document)
'sentence_lengths': self.sentence_lengths[idx], # list of ints (actual lengths of sentences in this document)
'labels': self.labels[idx]
}
def collate_fn_hierarchical(batch):
# batch is a list of dictionaries, where each dict is an item from __getitem__
# Separate components from the batch
input_ids_docs_list = [item['input_ids'] for item in batch] # List of (list of lists of token_ids)
num_sentences_list = [item['num_sentences'] for item in batch] # List of ints
sentence_lengths_docs_list = [item['sentence_lengths'] for item in batch] # List of (list of ints)
labels = torch.stack([item['labels'] for item in batch]) # Tensor (batch_size,)
batch_size = len(batch)
max_doc_num_sentences_in_batch = max(num_sentences_list) if num_sentences_list else 0
# Step 1: Prepare word-level data for packing
# We need to flatten all sentences from all documents in the batch
# and collect their actual lengths.
all_sentences_in_batch = []
all_sentence_lengths_in_batch = []
doc_sentence_pointers = [] # Helps in reconstructing document structure after word GRU
current_total_sentences = 0
for i, doc_sentences in enumerate(input_ids_docs_list):
# doc_sentences is a list of lists of token_ids
# doc_sentence_lengths is a list of ints
doc_sentence_lengths = sentence_lengths_docs_list[i]
# Convert each sentence (list of token_ids) to a tensor
sentences_as_tensors = [torch.tensor(s, dtype=torch.long) for s in doc_sentences]
all_sentences_in_batch.extend(sentences_as_tensors)
all_sentence_lengths_in_batch.extend(doc_sentence_lengths)
doc_sentence_pointers.append((current_total_sentences, current_total_sentences + len(doc_sentences)))
current_total_sentences += len(doc_sentences)
padded_sentences_for_attention = rnn_utils.pad_sequence(all_sentences_in_batch, batch_first=True, padding_value=text_processor_han.pad_id)
# Create a tensor for the sentence lengths per document, padded to max_doc_num_sentences_in_batch
padded_sentence_lengths_per_doc = []
for doc_len_list in sentence_lengths_docs_list:
current_doc_lengths = torch.tensor(doc_len_list, dtype=torch.long)
if len(current_doc_lengths) < max_doc_num_sentences_in_batch:
pad_len = max_doc_num_sentences_in_batch - len(current_doc_lengths)
current_doc_lengths = torch.cat([current_doc_lengths, torch.zeros(pad_len, dtype=torch.long)], dim=0)
padded_sentence_lengths_per_doc.append(current_doc_lengths)
padded_sentence_lengths_per_doc_tensor = torch.stack(padded_sentence_lengths_per_doc) # (batch_size, max_sentences_in_batch)
# Create final input_ids tensor for the model. This will be (batch_size, max_sentences_in_batch, max_words_in_batch)
# It needs to be constructed from padded_sentences_for_attention
# Reconstruct documents from padded_sentences_for_attention
final_input_ids = []
for i, (start_idx, end_idx) in enumerate(doc_sentence_pointers):
current_doc_sentences = padded_sentences_for_attention[start_idx:end_idx]
if current_doc_sentences.size(0) < max_doc_num_sentences_in_batch:
sentence_pad_shape = (max_doc_num_sentences_in_batch - current_doc_sentences.size(0), current_doc_sentences.size(1))
sentence_pad = torch.full(sentence_pad_shape, text_processor_han.pad_id, dtype=torch.long)
current_doc_sentences = torch.cat([current_doc_sentences, sentence_pad], dim=0)
final_input_ids.append(current_doc_sentences)
final_input_ids_tensor = torch.stack(final_input_ids) # (batch_size, max_sentences_in_batch, max_words_in_batch)
return {
'input_ids': final_input_ids_tensor, # (batch_size, max_sentences_in_batch, max_words_in_batch)
'num_sentences': torch.tensor(num_sentences_list, dtype=torch.long), # (batch_size,) - actual counts
'sentence_lengths': padded_sentence_lengths_per_doc_tensor, # (batch_size, max_sentences_in_batch) - actual lengths of each sentence, padded for consistent shape
'labels': labels
}
# Recreate datasets
train_dataset = HierarchicalTextDataset(
train_processed['input_ids'],
train_processed['num_sentences'],
train_processed['sentence_lengths'],
train_processed['labels']
)
test_dataset = HierarchicalTextDataset(
test_processed['input_ids'],
test_processed['num_sentences'],
test_processed['sentence_lengths'],
test_processed['labels']
)
print(f"Train Dataset size: {len(train_dataset)}")
print(f"Test Dataset size: {len(test_dataset)}")
batch_size = 64
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
num_workers=4, pin_memory=True, collate_fn=collate_fn_hierarchical)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
pin_memory=True, collate_fn=collate_fn_hierarchical)
Train Dataset size: 10949
Test Dataset size: 7267
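Before training, it is worth pulling a single batch through the collate function to confirm the padded shapes (a minimal sanity check):
In [ ]:
# Fetch one batch and check the padded tensor shapes produced by collate_fn_hierarchical.
batch = next(iter(train_dataloader))
print(batch['input_ids'].shape)         # (batch_size, max_sentences_in_batch, max_words_in_batch)
print(batch['num_sentences'].shape)     # (batch_size,)
print(batch['sentence_lengths'].shape)  # (batch_size, max_sentences_in_batch)
print(batch['labels'].shape)            # (batch_size,)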
In [12]:
!wget http://nlp.stanford.edu/data/glove.6B.zip -P /content
!unzip -q glove.6B.zip
--2025-11-14 08:09:33--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2025-11-14 08:09:34--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2025-11-14 08:09:34--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘/content/glove.6B.zip’
glove.6B.zip        100%[===================>] 822.24M  4.95MB/s    in 2m 55s
2025-11-14 08:12:30 (4.70 MB/s) - ‘/content/glove.6B.zip’ saved [862182613/862182613]
In [13]:
import numpy as np
from tqdm import tqdm
def load_glove_embeddings(glove_file, embedding_dim):
embeddings_index = {}
with open(glove_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="Loading GloVe"):
values = line.split()
word = values[0]
vector = np.asarray(values[1:], dtype='float32')
embeddings_index[word] = vector
print(f'Loaded {len(embeddings_index)} word vectors.')
return embeddings_index
glove_embeddings = load_glove_embeddings('/content/glove.6B.100d.txt', 100)
Loading GloVe: 400000it [00:08, 44774.83it/s]
Loaded 400000 word vectors.
In [14]:
embedding_size = 100
def create_embedding_matrix(vocab, glove_embeddings, embedding_size):
vocab_size = len(vocab)
embedding_matrix = np.zeros((vocab_size, embedding_size))
matched = 0
for word, idx in tqdm(vocab.items(), desc="Creating embedding matrix"):
if word in glove_embeddings:
embedding_matrix[idx] = glove_embeddings[word]
matched += 1
else:
embedding_matrix[idx] = np.random.normal(0, 0.1, embedding_size)
# Set PAD token to zeros
pad_id = vocab['[PAD]']
embedding_matrix[pad_id] = np.zeros(embedding_size)
print(f"Matched {matched}/{vocab_size} words with GloVe ({matched/vocab_size*100:.1f}%)")
return embedding_matrix
embedding_matrix = create_embedding_matrix(tokenizer.get_vocab(), glove_embeddings, embedding_size)
Creating embedding matrix: 100%|██████████| 50758/50758 [00:00<00:00, 300721.36it/s]
Matched 34519/50758 words with GloVe (68.0%)
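Two quick checks on the resulting matrix: the `[PAD]` row should be all zeros (so it stays inert through `padding_idx`), and the overall scale should look like typical GloVe vectors (a minimal sketch):
In [ ]:
# Verify the [PAD] row is zero and inspect the matrix scale.
assert np.allclose(embedding_matrix[vocab['[PAD]']], 0.0)
print(f"Embedding matrix shape: {embedding_matrix.shape}")
print(f"Mean |value| over non-PAD rows: {np.abs(embedding_matrix[1:]).mean():.4f}")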
In [ ]:
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
class Attention(nn.Module):
def __init__(self, hidden_size):
super(Attention, self).__init__()
self.attention = nn.Linear(hidden_size, hidden_size)
self.context_vector = nn.Linear(hidden_size, 1, bias=False)
def forward(self, hidden_state, mask):
u = torch.tanh(self.attention(hidden_state))
score = self.context_vector(u).squeeze(-1)
        score.masked_fill_(mask == 0, -1e9)  # Mask padded positions with a large negative value so softmax gives them ~0 weight
attention_weights = F.softmax(score, dim=-1)
context_vector = torch.bmm(
attention_weights.unsqueeze(1),
hidden_state
).squeeze(1)
return context_vector, attention_weights
class HAN(nn.Module):
def __init__(self, vocab_size, embedding_size, hidden_size, num_classes, dropout=0.5, pad_idx=0):
super(HAN, self).__init__()
self.pad_idx = pad_idx
self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)
self.word_gru = nn.GRU(
embedding_size,
hidden_size,
bidirectional=True,
batch_first=True
)
self.word_attention = Attention(hidden_size * 2)
self.sentence_gru = nn.GRU(
hidden_size * 2, # Output of word attention is hidden_size * 2
hidden_size,
bidirectional=True,
batch_first=True
)
self.sentence_attention = Attention(hidden_size * 2)
self.fc = nn.Linear(hidden_size * 2, num_classes)
self.dropout = nn.Dropout(dropout)
def forward(self, documents, num_sentences, sentence_lengths):
batch_size, max_sentences, max_words = documents.size()
word_input = documents.view(batch_size * max_sentences, max_words)
flat_sentence_lengths = sentence_lengths.view(-1) # (batch_size * max_sentences)
# Filter out sentences that have length 0 (these are the padded sentences at the document level)
non_zero_lengths_mask = (flat_sentence_lengths > 0)
actual_sentences_input = word_input[non_zero_lengths_mask]
actual_sentence_lengths = flat_sentence_lengths[non_zero_lengths_mask]
embedded_words = self.embedding(actual_sentences_input) # (num_actual_sentences, max_words_in_batch, embedding_size)
# Pack padded sequence for word GRU
packed_embedded_words = rnn_utils.pack_padded_sequence(
embedded_words,
actual_sentence_lengths.cpu(),
batch_first=True,
enforce_sorted=False # Let pack_padded_sequence handle sorting
)
packed_word_gru_out, _ = self.word_gru(packed_embedded_words)
word_gru_out, _ = rnn_utils.pad_packed_sequence(packed_word_gru_out, batch_first=True, total_length=max_words)
# word_gru_out shape: (num_actual_sentences, max_words_in_batch, 2*hidden_size)
# Expand word_gru_out back to (batch_size * max_sentences, max_words, 2*hidden_size)
# for attention. Need to put zeros where the padded sentences were.
full_word_gru_out = torch.zeros(
batch_size * max_sentences, max_words, 2 * self.word_gru.hidden_size,
device=documents.device, dtype=word_gru_out.dtype
)
full_word_gru_out[non_zero_lengths_mask] = word_gru_out
# Create mask for word attention based on padding
word_attention_mask = (word_input != self.pad_idx) # (batch_size * max_sentences, max_words)
# Get sentence vectors using word attention
sentence_vectors, _ = self.word_attention(full_word_gru_out.contiguous(), word_attention_mask) # (batch_size * max_sentences, 2*hidden_size)
# --- Sentence-level GRU and Attention ---
# Reshape sentence_vectors to (batch_size, max_sentences, 2*hidden_size)
sentence_vectors_reshaped = sentence_vectors.view(batch_size, max_sentences, -1)
# num_sentences contains the actual number of sentences per document
non_zero_num_sentences_mask = (num_sentences > 0)
actual_num_sentences = num_sentences[non_zero_num_sentences_mask]
actual_sentence_vectors = sentence_vectors_reshaped[non_zero_num_sentences_mask]
if actual_sentence_vectors.size(0) == 0:
document_vectors = torch.zeros(batch_size, 2 * self.sentence_gru.hidden_size, device=documents.device)
output = self.fc(self.dropout(document_vectors))
return output, None, None
# Pack padded sequence for sentence GRU
packed_embedded_sentences = rnn_utils.pack_padded_sequence(
actual_sentence_vectors,
actual_num_sentences.cpu(),
batch_first=True,
enforce_sorted=False
)
packed_sentence_gru_out, _ = self.sentence_gru(packed_embedded_sentences)
sentence_gru_out, _ = rnn_utils.pad_packed_sequence(packed_sentence_gru_out, batch_first=True, total_length=max_sentences)
# sentence_gru_out shape: (num_actual_documents, max_sentences_in_batch, 2*hidden_size)
# Expand sentence_gru_out back to (batch_size, max_sentences, 2*hidden_size)
full_sentence_gru_out = torch.zeros(
batch_size, max_sentences, 2 * self.sentence_gru.hidden_size,
device=documents.device, dtype=sentence_gru_out.dtype
)
full_sentence_gru_out[non_zero_num_sentences_mask] = sentence_gru_out
# Create mask for sentence attention based on actual number of sentences per document
sentence_attention_mask = (torch.arange(max_sentences, device=documents.device).expand(batch_size, max_sentences) < num_sentences.unsqueeze(1))
# Get document vectors using sentence attention
document_vectors, sentence_att_weights = self.sentence_attention(full_sentence_gru_out, sentence_attention_mask) # (batch_size, 2*hidden_size)
output = self.fc(self.dropout(document_vectors))
return output, None, sentence_att_weights # Word attention weights are tricky to return in this setup.
# Instantiate and move model to device
vocab_size = tokenizer.get_vocab_size()
embedding_size = 100
hidden_size = 64
num_classes = len(set(train_labels_cleaned))
model = HAN(
vocab_size=vocab_size,
embedding_size=embedding_size,
hidden_size=hidden_size,
num_classes=num_classes,
dropout=0.5,
pad_idx=vocab['[PAD]']
)
model.embedding.weight.data.copy_(torch.FloatTensor(embedding_matrix))
model.to(device)
Out[ ]:
HAN(
(embedding): Embedding(50758, 100, padding_idx=0)
(word_gru): GRU(100, 64, batch_first=True, bidirectional=True)
(word_attention): Attention(
(attention): Linear(in_features=128, out_features=128, bias=True)
(context_vector): Linear(in_features=128, out_features=1, bias=False)
)
(sentence_gru): GRU(128, 64, batch_first=True, bidirectional=True)
(sentence_attention): Attention(
(attention): Linear(in_features=128, out_features=128, bias=True)
(context_vector): Linear(in_features=128, out_features=1, bias=False)
)
(fc): Linear(in_features=128, out_features=20, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
)
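A dummy forward pass with fabricated tensors is a cheap way to confirm the masking and packing logic runs end to end before training; the shapes below are arbitrary and the batch is not taken from the data (a sketch only):
In [ ]:
# Smoke-test the model with a fabricated batch of 2 documents (3 sentences x 5 words max).
with torch.no_grad():
    dummy_docs = torch.randint(2, vocab_size, (2, 3, 5), device=device)    # random non-special token ids
    dummy_docs[1, 2, :] = vocab['[PAD]']                                   # second doc has only 2 real sentences
    dummy_num_sents = torch.tensor([3, 2], device=device)                  # actual sentences per document
    dummy_sent_lens = torch.tensor([[5, 4, 3], [5, 2, 0]], device=device)  # 0 marks the padded sentence
    logits, _, _ = model(dummy_docs, dummy_num_sents, dummy_sent_lens)
    print(logits.shape)  # expected: (2, num_classes)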
In [16]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import torch
import torch.nn as nn
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='max',
factor=0.5,
patience=2
)
def train_epoch(model, dataloader, criterion, optimizer, device):
model.train()
total_loss = 0
predictions = []
true_labels = []
for batch in tqdm(dataloader, desc="Training"):
input_ids = batch['input_ids'].to(device)
num_sentences = batch['num_sentences'].to(device)
sentence_lengths = batch['sentence_lengths'].to(device)
labels = batch['labels'].to(device)
optimizer.zero_grad()
outputs, _, _ = model(input_ids, num_sentences, sentence_lengths)
loss = criterion(outputs, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
total_loss += loss.item()
preds = outputs.argmax(dim=1).cpu().numpy()
predictions.extend(preds)
true_labels.extend(labels.cpu().numpy())
optimizer.step()
avg_loss = total_loss / len(dataloader)
accuracy = accuracy_score(true_labels, predictions)
return avg_loss, accuracy
def evaluate(model, dataloader, criterion, device):
model.eval()
total_loss = 0
predictions = []
true_labels = []
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating"):
input_ids = batch['input_ids'].to(device)
num_sentences = batch['num_sentences'].to(device)
sentence_lengths = batch['sentence_lengths'].to(device)
labels = batch['labels'].to(device)
outputs, _, _ = model(input_ids, num_sentences, sentence_lengths)
loss = criterion(outputs, labels)
total_loss += loss.item()
preds = outputs.argmax(dim=1).cpu().numpy()
predictions.extend(preds)
true_labels.extend(labels.cpu().numpy())
avg_loss = total_loss / len(dataloader)
accuracy = accuracy_score(true_labels, predictions)
return avg_loss, accuracy, predictions, true_labels
# Training loop with early stopping
num_epochs = 50
best_accuracy = 0
patience = 3
patience_counter = 0
for epoch in range(num_epochs):
train_loss, train_acc = train_epoch(model, train_dataloader, criterion, optimizer, device)
val_loss, val_acc, _, _ = evaluate(model, test_dataloader, criterion, device)
    scheduler.step(val_acc)  # Step the LR scheduler on validation accuracy
print(f"Epoch {epoch+1}/{num_epochs}")
print(f" Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
print(f" Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
if val_acc > best_accuracy:
best_accuracy = val_acc
torch.save(model.state_dict(), 'best_han_bpe_model.pth')
patience_counter = 0
print(f" ✓ Saved best model (accuracy: {val_acc:.4f})")
else:
patience_counter += 1
print(f" No improvement for {patience_counter}/{patience} epochs")
if patience_counter >= patience:
print(f"\nEarly stopping triggered after {epoch+1} epochs!")
break
print()
print(f"✓ Training completed! Best validation accuracy: {best_accuracy:.4f}")
Training: 100%|██████████| 172/172 [00:07<00:00, 22.35it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 49.94it/s]
Epoch 1/50 Train Loss: 2.9844, Train Acc: 0.0695 Val Loss: 2.9632, Val Acc: 0.0936 ✓ Saved best model (accuracy: 0.0936)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.42it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.97it/s]
Epoch 2/50 Train Loss: 2.8932, Train Acc: 0.1460 Val Loss: 2.7265, Val Acc: 0.2071 ✓ Saved best model (accuracy: 0.2071)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.75it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.60it/s]
Epoch 3/50 Train Loss: 2.5067, Train Acc: 0.2148 Val Loss: 2.2809, Val Acc: 0.2952 ✓ Saved best model (accuracy: 0.2952)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.36it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.46it/s]
Epoch 4/50 Train Loss: 2.1404, Train Acc: 0.3105 Val Loss: 2.0013, Val Acc: 0.3916 ✓ Saved best model (accuracy: 0.3916)
Training: 100%|██████████| 172/172 [00:05<00:00, 28.71it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.89it/s]
Epoch 5/50 Train Loss: 1.9193, Train Acc: 0.3879 Val Loss: 1.8398, Val Acc: 0.4597 ✓ Saved best model (accuracy: 0.4597)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.84it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.78it/s]
Epoch 6/50 Train Loss: 1.7616, Train Acc: 0.4479 Val Loss: 1.7218, Val Acc: 0.4965 ✓ Saved best model (accuracy: 0.4965)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.36it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.49it/s]
Epoch 7/50 Train Loss: 1.6300, Train Acc: 0.4893 Val Loss: 1.6252, Val Acc: 0.5155 ✓ Saved best model (accuracy: 0.5155)
Training: 100%|██████████| 172/172 [00:05<00:00, 30.02it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.31it/s]
Epoch 8/50 Train Loss: 1.5276, Train Acc: 0.5230 Val Loss: 1.5531, Val Acc: 0.5320 ✓ Saved best model (accuracy: 0.5320)
Training: 100%|██████████| 172/172 [00:05<00:00, 30.33it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.89it/s]
Epoch 9/50 Train Loss: 1.4372, Train Acc: 0.5566 Val Loss: 1.4931, Val Acc: 0.5425 ✓ Saved best model (accuracy: 0.5425)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.34it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.23it/s]
Epoch 10/50 Train Loss: 1.3593, Train Acc: 0.5723 Val Loss: 1.4533, Val Acc: 0.5496 ✓ Saved best model (accuracy: 0.5496)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.17it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.87it/s]
Epoch 11/50 Train Loss: 1.2891, Train Acc: 0.5997 Val Loss: 1.4162, Val Acc: 0.5588 ✓ Saved best model (accuracy: 0.5588)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.85it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.87it/s]
Epoch 12/50 Train Loss: 1.2429, Train Acc: 0.6148 Val Loss: 1.3898, Val Acc: 0.5656 ✓ Saved best model (accuracy: 0.5656)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.02it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.43it/s]
Epoch 13/50 Train Loss: 1.1878, Train Acc: 0.6261 Val Loss: 1.3713, Val Acc: 0.5722 ✓ Saved best model (accuracy: 0.5722)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.05it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.93it/s]
Epoch 14/50 Train Loss: 1.1397, Train Acc: 0.6439 Val Loss: 1.3511, Val Acc: 0.5797 ✓ Saved best model (accuracy: 0.5797)
Training: 100%|██████████| 172/172 [00:05<00:00, 30.01it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.53it/s]
Epoch 15/50 Train Loss: 1.0991, Train Acc: 0.6582 Val Loss: 1.3357, Val Acc: 0.5835 ✓ Saved best model (accuracy: 0.5835)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.33it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.00it/s]
Epoch 16/50 Train Loss: 1.0537, Train Acc: 0.6652 Val Loss: 1.3275, Val Acc: 0.5916 ✓ Saved best model (accuracy: 0.5916)
Training: 100%|██████████| 172/172 [00:06<00:00, 28.42it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 46.00it/s]
Epoch 17/50 Train Loss: 1.0160, Train Acc: 0.6830 Val Loss: 1.3234, Val Acc: 0.5923 ✓ Saved best model (accuracy: 0.5923)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.45it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.11it/s]
Epoch 18/50 Train Loss: 0.9807, Train Acc: 0.6928 Val Loss: 1.3156, Val Acc: 0.5945 ✓ Saved best model (accuracy: 0.5945)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.65it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.56it/s]
Epoch 19/50 Train Loss: 0.9505, Train Acc: 0.7032 Val Loss: 1.3260, Val Acc: 0.5979 ✓ Saved best model (accuracy: 0.5979)
Training: 100%|██████████| 172/172 [00:05<00:00, 28.87it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.18it/s]
Epoch 20/50 Train Loss: 0.9164, Train Acc: 0.7181 Val Loss: 1.3103, Val Acc: 0.6024 ✓ Saved best model (accuracy: 0.6024)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.11it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.74it/s]
Epoch 21/50 Train Loss: 0.8766, Train Acc: 0.7297 Val Loss: 1.3111, Val Acc: 0.6066 ✓ Saved best model (accuracy: 0.6066)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.27it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 53.43it/s]
Epoch 22/50 Train Loss: 0.8431, Train Acc: 0.7419 Val Loss: 1.3191, Val Acc: 0.6037 No improvement for 1/3 epochs
Training: 100%|██████████| 172/172 [00:05<00:00, 28.97it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.71it/s]
Epoch 23/50 Train Loss: 0.8196, Train Acc: 0.7502 Val Loss: 1.3183, Val Acc: 0.6077 ✓ Saved best model (accuracy: 0.6077)
Training: 100%|██████████| 172/172 [00:05<00:00, 28.75it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.61it/s]
Epoch 24/50 Train Loss: 0.7874, Train Acc: 0.7628 Val Loss: 1.3199, Val Acc: 0.6070 No improvement for 1/3 epochs
Training: 100%|██████████| 172/172 [00:05<00:00, 29.67it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.84it/s]
Epoch 25/50 Train Loss: 0.7555, Train Acc: 0.7735 Val Loss: 1.3216, Val Acc: 0.6113 ✓ Saved best model (accuracy: 0.6113)
Training: 100%|██████████| 172/172 [00:05<00:00, 29.86it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 51.14it/s]
Epoch 26/50 Train Loss: 0.7309, Train Acc: 0.7824 Val Loss: 1.3284, Val Acc: 0.6095 No improvement for 1/3 epochs
Training: 100%|██████████| 172/172 [00:05<00:00, 29.73it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 52.87it/s]
Epoch 27/50 Train Loss: 0.7022, Train Acc: 0.7902 Val Loss: 1.3334, Val Acc: 0.6108 No improvement for 2/3 epochs
Training: 100%|██████████| 172/172 [00:05<00:00, 29.25it/s] Evaluating: 100%|██████████| 114/114 [00:02<00:00, 50.73it/s]
Epoch 28/50 Train Loss: 0.6787, Train Acc: 0.7996 Val Loss: 1.3473, Val Acc: 0.6108 No improvement for 3/3 epochs
Early stopping triggered after 28 epochs!
✓ Training completed! Best validation accuracy: 0.6113
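With training finished, the checkpoint with the best validation accuracy can be reloaded for a per-class breakdown on the test split (a minimal sketch using scikit-learn's `classification_report`; the integer label ids are reported as-is, since the newsgroup-name mapping is not used elsewhere in this notebook):
In [ ]:
# Reload the best checkpoint and report per-class precision/recall/F1 on the test set.
from sklearn.metrics import classification_report
model.load_state_dict(torch.load('best_han_bpe_model.pth', map_location=device))
test_loss, test_acc, test_preds, test_true = evaluate(model, test_dataloader, criterion, device)
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
print(classification_report(test_true, test_preds, digits=4))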