| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292 |
- # -*- coding: utf-8 -*-
- """BidiRag 测试脚本"""
- import sys
- import os
- # 添加项目根目录到路径
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- from bdirag.bidi_rag import BidiRag
- from loguru import logger
def test_basic_retrieval():
    """Check adding plain-text documents and retrieving them by query.

    Builds a ``BidiRag`` with ``rag_method='bm25'``, adds three sample
    announcement texts, verifies the reported document count, runs a
    ``top_k=2`` query and checks that at least one returned document
    contains the query term.

    Returns:
        bool: True when every check passes; False when an assertion or
        any other exception occurs (the traceback is printed, not raised).
    """
    import traceback

    separator = "=" * 80
    print("\n" + separator)
    print("测试 1: 基础检索功能")
    print(separator)

    try:
        # Set up the retriever and load the sample corpus.
        engine = BidiRag(rag_method='bm25')
        engine.add_texts([
            "招标公告:招标人为XX市财政局,项目预算100万元",
            "中标公告:中标人为XX建设有限公司,中标金额98万元",
            "采购公告:采购单位XX医院,采购医疗设备",
        ])

        doc_total = engine.get_document_count()
        assert doc_total == 3, "文档数量不正确"
        print("✓ 文档添加成功")

        # Run the query.
        hits = engine.retrieve("招标人", top_k=2)
        hit_count = len(hits)
        assert hit_count > 0, "检索结果为空"
        print(f"✓ 检索成功,返回 {hit_count} 个结果")

        # At least one hit must literally contain the query term; each hit
        # is unpacked as a (document, score) pair.
        keyword_present = any(
            "招标人" in document.page_content for document, _score in hits
        )
        assert keyword_present, "结果中未找到关键词"
        print("✓ 检索结果正确")

        print("✅ 测试 1 通过\n")
    except Exception as exc:
        print(f"❌ 测试 1 失败: {exc}\n")
        traceback.print_exc()
        return False

    return True
def test_keyword_search():
    """Exercise ``search_keywords`` with one keyword and with two keywords.

    With three sample texts loaded, the single keyword "招标人" is expected
    to match exactly two documents, and the pair ["招标人", "中标人"] is
    expected to match exactly one document (the one mentioning "项目C"),
    i.e. multiple keywords are expected to combine with AND semantics.

    Returns:
        bool: True when every check passes; False when an assertion or
        any other exception occurs (the traceback is printed, not raised).
    """
    import traceback

    separator = "=" * 80
    print("\n" + separator)
    print("测试 2: 关键词搜索")
    print(separator)

    try:
        engine = BidiRag(rag_method='bm25')
        engine.add_texts([
            "招标人:XX局,项目A",
            "中标人:XX公司,项目B",
            "招标人:XX委,中标人:XX企业,项目C",
        ])

        # Single keyword.
        single_matches = engine.search_keywords(["招标人"])
        single_count = len(single_matches)
        assert single_count == 2, f"应找到2个文档,实际找到 {single_count} 个"
        print("✓ 单关键词搜索成功")

        # Several keywords: only the document containing all of them
        # should come back.
        combined_matches = engine.search_keywords(["招标人", "中标人"])
        combined_count = len(combined_matches)
        assert combined_count == 1, f"应找到1个文档,实际找到 {combined_count} 个"
        first_match = combined_matches[0]
        assert "项目C" in first_match.page_content
        print("✓ 多关键词搜索成功")

        print("✅ 测试 2 通过\n")
    except Exception as exc:
        print(f"❌ 测试 2 失败: {exc}\n")
        traceback.print_exc()
        return False

    return True
def test_html_retrieval():
    """Check retrieval over HTML documents.

    Builds a ``BidiRag`` with ``rag_method='bm25_html_tree'``, adds two
    small HTML pages, and runs a ``top_k=2`` query with a ``keywords``
    argument. The only check on the outcome is that the result is
    non-empty.

    Returns:
        bool: True when every check passes; False when an assertion or
        any other exception occurs (the traceback is printed, not raised).
    """
    import traceback

    separator = "=" * 80
    print("\n" + separator)
    print("测试 3: HTML 文档检索")
    print(separator)

    try:
        # Use the HTML-tree retrieval method.
        engine = BidiRag(rag_method='bm25_html_tree')

        # Each page is assembled line by line with explicit newlines.
        # NOTE(review): the copy of this file under review had lost its
        # indentation, so any leading whitespace that may have existed
        # inside the original HTML literals is not reproduced here --
        # confirm it is insignificant to the HTML parser.
        tender_page = (
            "<html><body>\n"
            "<h1>招标公告</h1>\n"
            "<p>招标人:XX市政府采购中心</p>\n"
            "<p>项目预算:200万元</p>\n"
            "</body></html>"
        )
        award_page = (
            "<html><body>\n"
            "<h1>中标公告</h1>\n"
            "<p>中标人:XX科技有限公司</p>\n"
            "<p>中标金额:180万元</p>\n"
            "</body></html>"
        )

        engine.add_texts([tender_page, award_page])
        print("✓ HTML 文档添加成功")

        # Run the query.
        query_options = {"top_k": 2, "keywords": ["招标人", "中标人"]}
        hits = engine.retrieve("招标人信息", **query_options)

        hit_count = len(hits)
        assert hit_count > 0, "检索结果为空"
        print(f"✓ HTML 检索成功,返回 {hit_count} 个结果")

        print("✅ 测试 3 通过\n")
    except Exception as exc:
        print(f"❌ 测试 3 失败: {exc}\n")
        traceback.print_exc()
        return False

    return True
def test_different_methods():
    """Smoke-test several ``rag_method`` values against the same corpus.

    For each of 'bm25', 'tfidf' and 'keyword', a fresh ``BidiRag`` is
    created, two sample texts are added and a ``top_k=1`` query is run.
    A method counts as working only if the query returns a non-empty
    result; an empty result prints a warning and an exception prints an
    error, and neither stops the remaining methods from being tried.

    Returns:
        bool: True only when every listed method produced a non-empty
        result.
    """
    separator = "=" * 80
    print("\n" + separator)
    print("测试 4: 不同 RAG 方法")
    print(separator)

    sample_docs = [
        "招标人A,项目1",
        "中标人B,项目2",
    ]
    method_names = ['bm25', 'tfidf', 'keyword']

    working = 0
    for method_name in method_names:
        try:
            print(f"\n测试方法: {method_name}")
            engine = BidiRag(rag_method=method_name)
            engine.add_texts(sample_docs)

            hits = engine.retrieve("招标人", top_k=1)
            if len(hits) == 0:
                # Ran without error but found nothing: not counted.
                print(f" ⚠ {method_name} 方法返回空结果")
                continue

            print(f" ✓ {method_name} 方法正常工作")
            working += 1
        except Exception as exc:
            # Failures are isolated per method so the loop keeps going.
            print(f" ❌ {method_name} 方法失败: {exc}")

    method_total = len(method_names)
    print(f"\n✅ {working}/{method_total} 个方法测试通过\n")
    return working == method_total
def test_filter_by_keywords():
    """Check that ``retrieve`` honours its ``keywords`` filter.

    Loads four sample texts, queries with ``keywords=["招标人"]`` and
    ``top_k=5``, and asserts that every returned document contains
    "招标人". Note that an empty result set passes these checks, since
    there is then nothing to assert on.

    Returns:
        bool: True when every check passes; False when an assertion or
        any other exception occurs (the traceback is printed, not raised).
    """
    import traceback

    separator = "=" * 80
    print("\n" + separator)
    print("测试 5: 关键词过滤")
    print(separator)

    try:
        engine = BidiRag(rag_method='bm25')
        engine.add_texts([
            "招标人:A公司,项目1",
            "中标人:B公司,项目2",
            "招标人:C公司,中标人:D公司,项目3",
            "采购人:E单位,项目4",
        ])

        # Retrieve with the keyword filter applied.
        hits = engine.retrieve("公司信息", top_k=5, keywords=["招标人"])

        # Every hit (a (document, score) pair) must contain the keyword.
        for document, _score in hits:
            assert "招标人" in document.page_content, (
                f"结果中应包含'招标人': {document.page_content}"
            )

        print(f"✓ 关键词过滤成功,返回 {len(hits)} 个结果")
        print("✅ 测试 5 通过\n")
    except Exception as exc:
        print(f"❌ 测试 5 失败: {exc}\n")
        traceback.print_exc()
        return False

    return True
def test_available_methods():
    """Check the method-listing and method-info accessors.

    Asserts that ``list_available_methods()`` returns a non-empty
    sequence (the first ten entries are printed) and that
    ``get_method_info()`` returns a mapping containing the keys
    'method_name' and 'document_count'.

    Returns:
        bool: True when every check passes; False when an assertion or
        any other exception occurs (the traceback is printed, not raised).
    """
    import traceback

    separator = "=" * 80
    print("\n" + separator)
    print("测试 6: 可用方法列表")
    print(separator)

    try:
        engine = BidiRag(rag_method='bm25')

        available = engine.list_available_methods()
        available_count = len(available)
        assert available_count > 0, "方法列表为空"
        print(f"✓ 可用方法数量: {available_count}")
        preview = ', '.join(available[:10])
        print(f"✓ 方法列表: {preview}...")

        # The info mapping must expose both of these keys.
        info = engine.get_method_info()
        for required_key in ('method_name', 'document_count'):
            assert required_key in info
        print(f"✓ 方法信息获取成功: {info['method_name']}")

        print("✅ 测试 6 通过\n")
    except Exception as exc:
        print(f"❌ 测试 6 失败: {exc}\n")
        traceback.print_exc()
        return False

    return True
def main():
    """Run every test function in order and print a summary.

    Each test is called in turn; its truthy/falsy return value is
    recorded as pass/fail. An exception escaping a test is reported and
    recorded as a failure, so one broken test cannot abort the run.

    Returns:
        int: 0 when all tests passed, 1 otherwise (suitable as a process
        exit status).
    """
    print("\n" + " " * 20)
    print("BidiRag 测试套件")
    print(" " * 20 + "\n")

    suite = [
        ("基础检索", test_basic_retrieval),
        ("关键词搜索", test_keyword_search),
        ("HTML 检索", test_html_retrieval),
        ("不同方法", test_different_methods),
        ("关键词过滤", test_filter_by_keywords),
        ("可用方法", test_available_methods),
    ]

    outcomes = []
    for label, runner in suite:
        try:
            outcome = runner()
        except Exception as exc:
            print(f"❌ {label} 测试异常: {exc}")
            outcome = False
        outcomes.append((label, outcome))

    # Summary.
    separator = "=" * 80
    print("\n" + separator)
    print("测试汇总")
    print(separator)

    passed_count = len([label for label, outcome in outcomes if outcome])
    total_count = len(outcomes)

    for label, outcome in outcomes:
        if outcome:
            verdict = "✅ 通过"
        else:
            verdict = "❌ 失败"
        print(f"{verdict} - {label}")

    print(f"\n总计: {passed_count}/{total_count} 测试通过")

    if passed_count != total_count:
        print(f"\n⚠️ {total_count - passed_count} 个测试失败")
        return 1

    print("\n🎉 所有测试通过!")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value (0 = all passed, 1 = failures) as the
    # process exit status. sys.exit is used instead of the builtin exit(),
    # which is installed by the site module for interactive sessions and is
    # not available when the interpreter runs without site (e.g. python -S).
    sys.exit(main())
|