# -*- coding: utf-8 -*-
# hyde_rag.py
"""HyDE RAG - generate a hypothetical document to improve embedding-based retrieval."""

from .base import BaseRAG
  4. class HyDERAG(BaseRAG):
  5. def __init__(self, **kwargs):
  6. super().__init__(**kwargs)
  7. def _generate_hypothetical_doc(self, query):
  8. prompt = (
  9. "请根据以下问题,生成一个假设的、包含相关信息的答案文档。\n\n"
  10. "问题:{}\n\n"
  11. "假设文档:".format(query)
  12. )
  13. return self._call_llm(prompt)
  14. def retrieve(self, query, k=10):
  15. hypothetical_doc = self._generate_hypothetical_doc(query)
  16. hypo_embedding = self.embedding_model.embed_query(hypothetical_doc)
  17. return self._deduplicate_results(self.vector_store.similarity_search(hypo_embedding, k), k)
  18. def generate(self, query, context):
  19. prompt = (
  20. "根据以下参考文档(基于假设文档检索),回答问题。\n\n"
  21. "参考文档:\n{}\n\n"
  22. "问题:{}\n\n"
  23. "请详细回答。".format(context, query)
  24. )
  25. return self._call_llm(prompt)