| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- # -*- coding: utf-8 -*-
- """Ensemble RAG - combine multiple retrieval strategies for robust results."""
- from .base import BaseRAG
- from loguru import logger
class EnsembleRAG(BaseRAG):
    """RAG pipeline that merges the rankings of several retrieval strategies.

    Two strategies are run for every query: a plain vector-similarity search
    and a keyword-overlap filter applied on top of a wider vector search.
    Scores of documents returned by more than one strategy are summed.

    NOTE(review): ``embedding_model``, ``vector_store``, ``_dedup_key``,
    ``_deduplicate_results`` and ``_call_llm`` are not defined in this class;
    presumably they are provided by ``BaseRAG`` -- confirm.
    """

    # The keyword strategy asks the vector store for ``k`` times this many
    # candidates before filtering them by word overlap.
    KEYWORD_CANDIDATE_MULTIPLIER = 5
    # A candidate is kept only when the fraction of distinct query words it
    # contains is strictly greater than this value.
    KEYWORD_MIN_OVERLAP = 0.3

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to the base class."""
        super().__init__(**kwargs)

    def retrieve(self, query, k=10):
        """Run every strategy, sum the scores per document and rank the union.

        Args:
            query: The query text handed to each strategy.
            k: Number of results requested from each strategy and passed on
                to ``_deduplicate_results`` as the final limit.

        Returns:
            Whatever ``self._deduplicate_results`` returns for the fused
            ``(doc, total_score)`` pairs sorted by descending total score.
            The ordering assumes that a higher score means a better match --
            TODO confirm against the vector store in use.
        """
        strategies = (
            ("semantic", self._semantic_retrieve),
            ("keyword", self._keyword_retrieve),
        )

        # dedup key -> (most recently seen doc for that key, summed score)
        fused = {}

        for name, strategy in strategies:
            # Deliberately best-effort: a strategy that raises is logged and
            # skipped so the remaining strategies can still contribute.
            try:
                hits = strategy(query, k)
                for doc, score in hits:
                    key = self._dedup_key(doc)
                    _, running_total = fused.get(key, (doc, 0.0))
                    fused[key] = (doc, running_total + score)
                logger.info("Ensemble strategy '{}' returned {} results".format(name, len(hits)))
            except Exception as e:
                logger.warning("Ensemble strategy '{}' failed: {}".format(name, e))

        ranked = sorted(fused.values(), key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(ranked, k)

    def _semantic_retrieve(self, query, k):
        """Embed ``query`` and return the vector store's top ``k`` matches.

        The result is consumed by ``retrieve`` as ``(doc, score)`` pairs.
        """
        query_embedding = self.embedding_model.embed_query(query)
        return self.vector_store.similarity_search(query_embedding, k)

    def _keyword_retrieve(self, query, k):
        """Return up to ``k`` vector-search hits that share words with ``query``.

        A wider candidate set (``k * KEYWORD_CANDIDATE_MULTIPLIER``) is pulled
        from the vector store. Each candidate's score is multiplied by the
        fraction of distinct query words found in ``doc.page_content``, and
        candidates at or below ``KEYWORD_MIN_OVERLAP`` are dropped.

        Words are obtained with ``str.split()``, i.e. split on whitespace and
        compared case-sensitively. Text that has no whitespace between words
        is therefore treated as a single token.

        Returns:
            A list of ``(doc, weighted_score)`` pairs sorted by descending
            weighted score, truncated to ``k`` entries.
        """
        # The query vocabulary is the same for every candidate: build it once.
        query_words = set(query.split())
        if not query_words:
            # With no query words every overlap would be 0 and every candidate
            # would be filtered out, so skip the embedding and search entirely.
            return []

        query_embedding = self.embedding_model.embed_query(query)
        candidates = self.vector_store.similarity_search(
            query_embedding, k * self.KEYWORD_CANDIDATE_MULTIPLIER
        )

        keyword_results = []
        for doc, score in candidates:
            doc_words = set(doc.page_content.split())
            overlap = len(query_words & doc_words) / len(query_words)
            if overlap > self.KEYWORD_MIN_OVERLAP:
                keyword_results.append((doc, score * overlap))

        keyword_results.sort(key=lambda pair: pair[1], reverse=True)
        return keyword_results[:k]

    def generate(self, query, context):
        """Build the answer prompt from ``context`` and ``query`` and call the LLM.

        Args:
            query: The user's question, inserted after the reference documents.
            context: The reference text inserted into the prompt.

        Returns:
            The value returned by ``self._call_llm`` for the formatted prompt.
        """
        prompt = (
            "根据以下多策略集成的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|