代码在最后

前言

在企业知识库和文档问答系统中,检索(Retrieval)是 RAG(Retrieval-Augmented Generation)架构的核心组件。本文将深入解析一个获得 RAG 挑战赛冠军的检索系统实现,该系统巧妙地结合了 BM25、向量检索和 LLM 重排序等多种技术。

系统架构概览

该检索系统采用了三层架构设计:

  1. BM25 检索器:基于词频统计的传统检索方法

  2. 向量检索器:基于语义相似度的现代检索方法

  3. 混合检索器:结合向量检索和 LLM 重排序的高级检索方法

核心组件详解

1. BM25 检索器(BM25Retriever)

BM25 是一种基于词频统计的经典检索算法,特别适合处理关键词匹配场景。

class BM25Retriever:
    def __init__(self, bm25_db_dir: Path, documents_dir: Path):
        # 初始化BM25检索器,指定BM25索引和文档目录
        self.bm25_db_dir = bm25_db_dir
        self.documents_dir = documents_dir

核心检索方法

def retrieve_by_company_name(self, company_name: str, query: str, top_n: int = 3, return_parent_pages: bool = False) -> List[Dict]:
    # 1. 根据公司名找到对应文档
    document_path = None
    for path in self.documents_dir.glob("*.json"):
        with open(path, 'r', encoding='utf-8') as f:
            doc = json.load(f)
            if doc["metainfo"]["company_name"] == company_name:
                document_path = path
                document = doc
                break
    
    # 2. 加载预训练的BM25索引
    bm25_path = self.bm25_db_dir / f"{document['metainfo']['sha1_name']}.pkl"
    with open(bm25_path, 'rb') as f:
        bm25_index = pickle.load(f)
    
    # 3. 计算BM25分数并排序
    tokenized_query = query.split()
    scores = bm25_index.get_scores(tokenized_query)
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    
    # 4. 构建检索结果
    retrieval_results = []
    for index in top_indices:
        score = round(float(scores[index]), 4)
        chunk = chunks[index]
        result = {
            "distance": score,
            "page": chunk["page"],
            "text": chunk["text"]
        }
        retrieval_results.append(result)
    
    return retrieval_results

技术亮点

  • 支持按公司名精确检索

  • 预计算并持久化 BM25 索引,提高查询效率

  • 支持返回文档块或完整页面两种模式

2. 向量检索器(VectorRetriever)

向量检索器是现代 RAG 系统的核心,通过语义相似度进行文档检索。

多 Embedding 提供商支持

def _get_embedding(self, text: str):
    # 根据 embedding_provider 获取文本的向量表示
    if self.embedding_provider == "openai":
        embedding = self.llm.embeddings.create(
            input=text,
            model="text-embedding-3-large"
        )
        return embedding.data[0].embedding
    elif self.embedding_provider == "dashscope":
        import dashscope
        rsp = dashscope.TextEmbedding.call(
            model="text-embedding-v1",
            input=[text]
        )
        return rsp['output']['embeddings'][0]['embedding']

向量数据库管理

def _load_dbs(self):
    # 加载所有向量库和对应文档,建立映射
    all_dbs = []
    all_documents_paths = list(self.documents_dir.glob('*.json'))
    vector_db_files = {db_path.stem: db_path for db_path in self.vector_db_dir.glob('*.faiss')}
    
    for document_path in all_documents_paths:
        stem = document_path.stem
        if stem not in vector_db_files:
            continue
            
        # 加载文档和对应的FAISS向量库
        with open(document_path, 'r', encoding='utf-8') as f:
            document = json.load(f)
        vector_db = faiss.read_index(str(vector_db_files[stem]))
        
        report = {
            "name": stem,
            "vector_db": vector_db,
            "document": document
        }
        all_dbs.append(report)
    return all_dbs

核心检索逻辑

def retrieve_by_company_name(self, company_name: str, query: str, top_n: int = 3, return_parent_pages: bool = False) -> List[Dict]:
    # 1. 找到目标公司文档
    target_report = None
    for report in self.all_dbs:
        if report["document"]["metainfo"]["company_name"] == company_name:
            target_report = report
            break
    
    # 2. 获取查询向量
    embedding = self._get_embedding(query)
    embedding_array = np.array(embedding, dtype=np.float32).reshape(1, -1)
    
    # 3. FAISS向量检索
    distances, indices = vector_db.search(x=embedding_array, k=top_n)
    
    # 4. 构建结果
    retrieval_results = []
    for distance, index in zip(distances[0], indices[0]):
        chunk = chunks[index]
        result = {
            "distance": round(float(distance), 4),
            "page": chunk["page"],
            "text": chunk["text"]
        }
        retrieval_results.append(result)
    
    return retrieval_results

技术特色

  • 支持 OpenAI 和 DashScope 两种 Embedding 提供商

  • 使用 FAISS 进行高效向量检索

  • 动态加载多个公司的向量数据库

  • 提供余弦相似度计算工具方法

3. 混合检索器(HybridRetriever)

混合检索器是系统的最高级组件,结合了向量检索和 LLM 重排序技术。

class HybridRetriever:
    def __init__(self, vector_db_dir: Path, documents_dir: Path):
        self.vector_retriever = VectorRetriever(vector_db_dir, documents_dir)
        self.reranker = LLMReranker()
        
    def retrieve_by_company_name(
        self, 
        company_name: str, 
        query: str, 
        llm_reranking_sample_size: int = 28,
        documents_batch_size: int = 2,
        top_n: int = 6,
        llm_weight: float = 0.7,
        return_parent_pages: bool = False
    ) -> List[Dict]:
        # 1. 向量检索获取初始候选集
        vector_results = self.vector_retriever.retrieve_by_company_name(
            company_name=company_name,
            query=query,
            top_n=llm_reranking_sample_size,
            return_parent_pages=return_parent_pages
        )
        
        # 2. LLM重排序优化结果
        reranked_results = self.reranker.rerank_documents(
            query=query,
            documents=vector_results,
            documents_batch_size=documents_batch_size,
            llm_weight=llm_weight
        )
        
        return reranked_results[:top_n]

系统设计亮点

1. 多层检索策略

系统采用了从简单到复杂的多层检索策略:

  • BM25:快速关键词匹配,适合精确查询

  • 向量检索:语义理解,适合模糊查询

  • 混合检索:结合两者优势,通过 LLM 重排序进一步优化

2. 灵活的配置参数

# 支持丰富的参数配置
llm_reranking_sample_size: int = 28  # 初始检索数量
documents_batch_size: int = 2        # LLM批处理大小
top_n: int = 6                       # 最终返回数量
llm_weight: float = 0.7              # LLM权重
return_parent_pages: bool = False    # 返回模式

3. 企业级特性

  • 多公司支持:按公司名精确检索,适合多租户场景

  • 错误处理:完善的异常处理和日志记录

  • 性能优化:预计算索引、批量处理、缓存机制

  • 扩展性:支持多种 Embedding 提供商

实际应用场景

1. 企业文档问答

# 查询某公司的财务信息
retriever = HybridRetriever(vector_db_dir, documents_dir)
results = retriever.retrieve_by_company_name(
    company_name="Apple Inc.",
    query="2023年第四季度营收情况",
    top_n=5
)

2. 多公司对比分析

# 对比多家公司的相同指标
apple_results = retriever.retrieve_by_company_name("Apple Inc.", query)
google_results = retriever.retrieve_by_company_name("Google LLC", query)

3. 全文档检索

# 获取公司所有相关文档
all_docs = vector_retriever.retrieve_all("Apple Inc.")

性能优化技巧

1. 索引预计算

  • BM25 索引预先计算并持久化

  • FAISS 向量库离线构建

  • 文档元数据缓存

2. 批量处理

  • LLM 重排序支持批量处理

  • 向量检索一次性返回多个结果

  • 文档加载批量操作

3. 内存管理

  • 按需加载向量数据库

  • 结果集大小限制

  • 及时释放不用的资源

总结

这个检索系统展示了现代 RAG 架构的最佳实践:

  1. 多模态检索:BM25 + 向量检索覆盖不同查询类型

  2. 智能重排序:LLM 重排序提升结果质量

  3. 企业级设计:多租户、高性能、易扩展

  4. 工程化实现:完善的错误处理和配置管理

对于构建企业级知识库系统,这个实现提供了很好的参考价值。通过合理的架构设计和技术选型,可以在保证检索质量的同时,实现良好的系统性能和用户体验。

参考资源


本文基于 RAG-Challenge-2 获奖项目的源码分析,展示了工业级检索系统的实现细节。希望对正在构建类似系统的开发者有所帮助。

import json
import logging
from typing import List, Tuple, Dict, Union
from rank_bm25 import BM25Okapi
import pickle
from pathlib import Path
import faiss
from openai import OpenAI
from dotenv import load_dotenv
import os
import numpy as np
from src.reranking import LLMReranker

_log = logging.getLogger(__name__)

class BM25Retriever:
    def __init__(self, bm25_db_dir: Path, documents_dir: Path):
        # 初始化BM25检索器,指定BM25索引和文档目录
        self.bm25_db_dir = bm25_db_dir
        self.documents_dir = documents_dir
        
    def retrieve_by_company_name(self, company_name: str, query: str, top_n: int = 3, return_parent_pages: bool = False) -> List[Dict]:
        # 按公司名检索相关文本块,返回BM25分数最高的top_n个块
        document_path = None
        for path in self.documents_dir.glob("*.json"):
            with open(path, 'r', encoding='utf-8') as f:
                doc = json.load(f)
                if doc["metainfo"]["company_name"] == company_name:
                    document_path = path
                    document = doc
                    break
                    
        if document_path is None:
            raise ValueError(f"No report found with '{company_name}' company name.")
            
        # 加载对应的BM25索引
        bm25_path = self.bm25_db_dir / f"{document['metainfo']['sha1_name']}.pkl"
        with open(bm25_path, 'rb') as f:
            bm25_index = pickle.load(f)
            
        # 获取文档内容和BM25索引
        document = document
        chunks = document["content"]["chunks"]
        pages = document["content"]["pages"]
        
        # 计算BM25分数
        tokenized_query = query.split()
        scores = bm25_index.get_scores(tokenized_query)
        
        actual_top_n = min(top_n, len(scores))
        top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:actual_top_n]
        
        retrieval_results = []
        seen_pages = set()
        
        for index in top_indices:
            score = round(float(scores[index]), 4)
            chunk = chunks[index]
            parent_page = next(page for page in pages if page["page"] == chunk["page"])
            
            if return_parent_pages:
                if parent_page["page"] not in seen_pages:
                    seen_pages.add(parent_page["page"])
                    result = {
                        "distance": score,
                        "page": parent_page["page"],
                        "text": parent_page["text"]
                    }
                    retrieval_results.append(result)
            else:
                result = {
                    "distance": score,
                    "page": chunk["page"],
                    "text": chunk["text"]
                }
                retrieval_results.append(result)
        
        return retrieval_results



class VectorRetriever:
    def __init__(self, vector_db_dir: Path, documents_dir: Path, embedding_provider: str = "dashscope"):
        # 初始化向量检索器,加载所有向量库和文档
        self.vector_db_dir = vector_db_dir
        self.documents_dir = documents_dir
        self.all_dbs = self._load_dbs()
        # 默认使用 dashscope 作为 embedding provider
        self.embedding_provider = embedding_provider.lower()
        self.llm = self._set_up_llm()

    def _set_up_llm(self):
        # 根据 embedding_provider 初始化对应的 LLM 客户端
        load_dotenv()
        if self.embedding_provider == "openai":
            llm = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                timeout=None,
                max_retries=2
            )
            return llm
        elif self.embedding_provider == "dashscope":
            import dashscope
            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
            return None  # dashscope 不需要 client 对象
        else:
            raise ValueError(f"不支持的 embedding provider: {self.embedding_provider}")

    def _get_embedding(self, text: str):
        # 根据 embedding_provider 获取文本的向量表示
        if self.embedding_provider == "openai":
            embedding = self.llm.embeddings.create(
                input=text,
                model="text-embedding-3-large"
            )
            return embedding.data[0].embedding
        elif self.embedding_provider == "dashscope":
            import dashscope
            rsp = dashscope.TextEmbedding.call(
                model="text-embedding-v1",
                input=[text]
            )
            # 兼容 dashscope 返回格式,不能用 resp.output,需用 resp['output']
            if 'output' in rsp and 'embeddings' in rsp['output']:
                # 多条输入(本处只有一条)
                emb = rsp['output']['embeddings'][0]
                if emb['embedding'] is None or len(emb['embedding']) == 0:
                    raise RuntimeError(f"DashScope返回的embedding为空,text_index={emb.get('text_index', None)}")
                return emb['embedding']
            elif 'output' in rsp and 'embedding' in rsp['output']:
                # 兼容单条输入格式
                if rsp['output']['embedding'] is None or len(rsp['output']['embedding']) == 0:
                    raise RuntimeError("DashScope返回的embedding为空")
                return rsp['output']['embedding']
            else:
                raise RuntimeError(f"DashScope embedding API返回格式异常: {rsp}")
        else:
            raise ValueError(f"不支持的 embedding provider: {self.embedding_provider}")

    @staticmethod
    def set_up_llm():
        # 静态方法,初始化OpenAI LLM
        load_dotenv()
        llm = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            timeout=None,
            max_retries=2
        )
        return llm

    def _load_dbs(self):
        # 加载所有向量库和对应文档,建立映射
        all_dbs = []
        # 获取所有JSON文档路径
        all_documents_paths = list(self.documents_dir.glob('*.json'))
        vector_db_files = {db_path.stem: db_path for db_path in self.vector_db_dir.glob('*.faiss')}
        
        for document_path in all_documents_paths:
            stem = document_path.stem
            if stem not in vector_db_files:
                _log.warning(f"No matching vector DB found for document {document_path.name}")
                continue
            try:
                with open(document_path, 'r', encoding='utf-8') as f:
                    document = json.load(f)
            except Exception as e:
                _log.error(f"Error loading JSON from {document_path.name}: {e}")
                continue
            
            # 校验文档结构
            if not (isinstance(document, dict) and "metainfo" in document and "content" in document):
                _log.warning(f"Skipping {document_path.name}: does not match the expected schema.")
                continue
            
            try:
                vector_db = faiss.read_index(str(vector_db_files[stem]))
            except Exception as e:
                _log.error(f"Error reading vector DB for {document_path.name}: {e}")
                continue
                
            report = {
                "name": stem,
                "vector_db": vector_db,
                "document": document
            }
            all_dbs.append(report)
        return all_dbs

    @staticmethod
    def get_strings_cosine_similarity(str1, str2):
        # 计算两个字符串的余弦相似度(通过嵌入)
        llm = VectorRetriever.set_up_llm()
        embeddings = llm.embeddings.create(input=[str1, str2], model="text-embedding-3-large")
        embedding1 = embeddings.data[0].embedding
        embedding2 = embeddings.data[1].embedding
        similarity_score = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
        similarity_score = round(similarity_score, 4)
        return similarity_score

    def retrieve_by_company_name(self, company_name: str, query: str, llm_reranking_sample_size: int = None, top_n: int = 3, return_parent_pages: bool = False) -> List[Tuple[str, float]]:
        # 按公司名检索相关文本块,返回向量距离最近的top_n个块
        target_report = None
        for report in self.all_dbs:
            document = report.get("document", {})
            metainfo = document.get("metainfo")
            if not metainfo:
                _log.error(f"Report '{report.get('name')}' is missing 'metainfo'!")
                raise ValueError(f"Report '{report.get('name')}' is missing 'metainfo'!")
            if metainfo.get("company_name") == company_name:
                target_report = report
                break
        
        if target_report is None:
            _log.error(f"No report found with '{company_name}' company name.")
            raise ValueError(f"No report found with '{company_name}' company name.")
        
        document = target_report["document"]
        vector_db = target_report["vector_db"]
        chunks = document["content"]["chunks"]
        pages = document["content"]["pages"]
        
        actual_top_n = min(top_n, len(chunks))
        
        # 获取 query 的 embedding,支持 openai/dashscope
        embedding = self._get_embedding(query)
        embedding_array = np.array(embedding, dtype=np.float32).reshape(1, -1)
        distances, indices = vector_db.search(x=embedding_array, k=actual_top_n)
    
        retrieval_results = []
        seen_pages = set()
        
        for distance, index in zip(distances[0], indices[0]):
            distance = round(float(distance), 4)
            chunk = chunks[index]
            parent_page = next(page for page in pages if page["page"] == chunk["page"])
            if return_parent_pages:
                if parent_page["page"] not in seen_pages:
                    seen_pages.add(parent_page["page"])
                    result = {
                        "distance": distance,
                        "page": parent_page["page"],
                        "text": parent_page["text"]
                    }
                    retrieval_results.append(result)
            else:
                result = {
                    "distance": distance,
                    "page": chunk["page"],
                    "text": chunk["text"]
                }
                retrieval_results.append(result)
            
        return retrieval_results

    def retrieve_all(self, company_name: str) -> List[Dict]:
        # 检索公司所有文本块,返回全部内容
        target_report = None
        for report in self.all_dbs:
            document = report.get("document", {})
            metainfo = document.get("metainfo")
            if not metainfo:
                continue
            if metainfo.get("company_name") == company_name:
                target_report = report
                break
        
        if target_report is None:
            _log.error(f"No report found with '{company_name}' company name.")
            raise ValueError(f"No report found with '{company_name}' company name.")
        
        document = target_report["document"]
        pages = document["content"]["pages"]
        
        all_pages = []
        for page in sorted(pages, key=lambda p: p["page"]):
            result = {
                "distance": 0.5,
                "page": page["page"],
                "text": page["text"]
            }
            all_pages.append(result)
            
        return all_pages


class HybridRetriever:
    def __init__(self, vector_db_dir: Path, documents_dir: Path):
        self.vector_retriever = VectorRetriever(vector_db_dir, documents_dir)
        self.reranker = LLMReranker()
        
    def retrieve_by_company_name(
        self, 
        company_name: str, 
        query: str, 
        llm_reranking_sample_size: int = 28,
        documents_batch_size: int = 2,
        top_n: int = 6,
        llm_weight: float = 0.7,
        return_parent_pages: bool = False
    ) -> List[Dict]:
        """
        Retrieve and rerank documents using hybrid approach.
        
        Args:
            company_name: Name of the company to search documents for
            query: Search query
            llm_reranking_sample_size: Number of initial results to retrieve from vector DB
            documents_batch_size: Number of documents to analyze in one LLM prompt
            top_n: Number of final results to return after reranking
            llm_weight: Weight given to LLM scores (0-1)
            return_parent_pages: Whether to return full pages instead of chunks
            
        Returns:
            List of reranked document dictionaries with scores
        """
        # Get initial results from vector retriever
        vector_results = self.vector_retriever.retrieve_by_company_name(
            company_name=company_name,
            query=query,
            top_n=llm_reranking_sample_size,
            return_parent_pages=return_parent_pages
        )
        
        # Rerank results using LLM
        reranked_results = self.reranker.rerank_documents(
            query=query,
            documents=vector_results,
            documents_batch_size=documents_batch_size,
            llm_weight=llm_weight
        )
        
        return reranked_results[:top_n]

Logo

电影级数字人,免显卡端渲染SDK,十行代码即可调用,工业级demo免费开源下载!

更多推荐