In [ ]:
from google.colab import drive
drive.mount('/content/drive')
base_dir = "/content/drive/MyDrive/huggingface-rag"
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [3]:
!pip install -U sentence-transformers accelerate einops
Requirement already satisfied: sentence-transformers in /usr/local/lib/python3.12/dist-packages (5.1.2)
Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (1.12.0)
Requirement already satisfied: einops in /usr/local/lib/python3.12/dist-packages (0.8.1)
Requirement already satisfied: transformers<5.0.0,>=4.41.0 in /usr/local/lib/python3.12/dist-packages (from sentence-transformers) (4.57.1)
Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from sentence-transformers) (4.67.1)
Requirement already satisfied: torch>=1.11.0 in /usr/local/lib/python3.12/dist-packages (from sentence-transformers) (2.9.0+cu126)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.12/dist-packages (from sentence-transformers) (1.6.1)
Requirement already satisfied: scipy in /usr/local/lib/python3.12/dist-packages (from sentence-transformers) (1.16.3)
Requirement already satisfied: huggingface-hub>=0.20.0 in /usr/local/lib/python3.12/dist-packages (from sentence-transformers) (0.36.0)
Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from sentence-transformers) (11.3.0)
Requirement already satisfied: typing_extensions>=4.5.0 in /usr/local/lib/python3.12/dist-packages (from sentence-transformers) (4.15.0)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from accelerate) (2.0.2)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from accelerate) (25.0)
Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from accelerate) (5.9.5)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from accelerate) (6.0.3)
Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from accelerate) (0.7.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.20.0->sentence-transformers) (3.20.0)
Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.20.0->sentence-transformers) (2025.3.0)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.20.0->sentence-transformers) (2.32.4)
Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.20.0->sentence-transformers) (1.2.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (75.2.0)
Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (1.14.0)
Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (3.5)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (3.1.6)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (12.6.77)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (12.6.77)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (12.6.80)
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (9.10.2.21)
Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (12.6.4.1)
Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (11.3.0.4)
Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (10.3.7.77)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (11.7.1.2)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (12.5.4.2)
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (0.7.1)
Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (2.27.5)
Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (3.3.20)
Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (12.6.77)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (12.6.85)
Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (1.11.1.6)
Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.11.0->sentence-transformers) (3.5.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers<5.0.0,>=4.41.0->sentence-transformers) (2025.11.3)
Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers<5.0.0,>=4.41.0->sentence-transformers) (0.22.1)
Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn->sentence-transformers) (1.5.2)
Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn->sentence-transformers) (3.6.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.11.0->sentence-transformers) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.11.0->sentence-transformers) (3.0.3)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.20.0->sentence-transformers) (3.4.4)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.20.0->sentence-transformers) (3.11)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.20.0->sentence-transformers) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.20.0->sentence-transformers) (2025.11.12)
In [4]:
import json
import torch
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
In [6]:
model_name = "jinaai/jina-embeddings-v2-base-en"
max_seq_length = 512
batch_size = 8
num_epochs = 1
output_path = f"{base_dir}/ft-jina-transformers-v1"

model = SentenceTransformer(model_name, trust_remote_code=True)
model.max_seq_length = max_seq_length
In [7]:
train_examples = []
dataset_path = f"{base_dir}/train_dataset.jsonl"

with open(dataset_path, 'r', encoding='utf-8') as f_in:
  for line in f_in:
    item = json.loads(line)
    train_examples.append(InputExample(texts=[
      item['query'],
      item['pos'],
      item['neg']
    ]))

print(f"{len(train_examples)} train data is loaded")
15019 train data is loaded
In [6]:
train_dataloader = DataLoader(
  train_examples,
  shuffle=True,
  batch_size=batch_size
)

train_loss = losses.MultipleNegativesRankingLoss(model=model)

print("Start finetuning...\n")

model.fit(
  train_objectives=[(train_dataloader, train_loss)],
  epochs = num_epochs,
  warmup_steps=int(len(train_dataloader) * 10),
  output_path=output_path,
  show_progress_bar=True,
  use_amp=True
)

print(f"Finetune over! Finetuned model has been saved in {output_path}")
Start finetuning...

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]
/usr/local/lib/python3.12/dist-packages/notebook/notebookapp.py:191: SyntaxWarning: invalid escape sequence '\/'
  | |_| | '_ \/ _` / _` |  _/ -_)
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:
 ··········
wandb: WARNING If you're specifying your api key in code, ensure this code is not shared publicly.
wandb: WARNING Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: fuy60703 (fuy60703-huazhong-university-of-science-and-technology) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.23.0
Run data is saved locally in /content/wandb/run-20251126_065724-xz2rmj67
[1878/1878 14:21, Epoch 1/1]
Step Training Loss
500 0.980200
1000 0.278000
1500 0.202600

Finetune over! Finetuned model has been saved in /content/drive/MyDrive/huggingface-rag/ft-jina-transformers-v1
In [8]:
import json
import random
import numpy as np
from sentence_transformers import util

model = SentenceTransformer(output_path, trust_remote_code=True)
model.to('cpu')
model.eval()

# Load dataset
with open(dataset_path, 'r', encoding='utf-8') as f:
  all_data = [json.loads(line) for line in f]

# Evaluation on sample pairs
print("Starting fine-tuning evaluation...\n")
print("=" * 100)

test_samples = random.sample(all_data, min(40, len(all_data)))
correct_predictions = 0

for idx, sample in enumerate(test_samples, 1):
  query = sample['query']
  pos_doc = sample['pos']
  neg_doc = sample['neg']

  # All embeddings are CPU numpy
  query_embedding = model.encode(query, convert_to_tensor=False)
  pos_embedding = model.encode(pos_doc, convert_to_tensor=False)
  neg_embedding = model.encode(neg_doc, convert_to_tensor=False)

  # Convert to torch tensor for similarity
  query_tensor = torch.tensor(query_embedding)
  pos_tensor = torch.tensor(pos_embedding)
  neg_tensor = torch.tensor(neg_embedding)

  pos_score = util.cos_sim(query_tensor, pos_tensor).item()
  neg_score = util.cos_sim(query_tensor, neg_tensor).item()
  is_correct = pos_score > neg_score
  correct_predictions += is_correct

  print(f"\nSample {idx}/{len(test_samples)}")
  print(f"Positive: {pos_score:.4f}, Negative: {neg_score:.4f}")
  print(f"Result: {'✓ Correct' if is_correct else '✗ Wrong'}")
  print("-" * 90)

print("\n" + "=" * 100)
print(f"Pairwise Accuracy: {correct_predictions}/{len(test_samples)} = {100 * correct_predictions / len(test_samples):.2f}%")
print("=" * 100)

# Retrieval Testing
print("\nRetrieval Task Testing...")
retrieval_samples = random.sample(all_data, min(10, len(all_data)))

for idx, sample in enumerate(retrieval_samples, 1):
  query = sample['query']
  pos_doc = sample['pos']

  # Sample negative candidates
  random_negs = random.sample([s['neg'] for s in all_data if s != sample], min(4, len(all_data)-1))
  candidates = [pos_doc] + random_negs
  random.shuffle(candidates)

  correct_idx = candidates.index(pos_doc)

  query_emb = model.encode(query, convert_to_tensor=False)
  cand_embs = model.encode(candidates, convert_to_tensor=False)

  # Convert to tensors for cosine similarity
  query_tensor = torch.tensor(query_emb)
  cand_tensor = torch.tensor(cand_embs)

  similarities = util.cos_sim(query_tensor, cand_tensor)[0]
  ranked_indices = similarities.argsort(descending=True).numpy()
  correct_rank = np.where(ranked_indices == correct_idx)[0][0] + 1

  print(f"\nRetrieval Test {idx}:")
  print(f"Correct Document Rank: {correct_rank}/{len(candidates)}")
  print(f"Top-1: {'✓' if correct_rank == 1 else '✗'}")

  print("Top-3:")
  for rank, i in enumerate(ranked_indices[:3], 1):
    print(f" {rank}. sim={similarities[i]:.4f} {'← correct' if i == correct_idx else ''}")
  print("-" * 90)

print("\nTesting completed! 🚀")
Starting fine-tuning evaluation...

====================================================================================================

Sample 1/40
Positive: 0.8836, Negative: 0.0052
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 2/40
Positive: 0.7760, Negative: -0.0524
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 3/40
Positive: 0.8126, Negative: 0.2996
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 4/40
Positive: 0.7417, Negative: -0.1425
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 5/40
Positive: 0.7156, Negative: 0.3208
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 6/40
Positive: 0.4373, Negative: 0.0201
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 7/40
Positive: 0.6209, Negative: -0.0475
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 8/40
Positive: 0.6389, Negative: -0.0110
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 9/40
Positive: 0.6835, Negative: -0.0322
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 10/40
Positive: 0.7110, Negative: -0.0592
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 11/40
Positive: 0.8333, Negative: 0.3832
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 12/40
Positive: 0.8234, Negative: 0.3786
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 13/40
Positive: 0.6846, Negative: 0.3276
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 14/40
Positive: 0.7523, Negative: 0.1574
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 15/40
Positive: 0.4928, Negative: 0.2356
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 16/40
Positive: 0.8164, Negative: 0.0141
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 17/40
Positive: 0.5266, Negative: 0.0799
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 18/40
Positive: 0.8020, Negative: 0.2003
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 19/40
Positive: 0.7023, Negative: -0.0496
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 20/40
Positive: 0.7398, Negative: -0.0339
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 21/40
Positive: 0.6253, Negative: 0.3056
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 22/40
Positive: 0.4634, Negative: -0.0055
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 23/40
Positive: 0.4318, Negative: 0.3263
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 24/40
Positive: 0.7547, Negative: -0.0265
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 25/40
Positive: 0.6421, Negative: 0.4205
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 26/40
Positive: 0.6076, Negative: 0.2016
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 27/40
Positive: 0.7400, Negative: 0.6232
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 28/40
Positive: 0.6000, Negative: -0.0592
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 29/40
Positive: 0.6903, Negative: 0.5531
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 30/40
Positive: 0.4140, Negative: 0.2203
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 31/40
Positive: 0.6745, Negative: 0.6913
Result: ✗ Wrong
------------------------------------------------------------------------------------------

Sample 32/40
Positive: 0.8263, Negative: 0.4778
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 33/40
Positive: 0.7925, Negative: 0.3103
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 34/40
Positive: 0.7583, Negative: 0.1287
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 35/40
Positive: 0.8496, Negative: 0.1525
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 36/40
Positive: 0.7250, Negative: 0.5422
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 37/40
Positive: 0.8378, Negative: -0.0069
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 38/40
Positive: 0.5784, Negative: 0.2563
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 39/40
Positive: 0.6282, Negative: 0.3806
Result: ✓ Correct
------------------------------------------------------------------------------------------

Sample 40/40
Positive: 0.5763, Negative: 0.3881
Result: ✓ Correct
------------------------------------------------------------------------------------------

====================================================================================================
Pairwise Accuracy: 39/40 = 97.50%
====================================================================================================

Retrieval Task Testing...

Retrieval Test 1:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.7410 ← correct
 2. sim=0.1145 
 3. sim=0.0799 
------------------------------------------------------------------------------------------

Retrieval Test 2:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.6482 ← correct
 2. sim=0.0599 
 3. sim=0.0550 
------------------------------------------------------------------------------------------

Retrieval Test 3:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.6007 ← correct
 2. sim=0.1302 
 3. sim=0.0835 
------------------------------------------------------------------------------------------

Retrieval Test 4:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.8003 ← correct
 2. sim=0.1457 
 3. sim=0.0883 
------------------------------------------------------------------------------------------

Retrieval Test 5:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.7396 ← correct
 2. sim=0.1031 
 3. sim=0.0846 
------------------------------------------------------------------------------------------

Retrieval Test 6:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.6588 ← correct
 2. sim=0.1011 
 3. sim=0.0529 
------------------------------------------------------------------------------------------

Retrieval Test 7:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.8398 ← correct
 2. sim=0.1840 
 3. sim=0.1325 
------------------------------------------------------------------------------------------

Retrieval Test 8:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.5624 ← correct
 2. sim=0.1171 
 3. sim=0.1068 
------------------------------------------------------------------------------------------

Retrieval Test 9:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.5426 ← correct
 2. sim=0.1341 
 3. sim=0.1120 
------------------------------------------------------------------------------------------

Retrieval Test 10:
Correct Document Rank: 1/5
Top-1: ✓
Top-3:
 1. sim=0.7565 ← correct
 2. sim=0.1497 
 3. sim=0.1364 
------------------------------------------------------------------------------------------

Testing completed! 🚀