step_back_rag.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. # -*- coding: utf-8 -*-
  2. """Step-Back RAG - ask a more general question for context, then specific."""
  3. from .base import BaseRAG
  4. class StepBackRAG(BaseRAG):
  5. def __init__(self, **kwargs):
  6. super().__init__(**kwargs)
  7. def _generate_step_back_query(self, query):
  8. prompt = (
  9. "针对以下具体问题,提出一个更宏观、更通用的背景问题。\n\n"
  10. "具体问题:{}\n\n"
  11. "宏观问题:".format(query)
  12. )
  13. return self._call_llm(prompt)
  14. def retrieve(self, query, k=10):
  15. step_back_query = self._generate_step_back_query(query)
  16. q1_embedding = self.embedding_model.embed_query(query)
  17. specific_results = self.vector_store.similarity_search(q1_embedding, k // 2)
  18. q2_embedding = self.embedding_model.embed_query(step_back_query)
  19. general_results = self.vector_store.similarity_search(q2_embedding, k // 2)
  20. combined = specific_results + general_results
  21. return self._deduplicate_results(combined, k)
  22. def generate(self, query, context):
  23. prompt = (
  24. "根据以下参考文档(宏观+具体检索),回答问题。\n\n"
  25. "参考文档:\n{}\n\n"
  26. "问题:{}\n\n"
  27. "请详细回答。".format(context, query)
  28. )
  29. return self._call_llm(prompt)