# -*- coding: utf-8 -*-
"""Corrective RAG - correct the answer by checking it against retrieved documents."""
from types import SimpleNamespace

from loguru import logger

from .base import BaseRAG
  5. class CorrectiveRAG(BaseRAG):
  6. def __init__(self, **kwargs):
  7. super().__init__(**kwargs)
  8. def _verify_answer(self, query, answer, docs):
  9. context = "\n\n".join([doc.page_content[:200] for doc, _ in docs[:3]])
  10. prompt = (
  11. "验证以下答案是否与参考文档一致。如果不一致,请指出需要修正的部分。\n\n"
  12. "问题:{}\n\n"
  13. "参考文档(前3个):\n{}\n\n"
  14. "答案:\n{}\n\n"
  15. "验证结果(一致/不一致)及需要修正的部分:".format(query, context, answer[:500])
  16. )
  17. return self._call_llm(prompt)
  18. def _regenerate_answer(self, query, original_answer, verification_result, context):
  19. prompt = (
  20. "根据验证结果,修正以下答案。\n\n"
  21. "原始答案:\n{}\n\n"
  22. "验证结果:\n{}\n\n"
  23. "参考文档:\n{}\n\n"
  24. "问题:{}\n\n"
  25. "修正后的答案:".format(original_answer[:500], verification_result, context, query)
  26. )
  27. return self._call_llm(prompt)
  28. def retrieve(self, query, k=10):
  29. query_embedding = self.embedding_model.embed_query(query)
  30. return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
  31. def generate(self, query, context):
  32. answer = self._call_llm(
  33. "根据以下参考文档,回答问题。\n\n"
  34. "参考文档:\n{}\n\n"
  35. "问题:{}\n\n"
  36. "请详细回答。".format(context, query)
  37. )
  38. docs = [(type('Doc', (), {'page_content': c.split('\n\n---\n\n')[0]})(), 0.0) for c in context.split('\n\n---\n\n')]
  39. verification = self._verify_answer(query, answer, docs[:3])
  40. if "不一致" in verification:
  41. logger.info("CorrectiveRAG: answer verification failed, regenerating...")
  42. answer = self._regenerate_answer(query, answer, verification, context)
  43. return answer