# -*- coding: utf-8 -*-
"""Self-RAG - self-reflection on retrieval and generation quality."""

import math
import re

from .base import BaseRAG
from loguru import logger


class SelfRAG(BaseRAG):
    """RAG variant that asks the LLM to grade its own retrieval and answers.

    After generating an answer, the LLM is asked to score it. When the score
    falls below ``reflection_threshold`` the LLM critiques the answer and a
    refined answer is generated, for at most ``max_reflections`` rounds.
    """

    # Score used whenever the LLM call fails or its reply holds no usable number.
    _DEFAULT_SCORE = 0.5

    # First unsigned decimal number in a reply such as "分数:0.8" or "0.8。".
    _SCORE_PATTERN = re.compile(r"\d+(?:\.\d+)?|\.\d+")

    def __init__(self, reflection_threshold=0.5, max_reflections=2, **kwargs):
        """Initialise the pipeline.

        Args:
            reflection_threshold: Answers scored strictly below this value
                trigger a critique-and-refine round.
            max_reflections: Upper bound on critique-and-refine rounds per
                ``generate`` call; ``0`` disables reflection.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.reflection_threshold = reflection_threshold
        self.max_reflections = max_reflections

    @classmethod
    def _parse_score(cls, response):
        """Extract a score in ``[0, 1]`` from free-form LLM output.

        The whole reply is tried as a float first; otherwise the first decimal
        number found in the text is used. Finite values are clamped to
        ``[0, 1]``. Non-string replies, replies without a number, and
        non-finite values (``nan``/``inf``) yield ``_DEFAULT_SCORE``.
        """
        if not isinstance(response, str):
            return cls._DEFAULT_SCORE

        text = response.strip()
        try:
            value = float(text)
        except ValueError:
            match = cls._SCORE_PATTERN.search(text)
            if match is None:
                return cls._DEFAULT_SCORE
            value = float(match.group())

        if not math.isfinite(value):
            return cls._DEFAULT_SCORE
        return max(0.0, min(1.0, value))

    def _score_from_llm(self, prompt):
        """Send a grading prompt to the LLM and return the parsed score.

        Scoring is best-effort: a failing LLM call is logged and mapped to
        ``_DEFAULT_SCORE`` instead of aborting the pipeline. Only ``Exception``
        is caught, so ``KeyboardInterrupt`` and ``SystemExit`` still propagate.
        """
        try:
            response = self._call_llm(prompt)
        except Exception as exc:
            logger.warning(
                "SelfRAG scoring call failed, using default score: {}", exc
            )
            return self._DEFAULT_SCORE
        return self._parse_score(response)

    def _assess_retrieval(self, query, docs):
        """Score how relevant and complete the retrieved documents are.

        Only the first 200 characters of the first three documents are shown
        to the LLM. ``docs`` is iterated as ``(document, _)`` pairs whose
        document exposes ``page_content``.
        """
        context = "\n\n".join(doc.page_content[:200] for doc, _ in docs[:3])
        prompt = (
            "评估以下检索到的文档对于回答问题的相关性和完整性。给出0-1的分数。\n\n"
            "问题:{}\n\n"
            "检索到的文档(前3个):\n{}\n\n"
            "相关性分数:".format(query, context)
        )
        return self._score_from_llm(prompt)

    def _assess_generation(self, query, answer):
        """Score the accuracy and completeness of ``answer`` for ``query``.

        Only the first 300 characters of the answer are shown to the LLM.
        """
        prompt = (
            "评估以下答案对于问题的准确性和完整性。给出0-1的分数。\n\n"
            "问题:{}\n\n"
            "答案:\n{}\n\n"
            "质量分数:".format(query, answer[:300])
        )
        return self._score_from_llm(prompt)

    def _generate_critique(self, query, answer):
        """Ask the LLM to list the shortcomings of ``answer``.

        Only the first 300 characters of the answer are shown to the LLM.
        Returns the raw LLM reply.
        """
        prompt = (
            "以下答案有哪些不足之处?请指出需要改进的地方。\n\n"
            "问题:{}\n\n"
            "答案:\n{}\n\n"
            "不足:".format(query, answer[:300])
        )
        return self._call_llm(prompt)

    def retrieve(self, query, k=10):
        """Retrieve up to ``k`` results for ``query`` and log their relevance.

        The relevance score is informational only: it is logged and does not
        change which results are returned.
        """
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k)

        relevance_score = self._assess_retrieval(query, initial_results)
        logger.info("SelfRAG retrieval relevance: {:.3f}", relevance_score)

        return self._deduplicate_results(initial_results, k)

    def generate(self, query, context):
        """Answer ``query`` from ``context``, refining low-quality answers.

        The first answer is scored; while the score stays strictly below
        ``reflection_threshold`` and fewer than ``max_reflections`` rounds have
        run, the LLM critiques the current answer and produces a refined one.

        Returns:
            The final answer text as returned by the LLM.
        """
        answer = self._call_llm(
            "根据以下参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )

        quality_score = self._assess_generation(query, answer)
        logger.info("SelfRAG generation quality: {:.3f}", quality_score)

        rounds_done = 0
        while (
            rounds_done < self.max_reflections
            and quality_score < self.reflection_threshold
        ):
            critique = self._generate_critique(query, answer)
            refined_prompt = (
                "之前的答案有以下不足:{}\n\n"
                "请重新回答以下问题,弥补这些不足。\n\n"
                "参考文档:\n{}\n\n"
                "问题:{}\n\n"
                "改进后的答案:".format(critique, context, query)
            )
            answer = self._call_llm(refined_prompt)
            rounds_done += 1

            # Re-score only when another round is still allowed; grading the
            # final answer would cost an LLM call without affecting the result.
            if rounds_done < self.max_reflections:
                quality_score = self._assess_generation(query, answer)
                logger.info("SelfRAG generation quality: {:.3f}", quality_score)

        return answer