# -*- coding: utf-8 -*- """ 测试 BidiRag 支持的所有 RAG 方法 这个脚本会逐一测试每种 RAG 方法的召回效果 """ import sys import os import time sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from bdirag.bidi_rag import BidiRag # 测试文档(纯文本格式) TEST_DOCS = [ """XX市第一人民医院医疗设备招标公告 项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目 项目编号:XX-ZB-2024-001 预算金额:500万元 采购内容:彩色多普勒超声诊断仪 1台 投标人资格要求: 1. 具有独立承担民事责任的能力 2. 具有有效的医疗器械经营许可证 3. 近三年内无不良经营记录 4. 投标保证金:人民币5万元整 技术需求: 1. 彩色多普勒超声诊断仪技术参数 - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头 - 显示屏:≥19英寸高清液晶显示器 - 质保期:整机质保三年 2. 交货时间:合同签订后60天内交货 3. 交货地点:XX市第一人民医院设备科 评标方法:采用综合评分法 - 技术部分:60分 - 商务部分:30分 - 价格部分:10分 付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5% 投标截止时间:2024年12月31日上午9:30 开标时间:同投标截止时间 投标文件递交地点:XX市公共资源交易中心""", """XX市智慧交通系统建设项目招标公告 项目名称:XX市智慧交通系统建设项目 项目编号:XX-ZB-2024-002 招标人:XX市交通运输局 预算金额:5000万元 项目内容: 1. 交通信号控制系统 2. 视频监控系统 3. 交通流量监测系统 4. 数据分析平台 资质要求: 1. 电子与智能化工程专业承包二级以上资质 2. 近三年至少完成2个类似项目业绩 3. 项目经理须具备机电工程专业一级建造师证书 评标方法:综合评分法 - 技术部分:60分 - 商务部分:40分 交货时间:合同签订后180天内 质保期:3年 付款方式:30%预付款,40%到货验收,25%最终验收,5%质保金""" ] def test_rag_method(method_name, query="预算金额", keywords=None): """测试单个 RAG 方法""" if keywords is None: keywords = ["预算"] print("\n" + "=" * 80) print("测试方法: {}".format(method_name)) print("=" * 80) try: # 初始化 RAG start_time = time.time() rag = BidiRag(rag_method=method_name) init_time = time.time() - start_time # 添加文档 start_time = time.time() rag.add_texts(TEST_DOCS) add_time = time.time() - start_time # 执行检索 start_time = time.time() results = rag.retrieve(query=query, top_k=3, keywords=keywords) retrieve_time = time.time() - start_time # 评估结果 relevant_count = 0 for doc, score in results: is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords) if is_relevant: relevant_count += 1 precision = relevant_count / len(results) if results else 0 # 输出结果 print("✓ 初始化时间: {:.2f}s".format(init_time)) print("✓ 索引构建时间: {:.2f}s".format(add_time)) print("✓ 检索时间: {:.4f}s".format(retrieve_time)) print("✓ 召回数量: {}".format(len(results))) print("✓ 精确度@3: {:.1%}".format(precision)) if results: print("\n前3个结果:") for i, (doc, score) in enumerate(results[:3], 1): preview = doc.page_content[:100].replace("\n", " ") print(" [{}] Score: {:.4f}".format(i, score)) print(" Preview: {}...".format(preview)) return { 'method': method_name, 'success': True, 'results_count': len(results), 'precision': precision, 'init_time': init_time, 'add_time': add_time, 'retrieve_time': retrieve_time, } except Exception as e: print("✗ 测试失败: {}".format(str(e))) import traceback traceback.print_exc() return { 'method': method_name, 'success': False, 'error': str(e), } def main(): """主测试函数""" print("=" * 80) print("BidiRag - 所有 RAG 方法测试") print("=" * 80) # 只测试不需要 embedding 模型的方法 methods_to_test = ['bm25', 'tfidf', 'keyword', 'bm25_html_tree'] print("\n测试方法数量: {}".format(len(methods_to_test))) print("方法列表: {}".format(', '.join(methods_to_test))) print("\n注意: 其他方法需要 embedding 模型,由于网络限制暂时跳过") # 测试结果汇总 results_summary = [] # 测试每个方法 for i, method in enumerate(methods_to_test, 1): print("\n[{}/{}] ".format(i, len(methods_to_test)), end="") result = test_rag_method(method) results_summary.append(result) # 打印汇总报告 print("\n\n" + "=" * 80) print("测试汇总报告") print("=" * 80) successful = [r for r in results_summary if r['success']] failed = [r for r in results_summary if not r['success']] print("\n总测试数: {}".format(len(results_summary))) print("成功: {}".format(len(successful))) print("失败: {}".format(len(failed))) if successful: print("\n{:<30} {:>8} {:>10} {:>12}".format("方法名称", "召回数", "精确度", "检索时间")) print("-" * 80) # 按精确度排序 successful.sort(key=lambda x: x['precision'], reverse=True) for r in successful: print("{:<30} {:>8} {:>9.1%} {:>11.4f}s".format( r['method'], r['results_count'], r['precision'], r['retrieve_time'])) if failed: print("\n失败的方法:") for r in failed: print(" - {}: {}".format(r['method'], r.get('error', 'Unknown error'))) # 推荐方法 print("\n" + "=" * 80) print("推荐使用方法") print("=" * 80) # 找出精确度最高的方法 best_methods = [r for r in successful if r['precision'] == 1.0] if best_methods: print("\n精确度 100% 的方法:") for r in best_methods: print(" ✓ {} (召回 {} 个结果, 耗时 {:.4f}s)".format( r['method'], r['results_count'], r['retrieve_time'])) # 快速方法(检索时间 < 0.01s) fast_methods = [r for r in successful if r['retrieve_time'] < 0.01] if fast_methods: print("\n快速检索方法 (< 10ms):") for r in fast_methods: print(" ✓ {} (精确度 {:.1%})".format(r['method'], r['precision'])) print("\n" + "=" * 80) print("测试完成!") print("=" * 80) if __name__ == "__main__": main()