"""
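Standalone tests for the BM25-based RAG methods on sample bidding documents.
Exercises plain-text BM25 retrieval (BM25RAG) and HTML-tree BM25 retrieval
(BM25HTMLTreeRAG), printing per-query latency and keyword-based precision,
and then compares the two methods side by side.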
"""
import sys
import os
import time
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from bdirag.document_processor import Document
from bdirag.rag_methods.bm25_rag import BM25RAG
from bdirag.rag_methods.naive_rag import NaiveRAG
from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
from examples.sample_data import SAMPLE_BIDDING_DOCS
# Sample tender announcement (a hospital ultrasound-scanner procurement notice)
# that is passed to BM25HTMLTreeRAG.build_index by the HTML-tree test and by
# the plain-text vs. HTML-tree comparison.
# NOTE(review): despite the name, the text as it appears here contains no HTML
# tags -- only numbered section headings and "- " bullet lines. Confirm whether
# markup was lost, or whether BM25HTMLTreeRAG is expected to build its tree
# from plain text.
SAMPLE_HTML = """
XX市第一人民医院医疗设备招标公告
一、项目概况
项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
项目编号:XX-ZB-2024-001
预算金额:500万元
采购内容:彩色多普勒超声诊断仪 1台
二、投标人资格要求
1. 具有独立承担民事责任的能力
2. 具有有效的医疗器械经营许可证
3. 近三年内无不良经营记录
4. 投标保证金:人民币5万元整
三、技术需求
1. 彩色多普勒超声诊断仪技术参数
- 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
- 显示屏:≥19英寸高清液晶显示器
- 质保期:整机质保三年
2. 交货时间:合同签订后60天内交货
3. 交货地点:XX市第一人民医院设备科
四、评标方法
采用综合评分法:
- 技术部分:60分
- 商务部分:30分
- 价格部分:10分
五、付款方式
合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%
六、投标截止时间
投标截止时间:2024年12月31日上午9:30
开标时间:同投标截止时间
投标文件递交地点:XX市公共资源交易中心
"""
def test_bm25_retrieval():
    """Index SAMPLE_BIDDING_DOCS with BM25RAG and report keyword hit rates.

    For each query, the top ``top_k`` documents are retrieved. A result is
    marked relevant when its content contains any of the query's expected
    keywords (case-insensitive substring match). Retrieval latency, a preview
    of each hit, and the fraction of returned results that are relevant are
    printed.

    This is a print-only smoke test; it makes no assertions.
    """
    top_k = 3

    print("=" * 60)
    print("BM25 RAG - Standalone Test (Plain Text)")
    print("=" * 60)

    print("\n[1/2] Preparing documents...")
    documents = [
        Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
        for doc in SAMPLE_BIDDING_DOCS
    ]
    print(" Prepared {} documents".format(len(documents)))

    print("\n[2/2] Initializing BM25RAG...")
    bm25_rag = BM25RAG()
    bm25_rag.index_documents(documents)
    print(" BM25 index built successfully")

    # (query, keywords): a retrieved document counts as relevant if its
    # content contains any of the keywords (compared case-insensitively).
    test_queries = [
        ("预算金额", ["budget", "Budget", "预算"]),
        ("投标保证金", ["bid bond", "Bid Bond", "保证金"]),
        ("资质要求", ["qualification", "Qualification", "资质"]),
        ("评标方法", ["evaluation", "Evaluation", "评标"]),
        ("质保期", ["warranty", "Warranty", "质保"]),
        ("付款方式", ["payment", "Payment", "付款"]),
        ("项目编号 XX-ZB", ["XX-ZB", "Project Code"]),
        ("交货时间", ["delivery", "Delivery", "交货"]),
    ]

    print("\n" + "=" * 60)
    print("BM25 Retrieval Test Results")
    print("=" * 60)
    for query, keywords in test_queries:
        print("\nQuery: {}".format(query))
        print("-" * 60)

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with wall-clock adjustments and is too coarse on some platforms to
        # time sub-millisecond retrievals.
        start = time.perf_counter()
        results = bm25_rag.retrieve(query, k=top_k)
        elapsed = time.perf_counter() - start
        print(" Retrieved {} documents in {:.4f}s".format(len(results), elapsed))

        # Evaluate relevance via case-insensitive keyword containment.
        lowered_keywords = [kw.lower() for kw in keywords]
        relevant_count = 0
        for i, (doc, score) in enumerate(results, 1):
            title = doc.metadata.get("title", "Unknown")
            preview = doc.page_content[:80].replace("\n", " ")
            content = doc.page_content.lower()
            is_relevant = any(kw in content for kw in lowered_keywords)
            if is_relevant:
                relevant_count += 1
            marker = "[OK]" if is_relevant else "[ --]"
            print(" [{}] {} {} (Score: {:.4f})".format(i, marker, title, score))
            print(" Preview: {}...".format(preview))

        # Divides by the number of results actually returned (at most top_k).
        precision = relevant_count / len(results) if results else 0
        print(" Precision@{}: {:.1%}".format(top_k, precision))

    print("\n\nBM25 plain text test complete!")
def test_bm25_html_tree():
    """Build a BM25HTMLTreeRAG index from SAMPLE_HTML and report keyword hit rates.

    For each query, the top ``top_k`` subtrees are fetched via
    ``tree_rag.query``. A subtree is marked relevant when its ``page_content``
    contains any expected keyword (case-insensitive substring match). Latency,
    the node ``path``, a content preview, and the fraction of returned subtrees
    that are relevant are printed.

    Only ``page_content`` is matched here; the ``path`` metadata is printed
    but not searched. This is a print-only smoke test; it makes no assertions.
    """
    top_k = 3

    print("\n\n" + "=" * 60)
    print("BM25 HTML Tree RAG - Standalone Test")
    print("=" * 60)

    print("\n[1/2] Parsing HTML and building tree...")
    tree_rag = BM25HTMLTreeRAG()
    tree_rag.build_index(SAMPLE_HTML)
    print(" HTML tree index built successfully")

    # (query, keywords): a retrieved subtree counts as relevant if its content
    # contains any of the keywords (compared case-insensitively).
    test_queries = [
        ("预算金额", ["预算", "Budget"]),
        ("投标保证金", ["保证金", "Bond"]),
        ("技术参数 探头", ["探头", "technical"]),
        ("评标方法 综合评分", ["评标", "综合评分"]),
        ("质保期", ["质保", "Warranty"]),
        ("付款方式", ["付款", "Payment"]),
        ("交货时间", ["交货", "Delivery"]),
    ]

    print("\n[2/2] Testing BM25 HTML Tree Retrieval...")
    print("=" * 60)
    for query, keywords in test_queries:
        print("\nQuery: {}".format(query))
        print("-" * 60)

        # perf_counter: monotonic, high-resolution timer meant for short intervals.
        start = time.perf_counter()
        results = tree_rag.query(query, k=top_k)
        elapsed = time.perf_counter() - start
        print(" Retrieved {} subtrees in {:.4f}s".format(len(results), elapsed))

        # Check relevance via case-insensitive keyword containment.
        lowered_keywords = [kw.lower() for kw in keywords]
        relevant_count = 0
        for i, (doc, score) in enumerate(results, 1):
            path = doc.metadata.get("path", "")
            content = doc.page_content.lower()
            is_relevant = any(kw in content for kw in lowered_keywords)
            if is_relevant:
                relevant_count += 1
            marker = "[OK]" if is_relevant else "[ --]"
            print(" [{}] {} Score: {:.4f}".format(i, marker, score))
            print(" Path: {}".format(path))
            print(" Content: {}...".format(doc.page_content[:120].replace("\n", " ")))

        # Divides by the number of results actually returned (at most top_k).
        precision = relevant_count / len(results) if results else 0
        print(" Precision@{}: {:.1%}".format(top_k, precision))

    print("\n\nBM25 HTML Tree test complete!")
def _is_relevant_with_context(doc, keywords):
    """Return True if any keyword occurs in the doc's content or its ``path``.

    The ``path`` metadata (which includes parent nodes) is searched as well.
    As a result, an HTML-tree subtree counts as relevant when a keyword appears
    only in an ancestor node. Matching is a case-insensitive substring test.
    Documents without a ``path`` entry are judged on content alone.
    """
    content = doc.page_content.lower()
    path = doc.metadata.get("path", "").lower()
    return any(kw.lower() in content or kw.lower() in path for kw in keywords)


def compare_bm25_vs_html_tree():
    """Compare plain-text BM25RAG with BM25HTMLTreeRAG on a few shared queries.

    SAMPLE_BIDDING_DOCS is indexed with BM25RAG and SAMPLE_HTML with
    BM25HTMLTreeRAG. Each query runs against both with ``k=top_k``. One table
    row is printed per query (result count, latency, keyword precision),
    followed by the average precision of each method.

    Relevance is judged differently per method: BM25 results on content only,
    HTML-tree results on content plus the node ``path``. The two methods also
    index different corpora, so the figures are indicative only.
    """
    top_k = 3
    # One shared layout keeps the header, data rows, average row and rule aligned.
    row_format = "{:<15} | {:>28} | {:>28}"
    rule = "-" * len(row_format.format("", "", ""))

    print("\n\n" + "=" * 60)
    print("BM25 Plain Text vs BM25 HTML Tree Comparison")
    print("=" * 60)

    documents = [
        Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
        for doc in SAMPLE_BIDDING_DOCS
    ]
    bm25_rag = BM25RAG()
    bm25_rag.index_documents(documents)

    tree_rag = BM25HTMLTreeRAG()
    tree_rag.build_index(SAMPLE_HTML)

    test_queries = [
        ("预算金额", ["预算", "Budget"]),
        ("质保期", ["质保", "Warranty"]),
        ("评标方法", ["评标", "Evaluation"]),
    ]

    print("\n" + row_format.format("Query", "BM25 Docs", "HTML Tree Docs"))
    print(rule)

    bm25_total_precision = 0
    tree_total_precision = 0
    num_queries = len(test_queries)

    for query, keywords in test_queries:
        lowered_keywords = [kw.lower() for kw in keywords]

        # perf_counter: monotonic, high-resolution timer meant for short intervals.
        t0 = time.perf_counter()
        bm25_results = bm25_rag.retrieve(query, k=top_k)
        bm25_time = time.perf_counter() - t0
        # BM25 precision: relevance judged on page content only.
        bm25_relevant = sum(
            1
            for doc, _ in bm25_results
            if any(kw in doc.page_content.lower() for kw in lowered_keywords)
        )
        bm25_precision = bm25_relevant / len(bm25_results) if bm25_results else 0

        t1 = time.perf_counter()
        tree_results = tree_rag.query(query, k=top_k)
        tree_time = time.perf_counter() - t1
        # HTML-tree precision: also considers parent-node context via ``path``.
        tree_relevant = sum(
            1 for doc, _ in tree_results if _is_relevant_with_context(doc, keywords)
        )
        tree_precision = tree_relevant / len(tree_results) if tree_results else 0

        print(row_format.format(
            query,
            "{} ({:.4f}s, P@{}:{:.1%})".format(len(bm25_results), bm25_time, top_k, bm25_precision),
            "{} ({:.4f}s, P@{}:{:.1%})".format(len(tree_results), tree_time, top_k, tree_precision),
        ))

        bm25_total_precision += bm25_precision
        tree_total_precision += tree_precision

    print(rule)
    print(row_format.format(
        "Average",
        "P@{}: {:.1%}".format(top_k, bm25_total_precision / num_queries),
        "P@{}: {:.1%}".format(top_k, tree_total_precision / num_queries),
    ))
    print("\nComparison complete!")
if __name__ == "__main__":
    test_bm25_retrieval()
    test_bm25_html_tree()
    # Side-by-side plain-text vs. HTML-tree comparison; comment out to skip it.
    compare_bm25_vs_html_tree()