# -*- coding: utf-8 -*-
"""Rerank RAG - semantic retrieval with LLM reranking."""

from .base import BaseRAG


class RerankRAG(BaseRAG):
    """RAG variant that re-scores vector-search hits with an optional reranker.

    Retrieval is two-stage: a candidate pool is fetched from the vector
    store by embedding similarity, then (when a rerank model is configured)
    every candidate is re-scored against the query and the pool is
    re-ordered by that score before deduplication.
    """

    def __init__(self, rerank_model=None, rerank_top_k=5, **kwargs):
        """Configure the reranker on top of the base RAG setup.

        Args:
            rerank_model: Object exposing ``compute_score(pairs)`` where
                ``pairs`` is a list of ``(query, text)`` tuples. When falsy,
                ``retrieve`` skips reranking entirely.
            rerank_top_k: Stored on the instance as ``self.rerank_top_k``.
                NOTE(review): not read by ``retrieve`` or ``generate`` in
                this class - confirm whether other code depends on it.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.rerank_model = rerank_model
        self.rerank_top_k = rerank_top_k
        # Minimum size of the candidate pool fetched before reranking.
        self.initial_k = 20

    def retrieve(self, query, k=10):
        """Fetch candidates by similarity, optionally rerank, then deduplicate.

        Args:
            query: Query passed to the embedding model and, paired with each
                candidate text, to the rerank model.
            k: Result count handed to ``self._deduplicate_results``.

        Returns:
            Whatever ``self._deduplicate_results`` returns for the ordered
            ``(doc, score)`` pairs. With a reranker the score is the rerank
            score as a ``float``; without one it is the vector-store score.
        """
        query_embedding = self.embedding_model.embed_query(query)

        # Never fetch fewer candidates than the caller asked for; otherwise
        # k > initial_k could not be satisfied. A None k keeps the old
        # fixed-size pool.
        if k is None:
            candidate_k = self.initial_k
        else:
            candidate_k = max(self.initial_k, k)
        candidates = self.vector_store.similarity_search(
            query_embedding, candidate_k
        )

        # No reranker configured, or nothing to rerank: keep similarity order.
        if not (self.rerank_model and candidates):
            return self._deduplicate_results(candidates, k)

        # Each candidate is unpacked as a (doc, similarity_score) pair; only
        # the document text is shown to the reranker.
        pairs = [(query, doc.page_content) for doc, _ in candidates]
        scores = self.rerank_model.compute_score(pairs)

        # The reranker may hand back one bare scalar instead of a sequence
        # (the original code special-cased int/float). Treat any
        # non-iterable result - including NumPy/tensor scalars - as a
        # single score.
        try:
            scores = list(scores)
        except TypeError:
            scores = [scores]

        # Replace the similarity score with the rerank score. The sort is
        # stable, so equally scored documents keep their similarity order.
        reranked = [
            (doc, float(score)) for (doc, _), score in zip(candidates, scores)
        ]
        reranked.sort(key=lambda item: item[1], reverse=True)
        return self._deduplicate_results(reranked, k)

    def generate(self, query, context):
        """Build the answer prompt from the context and send it to the LLM.

        The prompt is written in Chinese: it asks the model to answer the
        question in detail based on the reranked reference documents.

        Args:
            query: The user question inserted into the prompt.
            context: The reference-document text inserted into the prompt.

        Returns:
            The result of ``self._call_llm(prompt)``.
        """
        prompt = (
            "根据以下经过重排序的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)