| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- # -*- coding: utf-8 -*-
- """Multi-Query RAG - generate multiple query variants for retrieval."""
- from .base import BaseRAG
class MultiQueryRAG(BaseRAG):
    """RAG variant that retrieves with several LLM-generated rewrites of a query.

    The original query plus up to ``num_queries`` LLM-generated variants are
    each embedded and searched. Hits are merged by ``self._dedup_key``, keeping
    the best score per key, then ranked and trimmed to ``k``.
    """

    # Leading list markers LLMs commonly put on each generated line:
    # "-", "*", "•", or a 1-2 digit number followed by ".", "．", "、", ")", "）".
    # The negative lookahead keeps lines such as "3.5版本..." intact.
    _LIST_MARKER_PATTERN = r"^(?:[-*•]|\d{1,2}\s*[.．、)）](?!\d))\s*"

    def __init__(self, num_queries=3, **kwargs):
        """Create the retriever.

        Args:
            num_queries: Maximum number of query variants requested from the
                LLM. Values <= 0 disable query expansion, so only the original
                query is searched.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.num_queries = num_queries

    def _generate_queries(self, query):
        """Ask the LLM for up to ``self.num_queries`` rewrites of ``query``.

        Returns:
            A list of at most ``self.num_queries`` non-empty strings, one per
            non-blank line of the LLM response, with leading list markers
            stripped. The list is empty when expansion is disabled or the LLM
            returns an empty/falsy response.
        """
        import re  # local import so the file's import header stays unchanged

        if self.num_queries <= 0:
            # Skip the LLM call entirely. This also avoids a negative slice
            # below (e.g. ``[:-1]``), which would keep nearly every line.
            return []

        prompt = (
            "请为以下问题生成{}个不同的查询变体,用于提高检索效果。\n\n"
            "原始问题:{}\n\n"
            "查询变体(用换行分隔):".format(self.num_queries, query)
        )
        response = self._call_llm(prompt)
        if not response:
            return []

        variants = []
        for line in response.splitlines():
            cleaned = re.sub(self._LIST_MARKER_PATTERN, "", line.strip()).strip()
            if cleaned:
                variants.append(cleaned)
        return variants[:self.num_queries]

    def retrieve(self, query, k=10):
        """Retrieve up to ``k`` documents using ``query`` and its LLM variants.

        Each distinct query string is searched separately, with the original
        query first. Results are merged per ``self._dedup_key``, keeping the
        highest score. They are sorted by score, descending (a higher score is
        treated as more relevant), and then handed to
        ``self._deduplicate_results`` together with ``k``.
        """
        # Drop variants identical to the original (or to each other). Repeats
        # would cost an extra embed+search and shrink every query's budget.
        all_queries = list(dict.fromkeys([query] + self._generate_queries(query)))

        # Ceiling division makes the per-query budgets sum to at least k.
        # Floor division would leave the candidate pool short of k whenever k
        # is not a multiple of the query count (e.g. k=10 over 3 queries -> 9).
        per_query_k = max(-(-k // len(all_queries)), 3)

        best = {}
        for q in all_queries:
            q_embedding = self.embedding_model.embed_query(q)
            hits = self.vector_store.similarity_search(q_embedding, per_query_k)
            for doc, score in hits:
                key = self._dedup_key(doc)
                # Keep the best-scoring hit when several queries find the same doc.
                if key not in best or score > best[key][1]:
                    best[key] = (doc, score)

        ranked = sorted(best.values(), key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(ranked, k)

    def generate(self, query, context):
        """Answer ``query`` with the LLM, grounded on ``context``.

        Args:
            query: The user's question.
            context: Reference material inserted verbatim into the prompt.
                Presumably a pre-formatted string of retrieved documents;
                verify against the caller.

        Returns:
            Whatever ``self._call_llm`` returns for the built prompt.
        """
        prompt = (
            "根据以下多角度查询检索的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|