| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
- """
- BM25 RAG method standalone test
- Tests BM25 keyword retrieval performance on bidding documents
- """
- import sys
- import os
- import time
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- from bdirag.document_processor import Document
- from bdirag.rag_methods.bm25_rag import BM25RAG
- from bdirag.rag_methods.naive_rag import NaiveRAG
- from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
- from examples.sample_data import SAMPLE_BIDDING_DOCS
- SAMPLE_HTML = """
- <html>
- <body>
- <h1>XX市第一人民医院医疗设备招标公告</h1>
- <div>
- <h2>一、项目概况</h2>
- <p>项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目</p>
- <p>项目编号:XX-ZB-2024-001</p>
- <p>预算金额:500万元</p>
- <p>采购内容:彩色多普勒超声诊断仪 1台</p>
- </div>
- <div>
- <h2>二、投标人资格要求</h2>
- <p>1. 具有独立承担民事责任的能力</p>
- <p>2. 具有有效的医疗器械经营许可证</p>
- <p>3. 近三年内无不良经营记录</p>
- <p>4. 投标保证金:人民币5万元整</p>
- </div>
- <div>
- <h2>三、技术需求</h2>
- <p>1. 彩色多普勒超声诊断仪技术参数</p>
- <p> - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头</p>
- <p> - 显示屏:≥19英寸高清液晶显示器</p>
- <p> - 质保期:整机质保三年</p>
- <p>2. 交货时间:合同签订后60天内交货</p>
- <p>3. 交货地点:XX市第一人民医院设备科</p>
- </div>
- <div>
- <h2>四、评标方法</h2>
- <p>采用综合评分法:</p>
- <p> - 技术部分:60分</p>
- <p> - 商务部分:30分</p>
- <p> - 价格部分:10分</p>
- </div>
- <div>
- <h2>五、付款方式</h2>
- <p>合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%</p>
- </div>
- <div>
- <h2>六、投标截止时间</h2>
- <p>投标截止时间:2024年12月31日上午9:30</p>
- <p>开标时间:同投标截止时间</p>
- <p>投标文件递交地点:XX市公共资源交易中心</p>
- </div>
- </body>
- </html>
- """
- def test_bm25_retrieval():
- print("=" * 60)
- print("BM25 RAG - Standalone Test (Plain Text)")
- print("=" * 60)
- print("\n[1/2] Preparing documents...")
- documents = [
- Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
- for doc in SAMPLE_BIDDING_DOCS
- ]
- print(" Prepared {} documents".format(len(documents)))
- print("\n[2/2] Initializing BM25RAG...")
- bm25_rag = BM25RAG()
- bm25_rag.index_documents(documents)
- print(" BM25 index built successfully")
- test_queries = [
- ("预算金额", ["budget", "Budget", "预算"]),
- ("投标保证金", ["bid bond", "Bid Bond", "保证金"]),
- ("资质要求", ["qualification", "Qualification", "资质"]),
- ("评标方法", ["evaluation", "Evaluation", "评标"]),
- ("质保期", ["warranty", "Warranty", "质保"]),
- ("付款方式", ["payment", "Payment", "付款"]),
- ("项目编号 XX-ZB", ["XX-ZB", "Project Code"]),
- ("交货时间", ["delivery", "Delivery", "交货"]),
- ]
- print("\n" + "=" * 60)
- print("BM25 Retrieval Test Results")
- print("=" * 60)
- for query, keywords in test_queries:
- print("\nQuery: {}".format(query))
- print("-" * 60)
-
- start = time.time()
- results = bm25_rag.retrieve(query, k=3)
- elapsed = time.time() - start
-
- print(" Retrieved {} documents in {:.4f}s".format(len(results), elapsed))
-
- # 评估相关性
- relevant_count = 0
- for i, (doc, score) in enumerate(results, 1):
- title = doc.metadata.get("title", "Unknown")
- preview = doc.page_content[:80].replace("\n", " ")
-
- # 检查是否包含关键词
- is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
- if is_relevant:
- relevant_count += 1
- marker = "[OK]"
- else:
- marker = "[ --]"
-
- print(" [{}] {} {} (Score: {:.4f})".format(i, marker, title, score))
- print(" Preview: {}...".format(preview))
-
- precision = relevant_count / len(results) if results else 0
- print(" Precision@3: {:.1%}".format(precision))
-
- print("\n\nBM25 plain text test complete!")
- def test_bm25_html_tree():
- print("\n\n" + "=" * 60)
- print("BM25 HTML Tree RAG - Standalone Test")
- print("=" * 60)
- print("\n[1/2] Parsing HTML and building tree...")
- tree_rag = BM25HTMLTreeRAG()
- tree_rag.build_index(SAMPLE_HTML)
- print(" HTML tree index built successfully")
- test_queries = [
- ("预算金额", ["预算", "Budget"]),
- ("投标保证金", ["保证金", "Bond"]),
- ("技术参数 探头", ["探头", "technical"]),
- ("评标方法 综合评分", ["评标", "综合评分"]),
- ("质保期", ["质保", "Warranty"]),
- ("付款方式", ["付款", "Payment"]),
- ("交货时间", ["交货", "Delivery"]),
- ]
- print("\n[2/2] Testing BM25 HTML Tree Retrieval...")
- print("=" * 60)
- for query, keywords in test_queries:
- print("\nQuery: {}".format(query))
- print("-" * 60)
-
- start = time.time()
- results = tree_rag.query(query, k=3)
- elapsed = time.time() - start
-
- print(" Retrieved {} subtrees in {:.4f}s".format(len(results), elapsed))
-
- relevant_count = 0
- for i, (doc, score) in enumerate(results, 1):
- path = doc.metadata.get("path", "")
- title = doc.metadata.get("title", "")
-
- # 检查相关性
- is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
- if is_relevant:
- relevant_count += 1
- marker = "[OK]"
- else:
- marker = "[ --]"
-
- print(" [{}] {} Score: {:.4f}".format(i, marker, score))
- print(" Path: {}".format(path))
- print(" Content: {}...".format(doc.page_content[:120].replace("\n", " ")))
-
- precision = relevant_count / len(results) if results else 0
- print(" Precision@3: {:.1%}".format(precision))
-
- print("\n\nBM25 HTML Tree test complete!")
- def compare_bm25_vs_html_tree():
- print("\n\n" + "=" * 60)
- print("BM25 Plain Text vs BM25 HTML Tree Comparison")
- print("=" * 60)
- documents = [
- Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
- for doc in SAMPLE_BIDDING_DOCS
- ]
- bm25_rag = BM25RAG()
- bm25_rag.index_documents(documents)
- tree_rag = BM25HTMLTreeRAG()
- tree_rag.build_index(SAMPLE_HTML)
- test_queries = [
- ("预算金额", ["预算", "Budget"]),
- ("质保期", ["质保", "Warranty"]),
- ("评标方法", ["评标", "Evaluation"]),
- ]
- print("\n{:<15} | {:>15} | {:>15}".format("Query", "BM25 Docs", "HTML Tree Docs"))
- print("-" * 60)
- bm25_total_precision = 0
- tree_total_precision = 0
- num_queries = len(test_queries)
- for query, keywords in test_queries:
- t0 = time.time()
- bm25_results = bm25_rag.retrieve(query, k=3)
- bm25_time = time.time() - t0
- # 计算BM25的precision
- bm25_relevant = sum(1 for doc, _ in bm25_results if any(kw.lower() in doc.page_content.lower() for kw in keywords))
- bm25_precision = bm25_relevant / len(bm25_results) if bm25_results else 0
- t1 = time.time()
- tree_results = tree_rag.query(query, k=3)
- tree_time = time.time() - t1
- # 计算HTML Tree的precision - 考虑父节点上下文
- def is_relevant_with_context(doc, keywords):
- """Check relevance considering parent context from path."""
- content = doc.page_content.lower()
- path = doc.metadata.get("path", "").lower()
-
- # Check content
- if any(kw.lower() in content for kw in keywords):
- return True
-
- # Check path (which includes parent nodes)
- if any(kw.lower() in path for kw in keywords):
- return True
-
- return False
-
- tree_relevant = sum(1 for doc, _ in tree_results if is_relevant_with_context(doc, keywords))
- tree_precision = tree_relevant / len(tree_results) if tree_results else 0
- print("{:<15} | {:>10} ({:.4f}s, P@3:{:.1%}) | {:>10} ({:.4f}s, P@3:{:.1%})".format(
- query, len(bm25_results), bm25_time, bm25_precision,
- len(tree_results), tree_time, tree_precision))
-
- bm25_total_precision += bm25_precision
- tree_total_precision += tree_precision
- print("-" * 60)
- print("{:<15} | {:>15} | {:>15}".format(
- "Average",
- "P@3: {:.1%}".format(bm25_total_precision / num_queries),
- "P@3: {:.1%}".format(tree_total_precision / num_queries)))
- print("\nComparison complete!")
- if __name__ == "__main__":
- test_bm25_retrieval()
- test_bm25_html_tree()
- # 如果需要对比,取消下面的注释
- compare_bm25_vs_html_tree()
|