# -*- coding: utf-8 -*-
"""Ensemble RAG - combine multiple retrieval strategies for robust results."""

from .base import BaseRAG
from loguru import logger


class EnsembleRAG(BaseRAG):
    """RAG pipeline that fuses the results of several retrieval strategies.

    Every strategy returns ``(doc, score)`` pairs. Pairs that map to the same
    ``self._dedup_key(doc)`` have their scores summed across strategies, and
    the fused list is ranked by descending total score before being handed to
    ``self._deduplicate_results``.
    """

    # Keyword strategy: over-fetch ``k * keyword_candidate_multiplier``
    # semantic candidates, then filter them by query-token overlap.
    keyword_candidate_multiplier = 5
    # Keyword strategy: keep a candidate only if strictly more than this
    # fraction of the query's tokens also occur in the document text.
    keyword_overlap_threshold = 0.3

    def retrieve(self, query, k=10):
        """Return up to ``k`` fused ``(doc, score)`` pairs for ``query``.

        Each strategy is asked for ``k`` results. A strategy that raises is
        logged and skipped so the others still contribute (best effort).
        """
        # Embed the query once and share it with every strategy; previously
        # each strategy called embed_query itself, doubling embedding cost.
        try:
            query_embedding = self.embedding_model.embed_query(query)
        except Exception as e:
            # Fall back to letting each strategy embed (and log) on its own.
            logger.warning("Ensemble query embedding failed: {}".format(e))
            query_embedding = None

        strategies = (
            ("semantic", self._semantic_retrieve),
            ("keyword", self._keyword_retrieve),
        )

        fused = {}
        for name, strategy in strategies:
            try:
                results = strategy(query, k, query_embedding=query_embedding)
                for doc, score in results:
                    key = self._dedup_key(doc)
                    _, total = fused.get(key, (doc, 0.0))
                    # Scores accumulate across strategies, so a document found
                    # by several strategies ranks higher.
                    fused[key] = (doc, total + score)
                logger.info("Ensemble strategy '{}' returned {} results".format(name, len(results)))
            except Exception as e:
                # Deliberate best effort: one failing strategy must not sink
                # the whole retrieval.
                logger.warning("Ensemble strategy '{}' failed: {}".format(name, e))

        ranked = sorted(fused.values(), key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(ranked, k)

    def _semantic_retrieve(self, query, k, query_embedding=None):
        """Return the vector store's top-``k`` similarity hits for ``query``.

        ``query_embedding`` may be supplied to avoid re-embedding the query.
        """
        if query_embedding is None:
            query_embedding = self.embedding_model.embed_query(query)
        return self.vector_store.similarity_search(query_embedding, k)

    def _keyword_retrieve(self, query, k, query_embedding=None):
        """Return up to ``k`` semantic candidates re-scored by token overlap.

        Overlap is the fraction of the query's whitespace-separated tokens that
        also appear in ``doc.page_content``. Candidates whose overlap is not
        above ``keyword_overlap_threshold`` are dropped; the rest are scored
        ``score * overlap`` and sorted by that value, descending.

        NOTE(review): ``str.split()`` splits only on whitespace and is
        case/punctuation sensitive, so for unsegmented CJK queries the whole
        query becomes a single token — confirm this matching is intended.
        """
        # Tokenize the query once (loop-invariant). An empty token set can
        # never exceed the threshold, so skip the vector search entirely.
        query_words = set(query.split())
        if not query_words:
            return []

        if query_embedding is None:
            query_embedding = self.embedding_model.embed_query(query)
        candidates = self.vector_store.similarity_search(
            query_embedding, k * self.keyword_candidate_multiplier
        )

        threshold = self.keyword_overlap_threshold
        n_query_words = len(query_words)
        keyword_results = []
        for doc, score in candidates:
            doc_words = set(doc.page_content.split())
            overlap = len(query_words & doc_words) / n_query_words
            if overlap > threshold:
                keyword_results.append((doc, score * overlap))

        keyword_results.sort(key=lambda pair: pair[1], reverse=True)
        return keyword_results[:k]

    def generate(self, query, context):
        """Build the answer prompt from ``context`` and ``query`` and call the LLM.

        The prompt (in Chinese) asks the model to answer the question in detail
        based on the supplied multi-strategy reference documents.
        """
        prompt = (
            "根据以下多策略集成的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)