ensemble_rag.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # -*- coding: utf-8 -*-
  2. """Ensemble RAG - combine multiple retrieval strategies for robust results."""
  3. from .base import BaseRAG
  4. from loguru import logger
  5. class EnsembleRAG(BaseRAG):
  6. def __init__(self, **kwargs):
  7. super().__init__(**kwargs)
  8. def retrieve(self, query, k=10):
  9. strategies = [
  10. ("semantic", self._semantic_retrieve),
  11. ("keyword", self._keyword_retrieve),
  12. ]
  13. all_results = {}
  14. per_strategy_k = k
  15. for name, strategy in strategies:
  16. try:
  17. results = strategy(query, per_strategy_k)
  18. for doc, score in results:
  19. key = self._dedup_key(doc)
  20. if key not in all_results:
  21. all_results[key] = (doc, 0.0)
  22. all_results[key] = (doc, all_results[key][1] + score)
  23. logger.info("Ensemble strategy '{}' returned {} results".format(name, len(results)))
  24. except Exception as e:
  25. logger.warning("Ensemble strategy '{}' failed: {}".format(name, e))
  26. results = list(all_results.values())
  27. results.sort(key=lambda x: x[1], reverse=True)
  28. return self._deduplicate_results(results, k)
  29. def _semantic_retrieve(self, query, k):
  30. query_embedding = self.embedding_model.embed_query(query)
  31. return self.vector_store.similarity_search(query_embedding, k)
  32. def _keyword_retrieve(self, query, k):
  33. query_embedding = self.embedding_model.embed_query(query)
  34. results = self.vector_store.similarity_search(query_embedding, k * 5)
  35. keyword_results = []
  36. for doc, score in results:
  37. query_words = set(query.split())
  38. doc_words = set(doc.page_content.split())
  39. overlap = len(query_words & doc_words) / len(query_words) if query_words else 0
  40. if overlap > 0.3:
  41. keyword_results.append((doc, score * overlap))
  42. keyword_results.sort(key=lambda x: x[1], reverse=True)
  43. return keyword_results[:k]
  44. def generate(self, query, context):
  45. prompt = (
  46. "根据以下多策略集成的参考文档,回答问题。\n\n"
  47. "参考文档:\n{}\n\n"
  48. "问题:{}\n\n"
  49. "请详细回答。".format(context, query)
  50. )
  51. return self._call_llm(prompt)