| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- # -*- coding: utf-8 -*-
- """
- 测试 BidiRag 支持的所有 RAG 方法
- 这个脚本会逐一测试每种 RAG 方法的召回效果
- """
- import sys
- import os
- import time
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- from bdirag.bidi_rag import BidiRag
- # 测试文档(纯文本格式)
- TEST_DOCS = [
- """XX市第一人民医院医疗设备招标公告
- 项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
- 项目编号:XX-ZB-2024-001
- 预算金额:500万元
- 采购内容:彩色多普勒超声诊断仪 1台
- 投标人资格要求:
- 1. 具有独立承担民事责任的能力
- 2. 具有有效的医疗器械经营许可证
- 3. 近三年内无不良经营记录
- 4. 投标保证金:人民币5万元整
- 技术需求:
- 1. 彩色多普勒超声诊断仪技术参数
- - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
- - 显示屏:≥19英寸高清液晶显示器
- - 质保期:整机质保三年
- 2. 交货时间:合同签订后60天内交货
- 3. 交货地点:XX市第一人民医院设备科
- 评标方法:采用综合评分法
- - 技术部分:60分
- - 商务部分:30分
- - 价格部分:10分
- 付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%
- 投标截止时间:2024年12月31日上午9:30
- 开标时间:同投标截止时间
- 投标文件递交地点:XX市公共资源交易中心""",
- """XX市智慧交通系统建设项目招标公告
- 项目名称:XX市智慧交通系统建设项目
- 项目编号:XX-ZB-2024-002
- 招标人:XX市交通运输局
- 预算金额:5000万元
- 项目内容:
- 1. 交通信号控制系统
- 2. 视频监控系统
- 3. 交通流量监测系统
- 4. 数据分析平台
- 资质要求:
- 1. 电子与智能化工程专业承包二级以上资质
- 2. 近三年至少完成2个类似项目业绩
- 3. 项目经理须具备机电工程专业一级建造师证书
- 评标方法:综合评分法
- - 技术部分:60分
- - 商务部分:40分
- 交货时间:合同签订后180天内
- 质保期:3年
- 付款方式:30%预付款,40%到货验收,25%最终验收,5%质保金"""
- ]
- def test_rag_method(method_name, query="预算金额", keywords=None):
- """测试单个 RAG 方法"""
-
- if keywords is None:
- keywords = ["预算"]
-
- print("\n" + "=" * 80)
- print("测试方法: {}".format(method_name))
- print("=" * 80)
-
- try:
- # 初始化 RAG
- start_time = time.time()
- rag = BidiRag(rag_method=method_name)
- init_time = time.time() - start_time
-
- # 添加文档
- start_time = time.time()
- rag.add_texts(TEST_DOCS)
- add_time = time.time() - start_time
-
- # 执行检索
- start_time = time.time()
- results = rag.retrieve(query=query, top_k=3, keywords=keywords)
- retrieve_time = time.time() - start_time
-
- # 评估结果
- relevant_count = 0
- for doc, score in results:
- is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
- if is_relevant:
- relevant_count += 1
-
- precision = relevant_count / len(results) if results else 0
-
- # 输出结果
- print("✓ 初始化时间: {:.2f}s".format(init_time))
- print("✓ 索引构建时间: {:.2f}s".format(add_time))
- print("✓ 检索时间: {:.4f}s".format(retrieve_time))
- print("✓ 召回数量: {}".format(len(results)))
- print("✓ 精确度@3: {:.1%}".format(precision))
-
- if results:
- print("\n前3个结果:")
- for i, (doc, score) in enumerate(results[:3], 1):
- preview = doc.page_content[:100].replace("\n", " ")
- print(" [{}] Score: {:.4f}".format(i, score))
- print(" Preview: {}...".format(preview))
-
- return {
- 'method': method_name,
- 'success': True,
- 'results_count': len(results),
- 'precision': precision,
- 'init_time': init_time,
- 'add_time': add_time,
- 'retrieve_time': retrieve_time,
- }
-
- except Exception as e:
- print("✗ 测试失败: {}".format(str(e)))
- import traceback
- traceback.print_exc()
- return {
- 'method': method_name,
- 'success': False,
- 'error': str(e),
- }
- def main():
- """主测试函数"""
-
- print("=" * 80)
- print("BidiRag - 所有 RAG 方法测试")
- print("=" * 80)
-
- # 只测试不需要 embedding 模型的方法
- methods_to_test = ['bm25', 'tfidf', 'keyword', 'bm25_html_tree']
-
- print("\n测试方法数量: {}".format(len(methods_to_test)))
- print("方法列表: {}".format(', '.join(methods_to_test)))
- print("\n注意: 其他方法需要 embedding 模型,由于网络限制暂时跳过")
-
- # 测试结果汇总
- results_summary = []
-
- # 测试每个方法
- for i, method in enumerate(methods_to_test, 1):
- print("\n[{}/{}] ".format(i, len(methods_to_test)), end="")
- result = test_rag_method(method)
- results_summary.append(result)
-
- # 打印汇总报告
- print("\n\n" + "=" * 80)
- print("测试汇总报告")
- print("=" * 80)
-
- successful = [r for r in results_summary if r['success']]
- failed = [r for r in results_summary if not r['success']]
-
- print("\n总测试数: {}".format(len(results_summary)))
- print("成功: {}".format(len(successful)))
- print("失败: {}".format(len(failed)))
-
- if successful:
- print("\n{:<30} {:>8} {:>10} {:>12}".format("方法名称", "召回数", "精确度", "检索时间"))
- print("-" * 80)
-
- # 按精确度排序
- successful.sort(key=lambda x: x['precision'], reverse=True)
-
- for r in successful:
- print("{:<30} {:>8} {:>9.1%} {:>11.4f}s".format(
- r['method'], r['results_count'], r['precision'], r['retrieve_time']))
-
- if failed:
- print("\n失败的方法:")
- for r in failed:
- print(" - {}: {}".format(r['method'], r.get('error', 'Unknown error')))
-
- # 推荐方法
- print("\n" + "=" * 80)
- print("推荐使用方法")
- print("=" * 80)
-
- # 找出精确度最高的方法
- best_methods = [r for r in successful if r['precision'] == 1.0]
- if best_methods:
- print("\n精确度 100% 的方法:")
- for r in best_methods:
- print(" ✓ {} (召回 {} 个结果, 耗时 {:.4f}s)".format(
- r['method'], r['results_count'], r['retrieve_time']))
-
- # 快速方法(检索时间 < 0.01s)
- fast_methods = [r for r in successful if r['retrieve_time'] < 0.01]
- if fast_methods:
- print("\n快速检索方法 (< 10ms):")
- for r in fast_methods:
- print(" ✓ {} (精确度 {:.1%})".format(r['method'], r['precision']))
-
- print("\n" + "=" * 80)
- print("测试完成!")
- print("=" * 80)
- if __name__ == "__main__":
- main()
|