| 1234567891011121314151617181920212223242526272829303132333435363738 |
- # -*- coding: utf-8 -*-
- """Rerank RAG - semantic retrieval with LLM reranking."""
- from .base import BaseRAG
class RerankRAG(BaseRAG):
    """RAG variant that re-scores vector-search candidates with a reranker.

    Retrieval is two-stage: a similarity search over ``self.vector_store``
    produces a candidate pool, then (when a rerank model is configured) each
    ``(query, document text)`` pair is scored by ``rerank_model.compute_score``
    and the candidates are reordered by that score, highest first.

    ``self.embedding_model``, ``self.vector_store``, ``self._deduplicate_results``
    and ``self._call_llm`` are not defined in this class; they are expected to
    come from ``BaseRAG``.
    """

    def __init__(self, rerank_model=None, rerank_top_k=5, **kwargs):
        """Store the reranker configuration and initialise the base class.

        Args:
            rerank_model: Object exposing ``compute_score(pairs)``. When falsy,
                ``retrieve`` skips reranking and returns the raw search order.
            rerank_top_k: Stored on the instance as ``rerank_top_k``.
                NOTE(review): no method in this class reads it -- confirm
                whether ``BaseRAG`` or callers rely on it.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.rerank_model = rerank_model
        self.rerank_top_k = rerank_top_k
        # Minimum number of candidates pulled from the vector store before
        # reranking.
        self.initial_k = 20

    @staticmethod
    def _as_score_list(scores):
        """Return *scores* as a list, wrapping a lone scalar score.

        Iteration is attempted instead of an ``isinstance(int, float)`` check
        so that scalar types that are not builtin numbers (for example
        ``numpy.float32`` or a 0-d array/tensor) are also wrapped rather than
        failing later inside ``zip``.
        """
        try:
            return list(scores)
        except TypeError:
            return [scores]

    def retrieve(self, query, k=10):
        """Search for candidates, optionally rerank them, and return results.

        Args:
            query: Query passed to ``embedding_model.embed_query`` and paired
                with each candidate text for the reranker.
            k: Result limit handed to ``self._deduplicate_results``.

        Returns:
            Whatever ``self._deduplicate_results`` returns. With a reranker it
            receives ``(doc, float(score))`` tuples sorted by descending
            rerank score; without one (or with no candidates) it receives the
            similarity-search output unchanged.
        """
        query_embedding = self.embedding_model.embed_query(query)

        # Never fetch fewer candidates than the caller asked for; otherwise a
        # request with k > initial_k could not be satisfied.
        if k is None:
            candidate_count = self.initial_k
        else:
            candidate_count = max(self.initial_k, k)
        initial_docs = self.vector_store.similarity_search(
            query_embedding, candidate_count
        )

        if not self.rerank_model or len(initial_docs) == 0:
            return self._deduplicate_results(initial_docs, k)

        # Search results are consumed as (document, score) pairs; only the
        # document text is sent to the reranker.
        pairs = [(query, doc.page_content) for doc, _ in initial_docs]
        scores = self._as_score_list(self.rerank_model.compute_score(pairs))

        # sorted() is stable, so candidates with equal rerank scores keep
        # their similarity-search order.
        ranked = sorted(
            zip(initial_docs, scores),
            key=lambda item: item[1],
            reverse=True,
        )
        # Replace the original similarity score with the rerank score.
        rescored = [(doc, float(score)) for (doc, _), score in ranked]
        return self._deduplicate_results(rescored, k)

    def generate(self, query, context):
        """Build the answer prompt from *context* and *query* and call the LLM.

        Args:
            query: The user question inserted into the prompt.
            context: Reference text inserted into the prompt ahead of the
                question.

        Returns:
            The result of ``self._call_llm(prompt)``.
        """
        prompt = (
            "根据以下经过重排序的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|