手写 Huggingface RAG 系统(2)——Embedding 构建
1. 概述
完成分块之后,我们就需要对分块的数据进行 Embedding 了。为了让 Embedding 在 Huggingface Transformers 上的效果更好,我们使用我们 chunk 好的数据构造数据集,在 jinaai/jina-embeddings-v2-base-en 上进行微调。
微调后的模型上传到了 Finetune-jina-transformers-v1 中。
2. 数据集构造
数据集构造分为两个部分:正样本对生成和负样本对生成。
正样本对生成
正样本对生成比较简单,因为对于每个 chunk,它的正样本就是它自己,我们只需要构造一些问题就行。我们用 Deepseek API 来完成问题的构建:
text = chunk['text']
source = chunk['metadata']['source']
prompt = f"""
You are an expert AI developer working with the Hugging Face Transformers library.
Based strictly on the following documentation excerpt, generate 3 technical questions that a developer might ask.
Documentation Source: {source}
Context:
{text}
Requirements:
1. The questions must be answerable **solely** based on the provided context.
2. Include specific specific Python class names, function names, arguments, or technical terms found in the text (e.g., 'logits', 'AutoModel', 'training_args').
3. The questions should be concise and professional.
4. Output ONLY a valid JSON list of strings.
Output Format:
["Question 1?", "Question 2?", "Question 3?"]
"""
response = client.chat.completions.create(
model='deepseek-chat',
messages=[{
'role': 'user',
'content': prompt
}],
temperature=0.7
)
由于总共有 6000 多个 chunk,全部用 Deepseek API 生成有些小贵,因此我们只用 Deepseek API 生成 2000 个,剩下的用我们在 chunk 时记录的 header 信息来构建:
def generate_questions_from_headers(chunk):
headers = chunk.get('metadata', {}).get('headers')
if not headers:
return []
questions = []
last_header = headers[-1]
questions.append(f"What does the documentation say about '{last_header}'?")
if len(headers) > 1:
parent_headers = " > ".join(headers[:-1])
questions.append(f"Explain the '{last_header}' section within '{parent_headers}'.")
questions.append(f"How do I use or implement '{last_header}' according to the provided text?")
return questions
然后我们遍历之间创建的 jsonl 分块,对每个分块进行问题生成即可。最终生成的问题格式如下:
{
"query": "What is the backbone model used by ColQwen2 to capture visual elements in document pages?",
"pos": "Context: ColQwen2\n\n# ColQwen2\n\nColQwen2(https://huggingface.co/papers/2407.01449) is a variant of the ColPali model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColQwen2 treats each page as an image. It uses the Qwen2-VL backbone to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.\n\nThis model was contributed by @tonywu71(https://huggingface.co/tonywu71) (ILLUIN Technology) and @yonigozlan(https://huggingface.co/yonigozlan) (HuggingFace).\n\nYou can find all the original ColPali checkpoints under Vidore's Hf-native ColVision Models(https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.\n\n> !TIP\n> Click on the ColQwen2 models in the right sidebar for more examples of how to use ColQwen2 for image retrieval."
}
负样本对生成
除了正样本,我们还需要构建一些和正样本非常接近的负样本,让模型学会识别那些看起来像但是实际不是的文档。我们用 BM25 算法对每个 chunk 检索它的 Top10 样本,然后把挖掉 pos_doc(就是我们构造正样本的文档,这个肯定在 Top10 里面)后的排名第一的样本作为负样本。
在将文档用 BM25 算法进行评分前,我们需要进行简单的 tokenize:
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.data.find('corpora/stopwords')
nltk.download('stopwords')
nltk.data.find('tokenizers/punkt')
stopwords = set(nltk.corpus.stopwords.words('english'))
stemmer = PorterStemmer()
def tokenizer(text: str) -> List[str]:
text = str(text or "").lower()
tokens = nltk.word_tokenize(text)
# 移除标点与非数字
tokens = [word for word in tokens if word.isalnum()]
tokens = [word for word in tokens if word not in stopwords]
tokens = [stemmer.stem(word) for word in tokens]
return [token for token in tokens if token]
然后就可以开始构建负样本了:
bm25 = BM25Okapi(tokenized_corpus)
......
item = json.loads(line)
query = item['query']
pos_doc = item['pos']
tokenized_query = tokenizer(query)
top_n = bm25.get_top_n(tokenized_query, corpus_texts, n=10)
negatives = [doc for doc in top_n if doc != pos_doc]
hard_negative = negatives[0]
3. Embedding 模型微调
在生成数据集后,我们使用数据集在 jinaai/jina-embeddings-v2-base-en 上进行微调。这个比较简单,将数据集按照 jinaai 需要的格式放好即可:
可以看到我们只经过 1 epoch 微调后的模型的效果非常好。