In [1]:
from google.colab import drive
drive.mount('/content/drive')
base_dir = "/content/drive/MyDrive/huggingface-rag"
In [2]:
!pip install -U sentence-transformers openai qdrant_client fastembed
In [3]:
import os
import httpx
from openai import OpenAI
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from fastembed import SparseTextEmbedding
qdrant_path = f"{base_dir}/qdrant_hybrid_db"
collection_name = 'huggingface_transformers_docs'
dense_model = SentenceTransformer("fyerfyer/finetune-jina-transformers-v1", trust_remote_code=True)
sparse_model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
lock_file = os.path.join(qdrant_path, ".lock")
# Qdrant's local mode locks its storage directory; clear a stale lock left by a previous session
if os.path.exists(lock_file):
    try:
        os.remove(lock_file)
        print(f"Removed stale lock file: {lock_file}")
    except Exception as e:
        print(f"Warning: Could not remove lock file: {e}")
In [4]:
from google.colab import userdata
class HFRAG:
    def __init__(self):
        self.dense_model = dense_model
        self.sparse_model = sparse_model
        self.db_client = QdrantClient(path=qdrant_path)
        if not self.db_client.collection_exists(collection_name):
            raise ValueError(f"Cannot find collection '{collection_name}'; please check the Qdrant path")
        print(f"Successfully connected to Qdrant at: {qdrant_path}")
        self.llm_client = OpenAI(
            api_key=userdata.get('DEEPSEEK_API_KEY'),
            base_url="https://api.deepseek.com",
            http_client=httpx.Client(proxy=None, trust_env=False)  # Needed when a local proxy is enabled; otherwise requests fail
        )

    def retrieve(self, query: str, top_k: int = 5):
        # Generate dense vector
        query_dense_vec = self.dense_model.encode(query).tolist()
        # Generate sparse vector
        query_sparse_gen = list(self.sparse_model.embed([query]))[0]
        query_sparse_vec = models.SparseVector(
            indices=query_sparse_gen.indices.tolist(),
            values=query_sparse_gen.values.tolist()
        )
        # Create prefetch for dense retrieval
        prefetch_dense = models.Prefetch(
            query=query_dense_vec,
            using="text-dense",
            limit=20,
        )
        # Create prefetch for sparse retrieval
        prefetch_sparse = models.Prefetch(
            query=query_sparse_vec,
            using="text-sparse",
            limit=20,
        )
        # Hybrid search with RRF fusion
        results = self.db_client.query_points(
            collection_name=collection_name,
            prefetch=[prefetch_dense, prefetch_sparse],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=top_k,
            with_payload=True
        ).points
        return results

    def generate(self, query: str, search_results):
        if not search_results:
            return "I'm sorry, but I couldn't find any relevant information in the knowledge base regarding your query."
        # Assemble the retrieved chunks into <doc> blocks the LLM can cite
        context_pieces = []
        for idx, hit in enumerate(search_results, 1):
            source = hit.payload.get('source', 'unknown')
            filename = source.split('/')[-1] if '/' in source else source
            text = hit.payload['text']
            piece = f"""<doc id="{idx}" source="{filename}">
{text}
</doc>"""
            context_pieces.append(piece)
        context_str = "\n\n".join(context_pieces)
        system_prompt = """You are an expert AI assistant specializing in the Hugging Face Transformers library and NLP technology.
YOUR MISSION:
Answer the user's question using ONLY the provided "Retrieved Context". Do not rely on your internal knowledge base unless it is to explain syntax or general programming concepts not covered in the documents.
GUIDELINES:
1. **Grounding**: Base your answer strictly on the provided context chunks.
2. **Code First**: If the context contains code examples, prioritize showing them in your answer using Python markdown blocks.
3. **Citation**: When referencing specific information, cite the source file name (e.g., `[model_doc.md]`).
4. **Honesty**: If the provided context does not contain enough information to answer the question, state: "The provided documents do not contain the answer to this question." Do not hallucinate or make up parameters.
5. **Clarity**: Keep explanations concise and technical.
Output Format:
- Use Markdown for formatting.
- Use `code blocks` for function names and parameters.
"""
        # User prompt: the query plus the retrieved context
        user_prompt = f"""
### User Query
{query}
### Retrieved Context
Please use the following documents to answer the query above:
{context_str}
### Answer
"""
        print(f"\nThinking (Processing {len(search_results)} context chunks)...")
        try:
            # Stream the completion so the answer is printed as it is generated
            response = self.llm_client.chat.completions.create(
                model="deepseek-chat",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.1,
                max_tokens=4096,
                stream=True
            )
            full_response = ""
            print("-" * 60)
            for chunk in response:
                if chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    print(content, end="", flush=True)
                    full_response += content
            print("\n" + "-" * 60)
            return full_response
        except Exception as e:
            return f"Error calling LLM: {e}"

    def chat(self, query: str):
        print(f"\nUser: {query}")
        results = self.retrieve(query)
        answer = self.generate(query, results)
        # generate() streams its output when results exist; print the fallback message otherwise
        if not results:
            print(answer)
        return answer
In [ ]:
if __name__ == "__main__":
    rag = HFRAG()
    print("\nHugging Face RAG assistant started! Type 'quit' to exit.")
    while True:
        user_input = input("\nPlease enter your question: ")
        if user_input.lower() in ['quit', 'exit']:
            break
        rag.chat(user_input)
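After exiting the loop, the same `rag` instance can be reused for a quick retrieval-only check. Note that Qdrant's local (path-based) storage is typically locked by the client that opened it, so constructing a second `HFRAG` in the same session may fail. The query string below is just an example.
In [ ]:
# Optional retrieval-only sanity check, reusing the existing `rag` instance
# (the local Qdrant storage is held by the client created above).
hits = rag.retrieve("How do I load a pretrained model with AutoModel?", top_k=3)
for hit in hits:
    print(f"{hit.score:.4f}  {hit.payload.get('source', 'unknown')}")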