| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- # -*- coding: utf-8 -*-
- """Keyword RAG - unified interface for keyword-based retrieval (BM25 or TF-IDF)."""
- from .base import BaseRAG, RAGResult
- from loguru import logger
class KeywordRAG(BaseRAG):
    """Keyword-based RAG pipeline scored with either BM25 or TF-IDF.

    The scoring backend is selected by ``search_method``:

    * ``"bm25"``  -- ``rank_bm25.BM25Okapi`` over whitespace-split tokens.
    * ``"tfidf"`` -- ``sklearn`` ``TfidfVectorizer`` + cosine similarity.

    Any other value builds no index and makes :meth:`retrieve` return ``[]``.

    NOTE(review): the BM25 path tokenizes with ``str.split()``, which only
    separates on whitespace; text in languages written without spaces will
    end up as very coarse tokens -- confirm this is acceptable for the corpus.
    """

    def __init__(self, search_method="bm25", llm_client=None, llm_model="gpt-4o", **kwargs):
        """Set up an empty, un-indexed retriever.

        Args:
            search_method: ``"bm25"`` or ``"tfidf"``.
            llm_client: Client object stored for answer generation.
            llm_model: Model name stored for answer generation.
            **kwargs: Accepted for signature compatibility and ignored.

        ``BaseRAG.__init__`` is not invoked here; the attributes below are
        assigned directly instead (presumably to avoid creating embedding
        resources that keyword search does not need -- TODO confirm).
        """
        # Dense-retrieval attributes are explicitly unused by this class.
        self.embedding_model = None
        self.vector_store = None

        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__
        self.search_method = search_method

        # Corpus and index state; populated by index_documents().
        self._all_documents = []
        self._all_texts = []
        self.bm25 = None
        self.tfidf_matrix = None
        self.vectorizer = None

    @staticmethod
    def _whitespace_tokens(text):
        """Split *text* on whitespace; shared by indexing and querying."""
        return text.split()

    def index_documents(self, documents):
        """Build the keyword index for ``documents``.

        Args:
            documents: Iterable of objects exposing a ``page_content`` string.

        An empty corpus, or an unsupported ``search_method``, leaves the
        retriever without an index (and :meth:`retrieve` returns ``[]``)
        instead of raising.
        """
        # Copy so later mutation of the caller's list cannot desynchronise
        # document positions from score positions; also accepts generators.
        self._all_documents = list(documents)
        self._all_texts = [doc.page_content for doc in self._all_documents]

        # Drop any index left over from a previous call before rebuilding.
        self.bm25 = None
        self.vectorizer = None
        self.tfidf_matrix = None

        if not self._all_documents:
            # Both backends fail on an empty corpus (BM25Okapi divides by the
            # corpus size, TfidfVectorizer rejects an empty vocabulary).
            logger.warning(
                "KeywordRAG ({}) received no documents; index not built".format(self.search_method)
            )
            return

        if self.search_method == "bm25":
            from rank_bm25 import BM25Okapi

            tokenized = [self._whitespace_tokens(text) for text in self._all_texts]
            self.bm25 = BM25Okapi(tokenized)
        elif self.search_method == "tfidf":
            from sklearn.feature_extraction.text import TfidfVectorizer

            self.vectorizer = TfidfVectorizer()
            self.tfidf_matrix = self.vectorizer.fit_transform(self._all_texts)
        else:
            # Nothing was built, so do not claim that an index exists.
            logger.warning(
                "KeywordRAG: unsupported search_method {!r}; index not built".format(self.search_method)
            )
            return

        logger.info(
            "KeywordRAG ({}) index built with {} documents".format(
                self.search_method, len(self._all_documents)
            )
        )

    def retrieve(self, query, k=10):
        """Score every indexed document against ``query``.

        Args:
            query: Query string.
            k: Forwarded to ``self._deduplicate_results`` as the result limit.

        Returns:
            Whatever ``self._deduplicate_results`` returns for the list of
            ``(document, score)`` pairs with strictly positive score, ordered
            by descending score. ``[]`` when the search method is unsupported
            or no index has been built.
        """
        if self.search_method == "bm25":
            if self.bm25 is None:
                logger.warning("KeywordRAG.retrieve called before an index was built")
                return []
            scores = self.bm25.get_scores(self._whitespace_tokens(query))
        elif self.search_method == "tfidf":
            if self.vectorizer is None or self.tfidf_matrix is None:
                logger.warning("KeywordRAG.retrieve called before an index was built")
                return []
            from sklearn.metrics.pairwise import cosine_similarity

            query_vec = self.vectorizer.transform([query])
            scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
        else:
            return []

        import numpy as np

        # Highest score first; documents that do not match at all (score <= 0)
        # are dropped rather than returned as padding.
        ranked = np.argsort(scores)[::-1]
        results = [
            (self._all_documents[idx], float(scores[idx]))
            for idx in ranked
            if scores[idx] > 0
        ]
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Build the answer prompt from ``context`` and ``query`` and call the LLM.

        Returns:
            The value returned by ``self._call_llm`` for the prompt.
        """
        prompt = (
            "根据以下参考文档(关键词检索:{}),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(self.search_method, context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieval followed by generation and report timings.

        Args:
            query: Question string.
            k: Result limit passed to :meth:`retrieve`.

        Returns:
            A ``RAGResult`` holding the answer, the retrieved documents, and
            retrieval / generation / total latencies in seconds. The
            ``num_context_tokens`` metadata entry is ``len(context)``, i.e. the
            length of the formatted context, not a tokenizer-based count.
        """
        import time

        # perf_counter is monotonic, so measured intervals cannot go negative
        # or jump when the system clock is adjusted.
        start_total = time.perf_counter()

        retrieval_start = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - retrieval_start

        # Context formatting is counted in the total but in neither sub-timer.
        context = self._format_context(docs)

        generation_start = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - generation_start

        total_time = time.perf_counter() - start_total

        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
|