test_bidi_rag.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # -*- coding: utf-8 -*-
  2. """BidiRag 测试脚本"""
  3. import sys
  4. import os
  5. # 添加项目根目录到路径
  6. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  7. from bdirag.bidi_rag import BidiRag
  8. from loguru import logger
  9. def test_basic_retrieval():
  10. """测试基础检索功能"""
  11. print("\n" + "=" * 80)
  12. print("测试 1: 基础检索功能")
  13. print("=" * 80)
  14. try:
  15. # 初始化
  16. rag = BidiRag(rag_method='bm25')
  17. # 添加测试文档
  18. docs = [
  19. "招标公告:招标人为XX市财政局,项目预算100万元",
  20. "中标公告:中标人为XX建设有限公司,中标金额98万元",
  21. "采购公告:采购单位XX医院,采购医疗设备",
  22. ]
  23. rag.add_texts(docs)
  24. assert rag.get_document_count() == 3, "文档数量不正确"
  25. print("✓ 文档添加成功")
  26. # 检索
  27. results = rag.retrieve("招标人", top_k=2)
  28. assert len(results) > 0, "检索结果为空"
  29. print(f"✓ 检索成功,返回 {len(results)} 个结果")
  30. # 验证结果包含关键词
  31. found = False
  32. for doc, score in results:
  33. if "招标人" in doc.page_content:
  34. found = True
  35. break
  36. assert found, "结果中未找到关键词"
  37. print("✓ 检索结果正确")
  38. print("✅ 测试 1 通过\n")
  39. return True
  40. except Exception as e:
  41. print(f"❌ 测试 1 失败: {e}\n")
  42. import traceback
  43. traceback.print_exc()
  44. return False
  45. def test_keyword_search():
  46. """测试关键词搜索"""
  47. print("\n" + "=" * 80)
  48. print("测试 2: 关键词搜索")
  49. print("=" * 80)
  50. try:
  51. rag = BidiRag(rag_method='bm25')
  52. docs = [
  53. "招标人:XX局,项目A",
  54. "中标人:XX公司,项目B",
  55. "招标人:XX委,中标人:XX企业,项目C",
  56. ]
  57. rag.add_texts(docs)
  58. # 搜索单个关键词
  59. results = rag.search_keywords(["招标人"])
  60. assert len(results) == 2, f"应找到2个文档,实际找到 {len(results)} 个"
  61. print("✓ 单关键词搜索成功")
  62. # 搜索多个关键词(AND 逻辑)
  63. results = rag.search_keywords(["招标人", "中标人"])
  64. assert len(results) == 1, f"应找到1个文档,实际找到 {len(results)} 个"
  65. assert "项目C" in results[0].page_content
  66. print("✓ 多关键词搜索成功")
  67. print("✅ 测试 2 通过\n")
  68. return True
  69. except Exception as e:
  70. print(f"❌ 测试 2 失败: {e}\n")
  71. import traceback
  72. traceback.print_exc()
  73. return False
  74. def test_html_retrieval():
  75. """测试 HTML 文档检索"""
  76. print("\n" + "=" * 80)
  77. print("测试 3: HTML 文档检索")
  78. print("=" * 80)
  79. try:
  80. # 使用 HTML Tree 方法
  81. rag = BidiRag(rag_method='bm25_html_tree')
  82. html_docs = [
  83. """<html><body>
  84. <h1>招标公告</h1>
  85. <p>招标人:XX市政府采购中心</p>
  86. <p>项目预算:200万元</p>
  87. </body></html>""",
  88. """<html><body>
  89. <h1>中标公告</h1>
  90. <p>中标人:XX科技有限公司</p>
  91. <p>中标金额:180万元</p>
  92. </body></html>"""
  93. ]
  94. rag.add_texts(html_docs)
  95. print("✓ HTML 文档添加成功")
  96. # 检索
  97. results = rag.retrieve(
  98. "招标人信息",
  99. top_k=2,
  100. keywords=["招标人", "中标人"]
  101. )
  102. assert len(results) > 0, "检索结果为空"
  103. print(f"✓ HTML 检索成功,返回 {len(results)} 个结果")
  104. print("✅ 测试 3 通过\n")
  105. return True
  106. except Exception as e:
  107. print(f"❌ 测试 3 失败: {e}\n")
  108. import traceback
  109. traceback.print_exc()
  110. return False
  111. def test_different_methods():
  112. """测试不同的 RAG 方法"""
  113. print("\n" + "=" * 80)
  114. print("测试 4: 不同 RAG 方法")
  115. print("=" * 80)
  116. test_docs = [
  117. "招标人A,项目1",
  118. "中标人B,项目2",
  119. ]
  120. methods = ['bm25', 'tfidf', 'keyword']
  121. success_count = 0
  122. for method in methods:
  123. try:
  124. print(f"\n测试方法: {method}")
  125. rag = BidiRag(rag_method=method)
  126. rag.add_texts(test_docs)
  127. results = rag.retrieve("招标人", top_k=1)
  128. if len(results) > 0:
  129. print(f" ✓ {method} 方法正常工作")
  130. success_count += 1
  131. else:
  132. print(f" ⚠ {method} 方法返回空结果")
  133. except Exception as e:
  134. print(f" ❌ {method} 方法失败: {e}")
  135. print(f"\n✅ {success_count}/{len(methods)} 个方法测试通过\n")
  136. return success_count == len(methods)
  137. def test_filter_by_keywords():
  138. """测试关键词过滤"""
  139. print("\n" + "=" * 80)
  140. print("测试 5: 关键词过滤")
  141. print("=" * 80)
  142. try:
  143. rag = BidiRag(rag_method='bm25')
  144. docs = [
  145. "招标人:A公司,项目1",
  146. "中标人:B公司,项目2",
  147. "招标人:C公司,中标人:D公司,项目3",
  148. "采购人:E单位,项目4",
  149. ]
  150. rag.add_texts(docs)
  151. # 使用关键词过滤
  152. results = rag.retrieve(
  153. "公司信息",
  154. top_k=5,
  155. keywords=["招标人"]
  156. )
  157. # 所有结果都应包含"招标人"
  158. for doc, score in results:
  159. assert "招标人" in doc.page_content, f"结果中应包含'招标人': {doc.page_content}"
  160. print(f"✓ 关键词过滤成功,返回 {len(results)} 个结果")
  161. print("✅ 测试 5 通过\n")
  162. return True
  163. except Exception as e:
  164. print(f"❌ 测试 5 失败: {e}\n")
  165. import traceback
  166. traceback.print_exc()
  167. return False
  168. def test_available_methods():
  169. """测试可用方法列表"""
  170. print("\n" + "=" * 80)
  171. print("测试 6: 可用方法列表")
  172. print("=" * 80)
  173. try:
  174. rag = BidiRag(rag_method='bm25')
  175. methods = rag.list_available_methods()
  176. assert len(methods) > 0, "方法列表为空"
  177. print(f"✓ 可用方法数量: {len(methods)}")
  178. print(f"✓ 方法列表: {', '.join(methods[:10])}...")
  179. # 测试方法信息
  180. info = rag.get_method_info()
  181. assert 'method_name' in info
  182. assert 'document_count' in info
  183. print(f"✓ 方法信息获取成功: {info['method_name']}")
  184. print("✅ 测试 6 通过\n")
  185. return True
  186. except Exception as e:
  187. print(f"❌ 测试 6 失败: {e}\n")
  188. import traceback
  189. traceback.print_exc()
  190. return False
  191. def main():
  192. """运行所有测试"""
  193. print("\n" + " " * 20)
  194. print("BidiRag 测试套件")
  195. print(" " * 20 + "\n")
  196. tests = [
  197. ("基础检索", test_basic_retrieval),
  198. ("关键词搜索", test_keyword_search),
  199. ("HTML 检索", test_html_retrieval),
  200. ("不同方法", test_different_methods),
  201. ("关键词过滤", test_filter_by_keywords),
  202. ("可用方法", test_available_methods),
  203. ]
  204. results = []
  205. for name, test_func in tests:
  206. try:
  207. result = test_func()
  208. results.append((name, result))
  209. except Exception as e:
  210. print(f"❌ {name} 测试异常: {e}")
  211. results.append((name, False))
  212. # 汇总
  213. print("\n" + "=" * 80)
  214. print("测试汇总")
  215. print("=" * 80)
  216. passed = sum(1 for _, r in results if r)
  217. total = len(results)
  218. for name, result in results:
  219. status = "✅ 通过" if result else "❌ 失败"
  220. print(f"{status} - {name}")
  221. print(f"\n总计: {passed}/{total} 测试通过")
  222. if passed == total:
  223. print("\n🎉 所有测试通过!")
  224. return 0
  225. else:
  226. print(f"\n⚠️ {total - passed} 个测试失败")
  227. return 1
  228. if __name__ == "__main__":
  229. exit(main())