| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- # -*- coding: utf-8 -*-
- """Adaptive RAG - adaptively choose retrieval strategy based on query complexity."""
- from .base import BaseRAG
- from loguru import logger
class AdaptiveRAG(BaseRAG):
    """RAG variant that picks a retrieval strategy from the query's complexity.

    ``retrieve`` asks the LLM to classify the query and then dispatches to one
    of three strategies: a single vector search (``simple``), a fan-out over
    LLM-generated sub-queries (``complex``), or an iterative loop that keeps
    asking the LLM for a follow-up query (``multi_step``).

    This class uses ``self._call_llm``, ``self.embedding_model``,
    ``self.vector_store``, ``self._dedup_key`` and
    ``self._deduplicate_results`` without defining them; they are presumably
    supplied by ``BaseRAG`` -- confirm there.
    """

    # Number of hits requested from the vector store on each refinement step.
    _RESULTS_PER_STEP = 5

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base class and set defaults."""
        super().__init__(**kwargs)
        # Upper bound on the number of search rounds in _multi_step_retrieve.
        self.max_iterations = 3

    def _analyze_query(self, query):
        """Classify ``query`` via the LLM.

        Returns:
            ``"complex"`` if the lower-cased reply contains "complex",
            otherwise ``"multi_step"`` if it contains "multi", otherwise
            ``"simple"`` (which is also the fallback for any unrecognised or
            empty reply).
        """
        template = (
            "分析以下问题的复杂度。请回答simple或complex或needs_multi_step。\n\n"
            "问题:{}\n\n"
            "回答:"
        )
        # ``or ""`` guards against a falsy reply; whether _call_llm can return
        # None is not visible from this class.
        reply = (self._call_llm(template.format(query)) or "").strip().lower()
        if "complex" in reply:
            return "complex"
        if "multi" in reply:
            return "multi_step"
        return "simple"

    def retrieve(self, query, k=10):
        """Retrieve results for ``query`` using the strategy chosen by the LLM.

        Args:
            query: The user question.
            k: Result budget handed to the chosen strategy.

        Returns:
            Whatever the selected strategy returns, i.e. the output of
            ``self._deduplicate_results``.
        """
        complexity = self._analyze_query(query)
        logger.info("AdaptiveRAG: complexity={}", complexity)

        if complexity == "simple":
            return self._simple_retrieve(query, k)
        if complexity == "complex":
            return self._complex_retrieve(query, k)
        return self._multi_step_retrieve(query, k)

    def _simple_retrieve(self, query, k):
        """Run one similarity search for ``query`` asking for ``k`` hits."""
        query_embedding = self.embedding_model.embed_query(query)
        hits = self.vector_store.similarity_search(query_embedding, k)
        return self._deduplicate_results(hits, k)

    def _complex_retrieve(self, query, k):
        """Search once per LLM-generated sub-query and merge the hits.

        Each sub-query asks the store for ``k // n + 1`` hits (``n`` being the
        number of sub-queries). Hits are ``(doc, score)`` pairs; for documents
        sharing a ``self._dedup_key`` only the pair with the largest score is
        kept, and the merged list is ordered by descending score before being
        passed to ``self._deduplicate_results``.
        """
        sub_queries = self._generate_sub_queries(query)
        if not sub_queries:
            # Decomposition produced nothing usable: search with the original
            # query rather than returning no documents at all.
            sub_queries = [query]

        per_query_k = k // len(sub_queries) + 1
        best_by_key = {}
        for sub_query in sub_queries:
            embedding = self.embedding_model.embed_query(sub_query)
            hits = self.vector_store.similarity_search(embedding, per_query_k)
            for doc, score in hits:
                key = self._dedup_key(doc)
                if key not in best_by_key or score > best_by_key[key][1]:
                    best_by_key[key] = (doc, score)

        ranked = sorted(
            best_by_key.values(), key=lambda pair: pair[1], reverse=True
        )
        return self._deduplicate_results(ranked, k)

    def _multi_step_retrieve(self, query, k):
        """Search iteratively, asking the LLM for a follow-up query each round.

        Runs at most ``self.max_iterations`` searches. The loop ends early
        once at least ``k`` hits have been collected (counted before
        de-duplication) or when the LLM returns an empty follow-up query.
        Collected ``(doc, score)`` pairs are ordered by descending score and
        passed to ``self._deduplicate_results``.
        """
        template = (
            "根据之前的查询,提出一个新的查询问题以获取更多信息。\n\n"
            "之前查询:{}\n\n"
            "新查询:"
        )
        collected = []
        current_query = query
        last_step = self.max_iterations - 1

        for step in range(self.max_iterations):
            embedding = self.embedding_model.embed_query(current_query)
            collected.extend(
                self.vector_store.similarity_search(
                    embedding, self._RESULTS_PER_STEP
                )
            )

            # Stop before refining when the budget is met, or when no further
            # search round would consume the refined query anyway.
            if len(collected) >= k or step == last_step:
                break

            refined = (
                self._call_llm(template.format(current_query)) or ""
            ).strip()
            if not refined:
                # Nothing to search for; keep what has been gathered so far.
                break
            current_query = refined

        collected.sort(key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(collected, k)

    def _generate_sub_queries(self, query):
        """Ask the LLM to decompose ``query`` and return the parts as a list.

        The reply is split on ASCII commas, full-width commas and newlines,
        because the prompt is written in Chinese and the reply may use either
        comma form or one sub-question per line. Blank pieces are dropped, so
        the result may be an empty list.
        """
        template = (
            "将以下问题分解为2-3个子问题,用逗号分隔。只返回子问题。\n\n"
            "原始问题:{}\n\n"
            "子问题:"
        )
        reply = self._call_llm(template.format(query)) or ""
        # Normalise every accepted separator to an ASCII comma, then split.
        normalized = reply.replace("\uff0c", ",").replace("\n", ",")
        pieces = (piece.strip() for piece in normalized.split(","))
        return [piece for piece in pieces if piece]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` given the retrieved ``context``.

        Args:
            query: The user question.
            context: Reference text interpolated into the prompt as-is.

        Returns:
            The value returned by ``self._call_llm`` for the answer prompt.
        """
        template = (
            "根据以下经过自适应检索的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。"
        )
        return self._call_llm(template.format(context, query))
|