# -*- coding: utf-8 -*-
"""Adaptive RAG - adaptively choose retrieval strategy based on query complexity."""
import re

from loguru import logger

from .base import BaseRAG

# Separators an LLM commonly puts between sub-questions: the ASCII comma the
# prompt asks for, the full-width Chinese comma, and line breaks (lists).
_SUB_QUERY_SPLIT = re.compile(r"[,，\n]+")
# Leading list markers such as "1.", "2、", "3)", "-" or "*". The negative
# lookahead keeps numbers like "3.5" intact.
_LIST_MARKER = re.compile(r"^(?:\d+[.、)）](?!\d)|[-*•])\s*")


class AdaptiveRAG(BaseRAG):
    """RAG variant that picks a retrieval strategy per query.

    The LLM first classifies the query (see ``_analyze_query``). The query is
    then routed to one of three strategies:

    * ``simple``: a single similarity search on the raw query.
    * ``complex``: the query is decomposed into sub-queries by the LLM, and
      their results are merged, keeping the best score per document.
    * ``multi_step``: iterative search, where the LLM proposes a follow-up
      query after each round.

    Relies on ``BaseRAG`` for ``_call_llm``, ``embedding_model``,
    ``vector_store``, ``_dedup_key`` and ``_deduplicate_results``.
    """

    # Number of documents fetched per round in multi-step retrieval.
    step_k = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Upper bound on search rounds in multi-step retrieval.
        self.max_iterations = 3

    def _analyze_query(self, query):
        """Ask the LLM to classify ``query``.

        Returns:
            One of ``"simple"``, ``"complex"`` or ``"multi_step"``. The reply
            is matched by substring. The English labels requested by the
            prompt are accepted, plus Chinese equivalents, since the prompt
            itself is Chinese. Anything unrecognised falls back to
            ``"simple"``.
        """
        prompt = (
            "分析以下问题的复杂度。请回答simple或complex或needs_multi_step。\n\n"
            "问题:{}\n\n"
            "回答:".format(query)
        )
        # Guard against an empty/None reply instead of crashing on .strip().
        response = (self._call_llm(prompt) or "").strip().lower()
        if "complex" in response or "复杂" in response:
            return "complex"
        if "multi" in response or "多步" in response:
            return "multi_step"
        return "simple"

    def retrieve(self, query, k=10):
        """Retrieve up to ``k`` documents with the strategy chosen for ``query``.

        Returns:
            The output of ``self._deduplicate_results`` for the selected
            strategy.
        """
        complexity = self._analyze_query(query)
        # Loguru formats positional args with str.format, lazily.
        logger.info("AdaptiveRAG: complexity={}", complexity)
        if complexity == "simple":
            return self._simple_retrieve(query, k)
        if complexity == "complex":
            return self._complex_retrieve(query, k)
        return self._multi_step_retrieve(query, k)

    def _simple_retrieve(self, query, k):
        """Run one similarity search for ``query`` and deduplicate the top ``k``."""
        query_embedding = self.embedding_model.embed_query(query)
        results = self.vector_store.similarity_search(query_embedding, k)
        return self._deduplicate_results(results, k)

    def _merge_best(self, best, results):
        """Fold ``(doc, score)`` pairs into ``best``, keeping the highest score per doc.

        ``best`` maps ``self._dedup_key(doc)`` to a ``(doc, score)`` pair and is
        updated in place. A higher score is treated as a better match, which is
        consistent with the descending sorts used by the callers.
        """
        for doc, score in results:
            key = self._dedup_key(doc)
            if key not in best or score > best[key][1]:
                best[key] = (doc, score)

    def _complex_retrieve(self, query, k):
        """Search each LLM-generated sub-query and merge the results.

        Each sub-query fetches ``k // n + 1`` hits (``n`` = number of
        sub-queries), so together they cover at least ``k`` candidates.
        """
        # If the LLM produced no usable sub-queries, search the original
        # query; otherwise ``k // len(sub_queries)`` would divide by zero.
        sub_queries = self._generate_sub_queries(query) or [query]
        per_query_k = k // len(sub_queries) + 1
        best = {}
        for sub_query in sub_queries:
            sub_embedding = self.embedding_model.embed_query(sub_query)
            self._merge_best(
                best, self.vector_store.similarity_search(sub_embedding, per_query_k)
            )
        results = sorted(best.values(), key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(results, k)

    def _multi_step_retrieve(self, query, k):
        """Iteratively search, asking the LLM for a follow-up query between rounds.

        Stops after ``self.max_iterations`` rounds, once ``k`` distinct
        documents have been collected, or when the LLM returns an empty
        follow-up query.
        """
        best = {}
        current_query = query
        last_step = self.max_iterations - 1
        for step in range(self.max_iterations):
            q_embedding = self.embedding_model.embed_query(current_query)
            self._merge_best(
                best, self.vector_store.similarity_search(q_embedding, self.step_k)
            )
            # Count distinct documents, not raw hits, and skip the LLM call on
            # the final round since its refined query would never be searched.
            if len(best) >= k or step == last_step:
                break
            refinement_prompt = (
                "根据之前的查询,提出一个新的查询问题以获取更多信息。\n\n"
                "之前查询:{}\n\n"
                "新查询:".format(current_query)
            )
            refined = (self._call_llm(refinement_prompt) or "").strip()
            if not refined:
                # Nothing new to search for; stop rather than embed "".
                break
            current_query = refined
        results = sorted(best.values(), key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(results, k)

    def _generate_sub_queries(self, query):
        """Ask the LLM to decompose ``query`` into sub-questions.

        Returns:
            A list of non-empty sub-question strings, possibly empty. The reply
            is split on ASCII/full-width commas and newlines, and leading list
            markers such as ``"1."`` are removed.
        """
        prompt = (
            "将以下问题分解为2-3个子问题,用逗号分隔。只返回子问题。\n\n"
            "原始问题:{}\n\n"
            "子问题:".format(query)
        )
        response = self._call_llm(prompt) or ""
        sub_queries = []
        for part in _SUB_QUERY_SPLIT.split(response):
            cleaned = _LIST_MARKER.sub("", part.strip()).strip()
            if cleaned:
                sub_queries.append(cleaned)
        return sub_queries

    def generate(self, query, context):
        """Answer ``query`` with the LLM, given retrieved ``context``.

        ``context`` is inserted into the prompt via ``str.format``.
        """
        prompt = (
            "根据以下经过自适应检索的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)