| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- # -*- coding: utf-8 -*-
- """TF-IDF RAG - term frequency-inverse document frequency retrieval."""
- from .base import BaseRAG, RAGResult
- from loguru import logger
class TFIDFRAG(BaseRAG):
    """RAG pipeline that retrieves documents by TF-IDF cosine similarity.

    Documents are indexed with scikit-learn's ``TfidfVectorizer``; a query is
    projected into the same term space and ranked against every indexed
    document. Retrieved documents are formatted into a prompt and answered by
    an LLM via the base class helpers.
    """

    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
        """Create an un-indexed TF-IDF RAG instance.

        Args:
            llm_client: Client passed through to the base class LLM helper.
            llm_model: Model name stored for LLM calls.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        # NOTE(review): BaseRAG.__init__ is not called here; these attributes
        # appear to mirror what it would set, with no embedding model or vector
        # store because TF-IDF needs neither — confirm against BaseRAG.
        self.embedding_model = None
        self.vector_store = None
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__
        self.vectorizer = None  # fitted TfidfVectorizer, or None when not indexed
        self.tfidf_matrix = None  # sparse (n_docs, vocab) matrix, or None
        self._all_documents = []  # row i of tfidf_matrix corresponds to item i

    def index_documents(self, documents):
        """Build (or rebuild) the TF-IDF index over ``documents``.

        Args:
            documents: Iterable of objects exposing a ``page_content`` string.
                It is materialized into a list, so generators are accepted.

        The new index replaces the previous one only after fitting succeeds,
        so a failure leaves the previous index intact and self-consistent.
        An empty input clears the index (``retrieve`` then returns ``[]``).
        """
        from sklearn.feature_extraction.text import TfidfVectorizer

        documents = list(documents)
        if not documents:
            # fit_transform raises ValueError on an empty corpus; treat it as
            # "nothing indexed" instead.
            self._all_documents = []
            self.vectorizer = None
            self.tfidf_matrix = None
            logger.warning("TF-IDF index cleared: no documents to index")
            return

        texts = [doc.page_content for doc in documents]
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform(texts)

        # Commit all index state together so the rows of tfidf_matrix always
        # line up with _all_documents and the vectorizer that produced them.
        self._all_documents = documents
        self.vectorizer = vectorizer
        self.tfidf_matrix = tfidf_matrix
        logger.info(
            "TF-IDF index built with {} documents, vocab size: {}".format(
                len(documents), len(vectorizer.vocabulary_)
            )
        )

    def retrieve(self, query, k=10):
        """Return documents ranked by TF-IDF cosine similarity to ``query``.

        Args:
            query: Query text.
            k: Number of results requested; forwarded to the base class
                de-duplication helper.

        Returns:
            Whatever ``_deduplicate_results`` returns for a list of
            ``(document, score)`` pairs, highest score first. Only documents
            with a strictly positive score (sharing at least one vocabulary
            term with the query) are considered. ``[]`` if nothing is indexed.
        """
        if self.tfidf_matrix is None:
            return []

        import numpy as np
        from sklearn.metrics.pairwise import cosine_similarity

        query_vec = self.vectorizer.transform([query])
        scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()

        # Descending by score, then drop zero-similarity documents in one
        # vectorized step (relative order of the survivors is unchanged).
        ranked = np.argsort(scores)[::-1]
        ranked = ranked[scores[ranked] > 0]
        results = [(self._all_documents[idx], float(scores[idx])) for idx in ranked]

        # The full positive-score list (not just the top k) is handed over,
        # presumably so de-duplication can still fill k slots — confirm in BaseRAG.
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using ``context`` as reference.

        Args:
            query: User question, inserted into the prompt.
            context: Formatted reference documents, inserted into the prompt.

        Returns:
            The result of the base class ``_call_llm`` for the built prompt.
        """
        # Chinese prompt: "Answer the question based on the following reference
        # documents (TF-IDF keyword retrieval) ... please answer in detail."
        # Adjacent literals are joined by the parser, so .format applies to all.
        prompt = (
            "根据以下参考文档(TF-IDF关键词检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieval and generation end to end and time each stage.

        Args:
            query: User question.
            k: Number of documents requested from ``retrieve``.

        Returns:
            A ``RAGResult`` with the answer, the retrieved documents, and
            per-stage latencies in seconds.
        """
        import time

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so latency measurements cannot go negative on clock adjustments.
        start_total = time.perf_counter()

        t0 = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - t0

        context = self._format_context(docs)

        t1 = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - t1

        total_time = time.perf_counter() - start_total
        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            # NOTE(review): despite the key name this is len(context); if
            # context is a string (presumably), it counts characters, not tokens.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
|