| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- # -*- coding: utf-8 -*-
- """BM25 RAG - probabilistic retrieval model."""
- from .base import BaseRAG, RAGResult
- from .bm25_backend import get_bm25_okapi
- from .tokenization import bm25_tokenize
- from loguru import logger
class BM25RAG(BaseRAG):
    """RAG pipeline whose retrieval step is BM25 keyword scoring.

    No embedding model or vector store is used: ``embedding_model`` and
    ``vector_store`` are explicitly set to ``None``. Documents are tokenized
    with ``bm25_tokenize`` and scored by the BM25 implementation returned
    from ``get_bm25_okapi()``.

    The helpers ``_deduplicate_results``, ``_format_context`` and
    ``_call_llm`` are not defined in this class.
    NOTE(review): presumably they are inherited from ``BaseRAG`` - confirm.
    """

    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
        """Create an empty, un-indexed BM25 RAG instance.

        Args:
            llm_client: Client object stored for answer generation.
            llm_model: Model name stored alongside the client.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        # The base-class initializer is not invoked here; every attribute is
        # set directly.
        # NOTE(review): presumably BaseRAG.__init__ would require an
        # embedding model / vector store - confirm before adding super().
        self.embedding_model = None
        self.vector_store = None
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__
        # Index state; these three are always replaced together by
        # index_documents().
        self.bm25 = None
        self._all_texts = []
        self._all_documents = []

    def _tokenize(self, text):
        """Tokenize mixed Chinese/English text for BM25."""
        return bm25_tokenize(text)

    def index_documents(self, documents):
        """Build (or rebuild) the BM25 index from ``documents``.

        Args:
            documents: Iterable of objects exposing ``page_content``.
                ``None`` or an empty iterable clears the index.
        """
        # Build everything in locals first so that a failure while tokenizing
        # or constructing the index cannot leave the document list and the
        # BM25 index describing different corpora.
        docs = list(documents or [])
        tokenized = [self._tokenize(doc.page_content) for doc in docs]

        index = None
        if docs:
            BM25Okapi = get_bm25_okapi()
            index = BM25Okapi(tokenized)

        self._all_documents = docs
        self._all_texts = tokenized
        self.bm25 = index
        logger.info("BM25 index built with {} documents", len(docs))

    def retrieve(self, query, k=10):
        """Return up to ``k`` indexed documents ranked by BM25 score.

        Args:
            query: Query text; it is tokenized with ``_tokenize``.
            k: Maximum number of results requested.

        Returns:
            An empty list when nothing is indexed, ``k <= 0`` or the query
            yields no tokens. Otherwise whatever
            ``self._deduplicate_results`` returns for the
            ``(document, score)`` pairs sorted by descending score.
        """
        if self.bm25 is None or k <= 0:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scores = self.bm25.get_scores(query_tokens)
        # Keep every document with a non-zero score: BM25 can produce
        # negative scores, so only an exact zero is treated as "no match".
        scored_docs = [
            (doc, float(score))
            for doc, score in zip(self._all_documents, scores)
            if score != 0
        ]
        scored_docs.sort(key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(scored_docs, k)

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using ``context``.

        The (Chinese) prompt tells the model that the reference documents
        come from BM25 keyword retrieval and asks for a detailed answer.

        Returns:
            The result of ``self._call_llm`` for the built prompt.
        """
        prompt = (
            "根据以下参考文档(BM25关键词检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieval and generation for ``query`` and time both stages.

        Args:
            query: The user question.
            k: Maximum number of documents to retrieve.

        Returns:
            A ``RAGResult`` carrying the answer, the retrieved documents,
            the retrieval / generation / total latencies in seconds, and
            metadata with the method name and ``len(context)``.
        """
        import time

        # perf_counter() is monotonic, so the measured intervals cannot be
        # distorted by wall-clock adjustments the way time.time() can.
        started = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - started

        # Context formatting is counted in the total only, not in either
        # per-stage latency.
        context = self._format_context(docs)

        generation_started = time.perf_counter()
        answer = self.generate(query, context)
        finished = time.perf_counter()

        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=finished - generation_started,
            latency_total=finished - started,
            # NOTE: despite the key name this is len(context), not a
            # tokenizer-based count; the key is kept for compatibility.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
|