self_rag.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # -*- coding: utf-8 -*-
  2. """Self-RAG - self-reflection on retrieval and generation quality."""
import re

from loguru import logger

from .base import BaseRAG
  5. class SelfRAG(BaseRAG):
  6. def __init__(self, reflection_threshold=0.5, max_reflections=2, **kwargs):
  7. super().__init__(**kwargs)
  8. self.reflection_threshold = reflection_threshold
  9. self.max_reflections = max_reflections
  10. def _assess_retrieval(self, query, docs):
  11. context = "\n\n".join([doc.page_content[:200] for doc, _ in docs[:3]])
  12. prompt = (
  13. "评估以下检索到的文档对于回答问题的相关性和完整性。给出0-1的分数。\n\n"
  14. "问题:{}\n\n"
  15. "检索到的文档(前3个):\n{}\n\n"
  16. "相关性分数:".format(query, context)
  17. )
  18. try:
  19. response = self._call_llm(prompt).strip()
  20. return float(response)
  21. except:
  22. return 0.5
  23. def _assess_generation(self, query, answer):
  24. prompt = (
  25. "评估以下答案对于问题的准确性和完整性。给出0-1的分数。\n\n"
  26. "问题:{}\n\n"
  27. "答案:\n{}\n\n"
  28. "质量分数:".format(query, answer[:300])
  29. )
  30. try:
  31. response = self._call_llm(prompt).strip()
  32. return float(response)
  33. except:
  34. return 0.5
  35. def _generate_critique(self, query, answer):
  36. prompt = (
  37. "以下答案有哪些不足之处?请指出需要改进的地方。\n\n"
  38. "问题:{}\n\n"
  39. "答案:\n{}\n\n"
  40. "不足:".format(query, answer[:300])
  41. )
  42. return self._call_llm(prompt)
  43. def retrieve(self, query, k=10):
  44. query_embedding = self.embedding_model.embed_query(query)
  45. initial_results = self.vector_store.similarity_search(query_embedding, k)
  46. relevance_score = self._assess_retrieval(query, initial_results)
  47. logger.info("SelfRAG retrieval relevance: {:.3f}".format(relevance_score))
  48. return self._deduplicate_results(initial_results, k)
  49. def generate(self, query, context):
  50. answer = self._call_llm(
  51. "根据以下参考文档,回答问题。\n\n"
  52. "参考文档:\n{}\n\n"
  53. "问题:{}\n\n"
  54. "请详细回答。".format(context, query)
  55. )
  56. quality_score = self._assess_generation(query, answer)
  57. logger.info("SelfRAG generation quality: {:.3f}".format(quality_score))
  58. if quality_score < self.reflection_threshold:
  59. critique = self._generate_critique(query, answer)
  60. refined_prompt = (
  61. "之前的答案有以下不足:{}\n\n"
  62. "请重新回答以下问题,弥补这些不足。\n\n"
  63. "参考文档:\n{}\n\n"
  64. "问题:{}\n\n"
  65. "改进后的答案:".format(critique, context, query)
  66. )
  67. answer = self._call_llm(refined_prompt)
  68. return answer