multi_query_rag.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # -*- coding: utf-8 -*-
  2. """Multi-Query RAG - generate multiple query variants for retrieval."""
  3. from .base import BaseRAG
  4. class MultiQueryRAG(BaseRAG):
  5. def __init__(self, num_queries=3, **kwargs):
  6. super().__init__(**kwargs)
  7. self.num_queries = num_queries
  8. def _generate_queries(self, query):
  9. prompt = (
  10. "请为以下问题生成{}个不同的查询变体,用于提高检索效果。\n\n"
  11. "原始问题:{}\n\n"
  12. "查询变体(用换行分隔):".format(self.num_queries, query)
  13. )
  14. response = self._call_llm(prompt)
  15. return [q.strip() for q in response.strip().split("\n") if q.strip()][:self.num_queries]
  16. def retrieve(self, query, k=10):
  17. queries = self._generate_queries(query)
  18. all_queries = [query] + queries
  19. all_results = {}
  20. per_query_k = max(k // len(all_queries), 3)
  21. for q in all_queries:
  22. q_embedding = self.embedding_model.embed_query(q)
  23. results = self.vector_store.similarity_search(q_embedding, per_query_k)
  24. for doc, score in results:
  25. key = self._dedup_key(doc)
  26. if key not in all_results or score > all_results[key][1]:
  27. all_results[key] = (doc, score)
  28. results = list(all_results.values())
  29. results.sort(key=lambda x: x[1], reverse=True)
  30. return self._deduplicate_results(results, k)
  31. def generate(self, query, context):
  32. prompt = (
  33. "根据以下多角度查询检索的参考文档,回答问题。\n\n"
  34. "参考文档:\n{}\n\n"
  35. "问题:{}\n\n"
  36. "请详细回答。".format(context, query)
  37. )
  38. return self._call_llm(prompt)