hybrid_search_rag.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # -*- coding: utf-8 -*-
  2. """Hybrid Search RAG - combines semantic and keyword (BM25) retrieval."""
  3. import numpy as np
  4. from .base import BaseRAG
  5. class HybridSearchRAG(BaseRAG):
  6. def __init__(self, bm25_weight=0.3, semantic_weight=0.7, **kwargs):
  7. super().__init__(**kwargs)
  8. self.bm25_weight = bm25_weight
  9. self.semantic_weight = semantic_weight
  10. self.bm25 = None
  11. self._all_texts = []
  12. self._all_documents = []
  13. def index_documents(self, documents):
  14. texts = [doc.page_content for doc in documents]
  15. embeddings = self.embedding_model.embed_documents(texts)
  16. self.vector_store.add_documents(documents, embeddings)
  17. self._all_documents = documents
  18. self._all_texts = [text.split() for text in texts]
  19. from rank_bm25 import BM25Okapi
  20. self.bm25 = BM25Okapi(self._all_texts)
  21. def retrieve(self, query, k=10):
  22. query_embedding = self.embedding_model.embed_query(query)
  23. semantic_results = self.vector_store.similarity_search(query_embedding, k * 2)
  24. bm25_scores = np.zeros(len(self._all_documents))
  25. query_tokens = query.split()
  26. bm25_scores = self.bm25.get_scores(query_tokens)
  27. semantic_scores = np.zeros(len(self._all_documents))
  28. for doc, score in semantic_results:
  29. for i, ref_doc in enumerate(self._all_documents):
  30. if doc.page_content == ref_doc.page_content:
  31. semantic_scores[i] = score
  32. break
  33. hybrid_scores = self.semantic_weight * semantic_scores + self.bm25_weight * bm25_scores
  34. top_indices = np.argsort(hybrid_scores)[::-1][:k]
  35. results = []
  36. for idx in top_indices:
  37. if hybrid_scores[idx] > 0:
  38. results.append((self._all_documents[idx], float(hybrid_scores[idx])))
  39. return self._deduplicate_results(results, k)
  40. def generate(self, query, context):
  41. prompt = (
  42. "根据以下混合检索(语义+关键词)的参考文档,回答问题。\n\n"
  43. "参考文档:\n{}\n\n"
  44. "问题:{}\n\n"
  45. "请详细回答。".format(context, query)
  46. )
  47. return self._call_llm(prompt)