# -*- coding: utf-8 -*-
"""BidiRag test script.

Each ``test_*`` function prints its progress and returns ``True`` on success
or ``False`` on failure (exceptions are caught and reported inside the test).
``main`` runs all of them, prints a summary and returns a process exit code:
0 when every test passed, 1 otherwise.

NOTE(review): because the tests catch their own exceptions and return a bool,
collecting them with pytest would make them always "pass" (pytest only warns
about non-None return values). Run this file directly instead. Checks use
``assert``, so do not run it with ``python -O``.
"""
import os
import sys
import traceback

# Put the parent of this file's directory on sys.path (presumably the project
# root) so that ``bdirag`` is importable when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from bdirag.bidi_rag import BidiRag  # noqa: E402  (needs the sys.path tweak above)
from loguru import logger  # noqa: E402


def _print_header(title):
    """Print the banner that introduces one test section."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _report_failure(label, exc):
    """Print a failure line for ``label`` followed by the active traceback."""
    print(f"❌ {label} 失败: {exc}\n")
    traceback.print_exc()


def test_basic_retrieval():
    """Test 1: index three texts with BM25 and check retrieval finds "招标人"."""
    _print_header("测试 1: 基础检索功能")
    try:
        rag = BidiRag(rag_method='bm25')

        docs = [
            "招标公告:招标人为XX市财政局,项目预算100万元",
            "中标公告:中标人为XX建设有限公司,中标金额98万元",
            "采购公告:采购单位XX医院,采购医疗设备",
        ]
        rag.add_texts(docs)
        assert rag.get_document_count() == 3, "文档数量不正确"
        print("✓ 文档添加成功")

        results = rag.retrieve("招标人", top_k=2)
        assert len(results) > 0, "检索结果为空"
        print(f"✓ 检索成功,返回 {len(results)} 个结果")

        # retrieve() yields (document, score) pairs; only the text is checked.
        found = any("招标人" in doc.page_content for doc, _score in results)
        assert found, "结果中未找到关键词"
        print("✓ 检索结果正确")

        print("✅ 测试 1 通过\n")
        return True
    except Exception as e:
        _report_failure("测试 1", e)
        return False


def test_keyword_search():
    """Test 2: ``search_keywords`` with one keyword and with two (AND logic)."""
    _print_header("测试 2: 关键词搜索")
    try:
        rag = BidiRag(rag_method='bm25')
        docs = [
            "招标人:XX局,项目A",
            "中标人:XX公司,项目B",
            "招标人:XX委,中标人:XX企业,项目C",
        ]
        rag.add_texts(docs)

        # Single keyword: the first and third documents contain it.
        results = rag.search_keywords(["招标人"])
        assert len(results) == 2, f"应找到2个文档,实际找到 {len(results)} 个"
        print("✓ 单关键词搜索成功")

        # Several keywords (AND logic): only the third document has both.
        results = rag.search_keywords(["招标人", "中标人"])
        assert len(results) == 1, f"应找到1个文档,实际找到 {len(results)} 个"
        assert "项目C" in results[0].page_content
        print("✓ 多关键词搜索成功")

        print("✅ 测试 2 通过\n")
        return True
    except Exception as e:
        _report_failure("测试 2", e)
        return False


def test_html_retrieval():
    """Test 3: index HTML documents with ``bm25_html_tree`` and query them."""
    _print_header("测试 3: HTML 文档检索")
    try:
        # Use the HTML-tree variant of BM25.
        rag = BidiRag(rag_method='bm25_html_tree')

        # Fixtures carry real (minimal) markup so the HTML-tree method has
        # elements to parse rather than plain text.
        html_docs = [
            """
            <div>
                <p>招标人:XX市政府采购中心</p>
                <p>项目预算:200万元</p>
            </div>
            """,
            """<div>
                <p>中标人:XX科技有限公司</p>
                <p>中标金额:180万元</p>
            </div>
            """,
        ]
        rag.add_texts(html_docs)
        print("✓ HTML 文档添加成功")

        results = rag.retrieve(
            "招标人信息",
            top_k=2,
            keywords=["招标人", "中标人"],
        )
        assert len(results) > 0, "检索结果为空"
        print(f"✓ HTML 检索成功,返回 {len(results)} 个结果")

        print("✅ 测试 3 通过\n")
        return True
    except Exception as e:
        _report_failure("测试 3", e)
        return False


def test_different_methods():
    """Test 4: check that each listed RAG method returns at least one result.

    Returns True only if every method succeeded; a failing method is reported
    and the remaining methods are still tried.
    """
    _print_header("测试 4: 不同 RAG 方法")

    test_docs = [
        "招标人A,项目1",
        "中标人B,项目2",
    ]
    methods = ['bm25', 'tfidf', 'keyword']
    success_count = 0

    for method in methods:
        try:
            print(f"\n测试方法: {method}")
            rag = BidiRag(rag_method=method)
            rag.add_texts(test_docs)
            results = rag.retrieve("招标人", top_k=1)

            if len(results) > 0:
                print(f"  ✓ {method} 方法正常工作")
                success_count += 1
            else:
                print(f"  ⚠ {method} 方法返回空结果")
        except Exception as e:
            # Keep going so one broken method does not hide the others.
            print(f"  ❌ {method} 方法失败: {e}")

    print(f"\n✅ {success_count}/{len(methods)} 个方法测试通过\n")
    return success_count == len(methods)


def test_filter_by_keywords():
    """Test 5: the ``keywords`` filter of ``retrieve`` keeps only matching texts."""
    _print_header("测试 5: 关键词过滤")
    try:
        rag = BidiRag(rag_method='bm25')
        docs = [
            "招标人:A公司,项目1",
            "中标人:B公司,项目2",
            "招标人:C公司,中标人:D公司,项目3",
            "采购人:E单位,项目4",
        ]
        rag.add_texts(docs)

        results = rag.retrieve(
            "公司信息",
            top_k=5,
            keywords=["招标人"],
        )

        # Without this, an empty result would make the check below pass
        # vacuously; two of the documents do contain the keyword.
        assert len(results) > 0, "检索结果为空"

        # Every returned document must contain the filter keyword.
        for doc, _score in results:
            assert "招标人" in doc.page_content, f"结果中应包含'招标人': {doc.page_content}"

        print(f"✓ 关键词过滤成功,返回 {len(results)} 个结果")
        print("✅ 测试 5 通过\n")
        return True
    except Exception as e:
        _report_failure("测试 5", e)
        return False


def test_available_methods():
    """Test 6: ``list_available_methods`` and ``get_method_info`` return data."""
    _print_header("测试 6: 可用方法列表")
    try:
        rag = BidiRag(rag_method='bm25')

        methods = rag.list_available_methods()
        assert len(methods) > 0, "方法列表为空"
        print(f"✓ 可用方法数量: {len(methods)}")
        # Show at most the first 10 names; add an ellipsis only when truncated.
        ellipsis = "..." if len(methods) > 10 else ""
        print(f"✓ 方法列表: {', '.join(methods[:10])}{ellipsis}")

        info = rag.get_method_info()
        assert 'method_name' in info, "方法信息缺少字段: method_name"
        assert 'document_count' in info, "方法信息缺少字段: document_count"
        print(f"✓ 方法信息获取成功: {info['method_name']}")

        print("✅ 测试 6 通过\n")
        return True
    except Exception as e:
        _report_failure("测试 6", e)
        return False


def main():
    """Run all tests, print a summary and return an exit code (0 = all passed)."""
    print("\n" + " " * 20)
    print("BidiRag 测试套件")
    print(" " * 20 + "\n")

    tests = [
        ("基础检索", test_basic_retrieval),
        ("关键词搜索", test_keyword_search),
        ("HTML 检索", test_html_retrieval),
        ("不同方法", test_different_methods),
        ("关键词过滤", test_filter_by_keywords),
        ("可用方法", test_available_methods),
    ]

    results = []
    for name, test_func in tests:
        try:
            result = test_func()
        except Exception as e:
            # Tests normally catch their own errors; this is a last-resort guard.
            print(f"❌ {name} 测试异常: {e}")
            result = False
        results.append((name, result))

    # Summary
    print("\n" + "=" * 80)
    print("测试汇总")
    print("=" * 80)

    passed = sum(1 for _, r in results if r)
    total = len(results)

    for name, result in results:
        status = "✅ 通过" if result else "❌ 失败"
        print(f"{status} - {name}")

    print(f"\n总计: {passed}/{total} 测试通过")

    if passed == total:
        print("\n🎉 所有测试通过!")
        return 0
    print(f"\n⚠️  {total - passed} 个测试失败")
    return 1


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is intended for the
    # interactive interpreter and is missing when Python runs with -S.
    sys.exit(main())