# -*- coding: utf-8 -*-
"""Keyword RAG - unified interface for keyword-based retrieval (BM25 or TF-IDF)."""

import time

from loguru import logger

from .base import BaseRAG, RAGResult


class KeywordRAG(BaseRAG):
    """RAG pipeline that retrieves with a lexical scorer instead of embeddings.

    ``search_method`` selects the scorer:

    * ``"bm25"``  - ``rank_bm25.BM25Okapi`` over whitespace-split tokens.
    * ``"tfidf"`` - scikit-learn ``TfidfVectorizer`` plus cosine similarity.

    Any other value is accepted (with a warning) but retrieves nothing.
    The scoring backends are imported lazily, so only the one actually in
    use needs to be installed.
    """

    _SUPPORTED_METHODS = ("bm25", "tfidf")

    def __init__(self, search_method="bm25", llm_client=None, llm_model="gpt-4o", **kwargs):
        """Create an un-indexed keyword retriever.

        Args:
            search_method: ``"bm25"`` or ``"tfidf"``.
            llm_client: LLM client stored on the instance. NOTE(review):
                presumably consumed by the inherited ``_call_llm`` - confirm
                against ``BaseRAG``.
            llm_model: Model name stored on the instance for the LLM call.
            **kwargs: Accepted and ignored. NOTE(review): presumably kept so
                all RAG classes share one constructor signature - confirm.
        """
        # BaseRAG.__init__ is not called here; every attribute read by the
        # methods below is set explicitly instead.
        # NOTE(review): confirm BaseRAG needs no other initialisation.
        self.embedding_model = None  # not used by keyword retrieval
        self.vector_store = None  # not used by keyword retrieval
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__  # reported as metadata["method"]
        self.search_method = search_method
        self._all_documents = []
        self._all_texts = []
        # Backend state; populated by index_documents() for the chosen method.
        self.bm25 = None
        self.tfidf_matrix = None
        self.vectorizer = None
        if search_method not in self._SUPPORTED_METHODS:
            logger.warning(
                "KeywordRAG: unsupported search_method {!r}; retrieve() will return no results".format(
                    search_method
                )
            )

    def index_documents(self, documents):
        """Build the keyword index over ``documents``.

        Args:
            documents: Iterable of objects exposing a ``page_content`` string.
                It is materialised into a list, so generators are accepted.

        Any previously built index is discarded first. An empty corpus
        leaves the index unset, so later ``retrieve`` calls return ``[]``.
        """
        self._all_documents = list(documents)
        self._all_texts = [doc.page_content for doc in self._all_documents]
        # Drop any earlier index so it can never be paired with the new
        # document list (positions would no longer line up).
        self.bm25 = None
        self.vectorizer = None
        self.tfidf_matrix = None

        if not self._all_texts:
            # Both backends fail on an empty corpus, so skip building.
            logger.warning(
                "KeywordRAG ({}) received no documents; index left empty".format(self.search_method)
            )
            return

        if self.search_method == "bm25":
            from rank_bm25 import BM25Okapi

            # NOTE(review): whitespace tokenisation treats unspaced text
            # (e.g. Chinese) as one token per run; consider a real tokenizer.
            tokenized = [text.split() for text in self._all_texts]
            self.bm25 = BM25Okapi(tokenized)
        elif self.search_method == "tfidf":
            from sklearn.feature_extraction.text import TfidfVectorizer

            self.vectorizer = TfidfVectorizer()
            self.tfidf_matrix = self.vectorizer.fit_transform(self._all_texts)
        logger.info(
            "KeywordRAG ({}) index built with {} documents".format(
                self.search_method, len(self._all_documents)
            )
        )

    def _score(self, query):
        """Return one relevance score per indexed document, or ``None``.

        ``None`` means there is nothing to score against: the method is
        unsupported, or no index has been built yet.
        """
        if self.search_method == "bm25":
            if self.bm25 is None:
                logger.warning("KeywordRAG (bm25) has no index; call index_documents() first")
                return None
            return self.bm25.get_scores(query.split())
        if self.search_method == "tfidf":
            if self.vectorizer is None:
                logger.warning("KeywordRAG (tfidf) has no index; call index_documents() first")
                return None
            from sklearn.metrics.pairwise import cosine_similarity

            query_vec = self.vectorizer.transform([query])
            return cosine_similarity(query_vec, self.tfidf_matrix).flatten()
        return None

    def retrieve(self, query, k=10):
        """Return ``(document, score)`` pairs for ``query``, best first.

        Only documents with a strictly positive score are kept. The list is
        passed to the inherited ``_deduplicate_results(results, k)``.
        NOTE(review): presumably that helper dedups and truncates to ``k``;
        confirm against ``BaseRAG``.
        """
        scores = self._score(query)
        if scores is None:
            return []

        import numpy as np

        scores = np.asarray(scores, dtype=float)
        # Filter before sorting so zero-score documents cost nothing. Sort
        # by descending score; a stable sort keeps ties in document order.
        candidates = np.flatnonzero(scores > 0)
        ranked = candidates[np.argsort(-scores[candidates], kind="stable")]
        results = [(self._all_documents[idx], float(scores[idx])) for idx in ranked]
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` from the retrieved ``context``.

        The prompt is in Chinese. It says: "Answer the question based on the
        following reference documents (keyword retrieval: <method>).
        Reference documents: <context>. Question: <query>. Please answer in
        detail." The call goes through the inherited ``_call_llm`` helper.
        """
        prompt = (
            "根据以下参考文档(关键词检索:{}),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(self.search_method, context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Retrieve up to ``k`` documents for ``query`` and generate an answer.

        Returns:
            ``RAGResult`` with the answer, the retrieved documents, and the
            elapsed retrieval / generation / total latencies in seconds.
        """
        # perf_counter() is monotonic and high-resolution; time.time() can
        # jump if the system clock is adjusted during the measurement.
        start_total = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - start_total

        context = self._format_context(docs)

        start_generation = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - start_generation

        total_time = time.perf_counter() - start_total
        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            # NOTE: despite its name, "num_context_tokens" holds len(context).
            # That is a character count if context is a string, not a token
            # count. The key is kept unchanged for downstream consumers.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )