test_methods_direct.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # -*- coding: utf-8 -*-
  2. """
  3. 直接测试各种 RAG 方法(不通过 BidiRag 封装)
  4. """
  5. import sys
  6. import os
  7. import time
  8. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  9. from bdirag.document_processor import Document
  10. from bdirag.rag_methods.bm25_rag import BM25RAG
  11. from bdirag.rag_methods.tfidf_rag import TFIDFRAG
  12. from bdirag.rag_methods.keyword_rag import KeywordRAG
  13. from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
  14. TEST_DOCS = [
  15. Document(page_content="""XX市第一人民医院医疗设备招标公告
  16. 项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
  17. 项目编号:XX-ZB-2024-001
  18. 预算金额:500万元
  19. 采购内容:彩色多普勒超声诊断仪 1台
  20. 投标人资格要求:
  21. 1. 具有独立承担民事责任的能力
  22. 2. 具有有效的医疗器械经营许可证
  23. 3. 近三年内无不良经营记录
  24. 4. 投标保证金:人民币5万元整
  25. 技术需求:
  26. 1. 彩色多普勒超声诊断仪技术参数
  27. - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
  28. - 显示屏:≥19英寸高清液晶显示器
  29. - 质保期:整机质保三年
  30. 2. 交货时间:合同签订后60天内交货
  31. 3. 交货地点:XX市第一人民医院设备科
  32. 评标方法:采用综合评分法
  33. - 技术部分:60分
  34. - 商务部分:30分
  35. - 价格部分:10分
  36. 付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%"""),
  37. Document(page_content="""XX市智慧交通系统建设项目招标公告
  38. 项目名称:XX市智慧交通系统建设项目
  39. 项目编号:XX-ZB-2024-002
  40. 招标人:XX市交通运输局
  41. 预算金额:5000万元
  42. 项目内容:
  43. 1. 交通信号控制系统
  44. 2. 视频监控系统
  45. 3. 交通流量监测系统
  46. 4. 数据分析平台
  47. 资质要求:
  48. 1. 电子与智能化工程专业承包二级以上资质
  49. 2. 近三年至少完成2个类似项目业绩
  50. 评标方法:综合评分法
  51. - 技术部分:60分
  52. - 商务部分:40分
  53. 交货时间:合同签订后180天内
  54. 质保期:3年""")
  55. ]
  56. def test_bm25():
  57. """测试 BM25"""
  58. print("\n" + "=" * 80)
  59. print("测试 BM25")
  60. print("=" * 80)
  61. rag = BM25RAG()
  62. rag.index_documents(TEST_DOCS)
  63. t0 = time.time()
  64. results = rag.retrieve("预算金额", k=3)
  65. retrieve_time = time.time() - t0
  66. relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
  67. precision = relevant / len(results) if results else 0
  68. print("召回数量: {}".format(len(results)))
  69. print("检索时间: {:.4f}s".format(retrieve_time))
  70. print("精确度: {:.1%}".format(precision))
  71. if results:
  72. print("\n结果预览:")
  73. for i, (doc, score) in enumerate(results[:2], 1):
  74. preview = doc.page_content[:80].replace("\n", " ")
  75. print(" [{}] Score={:.4f} | {}".format(i, score, preview))
  76. def test_tfidf():
  77. """测试 TF-IDF"""
  78. print("\n" + "=" * 80)
  79. print("测试 TF-IDF")
  80. print("=" * 80)
  81. rag = TFIDFRAG()
  82. rag.index_documents(TEST_DOCS)
  83. t0 = time.time()
  84. results = rag.retrieve("预算金额", k=3)
  85. retrieve_time = time.time() - t0
  86. relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
  87. precision = relevant / len(results) if results else 0
  88. print("召回数量: {}".format(len(results)))
  89. print("检索时间: {:.4f}s".format(retrieve_time))
  90. print("精确度: {:.1%}".format(precision))
  91. if results:
  92. print("\n结果预览:")
  93. for i, (doc, score) in enumerate(results[:2], 1):
  94. preview = doc.page_content[:80].replace("\n", " ")
  95. print(" [{}] Score={:.4f} | {}".format(i, score, preview))
  96. def test_keyword():
  97. """测试 Keyword"""
  98. print("\n" + "=" * 80)
  99. print("测试 Keyword")
  100. print("=" * 80)
  101. rag = KeywordRAG()
  102. rag.index_documents(TEST_DOCS)
  103. t0 = time.time()
  104. results = rag.retrieve("预算", k=3)
  105. retrieve_time = time.time() - t0
  106. relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
  107. precision = relevant / len(results) if results else 0
  108. print("召回数量: {}".format(len(results)))
  109. print("检索时间: {:.4f}s".format(retrieve_time))
  110. print("精确度: {:.1%}".format(precision))
  111. if results:
  112. print("\n结果预览:")
  113. for i, (doc, score) in enumerate(results[:2], 1):
  114. preview = doc.page_content[:80].replace("\n", " ")
  115. print(" [{}] Score={:.4f} | {}".format(i, score, preview))
  116. def test_bm25_html_tree():
  117. """测试 BM25 HTML Tree"""
  118. print("\n" + "=" * 80)
  119. print("测试 BM25 HTML Tree")
  120. print("=" * 80)
  121. html_content = "\n".join([doc.page_content for doc in TEST_DOCS])
  122. rag = BM25HTMLTreeRAG()
  123. rag.build_index(html_content)
  124. t0 = time.time()
  125. results = rag.query("预算金额", k=3)
  126. retrieve_time = time.time() - t0
  127. relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
  128. precision = relevant / len(results) if results else 0
  129. print("召回数量: {}".format(len(results)))
  130. print("检索时间: {:.4f}s".format(retrieve_time))
  131. print("精确度: {:.1%}".format(precision))
  132. if results:
  133. print("\n结果预览:")
  134. for i, (doc, score) in enumerate(results[:2], 1):
  135. path = doc.metadata.get("path", "")
  136. preview = doc.page_content[:80].replace("\n", " ")
  137. print(" [{}] Score={:.4f} | Path: {}".format(i, score, path[:60]))
  138. print(" Content: {}...".format(preview))
  139. if __name__ == "__main__":
  140. print("=" * 80)
  141. print("RAG 方法直接测试")
  142. print("=" * 80)
  143. try:
  144. test_bm25()
  145. except Exception as e:
  146. print("BM25 失败: " + str(e))
  147. try:
  148. test_tfidf()
  149. except Exception as e:
  150. print("TF-IDF 失败: " + str(e))
  151. try:
  152. test_keyword()
  153. except Exception as e:
  154. print("Keyword 失败: " + str(e))
  155. try:
  156. test_bm25_html_tree()
  157. except Exception as e:
  158. print("BM25 HTML Tree 失败: " + str(e))
  159. print("\n\n" + "=" * 80)
  160. print("测试完成!")
  161. print("=" * 80)