"""
BM25 RAG method standalone test.

Tests BM25 keyword retrieval performance on bidding documents.
"""
import sys
import os
import time

# Put the parent of this file's directory on sys.path so the `bdirag` and
# `examples` packages imported below resolve when this file is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from bdirag.document_processor import Document  # noqa: E402
from bdirag.rag_methods.bm25_rag import BM25RAG  # noqa: E402
from bdirag.rag_methods.naive_rag import NaiveRAG  # noqa: E402,F401
from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG  # noqa: E402
from examples.sample_data import SAMPLE_BIDDING_DOCS  # noqa: E402

# Number of results requested from each retriever; also the "k" in the
# printed "Precision@k" label.
TOP_K = 3

# Sample tender announcement fed to BM25HTMLTreeRAG.build_index().
# NOTE(review): despite the name, this literal contains no markup as it
# appears here; confirm BM25HTMLTreeRAG handles plain text, or restore the
# intended HTML.
SAMPLE_HTML = """

XX市第一人民医院医疗设备招标公告

一、项目概况

项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目

项目编号:XX-ZB-2024-001

预算金额:500万元

采购内容:彩色多普勒超声诊断仪 1台

二、投标人资格要求

1. 具有独立承担民事责任的能力

2. 具有有效的医疗器械经营许可证

3. 近三年内无不良经营记录

4. 投标保证金:人民币5万元整

三、技术需求

1. 彩色多普勒超声诊断仪技术参数

- 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头

- 显示屏:≥19英寸高清液晶显示器

- 质保期:整机质保三年

2. 交货时间:合同签订后60天内交货

3. 交货地点:XX市第一人民医院设备科

四、评标方法

采用综合评分法:

- 技术部分:60分

- 商务部分:30分

- 价格部分:10分

五、付款方式

合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%

六、投标截止时间

投标截止时间:2024年12月31日上午9:30

开标时间:同投标截止时间

投标文件递交地点:XX市公共资源交易中心

"""


def _build_documents():
    """Wrap every entry of SAMPLE_BIDDING_DOCS in a Document.

    Each entry's "content" becomes the page content; its "title" is stored
    under both the "title" and "source" metadata keys.
    """
    return [
        Document(page_content=doc["content"],
                 metadata={"title": doc["title"], "source": doc["title"]})
        for doc in SAMPLE_BIDDING_DOCS
    ]


def _contains_keyword(text, keywords):
    """Return True if any keyword occurs in ``text`` (case-insensitive)."""
    lowered = text.lower()  # lower once instead of once per keyword
    return any(kw.lower() in lowered for kw in keywords)


def _is_relevant(doc, keywords):
    """Return True if the document's page content contains any keyword."""
    return _contains_keyword(doc.page_content, keywords)


def _is_relevant_with_context(doc, keywords):
    """Check relevance considering parent context from the "path" metadata.

    A document counts as relevant if either its page content or its
    ``metadata["path"]`` string (missing -> "") contains any keyword.
    """
    return (_contains_keyword(doc.page_content, keywords)
            or _contains_keyword(doc.metadata.get("path", ""), keywords))


def test_bm25_retrieval():
    """Index SAMPLE_BIDDING_DOCS with BM25RAG and print top-k hits per query.

    For every query, each retrieved document is marked "[OK]" when its page
    content contains one of the query's expected keywords, and the share of
    marked documents among those returned is printed as the precision.
    """
    print("=" * 60)
    print("BM25 RAG - Standalone Test (Plain Text)")
    print("=" * 60)

    print("\n[1/2] Preparing documents...")
    documents = _build_documents()
    print(" Prepared {} documents".format(len(documents)))

    print("\n[2/2] Initializing BM25RAG...")
    bm25_rag = BM25RAG()
    bm25_rag.index_documents(documents)
    print(" BM25 index built successfully")

    test_queries = [
        ("预算金额", ["budget", "Budget", "预算"]),
        ("投标保证金", ["bid bond", "Bid Bond", "保证金"]),
        ("资质要求", ["qualification", "Qualification", "资质"]),
        ("评标方法", ["evaluation", "Evaluation", "评标"]),
        ("质保期", ["warranty", "Warranty", "质保"]),
        ("付款方式", ["payment", "Payment", "付款"]),
        ("项目编号 XX-ZB", ["XX-ZB", "Project Code"]),
        ("交货时间", ["delivery", "Delivery", "交货"]),
    ]

    print("\n" + "=" * 60)
    print("BM25 Retrieval Test Results")
    print("=" * 60)

    for query, keywords in test_queries:
        print("\nQuery: {}".format(query))
        print("-" * 60)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        results = bm25_rag.retrieve(query, k=TOP_K)
        elapsed = time.perf_counter() - start
        print(" Retrieved {} documents in {:.4f}s".format(len(results), elapsed))

        # Evaluate relevance of each retrieved document.
        relevant_count = 0
        for i, (doc, score) in enumerate(results, 1):
            title = doc.metadata.get("title", "Unknown")
            preview = doc.page_content[:80].replace("\n", " ")
            if _is_relevant(doc, keywords):
                relevant_count += 1
                marker = "[OK]"
            else:
                marker = "[ --]"
            print(" [{}] {} {} (Score: {:.4f})".format(i, marker, title, score))
            print(" Preview: {}...".format(preview))

        # Denominator is the number of results actually returned.
        precision = relevant_count / len(results) if results else 0
        print(" Precision@{}: {:.1%}".format(TOP_K, precision))

    print("\n\nBM25 plain text test complete!")


def test_bm25_html_tree():
    """Build a BM25HTMLTreeRAG index over SAMPLE_HTML and print top-k subtrees.

    Relevance is judged on each returned subtree's page content only; the
    subtree's "path" metadata is printed but not used for scoring here.
    """
    print("\n\n" + "=" * 60)
    print("BM25 HTML Tree RAG - Standalone Test")
    print("=" * 60)

    print("\n[1/2] Parsing HTML and building tree...")
    tree_rag = BM25HTMLTreeRAG()
    tree_rag.build_index(SAMPLE_HTML)
    print(" HTML tree index built successfully")

    test_queries = [
        ("预算金额", ["预算", "Budget"]),
        ("投标保证金", ["保证金", "Bond"]),
        ("技术参数 探头", ["探头", "technical"]),
        ("评标方法 综合评分", ["评标", "综合评分"]),
        ("质保期", ["质保", "Warranty"]),
        ("付款方式", ["付款", "Payment"]),
        ("交货时间", ["交货", "Delivery"]),
    ]

    print("\n[2/2] Testing BM25 HTML Tree Retrieval...")
    print("=" * 60)

    for query, keywords in test_queries:
        print("\nQuery: {}".format(query))
        print("-" * 60)

        start = time.perf_counter()
        results = tree_rag.query(query, k=TOP_K)
        elapsed = time.perf_counter() - start
        print(" Retrieved {} subtrees in {:.4f}s".format(len(results), elapsed))

        relevant_count = 0
        for i, (doc, score) in enumerate(results, 1):
            path = doc.metadata.get("path", "")
            # Check relevance against the subtree's own content.
            if _is_relevant(doc, keywords):
                relevant_count += 1
                marker = "[OK]"
            else:
                marker = "[ --]"
            print(" [{}] {} Score: {:.4f}".format(i, marker, score))
            print(" Path: {}".format(path))
            print(" Content: {}...".format(doc.page_content[:120].replace("\n", " ")))

        precision = relevant_count / len(results) if results else 0
        print(" Precision@{}: {:.1%}".format(TOP_K, precision))

    print("\n\nBM25 HTML Tree test complete!")


def compare_bm25_vs_html_tree():
    """Run the same queries through both retrievers and print a comparison.

    Plain-text BM25 results are judged on page content only; HTML-tree
    results are also credited when the keyword appears in their "path"
    metadata (parent-node context). Per-query timing and precision are
    printed, followed by the mean precision across all queries.
    """
    print("\n\n" + "=" * 60)
    print("BM25 Plain Text vs BM25 HTML Tree Comparison")
    print("=" * 60)

    bm25_rag = BM25RAG()
    bm25_rag.index_documents(_build_documents())

    tree_rag = BM25HTMLTreeRAG()
    tree_rag.build_index(SAMPLE_HTML)

    test_queries = [
        ("预算金额", ["预算", "Budget"]),
        ("质保期", ["质保", "Warranty"]),
        ("评标方法", ["评标", "Evaluation"]),
    ]

    print("\n{:<15} | {:>15} | {:>15}".format("Query", "BM25 Docs", "HTML Tree Docs"))
    print("-" * 60)

    bm25_total_precision = 0
    tree_total_precision = 0
    num_queries = len(test_queries)

    for query, keywords in test_queries:
        t0 = time.perf_counter()
        bm25_results = bm25_rag.retrieve(query, k=TOP_K)
        bm25_time = time.perf_counter() - t0

        # BM25 precision: keyword must appear in the document content.
        bm25_relevant = sum(1 for doc, _ in bm25_results if _is_relevant(doc, keywords))
        bm25_precision = bm25_relevant / len(bm25_results) if bm25_results else 0

        t1 = time.perf_counter()
        tree_results = tree_rag.query(query, k=TOP_K)
        tree_time = time.perf_counter() - t1

        # HTML tree precision: parent-node context from "path" also counts.
        tree_relevant = sum(
            1 for doc, _ in tree_results if _is_relevant_with_context(doc, keywords))
        tree_precision = tree_relevant / len(tree_results) if tree_results else 0

        print("{:<15} | {:>10} ({:.4f}s, P@3:{:.1%}) | {:>10} ({:.4f}s, P@3:{:.1%})".format(
            query, len(bm25_results), bm25_time, bm25_precision,
            len(tree_results), tree_time, tree_precision))

        bm25_total_precision += bm25_precision
        tree_total_precision += tree_precision

    print("-" * 60)
    print("{:<15} | {:>15} | {:>15}".format(
        "Average",
        "P@3: {:.1%}".format(bm25_total_precision / num_queries),
        "P@3: {:.1%}".format(tree_total_precision / num_queries)))

    print("\nComparison complete!")


if __name__ == "__main__":
    test_bm25_retrieval()
    test_bm25_html_tree()
    # Side-by-side comparison of both retrievers; comment out the call
    # below to skip it.
    compare_bm25_vs_html_tree()