# -*- coding: utf-8 -*-
"""TF-IDF RAG - term frequency-inverse document frequency retrieval."""

import time

from loguru import logger

from .base import BaseRAG, RAGResult


class TFIDFRAG(BaseRAG):
    """RAG pipeline that retrieves documents by TF-IDF cosine similarity.

    Documents are vectorized with scikit-learn's ``TfidfVectorizer``. A query
    is projected into the same vocabulary and scored against every indexed
    document by cosine similarity. Only documents with a strictly positive
    score (sharing at least one vocabulary term with the query) are returned.
    """

    def __init__(
        self,
        llm_client=None,
        llm_model="gpt-4o",
        *,
        vectorizer_kwargs=None,
        **kwargs,
    ):
        """Create an un-indexed TF-IDF RAG instance.

        Args:
            llm_client: Client object used by ``_call_llm`` for generation.
            llm_model: Model name stored on the instance.
            vectorizer_kwargs: Optional keyword arguments forwarded to
                ``TfidfVectorizer``. The default vectorizer keeps only tokens
                of 2+ word characters and does not segment CJK text. For
                Chinese corpora, something like
                ``{"analyzer": "char_wb", "ngram_range": (1, 2)}`` may be
                needed.
            **kwargs: Accepted but unused (presumably so a shared config can
                be passed to every RAG class).
        """
        # NOTE(review): BaseRAG.__init__ is not called; the attributes below
        # are set directly instead. Confirm against BaseRAG before adding a
        # super().__init__() call.
        self.embedding_model = None  # TF-IDF uses no dense embedding model
        self.vector_store = None  # nor an external vector store
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__
        self.vectorizer_kwargs = dict(vectorizer_kwargs or {})
        self.vectorizer = None
        self.tfidf_matrix = None
        self._all_documents = []

    def index_documents(self, documents):
        """Build the TF-IDF index over ``documents``, replacing any previous one.

        Args:
            documents: Iterable of objects exposing a ``page_content`` string.
                Row ``i`` of the TF-IDF matrix corresponds to the ``i``-th
                document. An empty iterable clears the index.

        Raises:
            ValueError: Propagated from ``TfidfVectorizer.fit_transform`` when
                the documents produce an empty vocabulary. The previous index
                is left untouched in that case.
        """
        from sklearn.feature_extraction.text import TfidfVectorizer

        # Copy so later mutation of the caller's list cannot desync row indices.
        docs = list(documents)
        if not docs:
            # sklearn raises on an empty corpus; treat it as "clear the index".
            self.vectorizer = None
            self.tfidf_matrix = None
            self._all_documents = []
            logger.warning("TF-IDF index cleared: no documents to index")
            return

        vectorizer = TfidfVectorizer(**self.vectorizer_kwargs)
        matrix = vectorizer.fit_transform([doc.page_content for doc in docs])

        # Commit all state together, only after fitting succeeded, so that
        # _all_documents always matches the rows of tfidf_matrix.
        self.vectorizer = vectorizer
        self.tfidf_matrix = matrix
        self._all_documents = docs
        logger.info(
            "TF-IDF index built with {} documents, vocab size: {}",
            len(docs),
            len(vectorizer.vocabulary_),
        )

    def retrieve(self, query, k=10):
        """Return documents most similar to ``query`` by TF-IDF cosine score.

        Args:
            query: Query string.
            k: Result limit, forwarded to ``_deduplicate_results``.

        Returns:
            ``[]`` if nothing has been indexed. Otherwise, the result of
            ``self._deduplicate_results`` applied to ``(document, score)``
            tuples with positive scores, sorted by descending score.
        """
        if self.tfidf_matrix is None or self.vectorizer is None:
            return []

        import numpy as np
        from sklearn.metrics.pairwise import cosine_similarity

        query_vec = self.vectorizer.transform([query])
        scores = cosine_similarity(query_vec, self.tfidf_matrix).ravel()

        # Keep only documents that share a term with the query, then order them
        # by descending score. The stable sort breaks ties by index order, so
        # rankings are deterministic.
        positive = np.flatnonzero(scores > 0)
        ranked = positive[np.argsort(-scores[positive], kind="stable")]
        results = [(self._all_documents[i], float(scores[i])) for i in ranked]
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using the retrieved ``context``.

        The (Chinese) prompt reads: "Answer the question based on the following
        reference documents (TF-IDF keyword retrieval). Reference documents:
        <context> Question: <query> Please answer in detail."
        """
        prompt = (
            "根据以下参考文档(TF-IDF关键词检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieval then generation, recording per-stage latency.

        Latencies are in seconds, measured with the monotonic
        ``time.perf_counter`` clock. Context formatting is included in the
        total latency but in neither stage latency.

        Returns:
            RAGResult with the answer, retrieved docs, latencies and metadata.
            ``metadata["num_context_tokens"]`` is ``len(context)``. That is a
            character count if ``_format_context`` returns a string, not a
            tokenizer token count.
        """
        start_total = time.perf_counter()

        t0 = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - t0

        context = self._format_context(docs)

        t1 = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - t1

        total_time = time.perf_counter() - start_total
        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )