rerank_rag.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # -*- coding: utf-8 -*-
  2. """Rerank RAG - semantic retrieval with LLM reranking."""
  3. from .base import BaseRAG
  4. class RerankRAG(BaseRAG):
  5. def __init__(self, rerank_model=None, rerank_top_k=5, **kwargs):
  6. super().__init__(**kwargs)
  7. self.rerank_model = rerank_model
  8. self.rerank_top_k = rerank_top_k
  9. self.initial_k = 20
  10. def retrieve(self, query, k=10):
  11. query_embedding = self.embedding_model.embed_query(query)
  12. initial_docs = self.vector_store.similarity_search(query_embedding, self.initial_k)
  13. if self.rerank_model and len(initial_docs) > 0:
  14. texts = [doc.page_content for doc, _ in initial_docs]
  15. pairs = [(query, text) for text in texts]
  16. scores = self.rerank_model.compute_score(pairs)
  17. if isinstance(scores, (int, float)):
  18. scores = [scores]
  19. reranked = list(zip(initial_docs, scores))
  20. reranked.sort(key=lambda x: x[1], reverse=True)
  21. return self._deduplicate_results([(doc, float(score)) for (doc, _), score in reranked], k)
  22. return self._deduplicate_results(initial_docs, k)
  23. def generate(self, query, context):
  24. prompt = (
  25. "根据以下经过重排序的参考文档,回答问题。\n\n"
  26. "参考文档:\n{}\n\n"
  27. "问题:{}\n\n"
  28. "请详细回答。".format(context, query)
  29. )
  30. return self._call_llm(prompt)