# -*- coding: utf-8 -*-
"""BM25 RAG - probabilistic retrieval model."""

import math
import time

from loguru import logger

from .base import BaseRAG, RAGResult
from .bm25_backend import get_bm25_okapi
from .tokenization import bm25_tokenize


class BM25RAG(BaseRAG):
    """Retrieval-augmented generation backed by a sparse BM25 (Okapi) index.

    Documents are tokenized with ``bm25_tokenize`` and scored against the
    query with the BM25 class returned by ``get_bm25_okapi``. The best-scoring
    documents are formatted into a context string and passed to the LLM.
    """

    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
        """Create a BM25 RAG pipeline with an empty index.

        Args:
            llm_client: Stored as ``self.llm_client`` for answer generation.
            llm_model: Stored as ``self.llm_model``.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        # NOTE(review): BaseRAG.__init__ is deliberately not called, presumably
        # because BM25 needs no embedding model or vector store — confirm.
        self.embedding_model = None
        self.vector_store = None
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__
        # BM25 index object; None until index_documents() builds a non-empty index.
        self.bm25 = None
        # Token lists, parallel to _all_documents.
        self._all_texts = []
        # Indexed documents, in the order BM25 scores are returned.
        self._all_documents = []

    def _tokenize(self, text):
        """Tokenize mixed Chinese/English text for BM25."""
        return bm25_tokenize(text)

    def index_documents(self, documents):
        """(Re)build the BM25 index from ``documents``.

        Any previous index is replaced. Each document's ``page_content`` is
        tokenized. An empty or ``None`` input leaves ``self.bm25`` as ``None``,
        so :meth:`retrieve` returns ``[]``.
        """
        self._all_documents = list(documents or [])
        self._all_texts = [self._tokenize(doc.page_content) for doc in self._all_documents]
        if not self._all_documents:
            self.bm25 = None
            logger.info("BM25 index built with 0 documents")
            return
        bm25_okapi_cls = get_bm25_okapi()
        self.bm25 = bm25_okapi_cls(self._all_texts)
        # loguru formats the "{}" placeholder lazily from the positional argument.
        logger.info("BM25 index built with {} documents", len(self._all_documents))

    def retrieve(self, query, k=10):
        """Return up to ``k`` documents ranked by BM25 score for ``query``.

        Returns ``[]`` when nothing is indexed, ``k <= 0``, or the query
        tokenizes to nothing. Documents scoring exactly 0, or NaN, are dropped.
        Negative scores are kept and rank last. Candidates are passed as
        (document, score) pairs, sorted by descending score, to
        ``_deduplicate_results``, which picks the final results.
        """
        if self.bm25 is None or k <= 0:
            return []
        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []
        scores = self.bm25.get_scores(query_tokens)
        scored_docs = []
        for doc, raw_score in zip(self._all_documents, scores):
            score = float(raw_score)
            # Keep every non-zero score: BM25 can return negative scores, which
            # still carry ranking signal. NaN is skipped because it cannot be
            # ordered and would scramble the sort. Some backends can produce it
            # in degenerate corpora, e.g. a zero average document length.
            if score != 0 and not math.isnan(score):
                scored_docs.append((doc, score))
        scored_docs.sort(key=lambda item: item[1], reverse=True)
        return self._deduplicate_results(scored_docs, k)

    def generate(self, query, context):
        """Build a prompt from ``context`` and ``query`` and return ``_call_llm``'s result.

        The prompt is in Chinese. It asks the model to answer the question in
        detail, using the BM25 keyword-retrieved reference documents.
        """
        prompt = (
            "根据以下参考文档(BM25关键词检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Retrieve context for ``query``, generate an answer, and time both stages.

        Latencies are measured with ``time.perf_counter``. It is a monotonic
        clock, so system clock adjustments do not distort the measurements.

        Returns:
            RAGResult with the answer, the retrieved docs, per-stage and total
            latencies in seconds, and metadata.
        """
        start_total = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - start_total
        context = self._format_context(docs)
        t_gen = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - t_gen
        total_time = time.perf_counter() - start_total
        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            metadata={
                "method": self.name,
                # NOTE(review): despite the key name, this is len(context). That
                # is a character count if context is a str, not a token count.
                # The key is kept for compatibility.
                "num_context_tokens": len(context),
            },
        )