test_all_rag_methods.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # -*- coding: utf-8 -*-
  2. """
  3. 测试 BidiRag 支持的所有 RAG 方法
  4. 这个脚本会逐一测试每种 RAG 方法的召回效果
  5. """
  6. import sys
  7. import os
  8. import time
  9. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  10. from bdirag.bidi_rag import BidiRag
  11. # 测试文档(纯文本格式)
  12. TEST_DOCS = [
  13. """XX市第一人民医院医疗设备招标公告
  14. 项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
  15. 项目编号:XX-ZB-2024-001
  16. 预算金额:500万元
  17. 采购内容:彩色多普勒超声诊断仪 1台
  18. 投标人资格要求:
  19. 1. 具有独立承担民事责任的能力
  20. 2. 具有有效的医疗器械经营许可证
  21. 3. 近三年内无不良经营记录
  22. 4. 投标保证金:人民币5万元整
  23. 技术需求:
  24. 1. 彩色多普勒超声诊断仪技术参数
  25. - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
  26. - 显示屏:≥19英寸高清液晶显示器
  27. - 质保期:整机质保三年
  28. 2. 交货时间:合同签订后60天内交货
  29. 3. 交货地点:XX市第一人民医院设备科
  30. 评标方法:采用综合评分法
  31. - 技术部分:60分
  32. - 商务部分:30分
  33. - 价格部分:10分
  34. 付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%
  35. 投标截止时间:2024年12月31日上午9:30
  36. 开标时间:同投标截止时间
  37. 投标文件递交地点:XX市公共资源交易中心""",
  38. """XX市智慧交通系统建设项目招标公告
  39. 项目名称:XX市智慧交通系统建设项目
  40. 项目编号:XX-ZB-2024-002
  41. 招标人:XX市交通运输局
  42. 预算金额:5000万元
  43. 项目内容:
  44. 1. 交通信号控制系统
  45. 2. 视频监控系统
  46. 3. 交通流量监测系统
  47. 4. 数据分析平台
  48. 资质要求:
  49. 1. 电子与智能化工程专业承包二级以上资质
  50. 2. 近三年至少完成2个类似项目业绩
  51. 3. 项目经理须具备机电工程专业一级建造师证书
  52. 评标方法:综合评分法
  53. - 技术部分:60分
  54. - 商务部分:40分
  55. 交货时间:合同签订后180天内
  56. 质保期:3年
  57. 付款方式:30%预付款,40%到货验收,25%最终验收,5%质保金"""
  58. ]
  59. def test_rag_method(method_name, query="预算金额", keywords=None):
  60. """测试单个 RAG 方法"""
  61. if keywords is None:
  62. keywords = ["预算"]
  63. print("\n" + "=" * 80)
  64. print("测试方法: {}".format(method_name))
  65. print("=" * 80)
  66. try:
  67. # 初始化 RAG
  68. start_time = time.time()
  69. rag = BidiRag(rag_method=method_name)
  70. init_time = time.time() - start_time
  71. # 添加文档
  72. start_time = time.time()
  73. rag.add_texts(TEST_DOCS)
  74. add_time = time.time() - start_time
  75. # 执行检索
  76. start_time = time.time()
  77. results = rag.retrieve(query=query, top_k=3, keywords=keywords)
  78. retrieve_time = time.time() - start_time
  79. # 评估结果
  80. relevant_count = 0
  81. for doc, score in results:
  82. is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
  83. if is_relevant:
  84. relevant_count += 1
  85. precision = relevant_count / len(results) if results else 0
  86. # 输出结果
  87. print("✓ 初始化时间: {:.2f}s".format(init_time))
  88. print("✓ 索引构建时间: {:.2f}s".format(add_time))
  89. print("✓ 检索时间: {:.4f}s".format(retrieve_time))
  90. print("✓ 召回数量: {}".format(len(results)))
  91. print("✓ 精确度@3: {:.1%}".format(precision))
  92. if results:
  93. print("\n前3个结果:")
  94. for i, (doc, score) in enumerate(results[:3], 1):
  95. preview = doc.page_content[:100].replace("\n", " ")
  96. print(" [{}] Score: {:.4f}".format(i, score))
  97. print(" Preview: {}...".format(preview))
  98. return {
  99. 'method': method_name,
  100. 'success': True,
  101. 'results_count': len(results),
  102. 'precision': precision,
  103. 'init_time': init_time,
  104. 'add_time': add_time,
  105. 'retrieve_time': retrieve_time,
  106. }
  107. except Exception as e:
  108. print("✗ 测试失败: {}".format(str(e)))
  109. import traceback
  110. traceback.print_exc()
  111. return {
  112. 'method': method_name,
  113. 'success': False,
  114. 'error': str(e),
  115. }
  116. def main():
  117. """主测试函数"""
  118. print("=" * 80)
  119. print("BidiRag - 所有 RAG 方法测试")
  120. print("=" * 80)
  121. # 只测试不需要 embedding 模型的方法
  122. methods_to_test = ['bm25', 'tfidf', 'keyword', 'bm25_html_tree']
  123. print("\n测试方法数量: {}".format(len(methods_to_test)))
  124. print("方法列表: {}".format(', '.join(methods_to_test)))
  125. print("\n注意: 其他方法需要 embedding 模型,由于网络限制暂时跳过")
  126. # 测试结果汇总
  127. results_summary = []
  128. # 测试每个方法
  129. for i, method in enumerate(methods_to_test, 1):
  130. print("\n[{}/{}] ".format(i, len(methods_to_test)), end="")
  131. result = test_rag_method(method)
  132. results_summary.append(result)
  133. # 打印汇总报告
  134. print("\n\n" + "=" * 80)
  135. print("测试汇总报告")
  136. print("=" * 80)
  137. successful = [r for r in results_summary if r['success']]
  138. failed = [r for r in results_summary if not r['success']]
  139. print("\n总测试数: {}".format(len(results_summary)))
  140. print("成功: {}".format(len(successful)))
  141. print("失败: {}".format(len(failed)))
  142. if successful:
  143. print("\n{:<30} {:>8} {:>10} {:>12}".format("方法名称", "召回数", "精确度", "检索时间"))
  144. print("-" * 80)
  145. # 按精确度排序
  146. successful.sort(key=lambda x: x['precision'], reverse=True)
  147. for r in successful:
  148. print("{:<30} {:>8} {:>9.1%} {:>11.4f}s".format(
  149. r['method'], r['results_count'], r['precision'], r['retrieve_time']))
  150. if failed:
  151. print("\n失败的方法:")
  152. for r in failed:
  153. print(" - {}: {}".format(r['method'], r.get('error', 'Unknown error')))
  154. # 推荐方法
  155. print("\n" + "=" * 80)
  156. print("推荐使用方法")
  157. print("=" * 80)
  158. # 找出精确度最高的方法
  159. best_methods = [r for r in successful if r['precision'] == 1.0]
  160. if best_methods:
  161. print("\n精确度 100% 的方法:")
  162. for r in best_methods:
  163. print(" ✓ {} (召回 {} 个结果, 耗时 {:.4f}s)".format(
  164. r['method'], r['results_count'], r['retrieve_time']))
  165. # 快速方法(检索时间 < 0.01s)
  166. fast_methods = [r for r in successful if r['retrieve_time'] < 0.01]
  167. if fast_methods:
  168. print("\n快速检索方法 (< 10ms):")
  169. for r in fast_methods:
  170. print(" ✓ {} (精确度 {:.1%})".format(r['method'], r['precision']))
  171. print("\n" + "=" * 80)
  172. print("测试完成!")
  173. print("=" * 80)
  174. if __name__ == "__main__":
  175. main()