# test_bm25.py
  1. """
  2. BM25 RAG method standalone test
  3. Tests BM25 keyword retrieval performance on bidding documents
  4. """
  5. import sys
  6. import os
  7. import time
  8. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  9. from bdirag.document_processor import Document
  10. from bdirag.rag_methods.bm25_rag import BM25RAG
  11. from bdirag.rag_methods.naive_rag import NaiveRAG
  12. from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
  13. from examples.sample_data import SAMPLE_BIDDING_DOCS
# Sample HTML bidding announcement (a hospital ultrasound-equipment tender with
# six <h2> sections). It is the input passed to BM25HTMLTreeRAG.build_index()
# by the HTML-tree test and by the comparison routine in this file.
SAMPLE_HTML = """
<html>
<body>
<h1>XX市第一人民医院医疗设备招标公告</h1>
<div>
<h2>一、项目概况</h2>
<p>项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目</p>
<p>项目编号:XX-ZB-2024-001</p>
<p>预算金额:500万元</p>
<p>采购内容:彩色多普勒超声诊断仪 1台</p>
</div>
<div>
<h2>二、投标人资格要求</h2>
<p>1. 具有独立承担民事责任的能力</p>
<p>2. 具有有效的医疗器械经营许可证</p>
<p>3. 近三年内无不良经营记录</p>
<p>4. 投标保证金:人民币5万元整</p>
</div>
<div>
<h2>三、技术需求</h2>
<p>1. 彩色多普勒超声诊断仪技术参数</p>
<p> - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头</p>
<p> - 显示屏:≥19英寸高清液晶显示器</p>
<p> - 质保期:整机质保三年</p>
<p>2. 交货时间:合同签订后60天内交货</p>
<p>3. 交货地点:XX市第一人民医院设备科</p>
</div>
<div>
<h2>四、评标方法</h2>
<p>采用综合评分法:</p>
<p> - 技术部分:60分</p>
<p> - 商务部分:30分</p>
<p> - 价格部分:10分</p>
</div>
<div>
<h2>五、付款方式</h2>
<p>合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%</p>
</div>
<div>
<h2>六、投标截止时间</h2>
<p>投标截止时间:2024年12月31日上午9:30</p>
<p>开标时间:同投标截止时间</p>
<p>投标文件递交地点:XX市公共资源交易中心</p>
</div>
</body>
</html>
"""
  61. def test_bm25_retrieval():
  62. print("=" * 60)
  63. print("BM25 RAG - Standalone Test (Plain Text)")
  64. print("=" * 60)
  65. print("\n[1/2] Preparing documents...")
  66. documents = [
  67. Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
  68. for doc in SAMPLE_BIDDING_DOCS
  69. ]
  70. print(" Prepared {} documents".format(len(documents)))
  71. print("\n[2/2] Initializing BM25RAG...")
  72. bm25_rag = BM25RAG()
  73. bm25_rag.index_documents(documents)
  74. print(" BM25 index built successfully")
  75. test_queries = [
  76. ("预算金额", ["budget", "Budget", "预算"]),
  77. ("投标保证金", ["bid bond", "Bid Bond", "保证金"]),
  78. ("资质要求", ["qualification", "Qualification", "资质"]),
  79. ("评标方法", ["evaluation", "Evaluation", "评标"]),
  80. ("质保期", ["warranty", "Warranty", "质保"]),
  81. ("付款方式", ["payment", "Payment", "付款"]),
  82. ("项目编号 XX-ZB", ["XX-ZB", "Project Code"]),
  83. ("交货时间", ["delivery", "Delivery", "交货"]),
  84. ]
  85. print("\n" + "=" * 60)
  86. print("BM25 Retrieval Test Results")
  87. print("=" * 60)
  88. for query, keywords in test_queries:
  89. print("\nQuery: {}".format(query))
  90. print("-" * 60)
  91. start = time.time()
  92. results = bm25_rag.retrieve(query, k=3)
  93. elapsed = time.time() - start
  94. print(" Retrieved {} documents in {:.4f}s".format(len(results), elapsed))
  95. # 评估相关性
  96. relevant_count = 0
  97. for i, (doc, score) in enumerate(results, 1):
  98. title = doc.metadata.get("title", "Unknown")
  99. preview = doc.page_content[:80].replace("\n", " ")
  100. # 检查是否包含关键词
  101. is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
  102. if is_relevant:
  103. relevant_count += 1
  104. marker = "[OK]"
  105. else:
  106. marker = "[ --]"
  107. print(" [{}] {} {} (Score: {:.4f})".format(i, marker, title, score))
  108. print(" Preview: {}...".format(preview))
  109. precision = relevant_count / len(results) if results else 0
  110. print(" Precision@3: {:.1%}".format(precision))
  111. print("\n\nBM25 plain text test complete!")
  112. def test_bm25_html_tree():
  113. print("\n\n" + "=" * 60)
  114. print("BM25 HTML Tree RAG - Standalone Test")
  115. print("=" * 60)
  116. print("\n[1/2] Parsing HTML and building tree...")
  117. tree_rag = BM25HTMLTreeRAG()
  118. tree_rag.build_index(SAMPLE_HTML)
  119. print(" HTML tree index built successfully")
  120. test_queries = [
  121. ("预算金额", ["预算", "Budget"]),
  122. ("投标保证金", ["保证金", "Bond"]),
  123. ("技术参数 探头", ["探头", "technical"]),
  124. ("评标方法 综合评分", ["评标", "综合评分"]),
  125. ("质保期", ["质保", "Warranty"]),
  126. ("付款方式", ["付款", "Payment"]),
  127. ("交货时间", ["交货", "Delivery"]),
  128. ]
  129. print("\n[2/2] Testing BM25 HTML Tree Retrieval...")
  130. print("=" * 60)
  131. for query, keywords in test_queries:
  132. print("\nQuery: {}".format(query))
  133. print("-" * 60)
  134. start = time.time()
  135. results = tree_rag.query(query, k=3)
  136. elapsed = time.time() - start
  137. print(" Retrieved {} subtrees in {:.4f}s".format(len(results), elapsed))
  138. relevant_count = 0
  139. for i, (doc, score) in enumerate(results, 1):
  140. path = doc.metadata.get("path", "")
  141. title = doc.metadata.get("title", "")
  142. # 检查相关性
  143. is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
  144. if is_relevant:
  145. relevant_count += 1
  146. marker = "[OK]"
  147. else:
  148. marker = "[ --]"
  149. print(" [{}] {} Score: {:.4f}".format(i, marker, score))
  150. print(" Path: {}".format(path))
  151. print(" Content: {}...".format(doc.page_content[:120].replace("\n", " ")))
  152. precision = relevant_count / len(results) if results else 0
  153. print(" Precision@3: {:.1%}".format(precision))
  154. print("\n\nBM25 HTML Tree test complete!")
  155. def compare_bm25_vs_html_tree():
  156. print("\n\n" + "=" * 60)
  157. print("BM25 Plain Text vs BM25 HTML Tree Comparison")
  158. print("=" * 60)
  159. documents = [
  160. Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
  161. for doc in SAMPLE_BIDDING_DOCS
  162. ]
  163. bm25_rag = BM25RAG()
  164. bm25_rag.index_documents(documents)
  165. tree_rag = BM25HTMLTreeRAG()
  166. tree_rag.build_index(SAMPLE_HTML)
  167. test_queries = [
  168. ("预算金额", ["预算", "Budget"]),
  169. ("质保期", ["质保", "Warranty"]),
  170. ("评标方法", ["评标", "Evaluation"]),
  171. ]
  172. print("\n{:<15} | {:>15} | {:>15}".format("Query", "BM25 Docs", "HTML Tree Docs"))
  173. print("-" * 60)
  174. bm25_total_precision = 0
  175. tree_total_precision = 0
  176. num_queries = len(test_queries)
  177. for query, keywords in test_queries:
  178. t0 = time.time()
  179. bm25_results = bm25_rag.retrieve(query, k=3)
  180. bm25_time = time.time() - t0
  181. # 计算BM25的precision
  182. bm25_relevant = sum(1 for doc, _ in bm25_results if any(kw.lower() in doc.page_content.lower() for kw in keywords))
  183. bm25_precision = bm25_relevant / len(bm25_results) if bm25_results else 0
  184. t1 = time.time()
  185. tree_results = tree_rag.query(query, k=3)
  186. tree_time = time.time() - t1
  187. # 计算HTML Tree的precision - 考虑父节点上下文
  188. def is_relevant_with_context(doc, keywords):
  189. """Check relevance considering parent context from path."""
  190. content = doc.page_content.lower()
  191. path = doc.metadata.get("path", "").lower()
  192. # Check content
  193. if any(kw.lower() in content for kw in keywords):
  194. return True
  195. # Check path (which includes parent nodes)
  196. if any(kw.lower() in path for kw in keywords):
  197. return True
  198. return False
  199. tree_relevant = sum(1 for doc, _ in tree_results if is_relevant_with_context(doc, keywords))
  200. tree_precision = tree_relevant / len(tree_results) if tree_results else 0
  201. print("{:<15} | {:>10} ({:.4f}s, P@3:{:.1%}) | {:>10} ({:.4f}s, P@3:{:.1%})".format(
  202. query, len(bm25_results), bm25_time, bm25_precision,
  203. len(tree_results), tree_time, tree_precision))
  204. bm25_total_precision += bm25_precision
  205. tree_total_precision += tree_precision
  206. print("-" * 60)
  207. print("{:<15} | {:>15} | {:>15}".format(
  208. "Average",
  209. "P@3: {:.1%}".format(bm25_total_precision / num_queries),
  210. "P@3: {:.1%}".format(tree_total_precision / num_queries)))
  211. print("\nComparison complete!")
  212. if __name__ == "__main__":
  213. test_bm25_retrieval()
  214. test_bm25_html_tree()
  215. # 如果需要对比,取消下面的注释
  216. compare_bm25_vs_html_tree()