tfidf_rag.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # -*- coding: utf-8 -*-
  2. """TF-IDF RAG - term frequency-inverse document frequency retrieval."""
  3. from .base import BaseRAG, RAGResult
  4. from loguru import logger
  5. class TFIDFRAG(BaseRAG):
  6. def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
  7. self.embedding_model = None
  8. self.vector_store = None
  9. self.llm_client = llm_client
  10. self.llm_model = llm_model
  11. self.name = self.__class__.__name__
  12. self.vectorizer = None
  13. self.tfidf_matrix = None
  14. self._all_documents = []
  15. def index_documents(self, documents):
  16. self._all_documents = documents
  17. texts = [doc.page_content for doc in documents]
  18. from sklearn.feature_extraction.text import TfidfVectorizer
  19. self.vectorizer = TfidfVectorizer()
  20. self.tfidf_matrix = self.vectorizer.fit_transform(texts)
  21. logger.info("TF-IDF index built with {} documents, vocab size: {}".format(len(documents), len(self.vectorizer.vocabulary_)))
  22. def retrieve(self, query, k=10):
  23. if self.tfidf_matrix is None:
  24. return []
  25. from sklearn.metrics.pairwise import cosine_similarity
  26. query_vec = self.vectorizer.transform([query])
  27. scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
  28. import numpy as np
  29. top_indices = np.argsort(scores)[::-1]
  30. results = []
  31. for idx in top_indices:
  32. if scores[idx] > 0:
  33. results.append((self._all_documents[idx], float(scores[idx])))
  34. return self._deduplicate_results(results, k)
  35. def generate(self, query, context):
  36. prompt = (
  37. "根据以下参考文档(TF-IDF关键词检索),回答问题。\n\n"
  38. "参考文档:\n{}\n\n"
  39. "问题:{}\n\n"
  40. "请详细回答。".format(context, query)
  41. )
  42. return self._call_llm(prompt)
  43. def query(self, query, k=10):
  44. import time
  45. start_total = time.time()
  46. t0 = time.time()
  47. docs = self.retrieve(query, k)
  48. retrieval_time = time.time() - t0
  49. context = self._format_context(docs)
  50. t1 = time.time()
  51. answer = self.generate(query, context)
  52. generation_time = time.time() - t1
  53. total_time = time.time() - start_total
  54. return RAGResult(
  55. answer=answer,
  56. retrieved_docs=docs,
  57. latency_retrieval=retrieval_time,
  58. latency_generation=generation_time,
  59. latency_total=total_time,
  60. metadata={"method": self.name, "num_context_tokens": len(context)},
  61. )