# -*- coding: utf-8 -*-
"""Corrective RAG - correct the answer by checking against retrieved documents."""

from loguru import logger

from .base import BaseRAG


class _ContextChunk:
    """Minimal document stand-in that exposes only ``page_content``.

    ``CorrectiveRAG.generate`` receives its context as a plain string, while
    ``CorrectiveRAG._verify_answer`` expects ``(document, score)`` pairs whose
    document has a ``page_content`` attribute. This wrapper bridges the two.
    """

    def __init__(self, page_content):
        self.page_content = page_content


class CorrectiveRAG(BaseRAG):
    """RAG variant that verifies the generated answer and rewrites it on mismatch.

    ``generate`` first produces an answer, then asks the LLM whether that
    answer is consistent with the reference documents; when the verification
    text contains the inconsistency marker, a corrected answer is requested.

    Relies on ``self._call_llm``, ``self.embedding_model``,
    ``self.vector_store`` and ``self._deduplicate_results``, none of which are
    defined in this class (presumably provided by ``BaseRAG`` - confirm there).
    """

    # Separator between document chunks inside the ``context`` string.
    # NOTE(review): presumably matches how the base class joins retrieved
    # documents into a context - confirm against BaseRAG.
    _CONTEXT_SEPARATOR = "\n\n---\n\n"
    # Number of documents shown to the verifier. The verification prompt text
    # also mentions "3", so keep the two in sync when changing this.
    _VERIFY_MAX_DOCS = 3
    # Characters kept from each document in the verification prompt.
    _VERIFY_DOC_CHARS = 200
    # Characters kept from the answer in verification/regeneration prompts.
    _ANSWER_CHARS = 500
    # Substring of the verification text that triggers regeneration.
    _INCONSISTENT_MARKER = "不一致"

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseRAG.__init__``."""
        super().__init__(**kwargs)

    def _verify_answer(self, query, answer, docs):
        """Ask the LLM whether ``answer`` agrees with the reference documents.

        Args:
            query: The user question.
            answer: The answer to check; only its first 500 characters are
                included in the prompt.
            docs: Sequence of ``(document, score)`` pairs. Only the first
                three are used, each contributing the first 200 characters
                of ``document.page_content``; the score is ignored.

        Returns:
            The result of ``self._call_llm`` for the verification prompt.
        """
        excerpts = [
            doc.page_content[: self._VERIFY_DOC_CHARS]
            for doc, _score in docs[: self._VERIFY_MAX_DOCS]
        ]
        context = "\n\n".join(excerpts)
        prompt = (
            "验证以下答案是否与参考文档一致。如果不一致,请指出需要修正的部分。\n\n"
            "问题:{}\n\n"
            "参考文档(前3个):\n{}\n\n"
            "答案:\n{}\n\n"
            "验证结果(一致/不一致)及需要修正的部分:"
        ).format(query, context, answer[: self._ANSWER_CHARS])
        return self._call_llm(prompt)

    def _regenerate_answer(self, query, original_answer, verification_result, context):
        """Ask the LLM to rewrite an answer according to a verification result.

        Args:
            query: The user question.
            original_answer: The answer to correct; only its first 500
                characters are included in the prompt.
            verification_result: Text produced by ``_verify_answer``.
            context: The full reference-document text, inserted unmodified.

        Returns:
            The result of ``self._call_llm`` for the correction prompt.
        """
        prompt = (
            "根据验证结果,修正以下答案。\n\n"
            "原始答案:\n{}\n\n"
            "验证结果:\n{}\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "修正后的答案:"
        ).format(
            original_answer[: self._ANSWER_CHARS],
            verification_result,
            context,
            query,
        )
        return self._call_llm(prompt)

    def retrieve(self, query, k=10):
        """Retrieve documents for ``query`` from the vector store.

        The query is embedded with ``self.embedding_model.embed_query``, the
        embedding is passed to ``self.vector_store.similarity_search`` together
        with ``k``, and the hits go through ``self._deduplicate_results``.

        Args:
            query: The user question.
            k: Value forwarded both to the similarity search and to the
                de-duplication step (default 10).

        Returns:
            Whatever ``self._deduplicate_results`` returns.
        """
        query_embedding = self.embedding_model.embed_query(query)
        hits = self.vector_store.similarity_search(query_embedding, k)
        return self._deduplicate_results(hits, k)

    def generate(self, query, context):
        """Generate an answer, verify it against the context, correct if needed.

        Args:
            query: The user question.
            context: Reference text; document chunks are expected to be
                separated by ``"\\n\\n---\\n\\n"``.

        Returns:
            The initial answer, or the regenerated answer when the
            verification text contains "不一致".
        """
        answer = self._call_llm(
            (
                "根据以下参考文档,回答问题。\n\n"
                "参考文档:\n{}\n\n"
                "问题:{}\n\n"
                "请详细回答。"
            ).format(context, query)
        )

        # Only the leading chunks are ever shown to the verifier, so slice
        # before wrapping. 0.0 is a placeholder score: _verify_answer ignores it.
        chunks = context.split(self._CONTEXT_SEPARATOR)[: self._VERIFY_MAX_DOCS]
        docs = [(_ContextChunk(chunk), 0.0) for chunk in chunks]

        verification = self._verify_answer(query, answer, docs)
        # NOTE(review): assumes _call_llm returns a string; the substring
        # test would raise TypeError on None - confirm against BaseRAG.
        if self._INCONSISTENT_MARKER in verification:
            logger.info("CorrectiveRAG: answer verification failed, regenerating...")
            answer = self._regenerate_answer(query, answer, verification, context)
        return answer