keyword_rag.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # -*- coding: utf-8 -*-
  2. """Keyword RAG - unified interface for keyword-based retrieval (BM25 or TF-IDF)."""
  3. from .base import BaseRAG, RAGResult
  4. from loguru import logger
  5. class KeywordRAG(BaseRAG):
  6. def __init__(self, search_method="bm25", llm_client=None, llm_model="gpt-4o", **kwargs):
  7. self.embedding_model = None
  8. self.vector_store = None
  9. self.llm_client = llm_client
  10. self.llm_model = llm_model
  11. self.name = self.__class__.__name__
  12. self.search_method = search_method
  13. self._all_documents = []
  14. self._all_texts = []
  15. self.bm25 = None
  16. self.tfidf_matrix = None
  17. self.vectorizer = None
  18. def index_documents(self, documents):
  19. self._all_documents = documents
  20. self._all_texts = [doc.page_content for doc in documents]
  21. if self.search_method == "bm25":
  22. from rank_bm25 import BM25Okapi
  23. tokenized = [t.split() for t in self._all_texts]
  24. self.bm25 = BM25Okapi(tokenized)
  25. elif self.search_method == "tfidf":
  26. from sklearn.feature_extraction.text import TfidfVectorizer
  27. self.vectorizer = TfidfVectorizer()
  28. self.tfidf_matrix = self.vectorizer.fit_transform(self._all_texts)
  29. logger.info("KeywordRAG ({}) index built with {} documents".format(self.search_method, len(documents)))
  30. def retrieve(self, query, k=10):
  31. if self.search_method == "bm25":
  32. query_tokens = query.split()
  33. scores = self.bm25.get_scores(query_tokens)
  34. elif self.search_method == "tfidf":
  35. from sklearn.metrics.pairwise import cosine_similarity
  36. query_vec = self.vectorizer.transform([query])
  37. scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
  38. else:
  39. return []
  40. import numpy as np
  41. top_indices = np.argsort(scores)[::-1]
  42. results = []
  43. for idx in top_indices:
  44. if scores[idx] > 0:
  45. results.append((self._all_documents[idx], float(scores[idx])))
  46. return self._deduplicate_results(results, k)
  47. def generate(self, query, context):
  48. prompt = (
  49. "根据以下参考文档(关键词检索:{}),回答问题。\n\n"
  50. "参考文档:\n{}\n\n"
  51. "问题:{}\n\n"
  52. "请详细回答。".format(self.search_method, context, query)
  53. )
  54. return self._call_llm(prompt)
  55. def query(self, query, k=10):
  56. import time
  57. start_total = time.time()
  58. t0 = time.time()
  59. docs = self.retrieve(query, k)
  60. retrieval_time = time.time() - t0
  61. context = self._format_context(docs)
  62. t1 = time.time()
  63. answer = self.generate(query, context)
  64. generation_time = time.time() - t1
  65. total_time = time.time() - start_total
  66. return RAGResult(
  67. answer=answer,
  68. retrieved_docs=docs,
  69. latency_retrieval=retrieval_time,
  70. latency_generation=generation_time,
  71. latency_total=total_time,
  72. metadata={"method": self.name, "num_context_tokens": len(context)},
  73. )