In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import pandas as pd

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
Using device: cuda
In [2]:
!wget https://www.manythings.org/anki/cmn-eng.zip -P /content
!unzip cmn-eng.zip
--2025-11-07 06:52:27--  https://www.manythings.org/anki/cmn-eng.zip
Resolving www.manythings.org (www.manythings.org)... 173.254.30.110
Connecting to www.manythings.org (www.manythings.org)|173.254.30.110|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1337935 (1.3M) [application/zip]
Saving to: ‘/content/cmn-eng.zip’

cmn-eng.zip         100%[===================>]   1.28M  --.-KB/s    in 0.1s    

2025-11-07 06:52:28 (13.0 MB/s) - ‘/content/cmn-eng.zip’ saved [1337935/1337935]

Archive:  cmn-eng.zip
  inflating: cmn.txt                 
  inflating: _about.txt              
In [3]:
train_data_pd = pd.read_csv('/content/cmn.txt', sep='\t', header=None, usecols=[0, 1])
train_data_pd.columns = ['english', 'chinese']
train_data = list(zip(train_data_pd['english'], train_data_pd['chinese']))

print(f"Total train_data: {len(train_data)}")
print(f"First 10 data: {train_data[:10]}")
Total train_data: 30979
First 10 data: [('Hi.', '嗨。'), ('Hi.', '你好。'), ('Run.', '你用跑的。'), ('Stay.', '待著。'), ('Stay.', '且慢。'), ('Stop!', '住手!'), ('Wait!', '等等!'), ('Wait!', '等一下!'), ('Begin.', '开始!'), ('Fight.', '開打。')]
In [4]:
!pip install zhconv
Collecting zhconv
  Downloading zhconv-1.4.3.tar.gz (211 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/211.6 kB ? eta -:--:--
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.6/211.6 kB 8.7 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: zhconv
  Building wheel for zhconv (setup.py) ... done
  Created wheel for zhconv: filename=zhconv-1.4.3-py2.py3-none-any.whl size=208852 sha256=081f6b61541c6d343a95cbe5926522f05aacfd6fab9ba38dac8fd8dfd9bd024b
  Stored in directory: /root/.cache/pip/wheels/61/90/d7/3604f0bf1943607b954e1de11c9ffd6911ef844b81ce9e5320
Successfully built zhconv
Installing collected packages: zhconv
Successfully installed zhconv-1.4.3
In [5]:
import re
import zhconv

def normalize_english(text):
  text = text.lower()
  text = re.sub(r"([.!?])", r" \1", text)
  text = re.sub(r"\s+", " ", text)
  text = re.sub(r"[^a-zA-Z.!?]+", " ", text)
  return text.strip()

def normalize_chinese(text):
  text = zhconv.convert(text, 'zh-cn')
  text = text.replace(',', ',').replace('。', '.').replace('!', '!')
  text = text.replace('?', '?').replace(':', ':')
  return text.strip()

cleaned_train_data = []

for en, cn in train_data:
  en = normalize_english(en)
  cn = normalize_chinese(cn)
  if en and cn:
    cleaned_train_data.append((en, cn))


print(f"Cleaned data: {cleaned_train_data[:10]}")
Cleaned data: [('hi .', '嗨.'), ('hi .', '你好.'), ('run .', '你用跑的.'), ('stay .', '待着.'), ('stay .', '且慢.'), ('stop !', '住手!'), ('wait !', '等等!'), ('wait !', '等一下!'), ('begin .', '开始!'), ('fight .', '开打.')]
In [6]:
import jieba

def tokenize_english(text):
  return text.split()

def tokenize_chinese(text):
  return list(jieba.cut(text))

tokenized_data = []
for en, cn in cleaned_train_data:
  en_tokens = tokenize_english(en)
  cn_tokens = tokenize_chinese(cn)
  tokenized_data.append((en_tokens, cn_tokens))

print(f"Tokenized data: {tokenized_data[20000:20030]}")
/usr/local/lib/python3.12/dist-packages/jieba/__init__.py:44: SyntaxWarning: invalid escape sequence '\.'
  re_han_default = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-]+)", re.U)
/usr/local/lib/python3.12/dist-packages/jieba/__init__.py:46: SyntaxWarning: invalid escape sequence '\s'
  re_skip_default = re.compile("(\r\n|\s)", re.U)
/usr/local/lib/python3.12/dist-packages/jieba/finalseg/__init__.py:78: SyntaxWarning: invalid escape sequence '\.'
  re_skip = re.compile("([a-zA-Z0-9]+(?:\.\d+)?%?)")
Building prefix dict from the default dictionary ...
DEBUG:jieba:Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
DEBUG:jieba:Dumping model to file cache /tmp/jieba.cache
Loading model cost 0.663 seconds.
DEBUG:jieba:Loading model cost 0.663 seconds.
Prefix dict has been built successfully.
DEBUG:jieba:Prefix dict has been built successfully.
Tokenized data: [(['i', 'm', 'just', 'worried', 'about', 'my', 'weight', '.'], ['我', '只是', '担心', '我', '的', '体重', '.']), (['i', 'm', 'looking', 'for', 'a', 'small', 'suitcase', '.'], ['我', '正在', '找', '一个', '小', '手提箱', '.']), (['i', 'm', 'never', 'coming', 'back', 'here', 'again', '.'], ['我', '不会', '再', '回来', '了', '.']), (['i', 'm', 'not', 'as', 'optimistic', 'as', 'you', 'are', '.'], ['我', '不', '像', '你', '那么', '乐观', '.']), (['i', 'm', 'not', 'as', 'optimistic', 'as', 'you', 'are', '.'], ['我', '没有', '你', '那么', '乐观', '.']), (['i', 'm', 'not', 'picky', '.', 'i', 'll', 'eat', 'anything', '.'], ['我', '不', '挑剔', ',', '我', '什么', '都', '吃', '.']), (['i', 'm', 'not', 'sure', 'what', 'i', 'was', 'thinking', '.'], ['我', '不', '确定', '当时', '我', '正在', '想', '什么', '.']), (['i', 'm', 'old', 'enough', 'to', 'live', 'by', 'myself', '.'], ['我', '年纪', '够', '大', '了', '可以', '自己', '一个', '人住', '.']), (['i', 'm', 'old', 'enough', 'to', 'support', 'myself', '.'], ['我', '年纪', '够', '大', '可以', '养活', '我', '自己', '.']), (['i', 'm', 'quite', 'satisfied', 'with', 'my', 'life', '.'], ['我', '对', '我', '的', '人生', '很', '满意', '.']), (['i', 'm', 'reading', 'a', 'book', 'about', 'animals', '.'], ['我', '正在', '读', '一本', '关于', '动物', '的', '书', '.']), (['i', 'm', 'ready', 'to', 'do', 'anything', 'for', 'you', '.'], ['我', '甘心', '为', '你', '做', '任何', '事', '.']), (['i', 'm', 'really', 'looking', 'forward', 'to', 'it', '.'], ['我', '很', '期待', '哦', '.']), (['i', 'm', 'sick', 'and', 'tired', 'of', 'hamburgers', '.'], ['我', '对', '汉堡', '感到', '厌烦', '了', '.']), (['i', 'm', 'sick', 'and', 'tired', 'of', 'hamburgers', '.'], ['我', '吃腻', '了', '汉堡', '.']), (['i', 'm', 'sick', 'and', 'tired', 'of', 'hamburgers', '.'], ['汉堡', '我', '都', '吃腻', '了', '.']), (['i', 'm', 'sorry', 'to', 'bother', 'you', 'so', 'often', '.'], ['一直', '打扰', '你', '不好意思', '.']), (['i', 'm', 'sure', 'i', 'won', 't', 'be', 'of', 'much', 'help', '.'], ['我', '肯定', '帮不上', '什么', '忙', '.']), (['i', 'm', 'sure', 'tom', 'came', 'here', 'yesterday', '.'], ['我', '很', '肯定', '汤姆', '昨天', '来过', '这里', '.']), (['i', 'm', 'sure', 'that', 'he', 'll', 'come', 'on', 'time', '.'], ['我', '确定', '他会', '准时', '来', '.']), (['i', 'm', 'the', 'only', 'one', 'who', 'can', 'do', 'that', '.'], ['我', '是', '唯一', '能', '做到', '那个', '的', '人', '.']), (['i', 'm', 'the', 'tallest', 'one', 'in', 'the', 'class', '.'], ['我', '在', '班里', '是', '最高', '的', '.']), (['i', 'm', 'three', 'years', 'younger', 'than', 'you', '.'], ['我', '比', '你', '小', '三岁', '.']), (['i', 'm', 'tired', 'of', 'watching', 'television', '.'], ['我', '厌倦', '了', '看电视', '.']), (['i', 'm', 'tired', 'of', 'watching', 'television', '.'], ['我', '看电视', '看到', '厌烦', '了', '.']), (['i', 'm', 'too', 'sleepy', 'to', 'do', 'my', 'homework', '.'], ['我太累', '了', ',', '做不了', '功课', '.']), (['i', 'm', 'very', 'interested', 'in', 'languages', '.'], ['我', '对', '语言', '很感兴趣', '.']), (['i', 'm', 'very', 'worried', 'about', 'my', 'weight', '.'], ['我', '很', '担心', '我', '的', '体重', '.']), (['i', 've', 'been', 'here', 'many', 'times', 'before', '.'], ['我', '以前', '来', '过', '很', '多次', '了', '.']), (['i', 've', 'changed', 'my', 'website', 's', 'layout', '.'], ['我', '改', '了', '一下', '我', '网站', '的', '版面设计', '.'])]
In [7]:
def prepare_lang_tagged_data(tokenized_data, max_len=30):
  """
  [(en_text, cn_text), ...] =>
  [
      (['<en>', 'hello'], ['<cn>', '你好']),  # 英→中
      (['<cn>', '你好'], ['<en>', 'hello']),  # 中→英
  ]
  """

  processed_data = []
  for en_tokens, cn_tokens in tokenized_data:
    # Leave space for eos lang tag
    if len(en_tokens) > max_len - 3 or len(cn_tokens) > max_len - 3:
      continue

    src_en = ['<en>'] + en_tokens
    tgt_cn = ['<cn>'] + cn_tokens
    processed_data.append((src_en, tgt_cn))

    src_cn = ['<cn>'] + cn_tokens
    tgt_en = ['<en>'] + en_tokens
    processed_data.append((src_cn, tgt_en))

  return processed_data

lang_tagged_data = prepare_lang_tagged_data(tokenized_data, 30)
print(f"Lang tagged data: {lang_tagged_data[:10]}")
Lang tagged data: [(['<en>', 'hi', '.'], ['<cn>', '嗨', '.']), (['<cn>', '嗨', '.'], ['<en>', 'hi', '.']), (['<en>', 'hi', '.'], ['<cn>', '你好', '.']), (['<cn>', '你好', '.'], ['<en>', 'hi', '.']), (['<en>', 'run', '.'], ['<cn>', '你', '用', '跑', '的', '.']), (['<cn>', '你', '用', '跑', '的', '.'], ['<en>', 'run', '.']), (['<en>', 'stay', '.'], ['<cn>', '待', '着', '.']), (['<cn>', '待', '着', '.'], ['<en>', 'stay', '.']), (['<en>', 'stay', '.'], ['<cn>', '且慢', '.']), (['<cn>', '且慢', '.'], ['<en>', 'stay', '.'])]
In [ ]:
from collections import Counter

class Vocabulary:
  def __init__(self, lang_tags = ['en', 'cn']):
    self.word2idx = {
        '<pad>': 0,
        '<sos>': 1,
        '<eos>': 2,
        '<unk>': 3
    }
    self.special_tokens = ['<pad>', '<sos>', '<eos>', '<unk>'] # Initialize with common special tokens

    for lang in lang_tags:
      tag = '<' + lang + '>'
      self.word2idx[tag] = len(self.word2idx)
      self.special_tokens.append(tag) # Add language tags to special tokens

    self.idx2word = {i: w for w, i in self.word2idx.items()}
    self.word_count = Counter()
    self.n_words = len(self.word2idx)

  def add_sentence(self, tokens):
    for word in tokens:
      if word not in self.word2idx:
        self.word_count[word] += 1

  def build(self, min_count=2):
    for word, count in self.word_count.items():
      if count > min_count:
        self.word2idx[word] = self.n_words
        self.idx2word[self.n_words] = word
        self.n_words += 1

  def encode(self, tokens, add_sos=True, add_eos=True):
    indices = []
    if add_sos:
      indices.append(self.word2idx['<sos>'])

    for word in tokens:
      idx = self.word2idx.get(word, self.word2idx['<unk>'])
      indices.append(idx)

    if add_eos:
      indices.append(self.word2idx['<eos>'])

    return indices

  def decode(self, indices, skip_special=True):
    words = []
    for idx in indices:
      word = self.idx2word.get(idx, '<unk>')
      if skip_special and word in self.special_tokens:
        continue
      words.append(word)

    return words

vocab = Vocabulary()
for src_tokens, tgt_tokens in lang_tagged_data:
  vocab.add_sentence(src_tokens)
  vocab.add_sentence(tgt_tokens)

vocab.build(min_count=1)

indexed_pairs = []
for src_tokens, tgt_tokens in lang_tagged_data:
  src_indices = vocab.encode(src_tokens)
  tgt_indices = vocab.encode(tgt_tokens)
  indexed_pairs.append((src_indices, tgt_indices))

print(f"Indexed pairs: {indexed_pairs[:10]}")
Indexed pairs: [([1, 4, 6, 7, 2], [1, 5, 8, 7, 2]), ([1, 5, 8, 7, 2], [1, 4, 6, 7, 2]), ([1, 4, 6, 7, 2], [1, 5, 9, 7, 2]), ([1, 5, 9, 7, 2], [1, 4, 6, 7, 2]), ([1, 4, 10, 7, 2], [1, 5, 11, 12, 13, 14, 7, 2]), ([1, 5, 11, 12, 13, 14, 7, 2], [1, 4, 10, 7, 2]), ([1, 4, 15, 7, 2], [1, 5, 16, 17, 7, 2]), ([1, 5, 16, 17, 7, 2], [1, 4, 15, 7, 2]), ([1, 4, 15, 7, 2], [1, 5, 18, 7, 2]), ([1, 5, 18, 7, 2], [1, 4, 15, 7, 2])]
In [9]:
from torch.nn.utils.rnn import pad_sequence

class TranslationDataset(Dataset):
  def __init__(self, indexed_pairs):
    self.pairs = indexed_pairs

  def __len__(self):
    return len(self.pairs)

  def __getitem__(self, idx):
    return self.pairs[idx]

def collate_batch(batch):
  src_batch = [torch.LongTensor(pair[0]) for pair in batch]
  tgt_batch = [torch.LongTensor(pair[1]) for pair in batch]

  src_lengths = torch.LongTensor([len(s) for s in src_batch])
  tgt_lengths = torch.LongTensor([len(t) for t in tgt_batch])

  src_padded = pad_sequence(src_batch, batch_first=True, padding_value=0)
  tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=0)

  return src_padded, tgt_padded, src_lengths, tgt_lengths

batch_size = 128

train_size = int(0.9 * len(indexed_pairs))
train_data = indexed_pairs[:train_size]
val_data = indexed_pairs[train_size:]

train_dataset = TranslationDataset(train_data)
val_dataset = TranslationDataset(val_data)
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size, shuffle=False, collate_fn=collate_batch)
In [ ]:
import random

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F

class Encoder(nn.Module):
  def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=2, dropout=0.3):
    super(Encoder, self).__init__()
    self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
    self.dropout = nn.Dropout(dropout)

    self.lstm = nn.LSTM(
      embed_dim,
      hidden_dim,
      num_layers,
      batch_first=True,
      dropout=dropout if num_layers > 1 else 0,
      bidirectional=True
    )

    self.fc_hidden = nn.Linear(hidden_dim * 2, hidden_dim)
    self.fc_cell = nn.Linear(hidden_dim * 2, hidden_dim)

  def forward(self, src, src_lengths):
    embedded = self.dropout(self.embedding(src))

    packed_embedded = pack_padded_sequence(
      embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=False
    )

    packed_outputs, (hidden, cell) = self.lstm(packed_embedded)

    outputs, _ = pad_packed_sequence(packed_outputs, batch_first=True)
    hidden = self._combine_bidirectional(hidden)
    cell = self._combine_bidirectional(cell)

    return outputs, hidden, cell

  def _combine_bidirectional(self, state):
    num_layers = state.shape[0] // 2
    batch_size = state.shape[1]
    hidden_dim = state.shape[2]

    state = state.reshape(num_layers, 2, batch_size, hidden_dim)

    state = torch.cat([state[:, 0, :, :], state[:, 1, :, :]], dim=2)

    state = torch.tanh(self.fc_hidden(state))

    return state

class BahdanauAttention(nn.Module):
  def __init__(self, hidden_dim, encoder_dim):
    super(BahdanauAttention, self).__init__()
    self.hidden_dim = hidden_dim
    self.encoder_dim = encoder_dim

    # Attention layers
    self.attn_hidden = nn.Linear(hidden_dim, hidden_dim)
    self.attn_encoder = nn.Linear(encoder_dim, hidden_dim)
    self.attn_combine = nn.Linear(hidden_dim, 1, bias=False)

  def forward(self, hidden, encoder_outputs, mask=None):
    hidden_proj = self.attn_hidden(hidden).unsqueeze(1)
    encoder_proj = self.attn_encoder(encoder_outputs)
    energy = torch.tanh(hidden_proj + encoder_proj)
    attention_scores = self.attn_combine(energy).squeeze(2)
    attention_scores = attention_scores.masked_fill(mask, -1e10)

    attention_weights = F.softmax(attention_scores, dim=1)
    context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)

    return context, attention_weights

class Decoder(nn.Module):
  def __init__(self, vocab_size, embed_dim, hidden_dim, encoder_dim, num_layers=2, dropout=0.3):
    super(Decoder, self).__init__()
    self.vocab_size = vocab_size
    self.hidden_dim = hidden_dim

    self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
    self.dropout = nn.Dropout(dropout)

    self.attention = BahdanauAttention(hidden_dim, encoder_dim)

    self.lstm = nn.LSTM(
      embed_dim + encoder_dim,
      hidden_dim,
      num_layers,
      batch_first=True,
      dropout=dropout if num_layers > 1 else 0
    )

    self.fc = nn.Linear(hidden_dim + encoder_dim + embed_dim, vocab_size)

  def forward(self, tgt, hidden, cell, encoder_outputs, src_mask=None):
    embedded = self.dropout(self.embedding(tgt))  # (batch_size, 1, embed_dim)

    context, attention_weights = self.attention(
      hidden[-1], encoder_outputs, src_mask
    )

    lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)

    output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))

    prediction_input = torch.cat([
      output.squeeze(1),      # (batch_size, hidden_dim)
      context,                # (batch_size, encoder_dim)
      embedded.squeeze(1)     # (batch_size, embed_dim)
    ], dim=1)

    prediction = self.fc(prediction_input)

    return prediction, hidden, cell, attention_weights

class Seq2Seq(nn.Module):
  def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
    super().__init__()
    encoder_dim = hidden_dim * 2  # Bidirectional encoder

    self.encoder = Encoder(vocab_size, embed_dim, hidden_dim, num_layers, dropout)
    self.decoder = Decoder(vocab_size, embed_dim, hidden_dim, encoder_dim, num_layers, dropout)
    self.vocab_size = vocab_size

  def create_mask(self, src, pad_idx):
    return (src == pad_idx)

  def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
    batch_size = src.shape[0]
    tgt_len = tgt.shape[1]

    encoder_outputs, hidden, cell = self.encoder(src, src_lengths)
    src_mask = self.create_mask(src, pad_idx=0)
    outputs = torch.zeros(batch_size, tgt_len - 1, self.vocab_size).to(src.device)

    # First decoder input is <sos>
    decoder_input = tgt[:, 0].unsqueeze(1)

    # Decode step by step
    for t in range(1, tgt_len):
      output, hidden, cell, _ = self.decoder(
        decoder_input, hidden, cell, encoder_outputs, src_mask
      )

      outputs[:, t - 1] = output

      # Teacher forcing
      use_teacher_forcing = random.random() < teacher_forcing_ratio
      if use_teacher_forcing:
        decoder_input = tgt[:, t].unsqueeze(1)
      else:
        decoder_input = output.argmax(1).unsqueeze(1)

    return outputs

  def inference(self, src, src_lengths, sos_idx, eos_idx, max_len, device, pad_idx=0):
    self.eval()
    batch_size = src.shape[0]

    with torch.no_grad():
      # Encode
      encoder_outputs, hidden, cell = self.encoder(src, src_lengths)

      # Create source mask
      src_mask = self.create_mask(src, pad_idx)

      # Start with <sos>
      decoder_input = torch.full((batch_size, 1), sos_idx, dtype=torch.long, device=device)

      generated_tokens = []

      for _ in range(max_len):
        output, hidden, cell, _ = self.decoder(
          decoder_input, hidden, cell, encoder_outputs, src_mask
        )

        predicted_token = output.argmax(1)

        if batch_size == 1 and predicted_token.item() == eos_idx:
          break

        generated_tokens.append(predicted_token.item() if batch_size == 1 else predicted_token)
        decoder_input = predicted_token.unsqueeze(1)

    return generated_tokens
In [11]:
!pip install nltk
Requirement already satisfied: nltk in /usr/local/lib/python3.12/dist-packages (3.9.1)
Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from nltk) (8.3.0)
Requirement already satisfied: joblib in /usr/local/lib/python3.12/dist-packages (from nltk) (1.5.2)
Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.12/dist-packages (from nltk) (2024.11.6)
Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from nltk) (4.67.1)
In [12]:
import nltk
nltk.download('punkt')
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
Out[12]:
True
In [ ]:
class TeacherForcingScheduler:
  def __init__(self, initial_ratio=1.0, final_ratio=0.5, decay_epochs=10):
    self.initial_ratio = initial_ratio
    self.final_ratio = final_ratio
    self.decay_epochs = decay_epochs

  def get_ratio(self, epoch):
    if epoch >= self.decay_epochs:
      return self.final_ratio

    # Linear decay
    ratio = self.initial_ratio - (self.initial_ratio - self.final_ratio) * (epoch / self.decay_epochs)
    return ratio
In [ ]:
from tqdm import tqdm
import math

def train_epoch(model, dataloader, optimizer, criterion, device, teacher_forcing_ratio=0.5, clip=1.0):
  model.train()
  epoch_loss = 0

  for src, tgt, src_lengths, _ in tqdm(dataloader, desc="Training", leave=False):
    src, tgt, src_lengths = src.to(device), tgt.to(device), src_lengths.to(device)

    optimizer.zero_grad()

    output = model(src, src_lengths, tgt, teacher_forcing_ratio=teacher_forcing_ratio)

    # Reshape for loss calculation
    output_dim = output.shape[-1]
    output = output.reshape(-1, output_dim)
    target = tgt[:, 1:].reshape(-1)

    loss = criterion(output, target)
    loss.backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

    optimizer.step()
    epoch_loss += loss.item()

  return epoch_loss / len(dataloader)

def evaluate(model, dataloader, criterion, device):
  model.eval()
  epoch_loss = 0

  with torch.no_grad():
    for src, tgt, src_lengths, _ in tqdm(dataloader, desc="Evaluating", leave=False):
      src, tgt, src_lengths = src.to(device), tgt.to(device), src_lengths.to(device)

      # Use teacher forcing for evaluation 
      output = model(src, src_lengths, tgt, teacher_forcing_ratio=1.0)

      output_dim = output.shape[-1]
      output_flat = output.reshape(-1, output_dim)
      target_flat = tgt[:, 1:].reshape(-1)

      loss = criterion(output_flat, target_flat)
      epoch_loss += loss.item()

  return epoch_loss / len(dataloader)

def calculate_bleu(model, test_pairs, vocab, device, max_samples=500):
  model.eval()
  all_targets = []
  all_predictions = []

  # Use only a subset for BLEU to save time
  test_pairs = test_pairs[:max_samples]

  for src_indices, tgt_indices in tqdm(test_pairs, desc="Calculating BLEU", leave=False):
    src = torch.LongTensor(src_indices).unsqueeze(0).to(device)
    src_lengths = torch.LongTensor([len(src_indices)]).to(device)

    predicted_indices = model.inference(
      src, src_lengths,
      vocab.word2idx['<sos>'],
      vocab.word2idx['<eos>'],
      max_len=50,
      device=device,
      pad_idx=vocab.word2idx['<pad>']
    )

    predicted_tokens = vocab.decode(predicted_indices, skip_special=True)
    target_tokens = vocab.decode(tgt_indices, skip_special=True)

    if predicted_tokens:  # Only add non-empty predictions
      all_targets.append([target_tokens])
      all_predictions.append(predicted_tokens)

  bleu_score = nltk.translate.bleu_score.corpus_bleu(
    all_targets, all_predictions,
    smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method1
  )
  return bleu_score
In [ ]:
def translate(model, sentence, vocab, device, max_len=50, direction='en2cn'):
  model.eval()

  if direction == 'en2cn':
    tokens = ['<en>'] + tokenize_english(normalize_english(sentence))
    join_char = ''
  else:  # cn2en
    tokens = ['<cn>'] + tokenize_chinese(normalize_chinese(sentence))
    join_char = ' '

  # Encode input
  indices = vocab.encode(tokens)
  src = torch.LongTensor(indices).unsqueeze(0).to(device)
  src_lengths = torch.LongTensor([len(indices)]).to(device)

  # Generate translation
  output_indices = model.inference(
    src, src_lengths,
    vocab.word2idx['<sos>'],
    vocab.word2idx['<eos>'],
    max_len=max_len,
    device=device,
    pad_idx=vocab.word2idx['<pad>']
  )

  # Decode output
  output_tokens = vocab.decode(output_indices, skip_special=True)

  return join_char.join(output_tokens)
In [ ]:
vocab_size = vocab.n_words
embed_dim = 256
hidden_dim = 256 
num_layers = 2
dropout = 0.5

model = Seq2Seq(vocab_size, embed_dim, hidden_dim, num_layers, dropout).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  optimizer, mode='min', factor=0.5, patience=2
)
tf_scheduler = TeacherForcingScheduler(initial_ratio=1.0, final_ratio=0.5, decay_epochs=10)

num_epochs = 30
best_val_loss = float('inf')
patience = 7
patience_counter = 0

train_losses = []
val_losses = []

print(f"\nModel Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\n")

for epoch in range(num_epochs):
  tf_ratio = tf_scheduler.get_ratio(epoch)
  train_loss = train_epoch(
    model, train_loader, optimizer, criterion, device,
    teacher_forcing_ratio=tf_ratio, clip=1.0
  )

  val_loss = evaluate(model, val_loader, criterion, device)
  scheduler.step(val_loss)
  train_losses.append(train_loss)
  val_losses.append(val_loss)

  bleu_score = 0.0
  if (epoch + 1) % 5 == 0:
    bleu_score = calculate_bleu(model, val_data, vocab, device, max_samples=500)

  current_lr = optimizer.param_groups[0]['lr']
  print(f"Epoch {epoch+1}/{num_epochs}")
  print(f"  Train Loss: {train_loss:.4f} | Train PPL: {math.exp(train_loss):.2f}")
  print(f"  Val Loss:   {val_loss:.4f} | Val PPL:   {math.exp(val_loss):.2f}", end="")
  if bleu_score > 0:
    print(f" | BLEU: {bleu_score:.4f}")
  else:
    print()
  print(f"  TF Ratio: {tf_ratio:.3f} | LR: {current_lr:.6f}")

  # Early stopping
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    patience_counter = 0
    torch.save(model.state_dict(), 'best_model.pt')
    print("  ✓ New best model saved!")
  else:
    patience_counter += 1
    print(f"  ⏳ No improvement ({patience_counter}/{patience})")

  if patience_counter >= patience:
    print(f"\n⚠️ Early stopping triggered after {epoch+1} epochs")
    break

  print()

model.load_state_dict(torch.load('best_model.pt'))

print("\n" + "="*50)
print("TRANSLATION EXAMPLES")
print("="*50)

test_sentences_en = [
  "Hello",
  "How are you?",
  "I love you",
  "Good morning",
  "Thank you very much"
]

test_sentences_cn = [
  "你好",
  "谢谢",
  "我爱你",
  "早上好",
  "再见"
]

print("\nEnglish → Chinese:")
for sent in test_sentences_en:
  translation = translate(model, sent, vocab, device, direction='en2cn')
  print(f"  {sent:30s} → {translation}")

print("\nChinese → English:")
for sent in test_sentences_cn:
  translation = translate(model, sent, vocab, device, direction='cn2en')
  print(f"  {sent:30s} → {translation}")
Model Parameters: 35,239,091
Trainable Parameters: 35,239,091


Epoch 1/30
  Train Loss: 3.8726 | Train PPL: 48.07
  Val Loss:   4.5393 | Val PPL:   93.63
  TF Ratio: 1.000 | LR: 0.001000
  ✓ New best model saved!


Epoch 2/30
  Train Loss: 2.9843 | Train PPL: 19.77
  Val Loss:   4.2783 | Val PPL:   72.12
  TF Ratio: 0.950 | LR: 0.001000
  ✓ New best model saved!


Epoch 3/30
  Train Loss: 2.5886 | Train PPL: 13.31
  Val Loss:   4.0913 | Val PPL:   59.82
  TF Ratio: 0.900 | LR: 0.001000
  ✓ New best model saved!


Epoch 4/30
  Train Loss: 2.2901 | Train PPL: 9.88
  Val Loss:   3.9491 | Val PPL:   51.89
  TF Ratio: 0.850 | LR: 0.001000
  ✓ New best model saved!


Epoch 5/30
  Train Loss: 2.0648 | Train PPL: 7.88
  Val Loss:   3.8748 | Val PPL:   48.17 | BLEU: 0.0642
  TF Ratio: 0.800 | LR: 0.001000
  ✓ New best model saved!


Epoch 6/30
  Train Loss: 1.8832 | Train PPL: 6.57
  Val Loss:   3.8804 | Val PPL:   48.44
  TF Ratio: 0.750 | LR: 0.001000
  ⏳ No improvement (1/7)


Epoch 7/30
  Train Loss: 1.7208 | Train PPL: 5.59
  Val Loss:   3.8123 | Val PPL:   45.25
  TF Ratio: 0.700 | LR: 0.001000
  ✓ New best model saved!


Epoch 8/30
  Train Loss: 1.6359 | Train PPL: 5.13
  Val Loss:   3.7864 | Val PPL:   44.10
  TF Ratio: 0.650 | LR: 0.001000
  ✓ New best model saved!


Epoch 9/30
  Train Loss: 1.5199 | Train PPL: 4.57
  Val Loss:   3.7514 | Val PPL:   42.58
  TF Ratio: 0.600 | LR: 0.001000
  ✓ New best model saved!


Epoch 10/30
  Train Loss: 1.4580 | Train PPL: 4.30
  Val Loss:   3.8479 | Val PPL:   46.89 | BLEU: 0.1046
  TF Ratio: 0.550 | LR: 0.001000
  ⏳ No improvement (1/7)


Epoch 11/30
  Train Loss: 1.3919 | Train PPL: 4.02
  Val Loss:   3.8795 | Val PPL:   48.40
  TF Ratio: 0.500 | LR: 0.001000
  ⏳ No improvement (2/7)


Epoch 12/30
  Train Loss: 1.2947 | Train PPL: 3.65
  Val Loss:   3.8504 | Val PPL:   47.01
  TF Ratio: 0.500 | LR: 0.000500
  ⏳ No improvement (3/7)


Epoch 13/30
  Train Loss: 1.1110 | Train PPL: 3.04
  Val Loss:   3.8161 | Val PPL:   45.43
  TF Ratio: 0.500 | LR: 0.000500
  ⏳ No improvement (4/7)


Epoch 14/30
  Train Loss: 1.0525 | Train PPL: 2.86
  Val Loss:   3.8253 | Val PPL:   45.85
  TF Ratio: 0.500 | LR: 0.000500
  ⏳ No improvement (5/7)


Epoch 15/30
  Train Loss: 1.0181 | Train PPL: 2.77
  Val Loss:   3.7974 | Val PPL:   44.58 | BLEU: 0.1311
  TF Ratio: 0.500 | LR: 0.000250
  ⏳ No improvement (6/7)


Epoch 16/30
  Train Loss: 0.9238 | Train PPL: 2.52
  Val Loss:   3.8191 | Val PPL:   45.56
  TF Ratio: 0.500 | LR: 0.000250
  ⏳ No improvement (7/7)

⚠️ Early stopping triggered after 16 epochs

==================================================
TRANSLATION EXAMPLES
==================================================

English → Chinese:
  Hello                          → 你好!
  How are you?                   → 你怎么啊?
  I love you                     → 我爱你.
  Good morning                   → 早上好!
  Thank you very much            → 非常感谢,谢谢你.

Chinese → English:
  你好                             → hi !
  谢谢                             → thank you .
  我爱你                            → i m a nice .
  早上好                            → get up .
  再见                             → see you .
In [30]:
test_sentences_en = [
  "Hello",
  "How are you?",
  "I love you",
  "Good morning",
  "Thank you very much",
  "Can you give me a cup of tea?",
  "What is your name?",
  "Are you a software engineer?",
  "Are you an engineer?",
  "Can you speak Chinese?"
]

test_sentences_cn = [
  "你好",
  "谢谢",
  "我爱你",
  "早上好",
  "再见"
]

print("\nEnglish → Chinese:")
for sent in test_sentences_en:
  translation = translate(model, sent, vocab, device, direction='en2cn')
  print(f"  {sent:30s} → {translation}")

print("\nChinese → English:")
for sent in test_sentences_cn:
  translation = translate(model, sent, vocab, device, direction='cn2en')
  print(f"  {sent:30s} → {translation}")
English → Chinese:
  Hello                          → 你好!
  How are you?                   → 你怎么啊?
  I love you                     → 我爱你.
  Good morning                   → 早上好!
  Thank you very much            → 非常感谢,谢谢你.
  Can you give me a cup of tea?  → 你能给我一杯茶吗?
  What is your name?             → 你叫什么名字?
  Are you a software engineer?   → 你是个术士吗?
  Are you an engineer?           → 你是个术士吗?
  Can you speak Chinese?         → 你会讲日语吗?

Chinese → English:
  你好                             → hi !
  谢谢                             → thank you .
  我爱你                            → i m a nice .
  早上好                            → get up .
  再见                             → see you .