RAG Retrieval System Core Implementation: Multi-Method Retrieval and Reranking, a Walkthrough of retrieval.py
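The full code is at the end of the article.
Introduction
In enterprise knowledge-base and document question-answering systems, retrieval is the core component of the RAG (Retrieval-Augmented Generation) architecture. This article walks through the retrieval implementation from a winning entry in the RAG Challenge, which combines BM25, vector retrieval, and LLM reranking.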
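System architecture overview
The system provides three retrievers of increasing sophistication:
- BM25 retriever: a traditional method based on term-frequency statistics
- Vector retriever: a modern method based on semantic similarity
- Hybrid retriever: an advanced method that combines vector retrieval with LLM reranking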
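Core components in detail
1. BM25 retriever (BM25Retriever)
BM25 is a classic ranking algorithm based on term-frequency statistics, and it is particularly well suited to keyword matching.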
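To make the scoring concrete, here is a minimal pure-Python sketch of the textbook Okapi BM25 formula. It is not the code the project runs (the project loads a prebuilt `rank_bm25` index), and the function name `bm25_scores`, the defaults `k1=1.5` and `b=0.75`, and the non-negative IDF variant are assumptions made for illustration.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(query_tokens: List[str], corpus_tokens: List[List[str]],
                k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Score every tokenized document in corpus_tokens against query_tokens."""
    if not corpus_tokens:
        return []
    n_docs = len(corpus_tokens)
    avg_len = sum(len(doc) for doc in corpus_tokens) / n_docs
    # Document frequency: in how many documents each term appears.
    doc_freq = Counter(term for doc in corpus_tokens for term in set(doc))
    scores = []
    for doc in corpus_tokens:
        term_freq = Counter(doc)
        score = 0.0
        for term in query_tokens:
            tf = term_freq.get(term, 0)
            if tf == 0:
                continue  # terms absent from the document contribute nothing
            # Non-negative IDF variant (the "+ 1" inside the logarithm).
            idf = math.log((n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5) + 1.0)
            # Longer-than-average documents are penalized through b.
            length_norm = 1.0 - b + b * len(doc) / avg_len
            score += idf * tf * (k1 + 1.0) / (tf + k1 * length_norm)
        scores.append(score)
    return scores


# Whitespace tokenization, as in the retriever's query.split()
corpus = [
    "revenue grew in the fourth quarter".split(),
    "the board approved a new dividend policy".split(),
    "fourth quarter revenue and operating income".split(),
]
scores = bm25_scores("fourth quarter revenue".split(), corpus)
ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
```

The second document shares no term with the query, so it scores zero and ranks last. Note that whitespace tokenization only works for languages that separate words with spaces.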
class BM25Retriever:
    def __init__(self, bm25_db_dir: Path, documents_dir: Path):
        # Initialize the BM25 retriever with the BM25 index directory and the documents directory
        self.bm25_db_dir = bm25_db_dir
        self.documents_dir = documents_dir
Core retrieval method (simplified; the parent-page mode appears in the full source at the end):
def retrieve_by_company_name(self, company_name: str, query: str, top_n: int = 3, return_parent_pages: bool = False) -> List[Dict]:
    # 1. Find the document that belongs to the company
    document_path = None
    for path in self.documents_dir.glob("*.json"):
        with open(path, 'r', encoding='utf-8') as f:
            doc = json.load(f)
            if doc["metainfo"]["company_name"] == company_name:
                document_path = path
                document = doc
                break
    if document_path is None:
        raise ValueError(f"No report found with '{company_name}' company name.")
    # 2. Load the precomputed BM25 index
    bm25_path = self.bm25_db_dir / f"{document['metainfo']['sha1_name']}.pkl"
    with open(bm25_path, 'rb') as f:
        bm25_index = pickle.load(f)
    chunks = document["content"]["chunks"]
    # 3. Compute BM25 scores and sort in descending order
    tokenized_query = query.split()
    scores = bm25_index.get_scores(tokenized_query)
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    # 4. Build the retrieval results
    retrieval_results = []
    for index in top_indices:
        score = round(float(scores[index]), 4)
        chunk = chunks[index]
        result = {
            "distance": score,  # BM25 score (higher is better), stored under the shared "distance" key
            "page": chunk["page"],
            "text": chunk["text"]
        }
        retrieval_results.append(result)
    return retrieval_results
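Technical highlights:
- Exact lookup by company name
- BM25 indexes are precomputed and persisted, which speeds up queries
- Can return either document chunks or full pages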
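The two return modes can be sketched in isolation as follows. The helper name `collect_results` and the `score` key on each ranked chunk are hypothetical (the real retrievers read scores from a separate array); the deduplication by page mirrors the `seen_pages` logic in the full source.

```python
from typing import Dict, List


def collect_results(ranked_chunks: List[Dict], pages: List[Dict],
                    return_parent_pages: bool) -> List[Dict]:
    """Turn ranked chunks into results, optionally replacing each chunk by its page."""
    page_text = {page["page"]: page["text"] for page in pages}
    results, seen_pages = [], set()
    for chunk in ranked_chunks:
        if not return_parent_pages:
            # Chunk mode: one result per chunk.
            results.append({"distance": chunk["score"], "page": chunk["page"], "text": chunk["text"]})
        elif chunk["page"] not in seen_pages:
            # Page mode: the best-ranked chunk of a page decides that page's score.
            seen_pages.add(chunk["page"])
            results.append({"distance": chunk["score"], "page": chunk["page"],
                            "text": page_text[chunk["page"]]})
    return results


pages = [{"page": 1, "text": "full page one"}, {"page": 2, "text": "full page two"}]
ranked_chunks = [
    {"page": 2, "text": "chunk a", "score": 0.9},
    {"page": 2, "text": "chunk b", "score": 0.7},
    {"page": 1, "text": "chunk c", "score": 0.4},
]
as_chunks = collect_results(ranked_chunks, pages, return_parent_pages=False)
as_pages = collect_results(ranked_chunks, pages, return_parent_pages=True)
```

One consequence worth knowing: in page mode, several top chunks can collapse into one page, so the caller may receive fewer than `top_n` results.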
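2. Vector retriever (VectorRetriever)
The vector retriever is the core of a modern RAG system; it retrieves documents by semantic similarity.
Support for multiple embedding providers: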
def _get_embedding(self, text: str):
    # Get the vector representation of the text from the configured embedding_provider
    if self.embedding_provider == "openai":
        embedding = self.llm.embeddings.create(
            input=text,
            model="text-embedding-3-large"
        )
        return embedding.data[0].embedding
    elif self.embedding_provider == "dashscope":
        import dashscope
        rsp = dashscope.TextEmbedding.call(
            model="text-embedding-v1",
            input=[text]
        )
        return rsp['output']['embeddings'][0]['embedding']
Vector database management:
def _load_dbs(self):
    # Load every vector database together with its document and pair them up
    all_dbs = []
    all_documents_paths = list(self.documents_dir.glob('*.json'))
    vector_db_files = {db_path.stem: db_path for db_path in self.vector_db_dir.glob('*.faiss')}
    for document_path in all_documents_paths:
        stem = document_path.stem
        if stem not in vector_db_files:
            continue
        # Load the document and its matching FAISS index
        with open(document_path, 'r', encoding='utf-8') as f:
            document = json.load(f)
        vector_db = faiss.read_index(str(vector_db_files[stem]))
        report = {
            "name": stem,
            "vector_db": vector_db,
            "document": document
        }
        all_dbs.append(report)
    return all_dbs
Core retrieval logic (simplified):
def retrieve_by_company_name(self, company_name: str, query: str, top_n: int = 3, return_parent_pages: bool = False) -> List[Dict]:
    # 1. Find the target company's report
    target_report = None
    for report in self.all_dbs:
        if report["document"]["metainfo"]["company_name"] == company_name:
            target_report = report
            break
    if target_report is None:
        raise ValueError(f"No report found with '{company_name}' company name.")
    vector_db = target_report["vector_db"]
    chunks = target_report["document"]["content"]["chunks"]
    # 2. Embed the query
    embedding = self._get_embedding(query)
    embedding_array = np.array(embedding, dtype=np.float32).reshape(1, -1)
    # 3. FAISS vector search
    distances, indices = vector_db.search(x=embedding_array, k=top_n)
    # 4. Build the results
    retrieval_results = []
    for distance, index in zip(distances[0], indices[0]):
        chunk = chunks[index]
        result = {
            "distance": round(float(distance), 4),
            "page": chunk["page"],
            "text": chunk["text"]
        }
        retrieval_results.append(result)
    return retrieval_results
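Technical features:
- Supports two embedding providers, OpenAI and DashScope
- Uses FAISS for efficient vector search
- Discovers and loads the vector databases of multiple companies from disk
- Provides a utility method for computing cosine similarity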
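As a rough illustration of what the vector search computes, the sketch below ranks chunk vectors by cosine similarity with a brute-force scan in plain Python. It is a stand-in for the FAISS index, not a replacement: the function names are hypothetical, and the metric a real FAISS index uses depends on how the index was built, which the article does not state.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """dot(a, b) / (|a| * |b|), the formula used by get_strings_cosine_similarity."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec: Sequence[float], chunk_vecs: List[Sequence[float]],
          k: int) -> List[Tuple[int, float]]:
    """Return (chunk index, similarity) pairs for the k most similar chunks."""
    scored = [(i, cosine_similarity(query_vec, vec)) for i, vec in enumerate(chunk_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[: min(k, len(scored))]  # never ask for more than exists


# Toy 2-dimensional "embeddings"; real ones have hundreds or thousands of dimensions.
chunk_vecs = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
hits = top_k([1.0, 0.1], chunk_vecs, k=2)
```

A brute-force scan like this is linear in the number of chunks, which is why a dedicated index library is used once the corpus grows.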
3. Hybrid retriever (HybridRetriever)
The hybrid retriever is the most advanced component of the system; it combines vector retrieval with LLM reranking.
class HybridRetriever:
    def __init__(self, vector_db_dir: Path, documents_dir: Path):
        self.vector_retriever = VectorRetriever(vector_db_dir, documents_dir)
        self.reranker = LLMReranker()

    def retrieve_by_company_name(
        self,
        company_name: str,
        query: str,
        llm_reranking_sample_size: int = 28,
        documents_batch_size: int = 2,
        top_n: int = 6,
        llm_weight: float = 0.7,
        return_parent_pages: bool = False
    ) -> List[Dict]:
        # 1. Vector retrieval produces the initial candidate set
        vector_results = self.vector_retriever.retrieve_by_company_name(
            company_name=company_name,
            query=query,
            top_n=llm_reranking_sample_size,
            return_parent_pages=return_parent_pages
        )
        # 2. LLM reranking refines the results
        reranked_results = self.reranker.rerank_documents(
            query=query,
            documents=vector_results,
            documents_batch_size=documents_batch_size,
            llm_weight=llm_weight
        )
        return reranked_results[:top_n]
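The internals of `LLMReranker` are not shown in this article, so the following is only a sketch of one plausible way `llm_weight` could be applied: a weighted sum of an LLM relevance score and the vector score. The function `rerank`, the `fake_llm_score` stand-in, and the assumption that both scores lie in [0, 1] with higher meaning better are all hypothetical.

```python
from typing import Callable, Dict, List


def rerank(documents: List[Dict], llm_score: Callable[[str], float],
           llm_weight: float = 0.7) -> List[Dict]:
    """Blend an LLM relevance score with the vector score and sort descending."""
    vector_weight = 1.0 - llm_weight
    reranked = []
    for doc in documents:
        relevance = llm_score(doc["text"])  # assumed to lie in [0, 1]
        # Assumes doc["distance"] is a similarity (higher is better).
        combined = llm_weight * relevance + vector_weight * doc["distance"]
        reranked.append({**doc, "relevance_score": relevance,
                         "combined_score": round(combined, 4)})
    reranked.sort(key=lambda d: d["combined_score"], reverse=True)
    return reranked


def fake_llm_score(text: str) -> float:
    """Stand-in for a real LLM call: favors texts that mention revenue."""
    return 0.9 if "revenue" in text.lower() else 0.1


candidates = [
    {"distance": 0.80, "page": 3, "text": "Board composition and committees"},
    {"distance": 0.60, "page": 7, "text": "Quarterly revenue by segment"},
]
reranked = rerank(candidates, fake_llm_score, llm_weight=0.7)
```

Here the page with the lower vector score wins after reranking because the relevance score carries 70% of the weight. If the underlying index returned true distances (lower is better), they would have to be converted to a similarity before blending.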
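System design highlights
1. Layered retrieval strategy
The system offers retrieval strategies that range from simple to complex:
- BM25: fast keyword matching, suited to exact queries
- Vector retrieval: semantic understanding, suited to fuzzy queries
- Hybrid retrieval: vector retrieval refined by LLM reranking (in the code shown here, HybridRetriever does not call the BM25 retriever)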
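2. Flexible configuration parameters
# Configurable parameters of HybridRetriever.retrieve_by_company_name
llm_reranking_sample_size: int = 28  # number of initial candidates from the vector search
documents_batch_size: int = 2        # number of documents per LLM prompt
top_n: int = 6                       # number of results returned after reranking
llm_weight: float = 0.7              # weight given to the LLM score
return_parent_pages: bool = False    # return full pages instead of chunks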
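These defaults have a direct cost implication. Assuming the reranker issues one LLM prompt per batch, as the docstring of `documents_batch_size` suggests, 28 candidates in batches of 2 mean 14 prompts per query. The helper `make_batches` below is a hypothetical illustration of that arithmetic, not the project's batching code.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def make_batches(items: Sequence[T], batch_size: int) -> List[Sequence[T]]:
    """Split items into consecutive batches; the last batch may be shorter."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


candidates = list(range(28))                      # llm_reranking_sample_size = 28
batches = make_batches(candidates, batch_size=2)  # documents_batch_size = 2
llm_calls = len(batches)                          # one prompt per batch (assumption)
expected_calls = math.ceil(len(candidates) / 2)
```

Raising `documents_batch_size` lowers the number of prompts but makes each prompt longer, so the two parameters trade latency and cost against how much text the model must judge at once.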
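3. Enterprise-grade features
- Multi-company support: exact lookup by company name, suited to multi-tenant scenarios
- Error handling: exception handling and logging around document and index loading
- Performance: precomputed indexes, batched reranking, and documents and vector indexes held in memory after startup
- Extensibility: support for multiple embedding providers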
Practical application scenarios
1. Enterprise document question answering
# Query a company's financial information
retriever = HybridRetriever(vector_db_dir, documents_dir)
results = retriever.retrieve_by_company_name(
    company_name="Apple Inc.",
    query="Revenue in the fourth quarter of 2023",
    top_n=5
)
2. Multi-company comparison
# Compare the same metric across several companies
apple_results = retriever.retrieve_by_company_name("Apple Inc.", query)
google_results = retriever.retrieve_by_company_name("Google LLC", query)
3. Full-document retrieval
# Get every page of a company's report
all_docs = vector_retriever.retrieve_all("Apple Inc.")
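Performance optimization tips
1. Index precomputation
- BM25 indexes are computed ahead of time and persisted
- FAISS vector databases are built offline
- Document metadata is kept in memory by VectorRetriever (BM25Retriever re-reads the JSON files on each query)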
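One way to avoid re-reading every JSON file on each BM25 query is to build a company-name lookup once. This is a sketch of that idea rather than code from the project; `build_company_index` is a hypothetical helper, and the demo writes two tiny stand-in documents to a temporary directory.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict


def build_company_index(documents_dir: Path) -> Dict[str, Path]:
    """Map company_name -> document path by reading each JSON file once."""
    index: Dict[str, Path] = {}
    for path in sorted(documents_dir.glob("*.json")):
        with open(path, "r", encoding="utf-8") as f:
            document = json.load(f)
        metainfo = document.get("metainfo", {}) if isinstance(document, dict) else {}
        name = metainfo.get("company_name")
        if name is not None:
            index.setdefault(name, path)  # keep the first match in sorted order
    return index


with tempfile.TemporaryDirectory() as tmp:
    docs_dir = Path(tmp)
    for sha1, company in [("aaa", "Apple Inc."), ("bbb", "Google LLC")]:
        doc = {"metainfo": {"company_name": company, "sha1_name": sha1}, "content": {}}
        (docs_dir / f"{sha1}.json").write_text(json.dumps(doc), encoding="utf-8")
    company_index = build_company_index(docs_dir)
    apple_file = company_index["Apple Inc."].name
```

With such an index built at construction time, each query becomes a dictionary lookup followed by a single file read, instead of a scan over the whole documents directory.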
2. Batch processing
- LLM reranking processes candidates in batches
- A single vector search returns multiple results
- Documents are loaded in bulk
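3. Memory management
- Load vector databases on demand (the VectorRetriever shown here loads every index at construction, so this is a possible refinement; BM25 indexes are already loaded per query)
- Limit the size of result sets
- Release unused resources promptly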
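On-demand loading with a bounded memory footprint can be sketched with `functools.lru_cache` from the standard library. The loader `load_report` and the `maxsize=8` bound are hypothetical; the same wrapper shape could be applied to the call that reads a FAISS index, but that is left out here to keep the example dependency-free.

```python
import json
import tempfile
from functools import lru_cache
from pathlib import Path


@lru_cache(maxsize=8)  # hypothetical bound on how many reports stay in memory
def load_report(document_path: str) -> dict:
    """Parse one report on first use; repeated calls return the cached object."""
    with open(document_path, "r", encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "aaa.json"
    path.write_text(json.dumps({"metainfo": {"company_name": "Apple Inc."}}), encoding="utf-8")
    first = load_report(str(path))   # cache miss: reads and parses the file
    second = load_report(str(path))  # cache hit: no disk access
    cache_stats = load_report.cache_info()
```

When the cache is full, the least recently used report is evicted, which caps memory use. Because the cached dictionary is shared between callers, it should be treated as read-only.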
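Summary
This retrieval system illustrates several sound practices for a modern RAG architecture:
- Multi-method retrieval: BM25 and vector retrieval cover different query types
- LLM reranking: reranking improves result quality
- Enterprise-oriented design: multi-tenant, performant, and easy to extend
- Engineering discipline: error handling and configuration management
For anyone building an enterprise knowledge-base system, this implementation is a useful reference. With sound architecture and technology choices, it is possible to maintain retrieval quality while also achieving good system performance and user experience.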
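References
This article is based on an analysis of the source code of the award-winning RAG-Challenge-2 project and shows implementation details of an industrial-grade retrieval system. I hope it helps developers who are building similar systems. The complete retrieval.py follows.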
import json
import logging
from typing import List, Tuple, Dict, Union
from rank_bm25 import BM25Okapi
import pickle
from pathlib import Path
import faiss
from openai import OpenAI
from dotenv import load_dotenv
import os
import numpy as np
from src.reranking import LLMReranker
_log = logging.getLogger(__name__)

class BM25Retriever:
    def __init__(self, bm25_db_dir: Path, documents_dir: Path):
        # Initialize the BM25 retriever with the BM25 index directory and the documents directory
        self.bm25_db_dir = bm25_db_dir
        self.documents_dir = documents_dir

    def retrieve_by_company_name(self, company_name: str, query: str, top_n: int = 3, return_parent_pages: bool = False) -> List[Dict]:
        # Retrieve the text chunks of one company and return the top_n chunks with the highest BM25 scores
        document_path = None
        for path in self.documents_dir.glob("*.json"):
            with open(path, 'r', encoding='utf-8') as f:
                doc = json.load(f)
                if doc["metainfo"]["company_name"] == company_name:
                    document_path = path
                    document = doc
                    break
        if document_path is None:
            raise ValueError(f"No report found with '{company_name}' company name.")
        # Load the matching BM25 index
        bm25_path = self.bm25_db_dir / f"{document['metainfo']['sha1_name']}.pkl"
        with open(bm25_path, 'rb') as f:
            bm25_index = pickle.load(f)
        # Get the document's chunks and pages
        chunks = document["content"]["chunks"]
        pages = document["content"]["pages"]
        # Compute BM25 scores
        tokenized_query = query.split()
        scores = bm25_index.get_scores(tokenized_query)
        actual_top_n = min(top_n, len(scores))
        top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:actual_top_n]
        retrieval_results = []
        seen_pages = set()
        for index in top_indices:
            score = round(float(scores[index]), 4)
            chunk = chunks[index]
            parent_page = next(page for page in pages if page["page"] == chunk["page"])
            if return_parent_pages:
                # Page mode: emit each parent page once, scored by its best chunk
                if parent_page["page"] not in seen_pages:
                    seen_pages.add(parent_page["page"])
                    result = {
                        "distance": score,
                        "page": parent_page["page"],
                        "text": parent_page["text"]
                    }
                    retrieval_results.append(result)
            else:
                result = {
                    "distance": score,
                    "page": chunk["page"],
                    "text": chunk["text"]
                }
                retrieval_results.append(result)
        return retrieval_results

class VectorRetriever:
    def __init__(self, vector_db_dir: Path, documents_dir: Path, embedding_provider: str = "dashscope"):
        # Initialize the vector retriever and load all vector databases and documents
        self.vector_db_dir = vector_db_dir
        self.documents_dir = documents_dir
        self.all_dbs = self._load_dbs()
        # dashscope is the default embedding provider
        self.embedding_provider = embedding_provider.lower()
        self.llm = self._set_up_llm()

    def _set_up_llm(self):
        # Create the client that matches embedding_provider
        load_dotenv()
        if self.embedding_provider == "openai":
            llm = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                timeout=None,
                max_retries=2
            )
            return llm
        elif self.embedding_provider == "dashscope":
            import dashscope
            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
            return None  # dashscope does not need a client object
        else:
            raise ValueError(f"Unsupported embedding provider: {self.embedding_provider}")

    def _get_embedding(self, text: str):
        # Get the vector representation of the text from the configured embedding_provider
        if self.embedding_provider == "openai":
            embedding = self.llm.embeddings.create(
                input=text,
                model="text-embedding-3-large"
            )
            return embedding.data[0].embedding
        elif self.embedding_provider == "dashscope":
            import dashscope
            rsp = dashscope.TextEmbedding.call(
                model="text-embedding-v1",
                input=[text]
            )
            # Handle the dashscope response format: index with rsp['output'] rather than rsp.output
            if 'output' in rsp and 'embeddings' in rsp['output']:
                # Multi-input format (only one input is sent here)
                emb = rsp['output']['embeddings'][0]
                if emb['embedding'] is None or len(emb['embedding']) == 0:
                    raise RuntimeError(f"DashScope returned an empty embedding, text_index={emb.get('text_index', None)}")
                return emb['embedding']
            elif 'output' in rsp and 'embedding' in rsp['output']:
                # Single-input format
                if rsp['output']['embedding'] is None or len(rsp['output']['embedding']) == 0:
                    raise RuntimeError("DashScope returned an empty embedding")
                return rsp['output']['embedding']
            else:
                raise RuntimeError(f"Unexpected response format from the DashScope embedding API: {rsp}")
        else:
            raise ValueError(f"Unsupported embedding provider: {self.embedding_provider}")

    @staticmethod
    def set_up_llm():
        # Static method that creates an OpenAI client
        load_dotenv()
        llm = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            timeout=None,
            max_retries=2
        )
        return llm

    def _load_dbs(self):
        # Load every vector database together with its document and pair them up
        all_dbs = []
        # Collect the paths of all JSON documents
        all_documents_paths = list(self.documents_dir.glob('*.json'))
        vector_db_files = {db_path.stem: db_path for db_path in self.vector_db_dir.glob('*.faiss')}
        for document_path in all_documents_paths:
            stem = document_path.stem
            if stem not in vector_db_files:
                _log.warning(f"No matching vector DB found for document {document_path.name}")
                continue
            try:
                with open(document_path, 'r', encoding='utf-8') as f:
                    document = json.load(f)
            except Exception as e:
                _log.error(f"Error loading JSON from {document_path.name}: {e}")
                continue
            # Validate the document structure
            if not (isinstance(document, dict) and "metainfo" in document and "content" in document):
                _log.warning(f"Skipping {document_path.name}: does not match the expected schema.")
                continue
            try:
                vector_db = faiss.read_index(str(vector_db_files[stem]))
            except Exception as e:
                _log.error(f"Error reading vector DB for {document_path.name}: {e}")
                continue
            report = {
                "name": stem,
                "vector_db": vector_db,
                "document": document
            }
            all_dbs.append(report)
        return all_dbs

    @staticmethod
    def get_strings_cosine_similarity(str1, str2):
        # Compute the cosine similarity of two strings from their OpenAI embeddings
        llm = VectorRetriever.set_up_llm()
        embeddings = llm.embeddings.create(input=[str1, str2], model="text-embedding-3-large")
        embedding1 = embeddings.data[0].embedding
        embedding2 = embeddings.data[1].embedding
        similarity_score = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
        similarity_score = round(similarity_score, 4)
        return similarity_score

    def retrieve_by_company_name(self, company_name: str, query: str, llm_reranking_sample_size: int = None, top_n: int = 3, return_parent_pages: bool = False) -> List[Dict]:
        # Retrieve the text chunks of one company and return the top_n nearest chunks from the vector search
        target_report = None
        for report in self.all_dbs:
            document = report.get("document", {})
            metainfo = document.get("metainfo")
            if not metainfo:
                _log.error(f"Report '{report.get('name')}' is missing 'metainfo'!")
                raise ValueError(f"Report '{report.get('name')}' is missing 'metainfo'!")
            if metainfo.get("company_name") == company_name:
                target_report = report
                break
        if target_report is None:
            _log.error(f"No report found with '{company_name}' company name.")
            raise ValueError(f"No report found with '{company_name}' company name.")
        document = target_report["document"]
        vector_db = target_report["vector_db"]
        chunks = document["content"]["chunks"]
        pages = document["content"]["pages"]
        actual_top_n = min(top_n, len(chunks))
        # Embed the query (openai and dashscope are supported)
        embedding = self._get_embedding(query)
        embedding_array = np.array(embedding, dtype=np.float32).reshape(1, -1)
        distances, indices = vector_db.search(x=embedding_array, k=actual_top_n)
        retrieval_results = []
        seen_pages = set()
        for distance, index in zip(distances[0], indices[0]):
            distance = round(float(distance), 4)
            chunk = chunks[index]
            parent_page = next(page for page in pages if page["page"] == chunk["page"])
            if return_parent_pages:
                # Page mode: emit each parent page once, scored by its best chunk
                if parent_page["page"] not in seen_pages:
                    seen_pages.add(parent_page["page"])
                    result = {
                        "distance": distance,
                        "page": parent_page["page"],
                        "text": parent_page["text"]
                    }
                    retrieval_results.append(result)
            else:
                result = {
                    "distance": distance,
                    "page": chunk["page"],
                    "text": chunk["text"]
                }
                retrieval_results.append(result)
        return retrieval_results

    def retrieve_all(self, company_name: str) -> List[Dict]:
        # Return every page of the company's report, sorted by page number
        target_report = None
        for report in self.all_dbs:
            document = report.get("document", {})
            metainfo = document.get("metainfo")
            if not metainfo:
                continue
            if metainfo.get("company_name") == company_name:
                target_report = report
                break
        if target_report is None:
            _log.error(f"No report found with '{company_name}' company name.")
            raise ValueError(f"No report found with '{company_name}' company name.")
        document = target_report["document"]
        pages = document["content"]["pages"]
        all_pages = []
        for page in sorted(pages, key=lambda p: p["page"]):
            result = {
                "distance": 0.5,  # fixed placeholder score; no search is performed
                "page": page["page"],
                "text": page["text"]
            }
            all_pages.append(result)
        return all_pages

class HybridRetriever:
    def __init__(self, vector_db_dir: Path, documents_dir: Path):
        self.vector_retriever = VectorRetriever(vector_db_dir, documents_dir)
        self.reranker = LLMReranker()

    def retrieve_by_company_name(
        self,
        company_name: str,
        query: str,
        llm_reranking_sample_size: int = 28,
        documents_batch_size: int = 2,
        top_n: int = 6,
        llm_weight: float = 0.7,
        return_parent_pages: bool = False
    ) -> List[Dict]:
        """
        Retrieve and rerank documents using hybrid approach.

        Args:
            company_name: Name of the company to search documents for
            query: Search query
            llm_reranking_sample_size: Number of initial results to retrieve from vector DB
            documents_batch_size: Number of documents to analyze in one LLM prompt
            top_n: Number of final results to return after reranking
            llm_weight: Weight given to LLM scores (0-1)
            return_parent_pages: Whether to return full pages instead of chunks

        Returns:
            List of reranked document dictionaries with scores
        """
        # Get initial results from vector retriever
        vector_results = self.vector_retriever.retrieve_by_company_name(
            company_name=company_name,
            query=query,
            top_n=llm_reranking_sample_size,
            return_parent_pages=return_parent_pages
        )
        # Rerank results using LLM
        reranked_results = self.reranker.rerank_documents(
            query=query,
            documents=vector_results,
            documents_batch_size=documents_batch_size,
            llm_weight=llm_weight
        )
        return reranked_results[:top_n]