# bm25_rag.py (2.8 KB)
  1. # -*- coding: utf-8 -*-
  2. """BM25 RAG - probabilistic retrieval model."""
  3. from .base import BaseRAG, RAGResult
  4. from .bm25_backend import get_bm25_okapi
  5. from .tokenization import bm25_tokenize
  6. from loguru import logger
  7. class BM25RAG(BaseRAG):
  8. def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
  9. self.embedding_model = None
  10. self.vector_store = None
  11. self.llm_client = llm_client
  12. self.llm_model = llm_model
  13. self.name = self.__class__.__name__
  14. self.bm25 = None
  15. self._all_texts = []
  16. self._all_documents = []
  17. def _tokenize(self, text):
  18. """Tokenize mixed Chinese/English text for BM25."""
  19. return bm25_tokenize(text)
  20. def index_documents(self, documents):
  21. self._all_documents = list(documents or [])
  22. self._all_texts = [self._tokenize(doc.page_content) for doc in self._all_documents]
  23. if not self._all_documents:
  24. self.bm25 = None
  25. logger.info("BM25 index built with 0 documents")
  26. return
  27. BM25Okapi = get_bm25_okapi()
  28. self.bm25 = BM25Okapi(self._all_texts)
  29. logger.info("BM25 index built with {} documents".format(len(self._all_documents)))
  30. def retrieve(self, query, k=10):
  31. if self.bm25 is None or k <= 0:
  32. return []
  33. query_tokens = self._tokenize(query)
  34. if not query_tokens:
  35. return []
  36. scores = self.bm25.get_scores(query_tokens)
  37. scored_docs = []
  38. for i, score in enumerate(scores):
  39. # Include all documents with non-zero scores (BM25 can return negative scores)
  40. if score != 0:
  41. scored_docs.append((self._all_documents[i], float(score)))
  42. scored_docs.sort(key=lambda x: x[1], reverse=True)
  43. return self._deduplicate_results(scored_docs, k)
  44. def generate(self, query, context):
  45. prompt = (
  46. "根据以下参考文档(BM25关键词检索),回答问题。\n\n"
  47. "参考文档:\n{}\n\n"
  48. "问题:{}\n\n"
  49. "请详细回答。".format(context, query)
  50. )
  51. return self._call_llm(prompt)
  52. def query(self, query, k=10):
  53. import time
  54. start_total = time.time()
  55. t0 = time.time()
  56. docs = self.retrieve(query, k)
  57. retrieval_time = time.time() - t0
  58. context = self._format_context(docs)
  59. t1 = time.time()
  60. answer = self.generate(query, context)
  61. generation_time = time.time() - t1
  62. total_time = time.time() - start_total
  63. return RAGResult(
  64. answer=answer,
  65. retrieved_docs=docs,
  66. latency_retrieval=retrieval_time,
  67. latency_generation=generation_time,
  68. latency_total=total_time,
  69. metadata={"method": self.name, "num_context_tokens": len(context)},
  70. )