| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- # -*- coding: utf-8 -*-
- """Self-RAG - self-reflection on retrieval and generation quality."""
- from .base import BaseRAG
- from loguru import logger
class SelfRAG(BaseRAG):
    """RAG variant that asks the LLM to grade its own retrieval and answers.

    After retrieval, the LLM scores how relevant the top documents are (the
    score is only logged). After generation, the LLM scores the answer; while
    that score stays below ``reflection_threshold`` the answer is critiqued
    and regenerated, for at most ``max_reflections`` rounds.

    The LLM, embedding model, vector store and result de-duplication are
    reached through ``self._call_llm``, ``self.embedding_model``,
    ``self.vector_store`` and ``self._deduplicate_results``; none of them is
    defined in this class (presumably they come from ``BaseRAG`` -- confirm).
    """

    # Score reported when the LLM call fails or its reply contains no number.
    _FALLBACK_SCORE = 0.5

    def __init__(self, reflection_threshold=0.5, max_reflections=2, **kwargs):
        """Store the reflection settings and forward ``kwargs`` to the base.

        Args:
            reflection_threshold: Answers scored strictly below this value
                are critiqued and regenerated.
            max_reflections: Upper bound on critique/regenerate rounds per
                ``generate`` call. ``0`` disables self-reflection.
            **kwargs: Passed unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.reflection_threshold = reflection_threshold
        self.max_reflections = max_reflections

    @staticmethod
    def _parse_score(response, default=0.5):
        """Extract a score in ``[0, 1]`` from a free-form LLM reply.

        A reply that is exactly a number is used directly; otherwise the
        first unsigned decimal number found in the text is used, so replies
        such as ``"0.8分"`` still yield a score.

        Args:
            response: Raw LLM reply (converted with ``str``).
            default: Value returned when no usable number is present.

        Returns:
            The score clamped to ``[0.0, 1.0]``, or ``default``.
        """
        # Local import: the regex is needed only by this helper.
        import re

        text = str(response).strip()
        try:
            score = float(text)
        except ValueError:
            found = re.search(r"\d+(?:\.\d+)?|\.\d+", text)
            if found is None:
                return default
            score = float(found.group())

        # float() accepts "nan"; NaN is the only value not equal to itself.
        if score != score:
            return default
        return min(1.0, max(0.0, score))

    def _score_with_llm(self, prompt, label):
        """Send a grading prompt to the LLM and return the parsed score.

        Grading is best-effort: if the LLM call raises, the failure is logged
        and the fallback score is returned instead of propagating the error.

        Args:
            prompt: Fully formatted grading prompt.
            label: Short name of what is being graded, used in the log line.

        Returns:
            A float in ``[0.0, 1.0]``.
        """
        try:
            response = self._call_llm(prompt)
        except Exception as exc:
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            logger.warning(
                "SelfRAG {} scoring failed, using fallback score: {}".format(
                    label, exc
                )
            )
            return self._FALLBACK_SCORE
        return self._parse_score(response, self._FALLBACK_SCORE)

    def _assess_retrieval(self, query, docs):
        """Ask the LLM how relevant the top retrieved documents are.

        Only the first three entries of ``docs`` are shown to the LLM, each
        truncated to 200 characters.

        Args:
            query: The user question.
            docs: Sequence of 2-item entries whose first item exposes
                ``page_content`` (presumably ``(document, score)`` pairs --
                confirm against the vector store).

        Returns:
            A float in ``[0.0, 1.0]``; ``0.0`` when ``docs`` is empty.
        """
        if not docs:
            # Nothing was retrieved, so there is nothing for the LLM to grade.
            return 0.0

        context = "\n\n".join(doc.page_content[:200] for doc, _ in docs[:3])
        prompt = (
            "评估以下检索到的文档对于回答问题的相关性和完整性。给出0-1的分数。\n\n"
            "问题:{}\n\n"
            "检索到的文档(前3个):\n{}\n\n"
            "相关性分数:".format(query, context)
        )
        return self._score_with_llm(prompt, "retrieval")

    def _assess_generation(self, query, answer):
        """Ask the LLM to grade an answer for accuracy and completeness.

        Only the first 300 characters of ``answer`` are shown to the LLM.

        Args:
            query: The user question.
            answer: The generated answer text.

        Returns:
            A float in ``[0.0, 1.0]``.
        """
        prompt = (
            "评估以下答案对于问题的准确性和完整性。给出0-1的分数。\n\n"
            "问题:{}\n\n"
            "答案:\n{}\n\n"
            "质量分数:".format(query, answer[:300])
        )
        return self._score_with_llm(prompt, "generation")

    def _generate_critique(self, query, answer):
        """Ask the LLM to list the shortcomings of an answer.

        Only the first 300 characters of ``answer`` are shown to the LLM.

        Args:
            query: The user question.
            answer: The generated answer text.

        Returns:
            Whatever ``self._call_llm`` returns for the critique prompt.
        """
        prompt = (
            "以下答案有哪些不足之处?请指出需要改进的地方。\n\n"
            "问题:{}\n\n"
            "答案:\n{}\n\n"
            "不足:".format(query, answer[:300])
        )
        return self._call_llm(prompt)

    def retrieve(self, query, k=10):
        """Retrieve documents for ``query`` and log an LLM relevance score.

        Args:
            query: The user question.
            k: Passed to both the similarity search and the de-duplication.

        Returns:
            The result of ``self._deduplicate_results(results, k)``.
        """
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k)

        # NOTE(review): the score is informational only; it does not filter
        # results or trigger another retrieval.
        relevance_score = self._assess_retrieval(query, initial_results)
        logger.info("SelfRAG retrieval relevance: {:.3f}".format(relevance_score))

        return self._deduplicate_results(initial_results, k)

    def generate(self, query, context):
        """Answer ``query`` from ``context``, refining low-scoring answers.

        The first answer is graded by the LLM. While the grade is below
        ``self.reflection_threshold``, the LLM critiques the answer and writes
        a new one; this repeats for at most ``self.max_reflections`` rounds.
        The answer produced by the last permitted round is returned without
        being graded again.

        Args:
            query: The user question.
            context: Reference text inserted into the prompt.

        Returns:
            The final answer returned by ``self._call_llm``.
        """
        answer = self._call_llm(
            "根据以下参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )

        for _ in range(self.max_reflections):
            quality_score = self._assess_generation(query, answer)
            logger.info("SelfRAG generation quality: {:.3f}".format(quality_score))

            if quality_score >= self.reflection_threshold:
                break

            critique = self._generate_critique(query, answer)
            refined_prompt = (
                "之前的答案有以下不足:{}\n\n"
                "请重新回答以下问题,弥补这些不足。\n\n"
                "参考文档:\n{}\n\n"
                "问题:{}\n\n"
                "改进后的答案:".format(critique, context, query)
            )
            answer = self._call_llm(refined_prompt)

        return answer
|