| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- # -*- coding: utf-8 -*-
- """Hybrid Search RAG - combines semantic and keyword (BM25) retrieval."""
- import numpy as np
- from .base import BaseRAG
class HybridSearchRAG(BaseRAG):
    """RAG pipeline ranking documents by a weighted blend of semantic and BM25 scores.

    Documents are embedded into ``self.vector_store`` (provided by ``BaseRAG``)
    for semantic search, and tokenized into an in-memory BM25 index for keyword
    search. ``retrieve`` combines both signals as
    ``semantic_weight * semantic + bm25_weight * bm25``.
    """

    def __init__(self, bm25_weight=0.3, semantic_weight=0.7, *,
                 bm25_tokenizer=None, normalize_bm25=True, **kwargs):
        """Create the hybrid retriever.

        Args:
            bm25_weight: Weight applied to the BM25 (keyword) score.
            semantic_weight: Weight applied to the score returned by
                ``vector_store.similarity_search``.
            bm25_tokenizer: Optional callable ``str -> list[str]`` used for both
                documents and queries. Defaults to ``_default_tokenize``
                (whitespace split, with CJK characters as individual tokens).
            normalize_bm25: If true, BM25 scores are divided by their per-query
                maximum so they lie in [0, 1] before weighting. Raw BM25 scores
                are unbounded and would otherwise swamp the semantic score
                regardless of the configured weights.
            **kwargs: Forwarded to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.bm25_weight = bm25_weight
        self.semantic_weight = semantic_weight
        self.normalize_bm25 = normalize_bm25
        self._bm25_tokenize = bm25_tokenizer or self._default_tokenize
        self.bm25 = None
        self._all_texts = []         # tokenized corpus, aligned with _all_documents
        self._all_documents = []     # every document indexed so far, in insertion order
        self._index_by_content = {}  # page_content -> first index in _all_documents

    @staticmethod
    def _default_tokenize(text):
        """Split ``text`` into BM25 tokens.

        Each CJK ideograph (U+4E00-U+9FFF) becomes its own token, because such
        text has no spaces to split on. Any other run of non-whitespace
        characters is one token. For text without CJK characters this matches
        ``text.split()``.
        """
        import re
        return re.findall(r"[\u4e00-\u9fff]|[^\s\u4e00-\u9fff]+", text)

    def index_documents(self, documents):
        """Add ``documents`` to both the vector store and the BM25 index.

        Documents accumulate across calls, matching ``vector_store.add_documents``.
        The BM25 index is rebuilt over the whole corpus each time, because
        ``BM25Okapi`` has no incremental update. An empty input is a no-op,
        since ``BM25Okapi`` cannot be built over an empty corpus.
        """
        documents = list(documents)
        if not documents:
            return

        texts = [doc.page_content for doc in documents]
        embeddings = self.embedding_model.embed_documents(texts)
        self.vector_store.add_documents(documents, embeddings)

        start = len(self._all_documents)
        self._all_documents.extend(documents)
        self._all_texts.extend(self._bm25_tokenize(text) for text in texts)
        for offset, text in enumerate(texts):
            # Keep the first occurrence so duplicate contents map to one slot.
            self._index_by_content.setdefault(text, start + offset)

        # Local import: rank_bm25 is only required once documents are indexed.
        from rank_bm25 import BM25Okapi
        self.bm25 = BM25Okapi(self._all_texts)

    def retrieve(self, query, k=10):
        """Return up to ``k`` ``(document, hybrid_score)`` pairs for ``query``.

        Only documents with a positive hybrid score are returned. The final list
        is passed through ``self._deduplicate_results``. Returns ``[]`` if
        nothing has been indexed yet.
        """
        if self.bm25 is None:
            return []

        query_embedding = self.embedding_model.embed_query(query)
        # Fetch 2*k semantic candidates; all other documents get a semantic score of 0.
        semantic_results = self.vector_store.similarity_search(query_embedding, k * 2)

        bm25_scores = np.asarray(
            self.bm25.get_scores(self._bm25_tokenize(query)), dtype=float
        )
        if self.normalize_bm25:
            top_bm25 = bm25_scores.max()
            if top_bm25 > 0:
                bm25_scores = bm25_scores / top_bm25

        # NOTE(review): assumes similarity_search yields (doc, score) with higher
        # score = more similar; confirm against the vector store in use.
        semantic_scores = np.zeros(len(self._all_documents))
        for doc, score in semantic_results:
            idx = self._index_by_content.get(doc.page_content)
            if idx is not None:  # hits not indexed through this instance are ignored
                semantic_scores[idx] = score

        hybrid_scores = (self.semantic_weight * semantic_scores
                         + self.bm25_weight * bm25_scores)

        top_indices = np.argsort(hybrid_scores)[::-1][:k]
        results = [
            (self._all_documents[idx], float(hybrid_scores[idx]))
            for idx in top_indices
            if hybrid_scores[idx] > 0
        ]
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Answer ``query`` with the LLM, grounded in the retrieved ``context``.

        The prompt (in Chinese) says the reference documents come from hybrid
        (semantic + keyword) retrieval, then asks for a detailed answer.
        """
        prompt = (
            "根据以下混合检索(语义+关键词)的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|