adaptive_rag.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # -*- coding: utf-8 -*-
  2. """Adaptive RAG - adaptively choose retrieval strategy based on query complexity."""
  3. from .base import BaseRAG
  4. from loguru import logger
  5. class AdaptiveRAG(BaseRAG):
  6. def __init__(self, **kwargs):
  7. super().__init__(**kwargs)
  8. self.max_iterations = 3
  9. def _analyze_query(self, query):
  10. prompt = (
  11. "分析以下问题的复杂度。请回答simple或complex或needs_multi_step。\n\n"
  12. "问题:{}\n\n"
  13. "回答:".format(query)
  14. )
  15. response = self._call_llm(prompt).strip().lower()
  16. if "complex" in response:
  17. return "complex"
  18. elif "multi" in response:
  19. return "multi_step"
  20. return "simple"
  21. def retrieve(self, query, k=10):
  22. complexity = self._analyze_query(query)
  23. logger.info("AdaptiveRAG: complexity={}".format(complexity))
  24. if complexity == "simple":
  25. return self._simple_retrieve(query, k)
  26. elif complexity == "complex":
  27. return self._complex_retrieve(query, k)
  28. else:
  29. return self._multi_step_retrieve(query, k)
  30. def _simple_retrieve(self, query, k):
  31. query_embedding = self.embedding_model.embed_query(query)
  32. return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
  33. def _complex_retrieve(self, query, k):
  34. sub_queries = self._generate_sub_queries(query)
  35. all_results = {}
  36. for sq in sub_queries:
  37. sq_embedding = self.embedding_model.embed_query(sq)
  38. results = self.vector_store.similarity_search(sq_embedding, k // len(sub_queries) + 1)
  39. for doc, score in results:
  40. key = self._dedup_key(doc)
  41. if key not in all_results or score > all_results[key][1]:
  42. all_results[key] = (doc, score)
  43. results = list(all_results.values())
  44. results.sort(key=lambda x: x[1], reverse=True)
  45. return self._deduplicate_results(results, k)
  46. def _multi_step_retrieve(self, query, k):
  47. all_results = []
  48. current_query = query
  49. for i in range(self.max_iterations):
  50. q_embedding = self.embedding_model.embed_query(current_query)
  51. step_results = self.vector_store.similarity_search(q_embedding, 5)
  52. all_results.extend(step_results)
  53. if len(all_results) >= k:
  54. break
  55. refinement_prompt = (
  56. "根据之前的查询,提出一个新的查询问题以获取更多信息。\n\n"
  57. "之前查询:{}\n\n"
  58. "新查询:".format(current_query)
  59. )
  60. current_query = self._call_llm(refinement_prompt)
  61. all_results.sort(key=lambda x: x[1], reverse=True)
  62. return self._deduplicate_results(all_results, k)
  63. def _generate_sub_queries(self, query):
  64. prompt = (
  65. "将以下问题分解为2-3个子问题,用逗号分隔。只返回子问题。\n\n"
  66. "原始问题:{}\n\n"
  67. "子问题:".format(query)
  68. )
  69. response = self._call_llm(prompt)
  70. return [q.strip() for q in response.split(",") if q.strip()]
  71. def generate(self, query, context):
  72. prompt = (
  73. "根据以下经过自适应检索的参考文档,回答问题。\n\n"
  74. "参考文档:\n{}\n\n"
  75. "问题:{}\n\n"
  76. "请详细回答。".format(context, query)
  77. )
  78. return self._call_llm(prompt)