demo_bm25_retrieval.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # -*- coding: utf-8 -*-
  2. """
  3. 使用 BidiRag 的 BM25 方法召回内容片段示例
  4. 对比 bm25_html_tree 和 bm25 两种方法的效果
  5. """
  6. import sys
  7. import os
  8. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  9. from bdirag.bidi_rag import BidiRag
  10. def demo_bm25_retrieval():
  11. """演示使用 BM25 方法召回内容片段"""
  12. print("=" * 80)
  13. print("BidiRag - BM25 纯文本召回示例")
  14. print("=" * 80)
  15. # 1. 初始化 BidiRag(使用 bm25 方法)
  16. print("\n[步骤 1] 初始化 BidiRag (使用 bm25 方法)...")
  17. rag = BidiRag(rag_method='bm25')
  18. # 2. 准备纯文本文档(BM25 更适合纯文本)
  19. print("\n[步骤 2] 添加招投标公告(纯文本格式)...")
  20. sample_docs = [
  21. """XX市第一人民医院医疗设备招标公告
  22. 项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
  23. 项目编号:XX-ZB-2024-001
  24. 预算金额:500万元
  25. 采购内容:彩色多普勒超声诊断仪 1台
  26. 投标人资格要求:
  27. 1. 具有独立承担民事责任的能力
  28. 2. 具有有效的医疗器械经营许可证
  29. 3. 近三年内无不良经营记录
  30. 4. 投标保证金:人民币5万元整
  31. 技术需求:
  32. 1. 彩色多普勒超声诊断仪技术参数
  33. - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
  34. - 显示屏:≥19英寸高清液晶显示器
  35. - 质保期:整机质保三年
  36. 2. 交货时间:合同签订后60天内交货
  37. 3. 交货地点:XX市第一人民医院设备科
  38. 评标方法:采用综合评分法
  39. - 技术部分:60分
  40. - 商务部分:30分
  41. - 价格部分:10分
  42. 付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%
  43. 投标截止时间:2024年12月31日上午9:30
  44. 开标时间:同投标截止时间
  45. 投标文件递交地点:XX市公共资源交易中心"""
  46. ]
  47. rag.add_texts(sample_docs)
  48. print("✓ 文档已添加")
  49. # 3. 测试不同查询,召回内容片段
  50. test_queries = [
  51. ("预算金额", ["预算", "Budget"]),
  52. ("投标保证金", ["保证金", "Bond"]),
  53. ("技术参数 探头", ["探头", "technical"]),
  54. ("评标方法 综合评分", ["评标", "综合评分"]),
  55. ("质保期", ["质保", "Warranty"]),
  56. ("付款方式", ["付款", "Payment"]),
  57. ("交货时间", ["交货", "Delivery"]),
  58. ]
  59. print("\n" + "=" * 80)
  60. print("开始召回测试")
  61. print("=" * 80)
  62. for query, keywords in test_queries:
  63. print(f"\n{'=' * 80}")
  64. print(f"查询: {query}")
  65. print(f"{'-' * 80}")
  66. # 使用 retrieve 方法召回内容片段
  67. results = rag.retrieve(
  68. query=query,
  69. top_k=3,
  70. keywords=keywords
  71. )
  72. print(f"召回 {len(results)} 个内容片段:\n")
  73. # 评估相关性
  74. relevant_count = 0
  75. for i, (doc, score) in enumerate(results, 1):
  76. # 检查相关性
  77. is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
  78. if is_relevant:
  79. relevant_count += 1
  80. marker = "[✓]"
  81. else:
  82. marker = "[ ]"
  83. print(f" 片段 {i} {marker} (分数: {score:.4f})")
  84. print(f" 内容:")
  85. # 显示内容片段(前200字符)
  86. content_preview = doc.page_content[:200].replace("\n", " ")
  87. print(f" {content_preview}...")
  88. print()
  89. # 计算精确度
  90. precision = relevant_count / len(results) if results else 0
  91. print(f" 精确度@3: {precision:.1%}")
  92. print("\n" + "=" * 80)
  93. print("示例完成!")
  94. print("=" * 80)
  95. print("\n说明:")
  96. print("1. 使用 bm25 方法召回的是纯文本片段(非树节点)")
  97. print("2. 片段基于文档分块,不包含层级结构信息")
  98. print("3. BM25 算法擅长关键词匹配")
  99. print("4. 适合快速检索和简单场景")
  100. if __name__ == "__main__":
  101. demo_bm25_retrieval()