检索增强生成(RAG)实战指南
提示:下面的流程图使用 Mermaid 绘制,您可以在支持的 Markdown 渲染器中直接查看。
graph LR
Docs[文档集合] --> Embedding[向量化 (FAISS)]
Query[用户查询] --> EmbedQ[查询向量]
EmbedQ --> Search[相似度检索]
Search --> Retrieved[检索到的文档]
Retrieved --> Prompt[构造 RAG Prompt]
Prompt --> LLM[LLM 生成]
LLM --> Answer[最终答案]
关键步骤
- 文档准备:将原始文本切分成段落或句子。
- 向量化:使用预训练的嵌入模型(如
sentence-transformers/all-MiniLM-L6-v2)将每段文本转为向量,并使用 FAISS 建立索引。 - 检索:对用户查询进行同样的向量化,使用 FAISS 找到最相似的 N 条文档。
- 构造 Prompt:把检索到的文档拼接到 Prompt 中,交给 LLM 生成答案。
环境准备
_10pip install faiss-cpu sentence-transformers transformers torch
完整代码示例
_45import os_45import torch_45from sentence_transformers import SentenceTransformer_45from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer_45import faiss_45_45# 1. 加载嵌入模型_45embed_model = SentenceTransformer("all-MiniLM-L6-v2")_45_45# 2. 准备文档(这里使用本地 txt 文件示例)_45corpus_path = "./docs"_45texts = []_45for fname in os.listdir(corpus_path):_45 if fname.endswith('.txt'):_45 with open(os.path.join(corpus_path, fname), 'r', encoding='utf-8') as f:_45 texts.extend([para.strip() for para in f.read().split('\n\n') if para.strip()])_45_45# 3. 向量化文档_45embeddings = embed_model.encode(texts, batch_size=32, show_progress_bar=True, convert_to_numpy=True)_45_45# 4. 建立 FAISS 索引(使用 L2 距离)_45dimension = embeddings.shape[1]_45index = faiss.IndexFlatL2(dimension)_45index.add(embeddings)_45_45# 5. 加载 LLM(这里使用小模型示例)_45llm_name = "gpt2"_45tokenizer = AutoTokenizer.from_pretrained(llm_name)_45model = AutoModelForCausalLM.from_pretrained(llm_name)_45generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)_45_45def rag_query(query, top_k=3):_45 # 向量化查询_45 q_vec = embed_model.encode([query], convert_to_numpy=True)_45 # 检索相似文档_45 distances, indices = index.search(q_vec, top_k)_45 retrieved = "\n\n".join([texts[i] for i in indices[0]])_45 # 构造 Prompt_45 prompt = f"以下是与问题相关的文档:\n{retrieved}\n\n请根据上述文档,用中文回答以下问题:{query}\n答案:"_45 # 生成答案_45 result = generator(prompt, max_new_tokens=150, do_sample=True, temperature=0.7)_45 return result[0]["generated_text"].split("答案:")[-1].strip()_45_45# 示例调用_45print(rag_query("什么是注意力机制?"))
说明:
faiss.IndexFlatL2适用于小规模数据(几千条),如果数据量更大请使用IndexIVFFlat或HNSW。- 代码中使用了
gpt2作为演示模型,实际项目可替换为更强大的 LLM(如meta-llama/Llama-2-7b-chat-hf)。
小结
- RAG 将检索与生成相结合,显著提升答案的事实准确性。
- 只需几行代码即可在本地实现完整的检索增强生成工作流。
- 通过更换嵌入模型和 LLM,可轻松适配不同领域的知识库。