| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- # -*- coding: utf-8 -*-
- """Corrective RAG - correct the answer by checking against retrieved documents."""
- from .base import BaseRAG
- from loguru import logger
class CorrectiveRAG(BaseRAG):
    """RAG variant that double-checks its own answer against the context.

    ``generate`` drafts an answer, asks the LLM whether the draft agrees with
    the leading context chunks, and requests a corrected answer whenever the
    verdict contains the mismatch marker.
    """

    # Separator ``generate`` splits ``context`` on to recover individual chunks.
    # NOTE(review): presumably the same separator the caller joins documents
    # with -- confirm against the code that builds ``context``.
    _CONTEXT_SEPARATOR = "\n\n---\n\n"
    # Number of leading documents quoted in the verification prompt.
    _VERIFY_DOC_LIMIT = 3
    # Characters kept from each document quoted in the verification prompt.
    _VERIFY_DOC_CHARS = 200
    # Characters of the answer quoted in the verification/correction prompts.
    _ANSWER_CHARS = 500
    # Substring of the verification verdict that triggers a regeneration.
    _MISMATCH_MARKER = "不一致"

    class _ContextDoc:
        """Minimal stand-in exposing ``page_content`` for a context chunk."""

        __slots__ = ("page_content",)

        def __init__(self, page_content):
            self.page_content = page_content

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseRAG.__init__``."""
        # Kept (rather than inherited) so construction stays keyword-only.
        super().__init__(**kwargs)

    def _verify_answer(self, query, answer, docs):
        """Ask the LLM whether ``answer`` is consistent with ``docs``.

        Args:
            query: The question, inserted into the prompt as-is.
            answer: The draft answer; only its first ``_ANSWER_CHARS``
                characters are quoted.
            docs: Sequence of ``(doc, score)`` pairs where ``doc`` has a
                ``page_content`` string. Only the first ``_VERIFY_DOC_LIMIT``
                pairs are used and each is cut to ``_VERIFY_DOC_CHARS``
                characters; the score is ignored.

        Returns:
            Whatever ``self._call_llm`` returns for the verification prompt.
        """
        limit = self._VERIFY_DOC_LIMIT
        snippets = [
            doc.page_content[: self._VERIFY_DOC_CHARS] for doc, _ in docs[:limit]
        ]
        # The document count in the prompt is derived from the same limit used
        # for slicing so the two cannot drift apart.
        prompt = (
            "验证以下答案是否与参考文档一致。如果不一致,请指出需要修正的部分。\n\n"
            "问题:{query}\n\n"
            "参考文档(前{limit}个):\n{context}\n\n"
            "答案:\n{answer}\n\n"
            "验证结果(一致/不一致)及需要修正的部分:"
        ).format(
            query=query,
            limit=limit,
            context="\n\n".join(snippets),
            answer=answer[: self._ANSWER_CHARS],
        )
        return self._call_llm(prompt)

    def _regenerate_answer(self, query, original_answer, verification_result, context):
        """Ask the LLM to rewrite an answer using the verification verdict.

        Args:
            query: The question being answered.
            original_answer: The draft answer; only its first
                ``_ANSWER_CHARS`` characters are quoted.
            verification_result: Verdict text produced by ``_verify_answer``.
            context: Full reference text, inserted without truncation.

        Returns:
            Whatever ``self._call_llm`` returns for the correction prompt.
        """
        prompt = (
            "根据验证结果,修正以下答案。\n\n"
            "原始答案:\n{original}\n\n"
            "验证结果:\n{verdict}\n\n"
            "参考文档:\n{context}\n\n"
            "问题:{query}\n\n"
            "修正后的答案:"
        ).format(
            original=original_answer[: self._ANSWER_CHARS],
            verdict=verification_result,
            context=context,
            query=query,
        )
        return self._call_llm(prompt)

    def retrieve(self, query, k=10):
        """Embed ``query``, search the vector store and de-duplicate the hits.

        Args:
            query: Text handed to ``self.embedding_model.embed_query``.
            k: Passed both to ``similarity_search`` and to
                ``_deduplicate_results``.

        Returns:
            The value returned by ``self._deduplicate_results``.
        """
        query_embedding = self.embedding_model.embed_query(query)
        hits = self.vector_store.similarity_search(query_embedding, k)
        return self._deduplicate_results(hits, k)

    def generate(self, query, context):
        """Draft an answer from ``context`` and correct it if verification fails.

        Args:
            query: The question to answer.
            context: Reference text; chunks are recovered by splitting on
                ``_CONTEXT_SEPARATOR``.

        Returns:
            The draft answer, or the regenerated answer when the verification
            verdict contains ``_MISMATCH_MARKER``.
        """
        draft_prompt = (
            "根据以下参考文档,回答问题。\n\n"
            "参考文档:\n{context}\n\n"
            "问题:{query}\n\n"
            "请详细回答。"
        ).format(context=context, query=query)
        answer = self._call_llm(draft_prompt)

        # Only the leading chunks are ever quoted by the verifier, so wrap just
        # those instead of building an object for every chunk.
        chunks = context.split(self._CONTEXT_SEPARATOR)[: self._VERIFY_DOC_LIMIT]
        docs = [(self._ContextDoc(chunk), 0.0) for chunk in chunks]

        verification = self._verify_answer(query, answer, docs)

        # NOTE(review): plain substring match -- a verdict that merely repeats
        # the "一致/不一致" wording from the prompt also triggers a regeneration.
        if self._MISMATCH_MARKER in verification:
            logger.info("CorrectiveRAG: answer verification failed, regenerating...")
            answer = self._regenerate_answer(query, answer, verification, context)

        return answer
|