test_retrieval_dedup.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # -*- coding: utf-8 -*-
  2. """Focused tests for content-level retrieval deduplication."""
  3. import os
  4. import sys
  5. ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  6. if ROOT_DIR not in sys.path:
  7. sys.path.insert(0, ROOT_DIR)
  8. from bdirag.document_processor import Document
  9. from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
  10. from bdirag.rag_methods.bm25_rag import BM25RAG
  11. from bdirag.rag_methods.dedup import deduplicate_ranked_results
  12. from bdirag.rag_methods.ensemble_rag import EnsembleRAG
  13. from examples.rag_test_utils import FakeEmbedding, SimpleVectorStore, install_rank_bm25_fallback
  14. def test_deduplicate_ranked_results_keeps_highest_score():
  15. low = Document(page_content=" Duplicate content ", metadata={"source": "low"})
  16. high = Document(page_content="Duplicate content", metadata={"source": "high"})
  17. other = Document(page_content="Other content", metadata={"source": "other"})
  18. results = deduplicate_ranked_results([(low, 0.1), (other, 0.2), (high, 0.9)], k=10)
  19. assert len(results) == 2
  20. assert results[0][0].metadata["source"] == "high"
  21. assert results[0][1] == 0.9
  22. def test_deduplicate_ranked_results_keeps_first_on_score_tie():
  23. first = Document(page_content="Same content", metadata={"source": "first"})
  24. second = Document(page_content=" Same content ", metadata={"source": "second"})
  25. results = deduplicate_ranked_results([(first, 0.5), (second, 0.5)], k=10)
  26. assert len(results) == 1
  27. assert results[0][0].metadata["source"] == "first"
  28. def test_bm25_retrieve_deduplicates_equal_content_documents():
  29. install_rank_bm25_fallback()
  30. rag = BM25RAG()
  31. docs = [
  32. Document(page_content="alpha beta project budget", metadata={"source": "a"}),
  33. Document(page_content="alpha beta project budget", metadata={"source": "b"}),
  34. Document(page_content="alpha delivery schedule", metadata={"source": "c"}),
  35. Document(page_content="gamma warranty terms", metadata={"source": "d"}),
  36. Document(page_content="delta payment terms", metadata={"source": "e"}),
  37. ]
  38. rag.index_documents(docs)
  39. results = rag.retrieve("alpha beta", k=3)
  40. contents = [doc.page_content for doc, _ in results]
  41. assert contents.count("alpha beta project budget") == 1
  42. assert len(contents) == len(set(contents))
  43. def test_ensemble_retrieve_merges_duplicate_content_from_distinct_objects():
  44. docs = [
  45. Document(page_content="alpha beta project budget", metadata={"source": "a"}),
  46. Document(page_content="alpha beta project budget", metadata={"source": "b"}),
  47. Document(page_content="alpha delivery schedule", metadata={"source": "c"}),
  48. ]
  49. rag = EnsembleRAG(embedding_model=FakeEmbedding(), vector_store=SimpleVectorStore())
  50. embeddings = rag.embedding_model.embed_documents([doc.page_content for doc in docs])
  51. rag.vector_store.add_documents(docs, embeddings)
  52. results = rag.retrieve("alpha beta", k=3)
  53. contents = [doc.page_content for doc, _ in results]
  54. assert contents.count("alpha beta project budget") == 1
  55. assert len(contents) == len(set(contents))
  56. def test_html_tree_query_deduplicates_formatted_documents():
  57. rag = BM25HTMLTreeRAG()
  58. node_a = {"type": "p", "sentence_title_text": "A"}
  59. node_b = {"type": "p", "sentence_title_text": "B"}
  60. rag.retrieve_subtrees = lambda query, k: [
  61. (node_a, 0.7, "Repeated subtree text"),
  62. (node_b, 0.9, " Repeated subtree text "),
  63. ]
  64. rag.get_node_path = lambda node: node["sentence_title_text"]
  65. results = rag.query("repeated", k=5)
  66. assert len(results) == 1
  67. assert results[0][0].metadata["title"] == "B"
  68. assert results[0][1] == 0.9