# -*- coding: utf-8 -*-
"""Multi-Query RAG - generate multiple query variants for retrieval."""

import re

from .base import BaseRAG

# Leading list decoration that LLMs commonly prepend to each line of a list
# answer ("- foo", "* foo", "1. foo", "2) foo", "3、foo"). The negative
# lookahead keeps decimals such as "3.5" at the start of a query intact.
_LIST_MARKER_RE = re.compile(r"^(?:[-*•]\s+|\d+[.)、.)]\s*(?!\d))")


class MultiQueryRAG(BaseRAG):
    """RAG variant that retrieves with the original query plus LLM rewrites.

    The original query and up to ``num_queries`` LLM-generated variants are
    each embedded and searched; the per-query hits are merged, keeping the
    highest score seen for every document key.
    """

    def __init__(self, num_queries=3, **kwargs):
        """Store the number of variants to request.

        Args:
            num_queries: Maximum number of query variants to generate per
                call. Values ``<= 0`` disable variant generation.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.num_queries = num_queries

    def _generate_queries(self, query):
        """Ask the LLM for alternative phrasings of *query*.

        Args:
            query: The original user question.

        Returns:
            A list of at most ``self.num_queries`` distinct, non-empty
            variant strings, in the order the LLM produced them, with any
            leading bullet/numbering markers removed.
        """
        # Nothing to ask for: skip the LLM call. This also avoids a negative
        # count turning the final slice into "drop the last N variants".
        if self.num_queries <= 0:
            return []

        prompt = (
            "请为以下问题生成{}个不同的查询变体,用于提高检索效果。\n\n"
            "原始问题:{}\n\n"
            "查询变体(用换行分隔):".format(self.num_queries, query)
        )
        # Treat a missing (None) response like an empty one instead of
        # failing on ``None.strip()``.
        response = self._call_llm(prompt) or ""

        variants = []
        for line in response.strip().split("\n"):
            text = _LIST_MARKER_RE.sub("", line.strip()).strip()
            # De-duplicate before truncating so repeated lines do not use up
            # the ``num_queries`` quota.
            if text and text not in variants:
                variants.append(text)
        return variants[:self.num_queries]

    def retrieve(self, query, k=10):
        """Retrieve documents using the original query and its variants.

        Args:
            query: The original user question.
            k: Number of results requested; passed on to
                ``self._deduplicate_results``.

        Returns:
            Whatever ``self._deduplicate_results`` returns when given the
            merged ``(doc, score)`` pairs, sorted by descending score, and
            ``k``.
        """
        variants = self._generate_queries(query)
        # A variant identical to the original query would only repeat the
        # same embedding and search call, so it is dropped.
        all_queries = [query] + [v for v in variants if v != query]

        # Split the budget across the distinct queries, but never fetch
        # fewer than 3 hits per query.
        per_query_k = max(k // len(all_queries), 3)

        best_by_key = {}
        for q in all_queries:
            q_embedding = self.embedding_model.embed_query(q)
            hits = self.vector_store.similarity_search(q_embedding, per_query_k)
            for doc, score in hits:
                key = self._dedup_key(doc)
                # Keep only the highest-scoring hit for each document key.
                if key not in best_by_key or score > best_by_key[key][1]:
                    best_by_key[key] = (doc, score)

        results = sorted(
            best_by_key.values(), key=lambda item: item[1], reverse=True
        )
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Answer *query* with the LLM, grounded on the retrieved *context*.

        Args:
            query: The original user question.
            context: Reference text inserted verbatim into the prompt.

        Returns:
            The value returned by ``self._call_llm`` for the built prompt.
        """
        prompt = (
            "根据以下多角度查询检索的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)