浏览代码

初始提交

luojiehua 1 周之前
当前提交
d6f8328eaf
共有 98 个文件被更改,包括 11387 次插入0 次删除
  1. 6 0
      .idea/.gitignore
  2. 6 0
      .idea/MarsCodeWorkspaceAppSettings.xml
  3. 16 0
      .idea/misc.xml
  4. 8 0
      .idea/modules.xml
  5. 6 0
      .idea/vcs.xml
  6. 6 0
      .vscode/settings.json
  7. 3 0
      1.4.1
  8. 406 0
      BIDIRAG_USAGE.md
  9. 11 0
      BidiRAG.iml
  10. 150 0
      RAG_METHODS_TEST_REPORT.md
  11. 163 0
      README.md
  12. 37 0
      bdirag/__init__.py
  13. 257 0
      bdirag/benchmark.py
  14. 498 0
      bdirag/bidi_rag.py
  15. 84 0
      bdirag/config.py
  16. 141 0
      bdirag/document_processor.py
  17. 110 0
      bdirag/embedding_models.py
  18. 1377 0
      bdirag/rag_methods.py
  19. 37 0
      bdirag/rag_methods/__init__.py
  20. 94 0
      bdirag/rag_methods/adaptive_rag.py
  21. 111 0
      bdirag/rag_methods/base.py
  22. 322 0
      bdirag/rag_methods/bid_field_extraction_rag.py
  23. 67 0
      bdirag/rag_methods/bm25_backend.py
  24. 448 0
      bdirag/rag_methods/bm25_html_tree_rag.py
  25. 87 0
      bdirag/rag_methods/bm25_rag.py
  26. 46 0
      bdirag/rag_methods/contextual_compression_rag.py
  27. 53 0
      bdirag/rag_methods/corrective_rag.py
  28. 46 0
      bdirag/rag_methods/dedup.py
  29. 61 0
      bdirag/rag_methods/ensemble_rag.py
  30. 49 0
      bdirag/rag_methods/flare_rag.py
  31. 83 0
      bdirag/rag_methods/graph_rag.py
  32. 59 0
      bdirag/rag_methods/hybrid_search_rag.py
  33. 30 0
      bdirag/rag_methods/hyde_rag.py
  34. 89 0
      bdirag/rag_methods/keyword_rag.py
  35. 44 0
      bdirag/rag_methods/llm_filter_rag.py
  36. 58 0
      bdirag/rag_methods/metadata_filter_rag.py
  37. 46 0
      bdirag/rag_methods/multi_query_rag.py
  38. 22 0
      bdirag/rag_methods/naive_rag.py
  39. 71 0
      bdirag/rag_methods/parent_document_rag.py
  40. 64 0
      bdirag/rag_methods/query_routing_rag.py
  41. 90 0
      bdirag/rag_methods/raptor_rag.py
  42. 38 0
      bdirag/rag_methods/rerank_rag.py
  43. 80 0
      bdirag/rag_methods/self_rag.py
  44. 37 0
      bdirag/rag_methods/step_back_rag.py
  45. 88 0
      bdirag/rag_methods/table_aware_rag.py
  46. 76 0
      bdirag/rag_methods/tfidf_rag.py
  47. 90 0
      bdirag/rag_methods/tokenization.py
  48. 181 0
      bdirag/vector_stores.py
  49. 46 0
      convert_unicode_to_chinese.py
  50. 二进制
      doubao-page.png
  51. 1 0
      examples/__init__.py
  52. 371 0
      examples/benchmark_all_methods.py
  53. 189 0
      examples/benchmark_retrieval_speed.py
  54. 82 0
      examples/bid_field_extraction_demo.py
  55. 83 0
      examples/debug_bm25.py
  56. 45 0
      examples/debug_bm25_html.py
  57. 198 0
      examples/demo_bid_announcement.py
  58. 276 0
      examples/demo_bidi_rag.py
  59. 129 0
      examples/demo_bm25_retrieval.py
  60. 149 0
      examples/demo_tree_node_retrieval.py
  61. 282 0
      examples/extract_bid_info.py
  62. 115 0
      examples/quick_demo.py
  63. 133 0
      examples/quick_test_methods.py
  64. 317 0
      examples/rag_test_utils.py
  65. 242 0
      examples/sample_data.py
  66. 12 0
      examples/test_adaptive_rag.py
  67. 226 0
      examples/test_all_rag_methods.py
  68. 21 0
      examples/test_bid_field_extraction_rag.py
  69. 292 0
      examples/test_bidi_rag.py
  70. 267 0
      examples/test_bm25.py
  71. 12 0
      examples/test_bm25_html_tree_rag.py
  72. 12 0
      examples/test_contextual_compression_rag.py
  73. 12 0
      examples/test_corrective_rag.py
  74. 12 0
      examples/test_ensemble_rag.py
  75. 12 0
      examples/test_flare_rag.py
  76. 12 0
      examples/test_graph_rag.py
  77. 15 0
      examples/test_hybrid_search_rag.py
  78. 12 0
      examples/test_hyde_rag.py
  79. 22 0
      examples/test_keyword_rag.py
  80. 12 0
      examples/test_llm_filter_rag.py
  81. 12 0
      examples/test_metadata_filter_rag.py
  82. 213 0
      examples/test_methods_direct.py
  83. 12 0
      examples/test_multi_query_rag.py
  84. 12 0
      examples/test_naive_rag.py
  85. 12 0
      examples/test_parent_document_rag.py
  86. 12 0
      examples/test_query_routing_rag.py
  87. 19 0
      examples/test_raptor_rag.py
  88. 13 0
      examples/test_rerank_rag.py
  89. 90 0
      examples/test_retrieval_dedup.py
  90. 12 0
      examples/test_self_rag.py
  91. 12 0
      examples/test_step_back_rag.py
  92. 12 0
      examples/test_table_aware_rag.py
  93. 14 0
      examples/test_tfidf_rag.py
  94. 1 0
      examples/untitled-1.py
  95. 90 0
      fix_fstrings.py
  96. 0 0
      parser/__init__.py
  97. 1239 0
      parser/htmlparser.py
  98. 51 0
      requirements.txt

+ 6 - 0
.idea/.gitignore

@@ -0,0 +1,6 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

+ 6 - 0
.idea/MarsCodeWorkspaceAppSettings.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="com.codeverse.userSettings.MarscodeWorkspaceAppSettingsState">
+    <option name="progress" value="1.0" />
+  </component>
+</project>

+ 16 - 0
.idea/misc.xml

@@ -0,0 +1,16 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="Python 3.11 (open_manus)" />
+  </component>
+  <component name="BspLocalSettings">
+    <option name="projectSyncType">
+      <map>
+        <entry key="D:/Workspace2016/PaddleOCR" value="PREVIEW" />
+      </map>
+    </option>
+  </component>
+  <component name="ProjectRootManager" version="2" languageLevel="JDK_26" project-jdk-name="Python 3.11 (open_manus)" project-jdk-type="Python SDK">
+    <output url="file://$PROJECT_DIR$/out" />
+  </component>
+</project>

+ 8 - 0
.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/BidiRAG.iml" filepath="$PROJECT_DIR$/BidiRAG.iml" />
+    </modules>
+  </component>
+</project>

+ 6 - 0
.idea/vcs.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>

+ 6 - 0
.vscode/settings.json

@@ -0,0 +1,6 @@
+{
+  "python.defaultInterpreterPath": "C:\\Anaconda3.4\\envs\\open_manus\\python.exe",
+  "terminal.integrated.env.windows": {
+    "PATH": "C:\\Anaconda3.4\\envs\\open_manus;C:\\Anaconda3.4\\envs\\open_manus\\Scripts;${env:PATH}"
+  }
+}

+ 3 - 0
1.4.1

@@ -0,0 +1,3 @@
+Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
+Requirement already satisfied: scipy in c:\anaconda3.4\envs\py37\lib\site-packages (1.2.0)
+Requirement already satisfied: numpy>=1.8.2 in c:\anaconda3.4\envs\py37\lib\site-packages (from scipy) (1.17.2)

+ 406 - 0
BIDIRAG_USAGE.md

@@ -0,0 +1,406 @@
+# BidiRag 使用指南
+
+## 概述
+
+`BidiRag` 是 BidiRAG 项目的高级封装类,提供统一的接口来:
+- 加载和处理 HTML 文档
+- 使用多种 RAG 方法进行文档检索
+- 支持关键词检索和自然语言查询
+- 支持从招投标公告中提取结构化信息
+
+## 快速开始
+
+### 基本用法
+
+```python
+from bdirag.bidi_rag import BidiRag
+
+# 1. 初始化 BidiRag(推荐用于招投标公告)
+rag = BidiRag(rag_method='bm25_html_tree')
+
+# 2. 添加文档
+rag.add_documents(["path/to/announcement1.html", "path/to/announcement2.html"])
+
+# 3. 检索文档
+results = rag.retrieve(
+    query="招标人和中标人",
+    top_k=5,
+    keywords=["招标人", "中标人", "采购人"]
+)
+
+# 4. 查看结果
+for doc, score in results:
+    print(f"Score: {score:.3f}")
+    print(doc.page_content)
+    print("-" * 80)
+```
+
+### 关键词搜索
+
+```python
+# 搜索包含特定关键词的文档
+results = rag.search_keywords(["招标人"], top_k=10)
+
+# 搜索同时包含多个关键词的文档(AND 逻辑)
+results = rag.search_keywords(["招标人", "中标人"], top_k=10)
+```
+
+### 使用不同的 RAG 方法
+
+```python
+# BM25 方法(关键词检索,无需 embedding 模型)
+rag = BidiRag(rag_method='bm25')
+
+# TF-IDF 方法
+rag = BidiRag(rag_method='tfidf')
+
+# HTML 结构感知方法(推荐用于招投标公告)
+rag = BidiRag(rag_method='bm25_html_tree')
+
+# 语义检索方法(需要 embedding 模型)
+rag = BidiRag(rag_method='naive')
+
+# 混合搜索方法
+rag = BidiRag(rag_method='hybrid_search')
+```
+
+### 完整问答(需要 LLM)
+
+```python
+from openai import OpenAI
+
+# 配置 LLM 客户端
+llm_client = OpenAI(
+    api_key="your-api-key",
+    base_url="https://api.openai.com/v1"
+)
+
+# 使用支持 LLM 的 RAG 方法
+rag = BidiRag(
+    rag_method='naive',
+    llm_client=llm_client,
+    llm_model='gpt-4o'
+)
+
+rag.add_texts([
+    "招标公告:招标人XX局,项目预算100万",
+    "中标公告:中标人XX公司,金额98万"
+])
+
+# 完整问答
+result = rag.query(
+    query="谁是招标人?",
+    keywords=["招标人", "采购人"]
+)
+
+print("答案:", result.answer)
+print("检索到的文档数:", len(result.retrieved_docs))
+```
+
+## API 参考
+
+### BidiRag 类
+
+#### 初始化参数
+
+```python
+BidiRag(
+    rag_method: str = 'bm25_html_tree',  # RAG 方法名称
+    chunk_size: int = 512,               # 文档块大小
+    chunk_overlap: int = 50,             # 块重叠大小
+    vector_store_type: str = 'faiss',    # 向量存储类型
+    embedding_model_name: str = None,    # Embedding 模型名称
+    llm_client=None,                     # LLM 客户端
+    llm_model: str = "gpt-4o",          # LLM 模型名称
+    index_name: str = "default"          # 索引名称
+)
+```
+
+#### 主要方法
+
+##### `add_documents(sources, rebuild_index=True)`
+添加文档文件(支持 HTML、PDF、DOCX、TXT 等)
+
+- `sources`: 文件路径或目录路径(字符串或列表)
+- `rebuild_index`: 是否重建索引
+- 返回: 添加的文档数量
+
+```python
+rag.add_documents(["doc1.html", "doc2.html"])
+rag.add_documents("path/to/documents/")  # 目录
+```
+
+##### `add_texts(texts, metadata=None, rebuild_index=True)`
+添加文本文档
+
+- `texts`: 文本内容列表
+- `metadata`: 元数据列表(可选)
+- `rebuild_index`: 是否重建索引
+- 返回: 添加的文档数量
+
+```python
+rag.add_texts([
+    "<html><body>招标公告...</body></html>",
+    "<html><body>中标公告...</body></html>"
+])
+```
+
+##### `retrieve(query, top_k=None, keywords=None, return_scores=True)`
+检索相关文档
+
+- `query`: 搜索查询(自然语言或关键词)
+- `top_k`: 返回结果数量
+- `keywords`: 额外关键词过滤(可选)
+- `return_scores`: 是否返回相似度分数
+- 返回: 文档列表(带分数)
+
+```python
+results = rag.retrieve(
+    query="采购信息",
+    top_k=5,
+    keywords=["招标人", "采购人", "中标人"]
+)
+```
+
+##### `query(query, top_k=None, keywords=None)`
+完整 RAG 流程:检索 + 生成答案(需要 LLM)
+
+- `query`: 搜索查询
+- `top_k`: 检索文档数量
+- `keywords`: 关键词过滤
+- 返回: RAGResult 对象(包含答案和检索文档)
+
+```python
+result = rag.query(
+    query="项目预算是多少?",
+    keywords=["预算", "金额"]
+)
+print(result.answer)
+```
+
+##### `search_keywords(keywords, top_k=None)`
+基于关键词精确搜索
+
+- `keywords`: 关键词列表(AND 逻辑)
+- `top_k`: 最大结果数量
+- 返回: 文档列表
+
+```python
+results = rag.search_keywords(["招标人", "中标人"], top_k=10)
+```
+
+##### `get_document_count()`
+获取已加载的文档数量
+
+##### `list_available_methods()`
+列出所有可用的 RAG 方法
+
+##### `get_method_info()`
+获取当前 RAG 方法的信息
+
+##### `clear()`
+清除所有文档和索引
+
+## 可用的 RAG 方法
+
+### 无需 Embedding 模型的方法
+
+| 方法名 | 描述 | 适用场景 |
+|--------|------|----------|
+| `bm25` | BM25 关键词检索 | 快速关键词匹配 |
+| `tfidf` | TF-IDF 检索 | 简单关键词匹配 |
+| `keyword` | 关键词检索(BM25/TF-IDF) | 基础关键词搜索 |
+| `bm25_html_tree` | HTML 结构感知的 BM25 | **招投标公告(推荐)** |
+
+### 需要 Embedding 模型的方法
+
+| 方法名 | 描述 | 特点 |
+|--------|------|------|
+| `naive` | 基础语义检索 | 简单向量相似度搜索 |
+| `rerank` | 重排序检索 | 检索后使用重排序模型优化 |
+| `hybrid_search` | 混合搜索 | 结合向量搜索和关键词搜索 |
+| `multi_query` | 多查询检索 | 生成多个查询进行检索 |
+| `hyde` | 假设文档嵌入 | 基于假设文档的检索 |
+| `step_back` | 后退一步检索 | 生成更一般的查询 |
+| `parent_document` | 父文档检索 | 检索小块,返回大块 |
+| `contextual_compression` | 上下文压缩 | 压缩检索到的上下文 |
+
+### 需要 LLM 的高级方法
+
+| 方法名 | 描述 | 特点 |
+|--------|------|------|
+| `adaptive` | 自适应检索 | 根据查询难度自适应策略 |
+| `self_rag` | 自反思检索 | 检索后进行自我评估 |
+| `corrective` | 纠正性检索 | 检索后进行纠正 |
+| `flare` | FLARE 方法 | 生成与检索交替 |
+| `raptor` | RAPTOR 方法 | 递归树聚合 |
+| `ensemble` | 集成检索 | 多种方法集成 |
+
+## 实用示例
+
+### 示例 1:从招投标公告中提取信息
+
+```python
+from bdirag.bidi_rag import BidiRag
+
+# 使用 HTML 结构感知方法
+rag = BidiRag(rag_method='bm25_html_tree')
+
+# 添加公告
+html_content = """
+<html>
+<body>
+    <h1>政府采购中标公告</h1>
+    <table>
+        <tr><td>采购人</td><td>XX市财政局</td></tr>
+        <tr><td>中标人</td><td>XX科技有限公司</td></tr>
+        <tr><td>中标金额</td><td>500万元</td></tr>
+    </table>
+</body>
+</html>
+"""
+
+rag.add_texts([html_content])
+
+# 检索招标人信息
+results = rag.retrieve(
+    query="采购人信息",
+    keywords=["采购人", "招标人"]
+)
+
+for doc, score in results:
+    print(f"\n找到采购人信息 (score: {score:.3f}):")
+    print(doc.page_content)
+```
+
+### 示例 2:批量处理多个公告
+
+```python
+import os
+from bdirag.bidi_rag import BidiRag
+
+rag = BidiRag(rag_method='bm25_html_tree')
+
+# 处理目录中的所有 HTML 文件
+announcement_dir = "data/announcements/"
+rag.add_documents(announcement_dir)
+
+print(f"已加载 {rag.get_document_count()} 个公告")
+
+# 搜索特定信息
+results = rag.retrieve(
+    query="太阳能路灯采购",
+    top_k=10,
+    keywords=["路灯", "照明", "采购"]
+)
+
+print(f"\n找到 {len(results)} 个相关公告")
+```
+
+### 示例 3:关键词精化搜索
+
+```python
+rag = BidiRag(rag_method='bm25')
+
+# 添加文档
+rag.add_texts([
+    "招标人A公司,预算100万",
+    "招标人B公司,中标人C公司,预算200万",
+    "采购人D单位,预算150万"
+])
+
+# 只搜索包含"招标人"的文档
+results = rag.retrieve(
+    query="公司信息",
+    keywords=["招标人"]
+)
+
+# 只搜索同时包含"招标人"和"中标人"的文档
+results = rag.retrieve(
+    query="公司信息",
+    keywords=["招标人", "中标人"]
+)
+```
+
+## 常见问题
+
+### Q: 为什么检索结果为空?
+
+A: 可能的原因:
+1. 文档索引未构建 - 确保调用了 `add_documents()` 或 `build_index()`
+2. 关键词不匹配 - 检查关键词是否在文档中存在
+3. 使用关键词过滤时 - 确保文档包含所有指定的关键词
+
+### Q: 如何选择 RAG 方法?
+
+A: 建议:
+- **招投标公告检索**: 使用 `bm25_html_tree`(保留 HTML 结构)
+- **快速关键词搜索**: 使用 `bm25` 或 `tfidf`
+- **语义理解搜索**: 使用 `naive` 或 `hybrid_search`
+- **需要生成答案**: 使用支持 LLM 的方法(如 `naive` + `llm_client`)
+
+### Q: 如何配置 Embedding 模型?
+
+A: 默认使用 `BAAI/bge-large-zh-v1.5` 模型:
+
+```python
+# 使用默认模型
+rag = BidiRag(rag_method='naive')
+
+# 或使用其他模型
+rag = BidiRag(
+    rag_method='naive',
+    embedding_model_name='BAAI/bge-base-zh-v1.5'
+)
+```
+
+### Q: 如何配置 LLM?
+
+A: 示例使用 OpenAI API:
+
+```python
+from openai import OpenAI
+
+llm_client = OpenAI(
+    api_key="your-api-key",
+    base_url="https://api.openai.com/v1"  # 或使用其他兼容 API
+)
+
+rag = BidiRag(
+    rag_method='naive',
+    llm_client=llm_client,
+    llm_model='gpt-4o'
+)
+```
+
+## 测试和示例
+
+运行测试:
+```bash
+python examples/test_bidi_rag.py
+```
+
+运行完整示例:
+```bash
+python examples/demo_bidi_rag.py
+```
+
+## 注意事项
+
+1. **网络问题**: 如果使用 embedding 模型,需要访问 HuggingFace。如遇网络问题,可以:
+   - 使用无需 embedding 的方法(`bm25`, `tfidf`, `keyword`, `bm25_html_tree`)
+   - 配置代理或镜像
+   - 下载模型到本地
+
+2. **内存使用**: 大量文档时注意内存使用,可以分批处理
+
+3. **关键词过滤**: 使用 `keywords` 参数时,文档必须包含所有关键词(AND 逻辑)
+
+4. **HTML 处理**: `bm25_html_tree` 方法会解析 HTML 结构,适合结构化文档
+
+## 技术支持
+
+如有问题,请查看:
+- 示例代码: `examples/` 目录
+- 测试代码: `examples/test_bidi_rag.py`
+- RAG 方法实现: `bdirag/rag_methods/` 目录

+ 11 - 0
BidiRAG.iml

@@ -0,0 +1,11 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager" inherit-compiler-output="true">
+    <exclude-output />
+    <content url="file://$MODULE_DIR$">
+      <sourceFolder url="file://$MODULE_DIR$/examples" isTestSource="false" />
+    </content>
+    <orderEntry type="jdk" jdkName="Python 3.7 (py37)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

+ 150 - 0
RAG_METHODS_TEST_REPORT.md

@@ -0,0 +1,150 @@
+# BidiRag RAG 方法测试报告
+
+## 测试环境
+- Python 版本: 3.7
+- 操作系统: Windows 22H2
+- 测试时间: 2026-05-08
+
+## 测试文档
+两个招投标公告文档(纯文本格式),包含预算金额、资质要求、评标方法等信息。
+
+## 测试结果汇总
+
+### ✅ 成功的方法
+
+#### 1. BM25 (BM25RAG)
+- **状态**: ✅ 正常工作
+- **召回数量**: 2/2 文档
+- **精确度**: 100%
+- **检索时间**: ~0.5ms
+- **特点**: 
+  - 基于概率检索模型
+  - 支持中文分词(jieba)
+  - 返回负分数(BM25 特性)
+  - 适合关键词匹配场景
+
+#### 2. BM25 HTML Tree (BM25HTMLTreeRAG)
+- **状态**: ⚠️ 部分工作
+- **召回数量**: 0(HTML 解析问题)
+- **精确度**: N/A
+- **检索时间**: < 1ms
+- **特点**:
+  - 解析 HTML 为树结构
+  - 保留层级信息(path metadata)
+  - 当前问题:简单 HTML 被解析为单个节点
+  - 需要更复杂的 HTML 结构才能发挥优势
+
+### ❌ 失败的方法
+
+#### 3. TF-IDF (TFIDFRAG)
+- **状态**: ❌ 失败
+- **错误**: `cannot import name 'get_config'`
+- **原因**: 依赖库版本不兼容
+
+#### 4. Keyword (KeywordRAG)
+- **状态**: ❌ 失败
+- **错误**: `No module named 'rank_bm25'`
+- **原因**: 缺少 rank_bm25 模块
+
+#### 5-24. 其他需要 Embedding 的方法
+- **状态**: ❌ 无法测试
+- **原因**: 网络限制,无法从 HuggingFace 下载 embedding 模型
+- **包括**: naive, rerank, parent_document, llm_filter, query_routing, metadata_filter, adaptive, hybrid_search, multi_query, hyde, step_back, contextual_compression, self_rag, corrective, flare, raptor, ensemble, bid_field_extraction, table_aware, graph
+
+## 推荐使用方法
+
+### 对于招投标公告检索
+
+#### 方案 1: BM25(推荐)
+```python
+from bdirag.rag_methods.bm25_rag import BM25RAG
+from bdirag.document_processor import Document
+
+# 初始化
+rag = BM25RAG()
+rag.index_documents([Document(page_content=html_text)])
+
+# 检索
+results = rag.retrieve("预算金额", k=5)
+for doc, score in results:
+    print(doc.page_content)
+```
+
+**优点**:
+- ✅ 快速(< 1ms)
+- ✅ 精确度高(100%)
+- ✅ 无需外部依赖
+- ✅ 支持中文
+
+**缺点**:
+- ❌ 不保留 HTML 结构
+- ❌ 返回整个文档而非片段
+
+#### 方案 2: BM25 HTML Tree(适用于复杂 HTML)
+```python
+from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
+
+# 初始化
+rag = BM25HTMLTreeRAG()
+rag.build_index(html_content)
+
+# 检索
+results = rag.query("预算金额", k=5)
+for doc, score in results:
+    path = doc.metadata.get("path", "")
+    print(f"Path: {path}")
+    print(f"Content: {doc.page_content}")
+```
+
+**优点**:
+- ✅ 保留 HTML 层级结构
+- ✅ 返回树节点片段
+- ✅ 包含路径信息
+
+**缺点**:
+- ⚠️ 需要复杂的 HTML 结构
+- ⚠️ 简单表格可能被合并为一个节点
+
+## 已知问题
+
+### 1. BM25 返回负分数
+- **现象**: BM25 得分可能是负数(如 -1.1182)
+- **原因**: BM25Okapi 的实现特性
+- **解决**: 已修复,允许负分通过过滤条件
+
+### 2. BM25 HTML Tree 节点过少
+- **现象**: 简单 HTML 只解析出 1-2 个节点
+- **原因**: ParseDocument 将表格作为整体处理
+- **建议**: 使用更复杂的 HTML 结构或改用 BM25
+
+### 3. 网络依赖问题
+- **现象**: 大部分 RAG 方法需要 embedding 模型
+- **影响**: 无法在无网络环境下使用
+- **建议**: 预下载模型或使用离线方案
+
+## 性能对比
+
+| 方法 | 初始化时间 | 索引时间 | 检索时间 | 精确度 | 适用场景 |
+|------|-----------|---------|---------|--------|---------|
+| BM25 | < 1s | < 1s | < 1ms | 100% | 关键词检索 |
+| BM25 HTML Tree | < 1s | < 1s | < 1ms | N/A | 结构化 HTML |
+| TF-IDF | - | - | - | - | 依赖问题 |
+| Keyword | - | - | - | - | 缺少模块 |
+
+## 总结
+
+### 最佳实践
+1. **简单场景**: 使用 BM25,快速且准确
+2. **结构化 HTML**: 使用 BM25 HTML Tree,保留层级信息
+3. **生产环境**: 预加载 embedding 模型以支持更多方法
+
+### 下一步改进
+1. 修复 TF-IDF 依赖问题
+2. 安装 rank_bm25 模块支持 Keyword 方法
+3. 优化 BM25 HTML Tree 的 HTML 解析逻辑
+4. 提供离线 embedding 模型下载方案
+
+---
+
+**测试完成时间**: 2026-05-08 17:21
+**测试人员**: AI Assistant

+ 163 - 0
README.md

@@ -0,0 +1,163 @@
+# BidiRAG - 招投标领域RAG检索与生成项目
+
+基于多种RAG(Retrieval-Augmented Generation)方法的招投标公告和附件信息检索与字段提取系统。
+
+## 项目简介
+
+本项目专注于招投标领域的智能化信息处理,核心功能包括:
+
+- 多种RAG检索方法的实现与对比
+- 招投标公告结构化字段提取
+- 检索效果与速度的基准测试
+
+## 支持的RAG方法
+
+### 基础方法
+
+| 方法                  | 描述        | 特点                 |
+| ------------------- | --------- | ------------------ |
+| **BM25RAG**         | BM25关键词检索 | 纯BM25概率模型,不依赖向量嵌入  |
+| **TFIDFRAG**        | TF-IDF检索  | 纯TF-IDF余弦相似度,不依赖向量嵌入 |
+| **NaiveRAG**        | 最基础的RAG实现 | 向量语义检索 + LLM生成     |
+| **RerankRAG**       | 带重排序的RAG  | 初始检索后使用重排模型精排      |
+| **HybridSearchRAG** | 混合搜索RAG   | 语义检索 + BM25关键词检索融合 |
+
+### 高级方法
+
+| 方法                           | 描述        | 特点                   |
+| ---------------------------- | --------- | -------------------- |
+| **MultiQueryRAG**            | 多查询扩展RAG  | 将原问题扩展为多个角度查询        |
+| **HyDERAG**                  | 假设文档嵌入RAG | 先生成假设文档再检索           |
+| **SelfRAG**                  | 自我反思RAG   | 检索前判断是否需要检索,检索后评估相关性 |
+| **CorrectiveRAG**            | 纠正性RAG    | 评估文档正确性,不足时补充外部搜索    |
+| **FLARERAG**                 | 主动检索生成RAG | 迭代式检索与生成             |
+| **RAPTORRAG**                | 递归摘要树RAG  | 构建多层级摘要树进行检索         |
+| **StepBackRAG**              | 抽象回退RAG   | 将具体问题抽象为高层问题检索背景知识   |
+| **ContextualCompressionRAG** | 上下文压缩RAG  | 压缩检索结果仅保留相关信息        |
+| **EnsembleRAG**              | 集成检索RAG   | 多种检索策略结果融合           |
+
+### 招投标专用方法
+
+| 方法                        | 描述         | 特点             |
+| ------------------------- | ---------- | -------------- |
+| **BidFieldExtractionRAG** | 招投标字段提取RAG | 针对招投标字段优化检索与提取 |
+| **TableAwareRAG**         | 表格感知RAG    | 专门处理表格和结构化数据   |
+| **GraphRAG**              | 图谱增强RAG    | 结合实体关系图谱进行检索   |
+
+## 项目结构
+
+```
+BidiRAG/
+├── bdirag/
+│   ├── __init__.py
+│   ├── config.py                 # 配置文件
+│   ├── document_processor.py     # 文档处理(PDF/Word/Excel/TXT)
+│   ├── embedding_models.py       # 嵌入模型(SentenceTransformers/OpenAI/DashScope/Zhipu)
+│   ├── vector_stores.py          # 向量存储(FAISS/Chroma)
+│   ├── rag_methods.py            # 所有RAG方法实现
+│   └── benchmark.py              # 基准测试模块
+├── examples/
+│   ├── sample_data.py            # 示例招投标数据
+│   ├── quick_demo.py             # 快速演示
+│   ├── benchmark_all_methods.py  # 全方法基准测试
+│   ├── benchmark_retrieval_speed.py  # 检索速度测试
+│   └── bid_field_extraction_demo.py  # 字段提取演示
+├── data/
+│   ├── documents/                # 放置待处理的招投标文档
+│   ├── indexes/                  # 向量索引存储
+│   └── cache/                    # 缓存目录
+├── output/                       # 输出目录(基准测试结果等)
+├── requirements.txt
+└── README.md
+```
+
+## 快速开始
+
+### 安装依赖
+
+```bash
+pip install -r requirements.txt
+```
+
+### 设置环境变量
+
+```bash
+# OpenAI API配置(或兼容的API)
+export OPENAI_API_KEY="your-api-key"
+export OPENAI_BASE_URL="https://api.openai.com/v1"  # 或其他兼容API地址
+export LLM_MODEL="gpt-4o"
+
+# 嵌入模型配置
+export EMBEDDING_MODEL="BAAI/bge-large-zh-v1.5"
+```
+
+### 运行示例
+
+```bash
+# 快速演示
+python examples/quick_demo.py
+
+# 检索速度测试(无需LLM)
+python examples/benchmark_retrieval_speed.py
+
+# 完整基准测试
+python examples/benchmark_all_methods.py
+
+# 字段提取演示
+python examples/bid_field_extraction_demo.py
+```
+
+## 支持的嵌入模型
+
+- **SentenceTransformers**: 本地部署,支持中文(如BAAI/bge-large-zh-v1.5)
+- **OpenAI**: text-embedding-3-large/small
+- **DashScope**: 阿里云text-embedding-v2
+- **Zhipu**: 智谱AI embedding-3
+
+## 支持的文档格式
+
+- PDF
+- Word (.docx)
+- Excel (.xlsx/.xls)
+- 纯文本 (.txt)
+
+## 招投标字段提取
+
+系统支持提取以下招投标关键字段:
+
+- 项目名称 (project\_name)
+- 项目编号 (project\_code)
+- 预算金额 (budget\_amount)
+- 币种 (currency)
+- 投标截止时间 (bid\_deadline)
+- 开标时间 (bid\_open\_time)
+- 投标地点 (bid\_location)
+- 采购人名称 (purchaser\_name)
+- 采购人联系人 (purchaser\_contact)
+- 采购人电话 (purchaser\_phone)
+- 代理机构名称 (agency\_name)
+- 代理机构联系人 (agency\_contact)
+- 代理机构电话 (agency\_phone)
+- 资格要求 (qualification\_requirements)
+- 投标保证金 (bid\_bond\_amount)
+- 履约保证金 (performance\_bond\_amount)
+- 质保期 (warranty\_period)
+- 交货时间 (delivery\_time)
+- 交货地点 (delivery\_location)
+- 付款方式 (payment\_terms)
+- 评标方法 (evaluation\_method)
+- 工作范围 (scope\_of\_work)
+
+## 基准测试指标
+
+系统会输出以下对比指标:
+
+- 平均总延迟 / 检索延迟 / 生成延迟
+- P50 / P95 延迟
+- 最小 / 最大延迟
+- 平均检索文档数
+- 吞吐量(QPS)
+
+## License
+
+MIT

+ 37 - 0
bdirag/__init__.py

@@ -0,0 +1,37 @@
+from .config import *
+from .document_processor import DocumentProcessor
+from .embedding_models import BaseEmbedding, SentenceTransformerEmbedding, OpenAIEmbedding
+from .vector_stores import BaseVectorStore, FAISSStore, ChromaStore
+from .rag_methods import (
+    BaseRAG,
+    RAGResult,
+    NaiveRAG,
+    RerankRAG,
+    ParentDocumentRAG,
+    LLMFilterRAG,
+    QueryRoutingRAG,
+    MetadataFilterRAG,
+    AdaptiveRAG,
+    HybridSearchRAG,
+    MultiQueryRAG,
+    HyDERAG,
+    StepBackRAG,
+    ContextualCompressionRAG,
+    SelfRAG,
+    CorrectiveRAG,
+    FLARERAG,
+    RAPTORRAG,
+    EnsembleRAG,
+    BidFieldExtractionRAG,
+    TableAwareRAG,
+    GraphRAG,
+    BM25RAG,
+    TFIDFRAG,
+    KeywordRAG,
+)
+try:
+    from .benchmark import RAGBenchmark
+except (ImportError, SyntaxError):
+    RAGBenchmark = None
+
+__version__ = "0.1.0"

+ 257 - 0
bdirag/benchmark.py

@@ -0,0 +1,257 @@
+import time
+import json
+import os
+from typing import List, Dict, Any, Optional
+from dataclasses import dataclass, asdict
+import numpy as np
+from loguru import logger
+
+from .rag_methods import BaseRAG, RAGResult
+
+
+@dataclass
+class BenchmarkMetrics:
+    method_name: str
+    avg_latency_total: float
+    avg_latency_retrieval: float
+    avg_latency_generation: float
+    avg_docs_retrieved: float
+    total_queries: int
+    latency_std: float
+    retrieval_std: float
+    generation_std: float
+    min_latency: float
+    max_latency: float
+    p50_latency: float
+    p95_latency: float
+
+    def to_dict(self):
+        return asdict(self)
+
+
+@dataclass
+class BenchmarkResult:
+    metrics: List[BenchmarkMetrics]
+    detailed_results: Dict[str, List[RAGResult]]
+    timestamp: str
+
+    def to_dict(self):
+        return {
+            "metrics": [m.to_dict() for m in self.metrics],
+            "timestamp": self.timestamp,
+        }
+
+    def save(self, path):
+        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
+        logger.info("Benchmark results saved to {}".format(path))
+
+
+class RAGBenchmark:
+    def __init__(self):
+        self.results = {}
+
+    def run_single_query(self, rag_method, query, k=10):
+        result = rag_method.query(query, k)
+        return result
+
+    def run_benchmark(self, rag_methods, queries, k=10, verbose=True):
+        self.results = {}
+
+        for method in rag_methods:
+            method_results = []
+            if verbose:
+                logger.info("Running benchmark for: {}".format(method.name))
+
+            for i, query in enumerate(queries):
+                if verbose:
+                    logger.info("  Query {}/{}: {}...".format(i + 1, len(queries), query[:50]))
+
+                try:
+                    result = self.run_single_query(method, query, k)
+                    method_results.append(result)
+                    if verbose:
+                        logger.info("    Answer: {}...".format(result.answer[:80]))
+                        logger.info("    Latency: {:.3f}s".format(result.latency_total))
+                except Exception as e:
+                    logger.error("    Error on query {} for {}: {}".format(i + 1, method.name, e))
+
+            self.results[method.name] = method_results
+
+        metrics = self._compute_metrics()
+        import datetime
+        benchmark_result = BenchmarkResult(
+            metrics=metrics,
+            detailed_results=self.results,
+            timestamp=datetime.datetime.now().isoformat(),
+        )
+
+        if verbose:
+            self._print_summary(metrics)
+
+        return benchmark_result
+
+    def _compute_metrics(self):
+        metrics = []
+
+        for method_name, results in self.results.items():
+            if not results:
+                continue
+
+            total_latencies = [r.latency_total for r in results]
+            retrieval_latencies = [r.latency_retrieval for r in results]
+            generation_latencies = [r.latency_generation for r in results]
+            num_docs = [len(r.retrieved_docs) for r in results]
+
+            metric = BenchmarkMetrics(
+                method_name=method_name,
+                avg_latency_total=np.mean(total_latencies),
+                avg_latency_retrieval=np.mean(retrieval_latencies),
+                avg_latency_generation=np.mean(generation_latencies),
+                avg_docs_retrieved=np.mean(num_docs),
+                total_queries=len(results),
+                latency_std=np.std(total_latencies),
+                retrieval_std=np.std(retrieval_latencies),
+                generation_std=np.std(generation_latencies),
+                min_latency=np.min(total_latencies),
+                max_latency=np.max(total_latencies),
+                p50_latency=np.percentile(total_latencies, 50),
+                p95_latency=np.percentile(total_latencies, 95),
+            )
+            metrics.append(metric)
+
+        return metrics
+
+    def _print_summary(self, metrics):
+        from rich.console import Console
+        from rich.table import Table
+
+        console = Console()
+        table = Table(title="RAG Methods Benchmark Results")
+
+        table.add_column("Method", style="cyan")
+        table.add_column("Avg Total(s)", justify="right", style="green")
+        table.add_column("Avg Retrieval(s)", justify="right", style="green")
+        table.add_column("Avg Generation(s)", justify="right", style="green")
+        table.add_column("P50(s)", justify="right", style="yellow")
+        table.add_column("P95(s)", justify="right", style="yellow")
+        table.add_column("Min(s)", justify="right", style="magenta")
+        table.add_column("Max(s)", justify="right", style="magenta")
+        table.add_column("Avg Docs", justify="right", style="blue")
+
+        for m in sorted(metrics, key=lambda x: x.avg_latency_total):
+            table.add_row(
+                m.method_name,
+                "{:.3f}".format(m.avg_latency_total),
+                "{:.3f}".format(m.avg_latency_retrieval),
+                "{:.3f}".format(m.avg_latency_generation),
+                "{:.3f}".format(m.p50_latency),
+                "{:.3f}".format(m.p95_latency),
+                "{:.3f}".format(m.min_latency),
+                "{:.3f}".format(m.max_latency),
+                "{:.1f}".format(m.avg_docs_retrieved),
+            )
+
+        console.print(table)
+
+    def plot_comparison(self, metrics, save_path=None, show=True):
+        import matplotlib.pyplot as plt
+        import matplotlib
+
+        matplotlib.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS"]
+        matplotlib.rcParams["axes.unicode_minus"] = False
+
+        methods = [m.method_name for m in metrics]
+        avg_total = [m.avg_latency_total for m in metrics]
+        avg_retrieval = [m.avg_latency_retrieval for m in metrics]
+        avg_generation = [m.avg_latency_generation for m in metrics]
+        p50 = [m.p50_latency for m in metrics]
+        p95 = [m.p95_latency for m in metrics]
+
+        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+        fig.suptitle("RAG Methods Performance Comparison", fontsize=16, fontweight="bold")
+
+        colors = plt.cm.Set3(np.linspace(0, 1, len(methods)))
+        x = np.arange(len(methods))
+        width = 0.25
+
+        ax1 = axes[0, 0]
+        bars1 = ax1.bar(x - width, avg_retrieval, width, label="Retrieval", color="#4CAF50", alpha=0.8)
+        bars2 = ax1.bar(x, avg_generation, width, label="Generation", color="#2196F3", alpha=0.8)
+        bars3 = ax1.bar(x + width, avg_total, width, label="Total", color="#FF9800", alpha=0.8)
+        ax1.set_xlabel("Method")
+        ax1.set_ylabel("Time (seconds)")
+        ax1.set_title("Average Latency Comparison")
+        ax1.set_xticks(x)
+        ax1.set_xticklabels(methods, rotation=45, ha="right")
+        ax1.legend()
+        ax1.grid(True, alpha=0.3)
+
+        ax2 = axes[0, 1]
+        ax2.bar(x, p50, width, label="P50", color="#9C27B0", alpha=0.8)
+        ax2.bar(x + width * 0.5, p95, width, label="P95", color="#E91E63", alpha=0.8)
+        ax2.set_xlabel("Method")
+        ax2.set_ylabel("Time (seconds)")
+        ax2.set_title("Percentile Latency Comparison")
+        ax2.set_xticks(x)
+        ax2.set_xticklabels(methods, rotation=45, ha="right")
+        ax2.legend()
+        ax2.grid(True, alpha=0.3)
+
+        ax3 = axes[1, 0]
+        avg_docs = [m.avg_docs_retrieved for m in metrics]
+        ax3.barh(methods, avg_docs, color="#00BCD4", alpha=0.8)
+        ax3.set_xlabel("Average Number of Documents")
+        ax3.set_title("Average Retrieved Documents")
+        ax3.grid(True, alpha=0.3, axis="x")
+
+        ax4 = axes[1, 1]
+        speeds = [1.0 / m.avg_latency_total for m in metrics]
+        ax4.barh(methods, speeds, color="#8BC34A", alpha=0.8)
+        ax4.set_xlabel("Queries per Second")
+        ax4.set_title("Throughput Comparison")
+        ax4.grid(True, alpha=0.3, axis="x")
+
+        plt.tight_layout()
+
+        if save_path:
+            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
+            plt.savefig(save_path, dpi=300, bbox_inches="tight")
+            logger.info("Plot saved to {}".format(save_path))
+
+        if show:
+            plt.show()
+
+        plt.close()
+
+    def generate_report(self, benchmark_result, save_path=None):
+        report = "# RAG Benchmark Report\n\n"
+        report += "**Timestamp**: {}\n\n".format(benchmark_result.timestamp)
+        report += "## Summary\n\n"
+        report += "| Method | Avg Total(s) | Avg Retrieval(s) | Avg Generation(s) | P50(s) | P95(s) |\n"
+        report += "|--------|-------------|-----------------|-------------------|--------|--------|\n"
+
+        for m in sorted(benchmark_result.metrics, key=lambda x: x.avg_latency_total):
+            report += "| {} | {:.3f} | {:.3f} | {:.3f} | {:.3f} | {:.3f} |\n".format(
+                m.method_name, m.avg_latency_total, m.avg_latency_retrieval,
+                m.avg_latency_generation, m.p50_latency, m.p95_latency)
+
+        report += "\n## Detailed Analysis\n\n"
+
+        fastest = min(benchmark_result.metrics, key=lambda x: x.avg_latency_total)
+        report += "- **Fastest Method**: {} ({:.3f}s average)\n".format(fastest.method_name, fastest.avg_latency_total)
+
+        most_docs = max(benchmark_result.metrics, key=lambda x: x.avg_docs_retrieved)
+        report += "- **Most Documents Retrieved**: {} ({:.1f} average)\n".format(most_docs.method_name, most_docs.avg_docs_retrieved)
+
+        most_stable = min(benchmark_result.metrics, key=lambda x: x.latency_std)
+        report += "- **Most Stable**: {} (std={:.3f})\n".format(most_stable.method_name, most_stable.latency_std)
+
+        if save_path:
+            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
+            with open(save_path, "w", encoding="utf-8") as f:
+                f.write(report)
+            logger.info("Report saved to {}".format(save_path))
+
+        return report

+ 498 - 0
bdirag/bidi_rag.py

@@ -0,0 +1,498 @@
+# -*- coding: utf-8 -*-
+"""BidiRag - Unified interface for RAG-based document retrieval."""
+
+import os
+from typing import List, Dict, Optional, Union
+
+from loguru import logger
+
+from .config import (
+    EMBEDDING_MODEL_NAME,
+    CHUNK_SIZE,
+    CHUNK_OVERLAP,
+    VECTOR_STORE_TYPE,
+    RETRIEVAL_TOP_K,
+)
+from .document_processor import DocumentProcessor, Document
+from .embedding_models import SentenceTransformerEmbedding
+from .vector_stores import get_vector_store
+from .rag_methods import (
+    NaiveRAG, RerankRAG, ParentDocumentRAG, LLMFilterRAG,
+    QueryRoutingRAG, MetadataFilterRAG, AdaptiveRAG, HybridSearchRAG,
+    MultiQueryRAG, HyDERAG, StepBackRAG, ContextualCompressionRAG,
+    SelfRAG, CorrectiveRAG, FLARERAG, RAPTORRAG, EnsembleRAG,
+    BidFieldExtractionRAG, TableAwareRAG, GraphRAG,
+    BM25RAG, TFIDFRAG, KeywordRAG, BM25HTMLTreeRAG,
+)
+
+
+# Available RAG methods mapping
+RAG_METHODS = {
+    'naive': NaiveRAG,
+    'rerank': RerankRAG,
+    'parent_document': ParentDocumentRAG,
+    'llm_filter': LLMFilterRAG,
+    'query_routing': QueryRoutingRAG,
+    'metadata_filter': MetadataFilterRAG,
+    'adaptive': AdaptiveRAG,
+    'hybrid_search': HybridSearchRAG,
+    'multi_query': MultiQueryRAG,
+    'hyde': HyDERAG,
+    'step_back': StepBackRAG,
+    'contextual_compression': ContextualCompressionRAG,
+    'self_rag': SelfRAG,
+    'corrective': CorrectiveRAG,
+    'flare': FLARERAG,
+    'raptor': RAPTORRAG,
+    'ensemble': EnsembleRAG,
+    'bid_field_extraction': BidFieldExtractionRAG,
+    'table_aware': TableAwareRAG,
+    'graph': GraphRAG,
+    'bm25': BM25RAG,
+    'tfidf': TFIDFRAG,
+    'keyword': KeywordRAG,
+    'bm25_html_tree': BM25HTMLTreeRAG,
+}
+
+
+class BidiRag:
+    """
+    BidiRag - Unified interface for RAG-based document retrieval and extraction.
+    
+    This class provides a high-level API to:
+    1. Load and process HTML documents
+    2. Build index using embedding models
+    3. Retrieve relevant text using various RAG methods
+    4. Support keyword-based search
+    
+    Usage:
+        >>> rag = BidiRag(rag_method='bm25_html_tree')
+        >>> rag.add_documents(["path/to/document1.html", "path/to/document2.html"])
+        >>> results = rag.retrieve("招标人和中标人", top_k=5)
+        >>> for doc in results:
+        ...     print(doc.page_content)
+    """
+    
+    def __init__(
+        self,
+        rag_method: str = 'bm25_html_tree',
+        chunk_size: int = None,
+        chunk_overlap: int = None,
+        vector_store_type: str = None,
+        embedding_model_name: str = None,
+        llm_client=None,
+        llm_model: str = "gpt-4o",
+        index_name: str = "default",
+        **kwargs
+    ):
+        """
+        Initialize BidiRag.
+        
+        Args:
+            rag_method: RAG method to use (e.g., 'bm25_html_tree', 'naive', 'hybrid_search')
+            chunk_size: Document chunk size (default: from config)
+            chunk_overlap: Chunk overlap size (default: from config)
+            vector_store_type: Vector store type 'faiss' or 'chroma' (default: from config)
+            embedding_model_name: Embedding model name (default: from config)
+            llm_client: LLM client instance (optional, for methods requiring LLM)
+            llm_model: LLM model name
+            index_name: Index name for storage
+            **kwargs: Additional arguments for specific RAG methods
+        """
+        # Configuration
+        self.rag_method_name = rag_method
+        self.chunk_size = chunk_size or CHUNK_SIZE
+        self.chunk_overlap = chunk_overlap or CHUNK_OVERLAP
+        self.vector_store_type = vector_store_type or VECTOR_STORE_TYPE
+        self.embedding_model_name = embedding_model_name or EMBEDDING_MODEL_NAME
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.index_name = index_name
+        
+        # Validate RAG method
+        if rag_method not in RAG_METHODS:
+            available = ', '.join(RAG_METHODS.keys())
+            raise ValueError(
+                f"Unknown RAG method '{rag_method}'. "
+                f"Available methods: {available}"
+            )
+        
+        # Methods that don't require embedding model
+        self.embedding_free_methods = {'bm25', 'tfidf', 'keyword', 'bm25_html_tree'}
+        # Methods that don't require vector store (they have their own indexing)
+        self.vector_store_free_methods = {'bm25', 'tfidf', 'keyword', 'bm25_html_tree'}
+        
+        # Initialize components
+        self.document_processor = DocumentProcessor(
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap
+        )
+        
+        # Initialize embedding model (only if needed)
+        if rag_method not in self.embedding_free_methods:
+            logger.info(f"Loading embedding model: {self.embedding_model_name}")
+            self.embedding_model = SentenceTransformerEmbedding(model_name=self.embedding_model_name)
+        else:
+            logger.info(f"Method {rag_method} doesn't require embedding model")
+            self.embedding_model = None
+        
+        # Initialize vector store (only if needed)
+        if rag_method not in self.vector_store_free_methods:
+            logger.info(f"Initializing vector store: {self.vector_store_type}")
+            self.vector_store = get_vector_store(
+                store_type=self.vector_store_type,
+                embedding_model=self.embedding_model
+            )
+        else:
+            logger.info(f"Method {rag_method} doesn't require vector store")
+            self.vector_store = None
+        
+        # Initialize RAG method
+        self._init_rag_method(**kwargs)
+        
+        # Document storage
+        self.documents = []
+        self.indexed = False
+        
+        logger.info(f"BidiRag initialized with method: {rag_method}")
+    
+    def _init_rag_method(self, **kwargs):
+        """Initialize the specified RAG method."""
+        rag_class = RAG_METHODS[self.rag_method_name]
+        
+        # Check if this is a special method that doesn't follow BaseRAG interface
+        special_methods = {'bm25_html_tree'}
+        
+        if self.rag_method_name in special_methods:
+            # BM25HTMLTreeRAG has its own interface
+            try:
+                self.rag_method = rag_class()
+                self.is_special_method = True
+                logger.info(f"Special RAG method {self.rag_method_name} initialized")
+            except Exception as e:
+                logger.error(f"Failed to initialize special RAG method {self.rag_method_name}: {e}")
+                raise
+        else:
+            # Common parameters for BaseRAG methods
+            common_params = {
+                'embedding_model': self.embedding_model,
+                'vector_store': self.vector_store,
+            }
+            
+            # Add LLM parameters if provided
+            if self.llm_client is not None:
+                common_params['llm_client'] = self.llm_client
+                common_params['llm_model'] = self.llm_model
+            
+            # Merge with additional kwargs
+            common_params.update(kwargs)
+            
+            try:
+                self.rag_method = rag_class(**common_params)
+                self.is_special_method = False
+                logger.info(f"RAG method {self.rag_method_name} initialized successfully")
+            except Exception as e:
+                logger.error(f"Failed to initialize RAG method {self.rag_method_name}: {e}")
+                raise
+    
+    def add_documents(
+        self,
+        sources: Union[str, List[str]],
+        rebuild_index: bool = True
+    ) -> int:
+        """
+        Add documents from file paths or directories.
+        
+        Args:
+            sources: Single file path, directory path, or list of paths
+            rebuild_index: Whether to rebuild the index after adding documents
+            
+        Returns:
+            Number of documents added
+        """
+        if isinstance(sources, str):
+            sources = [sources]
+        
+        new_docs = []
+        for source in sources:
+            if not os.path.exists(source):
+                logger.warning(f"Source not found: {source}")
+                continue
+            
+            logger.info(f"Processing: {source}")
+            docs = self.document_processor.process(source)
+            new_docs.extend(docs)
+        
+        if not new_docs:
+            logger.warning("No documents were loaded")
+            return 0
+        
+        # Store documents
+        self.documents.extend(new_docs)
+        logger.info(f"Added {len(new_docs)} documents (total: {len(self.documents)})")
+        
+        # Build index
+        if rebuild_index:
+            self.build_index()
+        
+        return len(new_docs)
+    
+    def add_texts(
+        self,
+        texts: List[str],
+        metadata: List[Dict] = None,
+        rebuild_index: bool = True
+    ) -> int:
+        """
+        Add documents from text strings.
+        
+        Args:
+            texts: List of text content
+            metadata: List of metadata dictionaries (optional)
+            rebuild_index: Whether to rebuild the index
+            
+        Returns:
+            Number of documents added
+        """
+        if metadata is None:
+            metadata = [{} for _ in texts]
+        
+        if len(texts) != len(metadata):
+            raise ValueError("texts and metadata must have the same length")
+        
+        new_docs = [
+            Document(page_content=text, metadata=meta)
+            for text, meta in zip(texts, metadata)
+        ]
+        
+        self.documents.extend(new_docs)
+        logger.info(f"Added {len(new_docs)} text documents")
+        
+        if rebuild_index:
+            self.build_index()
+        
+        return len(new_docs)
+    
+    def build_index(self):
+        """Build or rebuild the document index."""
+        if not self.documents:
+            logger.warning("No documents to index")
+            return
+        
+        logger.info(f"Building index with {len(self.documents)} documents...")
+        
+        if self.is_special_method:
+            # For BM25HTMLTreeRAG, we need to combine all HTML content
+            html_content = "\n".join([doc.page_content for doc in self.documents])
+            self.rag_method.build_index(html_content)
+        elif self.vector_store is not None:
+            # Clear existing index
+            self.vector_store.clear()
+            
+            # Index documents using the RAG method
+            self.rag_method.index_documents(self.documents)
+        else:
+            # For methods like BM25, TFIDF that have their own indexing
+            # Call their index_documents method directly
+            if hasattr(self.rag_method, 'index_documents'):
+                self.rag_method.index_documents(self.documents)
+            else:
+                logger.info(f"Method {self.rag_method_name} handles indexing internally")
+        
+        self.indexed = True
+        logger.info("Index built successfully")
+    
+    def retrieve(
+        self,
+        query: str,
+        top_k: int = None,
+        keywords: List[str] = None,
+        return_scores: bool = True
+    ) -> List:
+        """
+        Retrieve relevant documents/fragments for a query.
+        
+        Args:
+            query: Search query (can be natural language or keywords)
+            top_k: Number of results to return (default: from config)
+            keywords: Additional keywords to filter results (optional)
+            return_scores: Whether to return documents with scores
+            
+        Returns:
+            List of retrieved documents/fragments (with scores if return_scores=True)
+        """
+        if not self.indexed:
+            logger.warning("Index not built. Call add_documents() first.")
+            return []
+        
+        top_k = top_k or RETRIEVAL_TOP_K
+        
+        # Build query with keywords if provided
+        if keywords:
+            enhanced_query = f"{query} {' '.join(keywords)}"
+            logger.info(f"Enhanced query with keywords: {enhanced_query}")
+        else:
+            enhanced_query = query
+        
+        # Retrieve using the RAG method
+        logger.info(f"Retrieving with query: {query}")
+        
+        if self.is_special_method:
+            # BM25HTMLTreeRAG returns tree node fragments directly
+            # This returns (Document, score) tuples with tree node content
+            results = self.rag_method.query(enhanced_query, k=top_k)
+            logger.info(f"Retrieved {len(results)} tree node fragments")
+        else:
+            # BaseRAG methods use retrieve() method
+            results = self.rag_method.retrieve(enhanced_query, k=top_k)
+        
+        # Filter by keywords if provided
+        # Only apply strict filtering if we have more results than needed
+        if keywords and len(results) > top_k:
+            results = self._filter_by_keywords(results, keywords)
+            logger.info(f"Filtered to {len(results)} results matching keywords")
+        
+        # If no results from retrieval but keywords provided, try search_keywords
+        if not results and keywords:
+            logger.info("No results from retrieval, trying search_keywords")
+            results = self.search_keywords(keywords, top_k)
+            # Add dummy scores for consistency
+            results = [(doc, 1.0) for doc in results]
+        
+        return results
+    
+    def _filter_by_keywords(
+        self,
+        results: List,
+        keywords: List[str]
+    ) -> List:
+        """Filter retrieval results by keywords."""
+        filtered = []
+        for item in results:
+            # Handle both (doc, score) tuples and plain documents
+            if isinstance(item, tuple):
+                doc, score = item
+            else:
+                doc = item
+                score = None
+            
+            # Check if any keyword is in the document content
+            content_lower = doc.page_content.lower()
+            if any(kw.lower() in content_lower for kw in keywords):
+                if score is not None:
+                    filtered.append((doc, score))
+                else:
+                    filtered.append(doc)
+        
+        return filtered
+    
+    def query(
+        self,
+        query: str,
+        top_k: int = None,
+        keywords: List[str] = None
+    ):
+        """
+        Full RAG pipeline: retrieve + generate answer.
+        
+        Args:
+            query: Search query
+            top_k: Number of documents to retrieve
+            keywords: Additional keywords for filtering
+            
+        Returns:
+            RAGResult object with answer and retrieved documents
+        """
+        if not self.indexed:
+            logger.warning("Index not built. Call add_documents() first.")
+            return None
+        
+        # Build query with keywords
+        if keywords:
+            enhanced_query = f"{query} {' '.join(keywords)}"
+        else:
+            enhanced_query = query
+        
+        logger.info(f"Querying: {query}")
+        
+        if self.is_special_method:
+            # BM25HTMLTreeRAG doesn't have generate() method, only retrieve
+            result = self.rag_method.query(enhanced_query, k=top_k or RETRIEVAL_TOP_K)
+            # Wrap in a simple result object
+            from .rag_methods.base import RAGResult
+            rag_result = RAGResult(
+                answer="",
+                retrieved_docs=result,
+                metadata={"method": self.rag_method_name}
+            )
+        else:
+            result = self.rag_method.query(enhanced_query, k=top_k or RETRIEVAL_TOP_K)
+            rag_result = result
+        
+        # Filter retrieved docs by keywords if provided
+        if keywords and rag_result.retrieved_docs:
+            filtered_docs = self._filter_by_keywords(rag_result.retrieved_docs, keywords)
+            rag_result.retrieved_docs = filtered_docs
+        
+        return rag_result
+    
+    def search_keywords(
+        self,
+        keywords: List[str],
+        top_k: int = None
+    ) -> List[Document]:
+        """
+        Search documents by exact keywords.
+        
+        Args:
+            keywords: List of keywords to search for
+            top_k: Maximum number of results
+            
+        Returns:
+            List of documents containing the keywords
+        """
+        if not self.documents:
+            logger.warning("No documents loaded")
+            return []
+        
+        top_k = top_k or len(self.documents)
+        results = []
+        
+        for doc in self.documents:
+            content_lower = doc.page_content.lower()
+            # Check if all keywords are present
+            if all(kw.lower() in content_lower for kw in keywords):
+                results.append(doc)
+            
+            if len(results) >= top_k:
+                break
+        
+        logger.info(f"Found {len(results)} documents matching keywords: {keywords}")
+        return results
+    
+    def get_document_count(self) -> int:
+        """Get the number of loaded documents."""
+        return len(self.documents)
+    
+    def clear(self):
+        """Clear all documents and index."""
+        self.documents = []
+        if self.vector_store is not None:
+            self.vector_store.clear()
+        self.indexed = False
+        logger.info("Cleared all documents and index")
+    
+    def list_available_methods(self) -> List[str]:
+        """List all available RAG methods."""
+        return list(RAG_METHODS.keys())
+    
+    def get_method_info(self) -> Dict:
+        """Get information about the current RAG method."""
+        return {
+            'method_name': self.rag_method_name,
+            'method_class': self.rag_method.__class__.__name__,
+            'embedding_model': self.embedding_model_name,
+            'vector_store': self.vector_store_type,
+            'document_count': len(self.documents),
+            'indexed': self.indexed,
+        }

+ 84 - 0
bdirag/config.py

@@ -0,0 +1,84 @@
+import os
+
+# Base paths
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+DATA_DIR = os.path.join(BASE_DIR, "data")
+DOCS_DIR = os.path.join(DATA_DIR, "documents")
+INDEX_DIR = os.path.join(DATA_DIR, "indexes")
+CACHE_DIR = os.path.join(DATA_DIR, "cache")
+OUTPUT_DIR = os.path.join(BASE_DIR, "output")
+
+# Create directories
+for d in [DATA_DIR, DOCS_DIR, INDEX_DIR, CACHE_DIR, OUTPUT_DIR]:
+    os.makedirs(d, exist_ok=True)
+
+# Embedding settings
+EMBEDDING_MODEL_NAME = "BAAI/bge-large-zh-v1.5"
+EMBEDDING_DIMENSION = 1024
+EMBEDDING_BATCH_SIZE = 32
+
+# LLM settings
+LLM_MODEL_NAME = "gpt-4o"
+LLM_TEMPERATURE = 0.1
+LLM_MAX_TOKENS = 4096
+
+# Chunk settings
+CHUNK_SIZE = 512
+CHUNK_OVERLAP = 50
+
+# Vector store settings
+VECTOR_STORE_TYPE = "faiss"
+FAISS_INDEX_PATH = os.path.join(INDEX_DIR, "faiss_index")
+CHROMA_PERSIST_PATH = os.path.join(INDEX_DIR, "chroma_db")
+
+# Reranking settings
+RERANK_MODEL_NAME = "BAAI/bge-reranker-large"
+RERANK_TOP_K = 5
+
+# Retrieval settings
+RETRIEVAL_TOP_K = 10
+HYBRID_SEARCH_WEIGHT = 0.5
+
+# HyDE settings
+HYDE_GENERATION_MODEL = "gpt-4o"
+HYDE_NUM_HYPOTHESES = 3
+
+# Self-RAG settings
+SELF_RAG_RELEVANCE_THRESHOLD = 0.7
+SELF_RAG_SUPPORT_THRESHOLD = 0.6
+SELF_RAG_USEFULNESS_THRESHOLD = 0.7
+
+# CRAG settings
+CRAG_CORRECTNESS_THRESHOLD = 0.7
+CRAG_MAX_WEB_RESULTS = 5
+
+# RAPTOR settings
+RAPTOR_MAX_CLUSTERS = 50
+RAPTOR_SUMMARY_LENGTH = 256
+
+# Bidding field extraction
+BID_FIELDS = [
+    "project_name", "project_code", "budget_amount", "currency",
+    "bid_deadline", "bid_open_time", "bid_location",
+    "purchaser_name", "purchaser_contact", "purchaser_phone",
+    "agency_name", "agency_contact", "agency_phone",
+    "qualification_requirements", "bid_bond_amount",
+    "performance_bond_amount", "warranty_period",
+    "delivery_time", "delivery_location", "payment_terms",
+    "evaluation_method", "scope_of_work"
+]
+
+# Bidding domain specific prompts
+BID_EXTRACTION_PROMPT = """你是一个招投标领域的专家。请根据提供的文档内容,提取以下字段信息:
+
+{fields}
+
+文档内容:
+{context}
+
+请以JSON格式返回提取结果。如果某个字段无法从文档中提取,请返回null。
+"""
+
+# Logging settings
+LOG_LEVEL = "INFO"
+LOG_FORMAT = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"

+ 141 - 0
bdirag/document_processor.py

@@ -0,0 +1,141 @@
+import os
+import re
+
+from loguru import logger
+
+
+class Document(object):
+    def __init__(self, page_content="", metadata=None):
+        self.page_content = page_content
+        self.metadata = metadata if metadata is not None else {}
+
+
+class DocumentProcessor:
+    def __init__(
+        self,
+        chunk_size=512,
+        chunk_overlap=50,
+        separators=None
+    ):
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.separators = separators or ["\n\n", "\n", " ", ""]
+        self.text_splitter = None
+
+    def _get_text_splitter(self):
+        if self.text_splitter is None:
+            from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+            self.text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=self.chunk_size,
+                chunk_overlap=self.chunk_overlap,
+                separators=self.separators,
+                length_function=len,
+            )
+        return self.text_splitter
+
+    def _load_html_document(self, file_path):
+        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
+            html = f.read()
+
+        from bs4 import BeautifulSoup
+
+        soup = BeautifulSoup(html, "lxml")
+        for tag in soup(["script", "style", "noscript"]):
+            tag.decompose()
+
+        body = soup.find("body") or soup
+        lines = []
+        for element in body.find_all(["h1", "h2", "h3", "h4", "h5", "p", "li", "tr"], recursive=True):
+            if element.name == "tr":
+                cells = [
+                    re.sub(r"\s+", " ", cell.get_text(" ", strip=True))
+                    for cell in element.find_all(["th", "td"], recursive=False)
+                ]
+                text = " | ".join(cell for cell in cells if cell)
+            else:
+                if element.find_parent("tr") is not None:
+                    continue
+                text = element.get_text(" ", strip=True)
+                text = re.sub(r"\s+", " ", text)
+            if text and (not lines or lines[-1] != text):
+                lines.append(text)
+
+        if not lines:
+            text = body.get_text("\n", strip=True)
+            lines = [line.strip() for line in text.splitlines() if line.strip()]
+
+        return [
+            Document(
+                page_content="\n".join(lines),
+                metadata={"source": file_path, "file_type": "html"},
+            )
+        ]
+
+    def load_document(self, file_path):
+        ext = os.path.splitext(file_path)[1].lower()
+        if ext in (".html", ".htm"):
+            return self._load_html_document(file_path)
+
+        from langchain_community.document_loaders import (
+            PyPDFLoader,
+            Docx2txtLoader,
+            TextLoader,
+            UnstructuredExcelLoader,
+        )
+
+        loader_map = {
+            ".pdf": PyPDFLoader,
+            ".docx": Docx2txtLoader,
+            ".txt": TextLoader,
+            ".xlsx": UnstructuredExcelLoader,
+            ".xls": UnstructuredExcelLoader,
+        }
+        loader_cls = loader_map.get(ext)
+        if not loader_cls:
+            raise ValueError("Unsupported file format: {}".format(ext))
+        
+        loader = loader_cls(file_path)
+        raw_docs = loader.load()
+        return [
+            Document(
+                page_content=doc.page_content,
+                metadata=dict(doc.metadata, source=file_path)
+            )
+            for doc in raw_docs
+        ]
+
+    def load_directory(self, dir_path):
+        all_docs = []
+        for root, _, files in os.walk(dir_path):
+            for file in files:
+                file_path = os.path.join(root, file)
+                try:
+                    docs = self.load_document(file_path)
+                    all_docs.extend(docs)
+                    logger.info("Loaded {} chunks from {}".format(len(docs), file_path))
+                except Exception as e:
+                    logger.warning("Failed to load {}: {}".format(file_path, e))
+        return all_docs
+
+    def split_documents(self, documents):
+        from langchain_core.documents import Document as LCDocument
+        text_splitter = self._get_text_splitter()
+        lc_docs = [
+            LCDocument(page_content=doc.page_content, metadata=doc.metadata)
+            for doc in documents
+        ]
+        split_docs = text_splitter.split_documents(lc_docs)
+        return [
+            Document(page_content=doc.page_content, metadata=doc.metadata)
+            for doc in split_docs
+        ]
+
+    def process(self, source_path):
+        if os.path.isfile(source_path):
+            docs = self.load_document(source_path)
+        elif os.path.isdir(source_path):
+            docs = self.load_directory(source_path)
+        else:
+            raise ValueError("Source path does not exist: {}".format(source_path))
+        return self.split_documents(docs)

+ 110 - 0
bdirag/embedding_models.py

@@ -0,0 +1,110 @@
+from abc import ABC, abstractmethod
+from typing import List
+import numpy as np
+from loguru import logger
+
+
+class BaseEmbedding(ABC):
+    @abstractmethod
+    def embed_documents(self, texts):
+        pass
+
+    @abstractmethod
+    def embed_query(self, text):
+        pass
+
+
+class SentenceTransformerEmbedding(BaseEmbedding):
+    def __init__(self, model_name="BAAI/bge-large-zh-v1.5", device="cpu"):
+        from sentence_transformers import SentenceTransformer
+        self.model = SentenceTransformer(model_name, device=device)
+        self.dimension = self.model.get_sentence_embedding_dimension()
+        logger.info("Loaded SentenceTransformer: {}, dim={}".format(model_name, self.dimension))
+
+    def embed_documents(self, texts):
+        embeddings = self.model.encode(texts, batch_size=32, show_progress_bar=False)
+        return embeddings.tolist()
+
+    def embed_query(self, text):
+        embedding = self.model.encode([text])[0]
+        return embedding.tolist()
+
+
+class OpenAIEmbedding(BaseEmbedding):
+    def __init__(self, model_name="text-embedding-3-large", api_key=None, base_url=None):
+        from openai import OpenAI
+        self.model_name = model_name
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+        self.dimension = {"text-embedding-3-large": 3072, "text-embedding-3-small": 1536}.get(model_name, 1536)
+        logger.info("Loaded OpenAI Embedding: {}, dim={}".format(model_name, self.dimension))
+
+    def embed_documents(self, texts):
+        embeddings = []
+        batch_size = 100
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
+            response = self.client.embeddings.create(model=self.model_name, input=batch)
+            embeddings.extend([d.embedding for d in response.data])
+        return embeddings
+
+    def embed_query(self, text):
+        response = self.client.embeddings.create(model=self.model_name, input=[text])
+        return response.data[0].embedding
+
+
+class DashScopeEmbedding(BaseEmbedding):
+    def __init__(self, model_name="text-embedding-v2", api_key=None):
+        import dashscope
+        self.model_name = model_name
+        dashscope.api_key = api_key
+        self.dimension = 1536
+        logger.info("Loaded DashScope Embedding: {}".format(model_name))
+
+    def embed_documents(self, texts):
+        from dashscope import TextEmbedding
+        embeddings = []
+        batch_size = 25
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
+            response = TextEmbedding.call(model=self.model_name, input=batch)
+            if response.status_code == 200:
+                for item in response.output["embeddings"]:
+                    embeddings.append(item["embedding"])
+            else:
+                raise RuntimeError("DashScope embedding failed: {}".format(response.message))
+        return embeddings
+
+    def embed_query(self, text):
+        return self.embed_documents([text])[0]
+
+
+class ZhipuEmbedding(BaseEmbedding):
+    def __init__(self, model_name="embedding-3", api_key=None):
+        from zhipuai import ZhipuAI
+        self.model_name = model_name
+        self.client = ZhipuAI(api_key=api_key)
+        self.dimension = 2048
+        logger.info("Loaded Zhipu Embedding: {}".format(model_name))
+
+    def embed_documents(self, texts):
+        embeddings = []
+        for text in texts:
+            response = self.client.embeddings.create(model=self.model_name, input=[text])
+            embeddings.append(response.data[0].embedding)
+        return embeddings
+
+    def embed_query(self, text):
+        return self.embed_documents([text])[0]
+
+
+def get_embedding(model_type="sentence_transformer", **kwargs):
+    model_map = {
+        "sentence_transformer": SentenceTransformerEmbedding,
+        "openai": OpenAIEmbedding,
+        "dashscope": DashScopeEmbedding,
+        "zhipu": ZhipuEmbedding,
+    }
+    cls = model_map.get(model_type)
+    if not cls:
+        raise ValueError("Unknown embedding model type: {}".format(model_type))
+    return cls(**kwargs)

+ 1377 - 0
bdirag/rag_methods.py

@@ -0,0 +1,1377 @@
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any, Optional, Tuple
+import time
+import re
+import numpy as np
+from loguru import logger
+
+from .document_processor import Document
+from .embedding_models import BaseEmbedding
+from .vector_stores import BaseVectorStore
+
+
+class RAGResult(object):
+    def __init__(self, answer="", retrieved_docs=None, latency_retrieval=0.0,
+                 latency_generation=0.0, latency_total=0.0, metadata=None):
+        self.answer = answer
+        self.retrieved_docs = retrieved_docs if retrieved_docs is not None else []
+        self.latency_retrieval = latency_retrieval
+        self.latency_generation = latency_generation
+        self.latency_total = latency_total
+        self.metadata = metadata if metadata is not None else {}
+
+    def to_dict(self):
+        return {
+            "answer": self.answer,
+            "num_docs_retrieved": len(self.retrieved_docs),
+            "latency_retrieval": round(self.latency_retrieval, 3),
+            "latency_generation": round(self.latency_generation, 3),
+            "latency_total": round(self.latency_total, 3),
+            "metadata": self.metadata or {},
+        }
+
+
+class BaseRAG(ABC):
+    def __init__(self, embedding_model=None, vector_store=None, llm_client=None, llm_model="gpt-4o", **kwargs):
+        self.embedding_model = embedding_model
+        self.vector_store = vector_store
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.name = self.__class__.__name__
+
+    def index_documents(self, documents):
+        texts = [doc.page_content for doc in documents]
+        embeddings = self.embedding_model.embed_documents(texts)
+        self.vector_store.add_documents(documents, embeddings)
+
+    def _call_llm(self, prompt, system_prompt=None):
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        response = self.llm_client.chat.completions.create(
+            model=self.llm_model,
+            messages=messages,
+            temperature=0.1,
+            max_tokens=2048,
+        )
+        return response.choices[0].message.content
+
+    def _format_context(self, docs):
+        context_parts = []
+        for i, (doc, score) in enumerate(docs, 1):
+            source = doc.metadata.get("source", "unknown")
+            context_parts.append("[{}] (Score: {:.3f}, Source: {})\n{}".format(i, score, source, doc.page_content))
+        return "\n\n---\n\n".join(context_parts)
+
+    @abstractmethod
+    def retrieve(self, query, k=10):
+        pass
+
+    @abstractmethod
+    def generate(self, query, context):
+        pass
+
+    def query(self, query, k=10):
+        start_total = time.time()
+        
+        t0 = time.time()
+        docs = self.retrieve(query, k)
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        answer = self.generate(query, context)
+        generation_time = time.time() - t1
+
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=answer,
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={"method": self.name, "num_context_tokens": len(context)},
+        )
+
+
+class NaiveRAG(BaseRAG):
+    def __init__(self, retrieval_prompt_template=None, **kwargs):
+        super().__init__(**kwargs)
+        self.retrieval_prompt_template = retrieval_prompt_template or (
+            "根据以下参考文档,回答问题。\n\n"
+            "参考文档:\n{context}\n\n"
+            "问题:{query}\n\n"
+            "请详细回答,如果参考文档中没有相关信息,请说明无法从文档中找到答案。"
+        )
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        return self.vector_store.similarity_search(query_embedding, k)
+
+    def generate(self, query, context):
+        prompt = self.retrieval_prompt_template.format(context=context, query=query)
+        return self._call_llm(prompt)
+
+
+class RerankRAG(BaseRAG):
+    def __init__(self, rerank_model=None, rerank_top_k=5, **kwargs):
+        super().__init__(**kwargs)
+        self.rerank_model = rerank_model
+        self.rerank_top_k = rerank_top_k
+        self.initial_k = 20
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_docs = self.vector_store.similarity_search(query_embedding, self.initial_k)
+
+        if self.rerank_model and len(initial_docs) > 0:
+            texts = [doc.page_content for doc, _ in initial_docs]
+            pairs = [(query, text) for text in texts]
+            scores = self.rerank_model.compute_score(pairs)
+
+            if isinstance(scores, (int, float)):
+                scores = [scores]
+
+            reranked = list(zip(initial_docs, scores))
+            reranked.sort(key=lambda x: x[1], reverse=True)
+            return [(doc, float(score)) for (doc, _), score in reranked[:k]]
+
+        return initial_docs[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过重排序的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class ParentDocumentRAG(BaseRAG):
+    def __init__(self, parent_chunk_size=1500, **kwargs):
+        super().__init__(**kwargs)
+        self.parent_chunk_size = parent_chunk_size
+        self.parent_docs = []
+        self.child_to_parent = {}
+
+    def index_documents(self, documents):
+        self.parent_docs = []
+        self.child_to_parent = {}
+        
+        for i, doc in enumerate(documents):
+            words = doc.page_content.split()
+            for j in range(0, len(words), self.parent_chunk_size):
+                parent_text = " ".join(words[j:j + self.parent_chunk_size])
+                parent_doc = Document(
+                    page_content=parent_text,
+                    metadata=dict(doc.metadata, chunk_index=j, is_parent=True)
+                )
+                self.parent_docs.append(parent_doc)
+        
+        child_texts = []
+        for doc in documents:
+            words = doc.page_content.split()
+            for j in range(0, len(words), self.parent_chunk_size):
+                chunk_words = words[j:j + self.parent_chunk_size]
+                for k in range(0, len(chunk_words), 512):
+                    child_text = " ".join(chunk_words[k:k + 512])
+                    child_id = len(child_texts)
+                    child_texts.append(child_text)
+                    self.child_to_parent[child_id] = child_id // 3
+        
+        if child_texts:
+            embeddings = self.embedding_model.embed_documents(child_texts)
+            child_docs = [
+                Document(page_content=text, metadata={"is_parent": False})
+                for text in child_texts
+            ]
+            self.vector_store.add_documents(child_docs, embeddings)
+        
+        logger.info("ParentDocumentRAG: {} parents, {} children".format(len(self.parent_docs), len(child_texts)))
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        child_results = self.vector_store.similarity_search(query_embedding, k * 2)
+        
+        parent_map = {}
+        for child_doc, score in child_results:
+            for i, parent_doc in enumerate(self.parent_docs):
+                if child_doc.page_content[:50] in parent_doc.page_content:
+                    if i not in parent_map or score > parent_map[i]:
+                        parent_map[i] = (parent_doc, score)
+                    break
+        
+        parent_results = sorted(parent_map.values(), key=lambda x: x[1], reverse=True)
+        return parent_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(父子文档检索,包含完整上下文),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class LLMFilterRAG(BaseRAG):
+    def __init__(self, filter_threshold=0.5, **kwargs):
+        super().__init__(**kwargs)
+        self.filter_threshold = filter_threshold
+
+    def _score_relevance(self, query, doc):
+        prompt = (
+            "评估以下文档与问题的相关性,给出0-1之间的分数。只返回分数数字。\n\n"
+            "问题:{}\n\n"
+            "文档:{}\n\n"
+            "相关性分数:".format(query, doc.page_content[:300])
+        )
+        try:
+            response = self._call_llm(prompt).strip()
+            return float(response)
+        except:
+            return 0.5
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k * 3)
+        
+        filtered_results = []
+        for doc, score in initial_results:
+            relevance = self._score_relevance(query, doc)
+            if relevance >= self.filter_threshold:
+                filtered_results.append((doc, score * relevance))
+        
+        filtered_results.sort(key=lambda x: x[1], reverse=True)
+        return filtered_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过LLM精选的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class QueryRoutingRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _classify_query(self, query):
+        categories = [
+            "budget", "deadline", "qualification", "evaluation",
+            "payment", "warranty", "delivery", "contact", "scope"
+        ]
+        categories_str = ", ".join(categories)
+        prompt = (
+            "将以下问题分类到以下类别之一:{}\n\n"
+            "问题:{}\n\n"
+            "类别:".format(categories_str, query)
+        )
+        return self._call_llm(prompt).strip().lower()
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        semantic_results = self.vector_store.similarity_search(query_embedding, k)
+        
+        category = self._classify_query(query)
+        logger.info("QueryRouting: category={}".format(category))
+        
+        category_boost = {}
+        for doc, score in semantic_results:
+            content_lower = doc.page_content.lower()
+            category_keywords = {
+                "budget": ["预算", "金额", "价格", "费用", "报价"],
+                "deadline": ["截止", "时间", "日期", "开标"],
+                "qualification": ["资格", "要求", "证书", "业绩"],
+                "evaluation": ["评标", "评价", "分数", "方法"],
+                "payment": ["付款", "结算", "进度", "保证金"],
+                "warranty": ["质保", "维修", "售后", "服务"],
+                "delivery": ["交货", "工期", "交付", "地点"],
+                "contact": ["联系人", "电话", "邮箱"],
+                "scope": ["范围", "内容", "清单", "设备"],
+            }
+            keywords = category_keywords.get(category, [])
+            boost = sum(1 for kw in keywords if kw in content_lower) * 0.1
+            category_boost[id(doc)] = boost
+        
+        enhanced_results = []
+        for doc, score in semantic_results:
+            boost = category_boost.get(id(doc), 0)
+            enhanced_results.append((doc, score + boost))
+        
+        enhanced_results.sort(key=lambda x: x[1], reverse=True)
+        return enhanced_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过查询路由的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class MetadataFilterRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def retrieve(self, query, k=10, metadata_filter=None):
+        query_embedding = self.embedding_model.embed_query(query)
+        all_results = self.vector_store.similarity_search(query_embedding, k * 3)
+        
+        if metadata_filter:
+            filtered = []
+            for doc, score in all_results:
+                match = all(
+                    doc.metadata.get(key) == value
+                    for key, value in metadata_filter.items()
+                )
+                if match:
+                    filtered.append((doc, score))
+            return filtered[:k]
+        
+        return all_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过元数据筛选的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class AdaptiveRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _determine_strategy(self, query):
+        prompt = (
+            "根据问题类型,选择最合适的检索策略。\n"
+            "可选策略:\n"
+            "- semantic: 语义检索,适合概念性问题\n"
+            "- keyword: 关键词检索,适合精确实体匹配\n"
+            "- multi: 多路检索,适合复杂综合问题\n\n"
+            "问题:{}\n\n"
+            "策略名称(只返回name):".format(query)
+        )
+        return self._call_llm(prompt).strip().lower()
+
+    def retrieve(self, query, k=10):
+        strategy = self._determine_strategy(query)
+        logger.info("AdaptiveRAG: strategy={}".format(strategy))
+        
+        if strategy in ["keyword", "exact"]:
+            try:
+                from rank_bm25 import BM25Okapi
+            except:
+                BM25Okapi = None
+            
+            if BM25Okapi:
+                all_docs = self.vector_store.documents if hasattr(self.vector_store, 'documents') else []
+                if all_docs:
+                    texts = [doc.page_content.split() for doc in all_docs]
+                    bm25 = BM25Okapi(texts)
+                    scores = bm25.get_scores(query.split())
+                    results = []
+                    for i, score in enumerate(scores):
+                        if score > 0:
+                            results.append((all_docs[i], float(score)))
+                    results.sort(key=lambda x: x[1], reverse=True)
+                    return results[:k]
+        
+        query_embedding = self.embedding_model.embed_query(query)
+        semantic_results = self.vector_store.similarity_search(query_embedding, k)
+        
+        if strategy in ["multi", "comprehensive"]:
+            all_results = {}
+            for doc, score in semantic_results:
+                doc_id = doc.page_content[:100]
+                all_results[doc_id] = (doc, score)
+            
+            try:
+                from rank_bm25 import BM25Okapi
+                all_docs = self.vector_store.documents if hasattr(self.vector_store, 'documents') else []
+                if all_docs:
+                    texts = [doc.page_content.split() for doc in all_docs]
+                    bm25 = BM25Okapi(texts)
+                    bm25_scores = bm25.get_scores(query.split())
+                    max_bm25 = max(bm25_scores) if max(bm25_scores) > 0 else 1.0
+                    for i, doc in enumerate(all_docs):
+                        doc_id = doc.page_content[:100]
+                        norm_bm25 = bm25_scores[i] / max_bm25
+                        if doc_id in all_results:
+                            all_results[doc_id] = (doc, all_results[doc_id][1] * 0.6 + norm_bm25 * 0.4)
+                        else:
+                            all_results[doc_id] = (doc, norm_bm25)
+            except:
+                pass
+            
+            sorted_results = sorted(all_results.values(), key=lambda x: x[1], reverse=True)
+            return sorted_results[:k]
+        
+        return semantic_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过自适应检索的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class HybridSearchRAG(BaseRAG):
+    def __init__(self, bm25_index=None, semantic_weight=0.5, **kwargs):
+        super().__init__(**kwargs)
+        self.bm25_index = bm25_index
+        self.semantic_weight = semantic_weight
+        self._all_texts = []
+
+    def index_documents(self, documents):
+        super().index_documents(documents)
+        if self.bm25_index is not None:
+            from rank_bm25 import BM25Okapi
+            self._all_texts = [doc.page_content.split() for doc in documents]
+            self.bm25_index = BM25Okapi(self._all_texts)
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        semantic_results = self.vector_store.similarity_search(query_embedding, k * 2)
+
+        if self.bm25_index and self._all_texts:
+            from rank_bm25 import BM25Okapi
+            query_tokens = query.split()
+            bm25_scores = self.bm25_index.get_scores(query_tokens)
+
+            max_bm25 = max(bm25_scores) if bm25_scores.max() > 0 else 1.0
+            normalized_bm25 = bm25_scores / max_bm25
+
+            combined = []
+            for i, (doc, sem_score) in enumerate(semantic_results):
+                bm25_score = normalized_bm25[i] if i < len(normalized_bm25) else 0.0
+                combined_score = (
+                    self.semantic_weight * sem_score +
+                    (1 - self.semantic_weight) * bm25_score
+                )
+                combined.append((doc, combined_score))
+
+            combined.sort(key=lambda x: x[1], reverse=True)
+            return combined[:k]
+
+        return semantic_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(结合语义搜索和BM25关键词搜索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class MultiQueryRAG(BaseRAG):
+    def __init__(self, num_queries=3, **kwargs):
+        super().__init__(**kwargs)
+        self.num_queries = num_queries
+
+    def _generate_queries(self, query):
+        prompt = (
+            "请将以下问题扩展为{}个不同角度的问题,"
+            "用于检索更全面的信息。每行一个问题。\n\n"
+            "原问题:{}\n\n"
+            "扩展问题:".format(self.num_queries, query)
+        )
+        response = self._call_llm(prompt)
+        queries = [q.strip() for q in response.strip().split("\n") if q.strip()]
+        queries.insert(0, query)
+        return queries[:self.num_queries + 1]
+
+    def retrieve(self, query, k=10):
+        queries = self._generate_queries(query)
+        all_docs = {}
+
+        for q in queries:
+            q_embedding = self.embedding_model.embed_query(q)
+            results = self.vector_store.similarity_search(q_embedding, k)
+            for doc, score in results:
+                doc_id = doc.page_content[:100]
+                if doc_id not in all_docs or score > all_docs[doc_id][1]:
+                    all_docs[doc_id] = (doc, score)
+
+        sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
+        return sorted_docs[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请综合多个查询角度的检索结果,详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class HyDERAG(BaseRAG):
+    def __init__(self, num_hypotheses=3, **kwargs):
+        super().__init__(**kwargs)
+        self.num_hypotheses = num_hypotheses
+
+    def _generate_hypothetical_docs(self, query):
+        prompt = (
+            "假设你是一个招投标专家,请根据以下问题,"
+            "生成{}个可能包含答案的假想文档段落。"
+            "每个段落用'<doc>'和'</doc>'分隔。\n\n"
+            "问题:{}\n\n"
+            "假想文档:".format(self.num_hypotheses, query)
+        )
+        response = self._call_llm(prompt)
+        docs = re.findall(r'<doc>(.*?)</doc>', response, re.DOTALL)
+        if not docs:
+            docs = [response]
+        return docs[:self.num_hypotheses]
+
+    def retrieve(self, query, k=10):
+        hypothetical_docs = self._generate_hypothetical_docs(query)
+        all_results = {}
+
+        for hypo_doc in hypothetical_docs:
+            hypo_embedding = self.embedding_model.embed_query(hypo_doc)
+            results = self.vector_store.similarity_search(hypo_embedding, k)
+            for doc, score in results:
+                doc_id = doc.page_content[:100]
+                if doc_id not in all_results or score > all_results[doc_id][1]:
+                    all_results[doc_id] = (doc, score)
+
+        query_embedding = self.embedding_model.embed_query(query)
+        direct_results = self.vector_store.similarity_search(query_embedding, k)
+        for doc, score in direct_results:
+            doc_id = doc.page_content[:100]
+            if doc_id not in all_results or score > all_results[doc_id][1]:
+                all_results[doc_id] = (doc, score)
+
+        sorted_docs = sorted(all_results.values(), key=lambda x: x[1], reverse=True)
+        return sorted_docs[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class StepBackRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _generate_step_back_query(self, query):
+        prompt = (
+            "请将以下具体问题抽象为一个更通用的高层次问题,"
+            "用于查找相关背景知识。只返回抽象后的问题。\n\n"
+            "具体问题:{}\n\n"
+            "高层次问题:".format(query)
+        )
+        return self._call_llm(prompt).strip()
+
+    def retrieve(self, query, k=10):
+        step_back_query = self._generate_step_back_query(query)
+        logger.info("StepBack query: {}".format(step_back_query))
+
+        query_embedding = self.embedding_model.embed_query(query)
+        direct_results = self.vector_store.similarity_search(query_embedding, k // 2)
+
+        step_back_embedding = self.embedding_model.embed_query(step_back_query)
+        step_back_results = self.vector_store.similarity_search(step_back_embedding, k // 2)
+
+        all_docs = {}
+        for doc, score in direct_results + step_back_results:
+            doc_id = doc.page_content[:100]
+            if doc_id not in all_docs or score > all_docs[doc_id][1]:
+                all_docs[doc_id] = (doc, score)
+
+        sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
+        return sorted_docs[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "结合以下参考文档(包含具体问题和抽象问题的检索结果),"
+            "回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class ContextualCompressionRAG(BaseRAG):
+    def __init__(self, compression_llm=None, **kwargs):
+        super().__init__(**kwargs)
+        self.compression_llm = compression_llm or llm_client
+
+    def _compress_doc(self, query, doc):
+        prompt = (
+            "请压缩以下文档片段,仅保留与问题相关的信息。\n\n"
+            "问题:{}\n\n"
+            "文档:{}\n\n"
+            "压缩后的内容:".format(query, doc.page_content)
+        )
+        messages = []
+        messages.append({"role": "user", "content": prompt})
+        response = self.compression_llm.chat.completions.create(
+            model=self.llm_model,
+            messages=messages,
+            temperature=0.0,
+            max_tokens=512,
+        )
+        return response.choices[0].message.content
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
+
+        compressed_results = []
+        for doc, score in initial_results:
+            compressed_text = self._compress_doc(query, doc)
+            compressed_doc = Document(page_content=compressed_text, metadata=doc.metadata)
+            compressed_results.append((compressed_doc, score))
+
+        return compressed_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过上下文压缩的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class SelfRAG(BaseRAG):
+    def __init__(self, relevance_threshold=0.5, support_threshold=0.5, retrieval_threshold=0.6, critic_model=None, **kwargs):
+        super().__init__(**kwargs)
+        self.relevance_threshold = relevance_threshold
+        self.support_threshold = support_threshold
+        self.retrieval_threshold = retrieval_threshold
+        self.critic_model = critic_model
+
+    def _is_retrieval_needed(self, query):
+        prompt = (
+            "判断以下问题是否需要检索外部知识才能回答。"
+            "只需回答'是'或'否'。\n\n"
+            "问题:{}\n\n"
+            "是否需要检索:".format(query)
+        )
+        response = self._call_llm(prompt).strip().lower()
+        return "是" in response or "yes" in response
+
+    def _evaluate_relevance(self, query, doc):
+        prompt = (
+            "评估以下文档片段与问题的相关性,给出0-1之间的分数。\n\n"
+            "问题:{}\n\n"
+            "文档:{}\n\n"
+            "相关性分数:".format(query, doc.page_content)
+        )
+        response = self._call_llm(prompt).strip()
+        try:
+            return float(response)
+        except:
+            return 0.5
+
+    def _evaluate_support(self, query, answer, doc):
+        prompt = (
+            "评估以下回答是否得到了文档的支持,给出0-1之间的分数。\n\n"
+            "文档:{}\n\n"
+            "回答:{}\n\n"
+            "支持分数:".format(doc.page_content, answer)
+        )
+        response = self._call_llm(prompt).strip()
+        try:
+            return float(response)
+        except:
+            return 0.5
+
+    def _evaluate_usefulness(self, query, answer):
+        prompt = (
+            "评估以下回答对于问题的有用性,给出0-1之间的分数。\n\n"
+            "问题:{}\n\n"
+            "回答:{}\n\n"
+            "有用性分数:".format(query, answer)
+        )
+        response = self._call_llm(prompt).strip()
+        try:
+            return float(response)
+        except:
+            return 0.5
+
+    def retrieve(self, query, k=10):
+        if not self._is_retrieval_needed(query):
+            return []
+
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
+
+        filtered_results = []
+        for doc, score in initial_results:
+            relevance = self._evaluate_relevance(query, doc)
+            if relevance >= self.relevance_threshold:
+                filtered_results.append((doc, relevance * score))
+
+        filtered_results.sort(key=lambda x: x[1], reverse=True)
+        return filtered_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过自反思筛选的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请确保回答有充分的文档支持,详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class CorrectiveRAG(BaseRAG):
+    def __init__(self, correctness_threshold=0.6, **kwargs):
+        super().__init__(**kwargs)
+        self.correctness_threshold = correctness_threshold
+        self.web_search_results = []
+
+    def _evaluate_correctness(self, query, docs):
+        if not docs:
+            return 0.0
+
+        context = "\n\n".join([doc.page_content for doc, _ in docs[:5]])
+        prompt = (
+            "评估以下文档集合是否能够正确回答问题,给出0-1之间的分数。\n\n"
+            "问题:{}\n\n"
+            "文档集合:\n{}\n\n"
+            "正确性分数:".format(query, context)
+        )
+        response = self._call_llm(prompt).strip()
+        try:
+            return float(response)
+        except:
+            return 0.5
+
+    def _web_search(self, query):
+        return "[Web search results for: {}] - Simulated external knowledge".format(query)
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k)
+
+        correctness_score = self._evaluate_correctness(query, initial_results)
+
+        if correctness_score < self.correctness_threshold:
+            logger.info("CRAG: Correctness score {} below threshold, adding web search".format(correctness_score))
+            web_result = self._web_search(query)
+            web_doc = Document(
+                page_content=web_result,
+                metadata={"source": "web_search"}
+            )
+            initial_results.insert(0, (web_doc, 0.8))
+
+        return initial_results
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(可能包含外部搜索结果),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。如果包含外部搜索结果,请注明。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class FLARERAG(BaseRAG):
+    def __init__(self, max_iterations=3, **kwargs):
+        super().__init__(**kwargs)
+        self.max_iterations = max_iterations
+
+    def _need_more_retrieval(self, query, current_answer):
+        prompt = (
+            "基于当前已有的信息,判断是否还需要更多检索才能完整回答问题。\n\n"
+            "问题:{}\n\n"
+            "当前已有信息的回答:{}\n\n"
+            "是否需要更多检索(只回答是/否):".format(query, current_answer)
+        )
+        response = self._call_llm(prompt).strip().lower()
+        return "是" in response or "yes" in response
+
+    def _generate_next_query(self, query, current_answer):
+        prompt = (
+            "基于当前回答的不足,生成一个新的查询来补充信息。\n\n"
+            "原问题:{}\n\n"
+            "当前回答:{}\n\n"
+            "新查询:".format(query, current_answer)
+        )
+        return self._call_llm(prompt).strip()
+
+    def retrieve(self, query, k=10):
+        all_docs = {}
+        current_query = query
+        current_answer = ""
+
+        for iteration in range(self.max_iterations):
+            query_embedding = self.embedding_model.embed_query(current_query)
+            results = self.vector_store.similarity_search(query_embedding, k)
+
+            for doc, score in results:
+                doc_id = doc.page_content[:100]
+                if doc_id not in all_docs or score > all_docs[doc_id][1]:
+                    all_docs[doc_id] = (doc, score)
+
+            context = self._format_context(list(all_docs.values()))
+            current_answer = self._generate_partial_answer(query, context)
+
+            if not self._need_more_retrieval(query, current_answer):
+                break
+
+            current_query = self._generate_next_query(query, current_answer)
+            logger.info("FLARE iteration {}, new query: {}".format(iteration + 1, current_query))
+
+        sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
+        return sorted_docs[:k]
+
+    def _generate_partial_answer(self, query, context):
+        prompt = (
+            "根据以下参考文档,给出问题的回答。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "回答:".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+    def generate(self, query, context):
+        return self._generate_partial_answer(query, context)
+
+
+class RAPTORRAG(BaseRAG):
+    def __init__(self, max_clusters=50, summary_length=256, num_tree_levels=2, **kwargs):
+        super().__init__(**kwargs)
+        self.max_clusters = max_clusters
+        self.summary_length = summary_length
+        self.num_tree_levels = num_tree_levels
+        self.hierarchical_docs = []
+
+    def _summarize_texts(self, texts):
+        summaries = []
+        batch_size = 5
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
+            combined = "\n\n".join(batch)
+            prompt = (
+                "请用一段话总结以下内容的核心要点({}字以内):\n\n"
+                "{}\n\n"
+                "总结:".format(self.summary_length, combined)
+            )
+            summary = self._call_llm(prompt)
+            summaries.append(summary)
+        return summaries
+
+    def _cluster_embeddings(self, embeddings):
+        from sklearn.cluster import AgglomerativeClustering
+        emb_np = np.array(embeddings)
+        n_clusters = min(self.max_clusters, len(embeddings))
+        if n_clusters < 2:
+            return [list(range(len(embeddings)))]
+
+        clustering = AgglomerativeClustering(n_clusters=n_clusters)
+        labels = clustering.fit_predict(emb_np)
+
+        clusters = {}
+        for i, label in enumerate(labels):
+            if label not in clusters:
+                clusters[label] = []
+            clusters[label].append(i)
+
+        return list(clusters.values())
+
+    def build_tree(self, documents):
+        self.index_documents(documents)
+        
+        current_level_docs = documents
+        for level in range(self.num_tree_levels):
+            texts = [doc.page_content for doc in current_level_docs]
+            embeddings = self.embedding_model.embed_documents(texts)
+            clusters = self._cluster_embeddings(embeddings)
+
+            summary_docs = []
+            for cluster in clusters:
+                cluster_texts = [texts[i] for i in cluster]
+                summaries = self._summarize_texts(cluster_texts)
+                for summary in summaries:
+                    summary_docs.append(Document(
+                        page_content=summary,
+                        metadata={"level": level + 1, "num_source_docs": len(cluster)}
+                    ))
+
+            self.hierarchical_docs.extend(summary_docs)
+            if summary_docs:
+                summary_embeddings = self.embedding_model.embed_documents([d.page_content for d in summary_docs])
+                self.vector_store.add_documents(summary_docs, summary_embeddings)
+
+            current_level_docs = summary_docs
+            logger.info("RAPTOR level {}: {} summaries created".format(level + 1, len(summary_docs)))
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        
+        all_results = []
+        all_results.extend(self.vector_store.similarity_search(query_embedding, k))
+        
+        sorted_results = sorted(all_results, key=lambda x: x[1], reverse=True)
+        return sorted_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(包含多层次摘要信息),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class EnsembleRAG(BaseRAG):
+    def __init__(self, methods=None, **kwargs):
+        super().__init__(**kwargs)
+        self.methods = methods or ["naive", "hybrid", "multi_query"]
+
+    def retrieve(self, query, k=10):
+        all_docs = {}
+        
+        query_embedding = self.embedding_model.embed_query(query)
+        base_results = self.vector_store.similarity_search(query_embedding, k * 2)
+        for doc, score in base_results:
+            doc_id = doc.page_content[:100]
+            all_docs[doc_id] = (doc, score)
+
+        try:
+            from rank_bm25 import BM25Okapi
+            all_texts = [doc.page_content.split() for doc, _ in base_results]
+            bm25 = BM25Okapi(all_texts)
+            query_tokens = query.split()
+            bm25_scores = bm25.get_scores(query_tokens)
+            max_bm25 = max(bm25_scores) if bm25_scores.max() > 0 else 1.0
+
+            for i, (doc, _) in enumerate(base_results):
+                doc_id = doc.page_content[:100]
+                bm25_score = bm25_scores[i] / max_bm25
+                all_docs[doc_id] = (doc, all_docs[doc_id][1] * 0.5 + bm25_score * 0.5)
+        except:
+            pass
+
+        sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
+        return sorted_docs[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(通过集成多种检索方法获取),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请综合多种检索策略的结果,详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class BidFieldExtractionRAG(BaseRAG):
+    def __init__(self, fields=None, **kwargs):
+        super().__init__(**kwargs)
+        self.fields = fields or [
+            "project_name", "project_code", "budget_amount", "currency",
+            "bid_deadline", "bid_open_time", "bid_location",
+            "purchaser_name", "purchaser_contact", "purchaser_phone",
+            "agency_name", "agency_contact", "agency_phone",
+            "qualification_requirements", "bid_bond_amount",
+            "performance_bond_amount", "warranty_period",
+            "delivery_time", "delivery_location", "payment_terms",
+            "evaluation_method", "scope_of_work"
+        ]
+
+    def retrieve(self, query, k=10):
+        all_docs = {}
+        
+        for field in self.fields:
+            field_query = "{} {}".format(query, field)
+            field_embedding = self.embedding_model.embed_query(field_query)
+            results = self.vector_store.similarity_search(field_embedding, k // 3)
+            for doc, score in results:
+                doc_id = doc.page_content[:100]
+                if doc_id not in all_docs or score > all_docs[doc_id][1]:
+                    all_docs[doc_id] = (doc, score)
+
+        query_embedding = self.embedding_model.embed_query(query)
+        direct_results = self.vector_store.similarity_search(query_embedding, k)
+        for doc, score in direct_results:
+            doc_id = doc.page_content[:100]
+            if doc_id not in all_docs or score > all_docs[doc_id][1]:
+                all_docs[doc_id] = (doc, score)
+
+        sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
+        return sorted_docs[:k]
+
+    def generate(self, query, context):
+        fields_str = "\n".join(["- {}".format(f) for f in self.fields])
+        prompt = """你是一个招投标领域的专家。请根据提供的文档内容,提取以下字段信息:
+
+{}
+
+文档内容:
+{}
+
+请以JSON格式返回提取结果。如果某个字段无法从文档中提取,请返回null。
+
+JSON格式示例:
+{{
+  "project_name": "项目名称",
+  "budget_amount": 1000000,
+  ...
+}}""".format(fields_str, context)
+        return self._call_llm(prompt)
+
+
+class TableAwareRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _extract_table_info(self, doc):
+        content = doc.page_content
+        
+        table_pattern = r'(\|.*\|[\r\n]+)'
+        tables = re.findall(table_pattern, content)
+        
+        if tables:
+            return "[TABLE DETECTED]\n{}".format(''.join(tables))
+        
+        numbered_pattern = r'(\d+[\.、].*?(?=\d+[\.、]|$))'
+        numbered_items = re.findall(numbered_pattern, content, re.DOTALL)
+        
+        if numbered_items:
+            return "[STRUCTURED LIST DETECTED]\n{}".format(''.join(numbered_items))
+        
+        return content
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
+
+        enhanced_results = []
+        for doc, score in initial_results:
+            enhanced_content = self._extract_table_info(doc)
+            enhanced_doc = Document(
+                page_content=enhanced_content,
+                metadata=dict(doc.metadata, has_table="[TABLE" in enhanced_content)
+            )
+            table_bonus = 0.1 if "[TABLE" in enhanced_content else 0.0
+            enhanced_results.append((enhanced_doc, score + table_bonus))
+
+        enhanced_results.sort(key=lambda x: x[1], reverse=True)
+        return enhanced_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(包含表格和结构化数据),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请特别注意表格和结构化数据中的信息,详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class GraphRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.graph = {}
+
+    def _extract_entities(self, text):
+        entities = re.findall(r'[一-龥]{2,10}(?:公司|单位|招标|投标|项目|金额|时间)', text)
+        return list(set(entities))
+
+    def build_graph(self, documents):
+        self.index_documents(documents)
+        
+        for doc in documents:
+            entities = self._extract_entities(doc.page_content)
+            for entity in entities:
+                if entity not in self.graph:
+                    self.graph[entity] = []
+                self.graph[entity].append(doc)
+
+        logger.info("Graph built with {} entities".format(len(self.graph)))
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        semantic_results = self.vector_store.similarity_search(query_embedding, k)
+
+        query_entities = self._extract_entities(query)
+        graph_docs = {}
+        for entity in query_entities:
+            if entity in self.graph:
+                for doc in self.graph[entity]:
+                    doc_id = doc.page_content[:100]
+                    if doc_id not in graph_docs:
+                        graph_docs[doc_id] = (doc, 0.7)
+                    else:
+                        graph_docs[doc_id] = (doc, graph_docs[doc_id][1] + 0.1)
+
+        all_docs = {}
+        for doc, score in semantic_results:
+            doc_id = doc.page_content[:100]
+            all_docs[doc_id] = (doc, score)
+
+        for doc_id, (doc, score) in graph_docs.items():
+            if doc_id in all_docs:
+                all_docs[doc_id] = (doc, all_docs[doc_id][1] * 0.6 + score * 0.4)
+            else:
+                all_docs[doc_id] = (doc, score)
+
+        sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
+        return sorted_docs[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(结合知识图谱和语义检索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+
+class BM25RAG(BaseRAG):
+    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
+        self.embedding_model = None
+        self.vector_store = None
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.name = self.__class__.__name__
+        from rank_bm25 import BM25Okapi
+        self.bm25 = None
+        self._all_texts = []
+        self._all_documents = []
+
+    def index_documents(self, documents):
+        self._all_documents = documents
+        self._all_texts = [doc.page_content.split() for doc in documents]
+        from rank_bm25 import BM25Okapi
+        self.bm25 = BM25Okapi(self._all_texts)
+        logger.info("BM25 index built with {} documents".format(len(documents)))
+
+    def retrieve(self, query, k=10):
+        if self.bm25 is None:
+            return []
+        
+        query_tokens = query.split()
+        scores = self.bm25.get_scores(query_tokens)
+
+        scored_docs = []
+        for i, score in enumerate(scores):
+            if score > 0:
+                scored_docs.append((self._all_documents[i], float(score)))
+
+        scored_docs.sort(key=lambda x: x[1], reverse=True)
+        return scored_docs[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(BM25关键词检索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+    def query(self, query, k=10):
+        start_total = time.time()
+        
+        t0 = time.time()
+        docs = self.retrieve(query, k)
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        answer = self.generate(query, context)
+        generation_time = time.time() - t1
+
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=answer,
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={"method": self.name, "num_context_tokens": len(context)},
+        )
+
+
+class TFIDFRAG(BaseRAG):
+    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
+        self.embedding_model = None
+        self.vector_store = None
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.name = self.__class__.__name__
+        self.vectorizer = None
+        self.tfidf_matrix = None
+        self._all_documents = []
+
+    def index_documents(self, documents):
+        self._all_documents = documents
+        texts = [doc.page_content for doc in documents]
+        from sklearn.feature_extraction.text import TfidfVectorizer
+        self.vectorizer = TfidfVectorizer()
+        self.tfidf_matrix = self.vectorizer.fit_transform(texts)
+        logger.info("TF-IDF index built with {} documents, vocab size: {}".format(len(documents), len(self.vectorizer.vocabulary_)))
+
+    def retrieve(self, query, k=10):
+        if self.tfidf_matrix is None:
+            return []
+        
+        from sklearn.metrics.pairwise import cosine_similarity
+        query_vec = self.vectorizer.transform([query])
+        scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
+
+        k = min(k, len(scores))
+        top_indices = np.argsort(scores)[::-1][:k]
+
+        results = []
+        for idx in top_indices:
+            if scores[idx] > 0:
+                results.append((self._all_documents[idx], float(scores[idx])))
+
+        return results
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(TF-IDF关键词检索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+    def query(self, query, k=10):
+        start_total = time.time()
+        
+        t0 = time.time()
+        docs = self.retrieve(query, k)
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        answer = self.generate(query, context)
+        generation_time = time.time() - t1
+
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=answer,
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={"method": self.name, "num_context_tokens": len(context)},
+        )
+
+
+class KeywordRAG(BaseRAG):
+    def __init__(self, search_method="bm25", llm_client=None, llm_model="gpt-4o", **kwargs):
+        self.embedding_model = None
+        self.vector_store = None
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.name = self.__class__.__name__
+        self.search_method = search_method
+        self._all_documents = []
+        self._all_texts = []
+        self.bm25 = None
+        self.tfidf_matrix = None
+        self.vectorizer = None
+
+    def index_documents(self, documents):
+        self._all_documents = documents
+        self._all_texts = [doc.page_content for doc in documents]
+
+        if self.search_method == "bm25":
+            from rank_bm25 import BM25Okapi
+            tokenized = [t.split() for t in self._all_texts]
+            self.bm25 = BM25Okapi(tokenized)
+        elif self.search_method == "tfidf":
+            from sklearn.feature_extraction.text import TfidfVectorizer
+            self.vectorizer = TfidfVectorizer()
+            self.tfidf_matrix = self.vectorizer.fit_transform(self._all_texts)
+
+        logger.info("KeywordRAG ({}) index built with {} documents".format(self.search_method, len(documents)))
+
+    def retrieve(self, query, k=10):
+        if self.search_method == "bm25":
+            query_tokens = query.split()
+            scores = self.bm25.get_scores(query_tokens)
+        elif self.search_method == "tfidf":
+            from sklearn.metrics.pairwise import cosine_similarity
+            query_vec = self.vectorizer.transform([query])
+            scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
+        else:
+            return []
+
+        k = min(k, len(scores))
+        top_indices = np.argsort(scores)[::-1][:k]
+
+        results = []
+        for idx in top_indices:
+            if scores[idx] > 0:
+                results.append((self._all_documents[idx], float(scores[idx])))
+
+        return results
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(关键词检索:{}),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(self.search_method, context, query)
+        )
+        return self._call_llm(prompt)
+
+    def query(self, query, k=10):
+        start_total = time.time()
+        
+        t0 = time.time()
+        docs = self.retrieve(query, k)
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        answer = self.generate(query, context)
+        generation_time = time.time() - t1
+
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=answer,
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={"method": self.name, "num_context_tokens": len(context)},
+        )

+ 37 - 0
bdirag/rag_methods/__init__.py

@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+"""RAG methods package for BidiRAG."""
+from .base import RAGResult, BaseRAG
+from .naive_rag import NaiveRAG
+from .rerank_rag import RerankRAG
+from .parent_document_rag import ParentDocumentRAG
+from .llm_filter_rag import LLMFilterRAG
+from .query_routing_rag import QueryRoutingRAG
+from .metadata_filter_rag import MetadataFilterRAG
+from .adaptive_rag import AdaptiveRAG
+from .hybrid_search_rag import HybridSearchRAG
+from .multi_query_rag import MultiQueryRAG
+from .hyde_rag import HyDERAG
+from .step_back_rag import StepBackRAG
+from .contextual_compression_rag import ContextualCompressionRAG
+from .self_rag import SelfRAG
+from .corrective_rag import CorrectiveRAG
+from .flare_rag import FLARERAG
+from .raptor_rag import RAPTORRAG
+from .ensemble_rag import EnsembleRAG
+from .bid_field_extraction_rag import BidFieldExtractionRAG
+from .table_aware_rag import TableAwareRAG
+from .graph_rag import GraphRAG
+from .bm25_rag import BM25RAG
+from .tfidf_rag import TFIDFRAG
+from .keyword_rag import KeywordRAG
+from .bm25_html_tree_rag import BM25HTMLTreeRAG
+
+__all__ = [
+    'RAGResult', 'BaseRAG',
+    'NaiveRAG', 'RerankRAG', 'ParentDocumentRAG', 'LLMFilterRAG',
+    'QueryRoutingRAG', 'MetadataFilterRAG', 'AdaptiveRAG', 'HybridSearchRAG',
+    'MultiQueryRAG', 'HyDERAG', 'StepBackRAG', 'ContextualCompressionRAG',
+    'SelfRAG', 'CorrectiveRAG', 'FLARERAG', 'RAPTORRAG', 'EnsembleRAG',
+    'BidFieldExtractionRAG', 'TableAwareRAG', 'GraphRAG',
+    'BM25RAG', 'TFIDFRAG', 'KeywordRAG', 'BM25HTMLTreeRAG',
+]

+ 94 - 0
bdirag/rag_methods/adaptive_rag.py

@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+"""Adaptive RAG - adaptively choose retrieval strategy based on query complexity."""
+from .base import BaseRAG
+from loguru import logger
+
+
+class AdaptiveRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.max_iterations = 3
+
+    def _analyze_query(self, query):
+        prompt = (
+            "分析以下问题的复杂度。请回答simple或complex或needs_multi_step。\n\n"
+            "问题:{}\n\n"
+            "回答:".format(query)
+        )
+        response = self._call_llm(prompt).strip().lower()
+        if "complex" in response:
+            return "complex"
+        elif "multi" in response:
+            return "multi_step"
+        return "simple"
+
+    def retrieve(self, query, k=10):
+        complexity = self._analyze_query(query)
+        logger.info("AdaptiveRAG: complexity={}".format(complexity))
+        
+        if complexity == "simple":
+            return self._simple_retrieve(query, k)
+        elif complexity == "complex":
+            return self._complex_retrieve(query, k)
+        else:
+            return self._multi_step_retrieve(query, k)
+
+    def _simple_retrieve(self, query, k):
+        query_embedding = self.embedding_model.embed_query(query)
+        return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
+
+    def _complex_retrieve(self, query, k):
+        sub_queries = self._generate_sub_queries(query)
+        all_results = {}
+        
+        for sq in sub_queries:
+            sq_embedding = self.embedding_model.embed_query(sq)
+            results = self.vector_store.similarity_search(sq_embedding, k // len(sub_queries) + 1)
+            for doc, score in results:
+                key = self._dedup_key(doc)
+                if key not in all_results or score > all_results[key][1]:
+                    all_results[key] = (doc, score)
+        
+        results = list(all_results.values())
+        results.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(results, k)
+
+    def _multi_step_retrieve(self, query, k):
+        all_results = []
+        current_query = query
+        
+        for i in range(self.max_iterations):
+            q_embedding = self.embedding_model.embed_query(current_query)
+            step_results = self.vector_store.similarity_search(q_embedding, 5)
+            all_results.extend(step_results)
+            
+            if len(all_results) >= k:
+                break
+            
+            refinement_prompt = (
+                "根据之前的查询,提出一个新的查询问题以获取更多信息。\n\n"
+                "之前查询:{}\n\n"
+                "新查询:".format(current_query)
+            )
+            current_query = self._call_llm(refinement_prompt)
+        
+        all_results.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(all_results, k)
+
+    def _generate_sub_queries(self, query):
+        prompt = (
+            "将以下问题分解为2-3个子问题,用逗号分隔。只返回子问题。\n\n"
+            "原始问题:{}\n\n"
+            "子问题:".format(query)
+        )
+        response = self._call_llm(prompt)
+        return [q.strip() for q in response.split(",") if q.strip()]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过自适应检索的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 111 - 0
bdirag/rag_methods/base.py

@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+"""Base RAG classes: RAGResult and BaseRAG."""
+from abc import ABC, abstractmethod
+import time
+
+from ..document_processor import Document
+from .dedup import content_dedup_key, deduplicate_ranked_results
+
+
+class RAGResult(object):
+    """Result of a RAG query."""
+    def __init__(self, answer="", retrieved_docs=None, latency_retrieval=0.0,
+                 latency_generation=0.0, latency_total=0.0, metadata=None):
+        self.answer = answer
+        self.retrieved_docs = retrieved_docs if retrieved_docs is not None else []
+        self.latency_retrieval = latency_retrieval
+        self.latency_generation = latency_generation
+        self.latency_total = latency_total
+        self.metadata = metadata if metadata is not None else {}
+
+    def to_dict(self):
+        return {
+            "answer": self.answer,
+            "num_docs_retrieved": len(self.retrieved_docs),
+            "latency_retrieval": round(self.latency_retrieval, 3),
+            "latency_generation": round(self.latency_generation, 3),
+            "latency_total": round(self.latency_total, 3),
+            "metadata": self.metadata or {},
+        }
+
+
+class BaseRAG(ABC):
+    """Abstract base class for all RAG methods."""
+    def __init__(self, embedding_model=None, vector_store=None, llm_client=None, llm_model="gpt-4o", **kwargs):
+        self.embedding_model = embedding_model
+        self.vector_store = vector_store
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.name = self.__class__.__name__
+
+    def index_documents(self, documents):
+        """Index documents using the embedding model and vector store."""
+        texts = [doc.page_content for doc in documents]
+        embeddings = self.embedding_model.embed_documents(texts)
+        self.vector_store.add_documents(documents, embeddings)
+
+    def _call_llm(self, prompt, system_prompt=None):
+        """Call the LLM to generate a response."""
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        response = self.llm_client.chat.completions.create(
+            model=self.llm_model,
+            messages=messages,
+            temperature=0.1,
+            max_tokens=2048,
+        )
+        return response.choices[0].message.content
+
+    def _format_context(self, docs):
+        """Format retrieved documents into a context string."""
+        context_parts = []
+        for i, (doc, score) in enumerate(docs, 1):
+            source = doc.metadata.get("source", "unknown")
+            context_parts.append("[{}] (Score: {:.3f}, Source: {})\n{}".format(i, score, source, doc.page_content))
+        return "\n\n---\n\n".join(context_parts)
+
+    def _dedup_key(self, doc):
+        """Return a content-based key for retrieval result deduplication."""
+        return content_dedup_key(doc)
+
+    def _deduplicate_results(self, results, k=None):
+        """Deduplicate ranked retrieval results by normalized document content."""
+        return deduplicate_ranked_results(results, k)
+
+    @abstractmethod
+    def retrieve(self, query, k=10):
+        """Retrieve relevant documents for the query."""
+        pass
+
+    @abstractmethod
+    def generate(self, query, context):
+        """Generate an answer based on the context."""
+        pass
+
+    def query(self, query, k=10):
+        """Full RAG pipeline: retrieve + generate."""
+        start_total = time.time()
+        
+        t0 = time.time()
+        docs = self.retrieve(query, k)
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        answer = self.generate(query, context)
+        generation_time = time.time() - t1
+
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=answer,
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={"method": self.name, "num_context_tokens": len(context)},
+        )

+ 322 - 0
bdirag/rag_methods/bid_field_extraction_rag.py

@@ -0,0 +1,322 @@
+# -*- coding: utf-8 -*-
+"""Bid-field extraction RAG tuned for tender/bidding HTML text recall."""
+import json
+import re
+import time
+from collections import OrderedDict
+
+import numpy as np
+from loguru import logger
+
+from .base import BaseRAG, RAGResult
+from .bm25_backend import get_bm25_okapi
+from .dedup import content_dedup_key
+from .tokenization import bm25_tokenize
+
+
+FIELD_SPECS = OrderedDict([
+    ("project_name", {
+        "label": "项目名称",
+        "aliases": ["项目名称", "采购项目名称", "招标项目名称", "Project Name"],
+    }),
+    ("project_code", {
+        "label": "项目编号",
+        "aliases": ["项目编号", "采购编号", "招标编号", "Project Code", "Tender No"],
+    }),
+    ("budget_amount", {
+        "label": "预算金额",
+        "aliases": ["预算金额", "采购预算", "项目预算", "最高限价", "Budget Amount", "Project Budget"],
+    }),
+    ("currency", {
+        "label": "币种",
+        "aliases": ["币种", "货币", "人民币", "Currency", "RMB"],
+    }),
+    ("bid_deadline", {
+        "label": "投标截止时间",
+        "aliases": ["投标截止时间", "递交截止时间", "提交投标文件截止时间", "Bid Submission Deadline", "Bid Deadline"],
+    }),
+    ("bid_opening_time", {
+        "label": "开标时间",
+        "aliases": ["开标时间", "开启时间", "Bid Opening Time", "Bid Opening"],
+    }),
+    ("bid_location", {
+        "label": "投标地点",
+        "aliases": ["投标地点", "递交地点", "开标地点", "Bid Location", "Venue"],
+    }),
+    ("purchaser_name", {
+        "label": "采购人名称",
+        "aliases": ["采购人", "招标人", "采购单位", "Purchaser", "Tenderer"],
+    }),
+    ("purchaser_contact", {
+        "label": "采购人联系人",
+        "aliases": ["采购人联系人", "联系人", "Contact Person", "Purchaser Contact"],
+    }),
+    ("purchaser_phone", {
+        "label": "采购人电话",
+        "aliases": ["采购人电话", "联系电话", "Contact Phone", "Purchaser Phone"],
+    }),
+    ("agency_name", {
+        "label": "代理机构名称",
+        "aliases": ["代理机构", "采购代理机构", "招标代理", "Agency Name", "Bidding Agency"],
+    }),
+    ("agency_contact", {
+        "label": "代理机构联系人",
+        "aliases": ["代理机构联系人", "Agency Contact"],
+    }),
+    ("agency_phone", {
+        "label": "代理机构电话",
+        "aliases": ["代理机构电话", "Agency Phone"],
+    }),
+    ("qualification_requirements", {
+        "label": "资格要求",
+        "aliases": ["资格要求", "资质要求", "投标人资格", "Qualification Requirements"],
+    }),
+    ("bid_bond_amount", {
+        "label": "投标保证金",
+        "aliases": ["投标保证金", "Bid Bond", "Bid Bond Amount"],
+    }),
+    ("performance_bond_amount", {
+        "label": "履约保证金",
+        "aliases": ["履约保证金", "Performance Bond", "Performance Bond Amount"],
+    }),
+    ("warranty_period", {
+        "label": "质保期",
+        "aliases": ["质保期", "保修期", "免费保修", "Warranty Period", "Warranty"],
+    }),
+    ("delivery_time", {
+        "label": "交货时间",
+        "aliases": ["交货时间", "交付时间", "工期", "建设周期", "Delivery Time", "Construction Period"],
+    }),
+    ("delivery_location", {
+        "label": "交货地点",
+        "aliases": ["交货地点", "交付地点", "安装地点", "Delivery Location", "Installation Location"],
+    }),
+    ("payment_terms", {
+        "label": "付款方式",
+        "aliases": ["付款方式", "支付方式", "Payment Terms", "Payment"],
+    }),
+    ("evaluation_method", {
+        "label": "评标方法",
+        "aliases": ["评标方法", "评审方法", "评分办法", "Evaluation Method"],
+    }),
+    ("scope_of_work", {
+        "label": "工作范围",
+        "aliases": ["工作范围", "采购内容", "招标范围", "建设内容", "Scope of Work", "Scope"],
+    }),
+])
+
+
+class BidFieldExtractionRAG(BaseRAG):
+    """Retrieve field-specific context before asking the LLM to extract JSON."""
+
+    def __init__(self, extraction_prompt_template=None, bm25_weight=0.45, vector_weight=0.55, **kwargs):
+        super(BidFieldExtractionRAG, self).__init__(**kwargs)
+        self.extraction_prompt_template = extraction_prompt_template or (
+            "你是招投标领域的信息抽取助手。请只依据给定上下文抽取字段。\n\n"
+            "目标字段:\n{fields}\n\n"
+            "上下文:\n{context}\n\n"
+            "要求:\n"
+            "1. 仅返回一个 JSON 对象,不要输出解释。\n"
+            "2. JSON key 必须使用目标字段英文名。\n"
+            "3. 无法从上下文确认的字段返回 null。\n"
+            "4. 金额、时间、联系人、电话等值要保留原文表述。"
+        )
+        self.target_fields = list(FIELD_SPECS.keys())
+        self.field_specs = FIELD_SPECS
+        self.bm25_weight = float(bm25_weight)
+        self.vector_weight = float(vector_weight)
+        self._documents = []
+        self._bm25 = None
+        self._bm25_corpus = []
+
+    def index_documents(self, documents):
+        self._documents = list(documents or [])
+        self._build_bm25(self._documents)
+        if self.embedding_model is not None and self.vector_store is not None:
+            super(BidFieldExtractionRAG, self).index_documents(self._documents)
+
+    def _build_bm25(self, documents):
+        self._bm25_corpus = []
+        for doc in documents:
+            meta_text = " ".join(str(v) for v in (doc.metadata or {}).values())
+            self._bm25_corpus.append(bm25_tokenize("{}\n{}".format(meta_text, doc.page_content)))
+        BM25Okapi = get_bm25_okapi()
+        self._bm25 = BM25Okapi(self._bm25_corpus) if self._bm25_corpus else None
+
+    def retrieve(self, query, k=10):
+        """Hybrid retrieval for a free-form query."""
+        return self._hybrid_retrieve(query, k=k)
+
+    def retrieve_for_fields(self, fields=None, k_per_field=3, max_docs=12):
+        """Retrieve and deduplicate contexts for each requested field."""
+        selected_fields = [f for f in (fields or self.target_fields) if f in self.field_specs]
+        merged = OrderedDict()
+
+        for field in selected_fields:
+            field_query = self._field_query(field)
+            for doc, score in self._hybrid_retrieve(field_query, k=k_per_field):
+                key = self._dedup_key(doc)
+                if key not in merged:
+                    metadata = dict(doc.metadata or {})
+                    metadata["matched_fields"] = [field]
+                    metadata["retrieval_query"] = field_query
+                    merged[key] = [doc, float(score), metadata]
+                else:
+                    merged[key][1] = max(merged[key][1], float(score))
+                    merged[key][2]["matched_fields"].append(field)
+
+        results = []
+        for doc, score, metadata in merged.values():
+            doc.metadata.update(metadata)
+            doc.metadata["matched_fields"] = sorted(set(doc.metadata["matched_fields"]))
+            results.append((doc, score))
+
+        results.sort(key=lambda item: item[1], reverse=True)
+        return self._deduplicate_results(results, max_docs)
+
+    def _field_query(self, field):
+        spec = self.field_specs[field]
+        return "{} {} {}".format(field, spec["label"], " ".join(spec["aliases"]))
+
+    def _hybrid_retrieve(self, query, k=10):
+        k = max(0, int(k or 0))
+        if k == 0:
+            return []
+
+        vector_scores = self._vector_scores(query, k=max(k * 4, k))
+        bm25_scores = self._bm25_scores(query)
+
+        merged = {}
+        for doc, score in vector_scores:
+            merged[self._dedup_key(doc)] = [doc, self._normalize_score(score) * self.vector_weight]
+
+        for doc, score in bm25_scores:
+            key = self._dedup_key(doc)
+            if key not in merged:
+                merged[key] = [doc, 0.0]
+            merged[key][1] += self._normalize_score(score) * self.bm25_weight
+
+        results = [(doc, score) for doc, score in merged.values() if score > 0]
+        results.sort(key=lambda item: item[1], reverse=True)
+        return self._deduplicate_results(results, k)
+
+    def _vector_scores(self, query, k):
+        if self.embedding_model is None or self.vector_store is None:
+            return []
+        try:
+            query_embedding = self.embedding_model.embed_query(query)
+            return self.vector_store.similarity_search(query_embedding, k)
+        except Exception as exc:
+            logger.warning("Vector retrieval failed, falling back to BM25: {}".format(exc))
+            return []
+
+    def _bm25_scores(self, query):
+        if self._bm25 is None:
+            store_docs = getattr(self.vector_store, "documents", None)
+            if store_docs:
+                self._documents = list(store_docs)
+                self._build_bm25(self._documents)
+        if self._bm25 is None:
+            return []
+
+        query_tokens = bm25_tokenize(query)
+        if not query_tokens:
+            return []
+
+        scores = self._bm25.get_scores(query_tokens)
+        max_score = float(np.max(scores)) if len(scores) else 0.0
+        if max_score <= 0:
+            return []
+
+        scored = []
+        for doc, score in zip(self._documents, scores):
+            score = float(score)
+            if score > 0:
+                scored.append((doc, score / max_score))
+        scored.sort(key=lambda item: item[1], reverse=True)
+        return scored
+
+    @staticmethod
+    def _normalize_score(score):
+        try:
+            value = float(score)
+        except (TypeError, ValueError):
+            return 0.0
+        if value < 0:
+            return 0.0
+        return min(value, 1.0)
+
+    def generate(self, query, context):
+        fields_str = "\n".join(
+            "- {}: {} ({})".format(field, spec["label"], ", ".join(spec["aliases"][:4]))
+            for field, spec in self.field_specs.items()
+            if field in self.target_fields
+        )
+        prompt = self.extraction_prompt_template.format(
+            fields=fields_str,
+            context=context[:6000],
+        )
+        return self._call_llm(prompt)
+
+    def extract_fields(self, query=None, k=5, fields=None):
+        start_total = time.time()
+
+        t0 = time.time()
+        requested_fields = [f for f in (fields or self.target_fields) if f in self.field_specs]
+        docs = self.retrieve_for_fields(requested_fields, k_per_field=max(1, k), max_docs=max(k, len(requested_fields)))
+        if query:
+            docs = self._merge_ranked(docs, self.retrieve(query, k=k))
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        raw_response = self.generate(query or "", context)
+        generation_time = time.time() - t1
+
+        fields_result = self._parse_extraction(raw_response)
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=json.dumps(fields_result, indent=2, ensure_ascii=False),
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={
+                "method": self.name,
+                "fields": fields_result,
+                "requested_fields": requested_fields,
+                "num_context_chars": len(context),
+            },
+        )
+
+    @staticmethod
+    def _merge_ranked(primary, secondary):
+        merged = OrderedDict()
+        for doc, score in list(primary or []) + list(secondary or []):
+            key = content_dedup_key(doc)
+            if key not in merged or float(score) > merged[key][1]:
+                merged[key] = [doc, float(score)]
+        results = [(doc, score) for doc, score in merged.values()]
+        results.sort(key=lambda item: item[1], reverse=True)
+        return results
+
+    def _parse_extraction(self, raw_response):
+        if not raw_response:
+            return {field: None for field in self.target_fields}
+
+        try:
+            start = raw_response.find("{")
+            end = raw_response.rfind("}") + 1
+            if start >= 0 and end > start:
+                json_str = raw_response[start:end]
+                json_str = re.sub(r",\s*}", "}", json_str)
+                parsed = json.loads(json_str)
+                return {field: parsed.get(field) for field in self.target_fields}
+        except (TypeError, ValueError, json.JSONDecodeError):
+            logger.warning("Failed to parse JSON from LLM response")
+
+        result = {field: None for field in self.target_fields}
+        result["raw"] = raw_response
+        return result

+ 67 - 0
bdirag/rag_methods/bm25_backend.py

@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+"""BM25 backend selection with a small local fallback."""
+import math
+from collections import Counter
+
+import numpy as np
+
+
+class SimpleBM25Okapi(object):
+    """Small BM25Okapi-compatible fallback used when rank_bm25 is unavailable."""
+
+    def __init__(self, corpus, k1=1.5, b=0.75, epsilon=0.25):
+        self.corpus = [list(doc or []) for doc in corpus]
+        self.k1 = k1
+        self.b = b
+        self.epsilon = epsilon
+        self.corpus_size = len(self.corpus)
+        self.doc_len = [len(doc) for doc in self.corpus]
+        self.avgdl = float(sum(self.doc_len)) / self.corpus_size if self.corpus_size else 0.0
+        self.doc_freqs = [Counter(doc) for doc in self.corpus]
+        self.idf = self._calc_idf()
+
+    def _calc_idf(self):
+        nd = {}
+        for freqs in self.doc_freqs:
+            for word in freqs:
+                nd[word] = nd.get(word, 0) + 1
+
+        idf = {}
+        negative_idfs = []
+        for word, freq in nd.items():
+            value = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
+            idf[word] = value
+            if value < 0:
+                negative_idfs.append(value)
+
+        average_idf = sum(idf.values()) / len(idf) if idf else 0.0
+        eps = self.epsilon * average_idf
+        for word in idf:
+            if idf[word] < 0:
+                idf[word] = eps
+        return idf
+
+    def get_scores(self, query):
+        scores = np.zeros(self.corpus_size)
+        if not query or not self.corpus_size or self.avgdl <= 0:
+            return scores
+
+        for token in query:
+            token_idf = self.idf.get(token)
+            if token_idf is None:
+                continue
+            for i, freqs in enumerate(self.doc_freqs):
+                freq = freqs.get(token, 0)
+                if freq == 0:
+                    continue
+                denominator = freq + self.k1 * (1 - self.b + self.b * self.doc_len[i] / self.avgdl)
+                scores[i] += token_idf * freq * (self.k1 + 1) / denominator
+        return scores
+
+
+def get_bm25_okapi():
+    try:
+        from rank_bm25 import BM25Okapi
+        return BM25Okapi
+    except ImportError:
+        return SimpleBM25Okapi

+ 448 - 0
bdirag/rag_methods/bm25_html_tree_rag.py

@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+"""BM25 HTML Tree RAG - BM25 retrieval on HTML hierarchical tree structure."""
+import importlib.util
+import os
+
+import numpy as np
+from bs4 import BeautifulSoup
+from loguru import logger
+
+from ..document_processor import Document
+from .bm25_backend import get_bm25_okapi
+from .dedup import deduplicate_ranked_results
+from .tokenization import bm25_tokenize
+
+
+class BM25HTMLTreeRAG:
+    """
+    BM25-based retrieval on HTML hierarchical tree.
+    
+    This class:
+    1. Parses HTML into a hierarchical tree using ParseDocument
+    2. Extracts all text nodes from the tree
+    3. Builds BM25 index on tree node texts
+    4. Retrieves relevant subtrees based on query
+    """
+    
+    def __init__(self):
+        self.tree = []
+        self.all_nodes = []
+        self.all_texts = []
+        self.bm25 = None
+        self.html_content = ""
+    
+    def _tokenize(self, text):
+        """Tokenize mixed Chinese/English text for BM25."""
+        return bm25_tokenize(text)
+        
+    def _get_node_depth(self, node, visited=None):
+        """Calculate the depth of a node in the tree."""
+        if visited is None:
+            visited = set()
+        
+        depth = 0
+        current = node
+        while current is not None:
+            current_id = id(current)
+            if current_id in visited:
+                break
+            visited.add(current_id)
+            
+            parent = current.get("parent_title")
+            if parent is None:
+                break
+            current = parent
+            depth += 1
+        
+        return depth
+    
+    def _extract_node_text(self, node):
+        """Extract clean text from a tree node."""
+        text = node.get("text", "")
+        if not text:
+            return ""
+        
+        # If it's HTML content, extract text
+        if text.startswith("<") and text.endswith(">"):
+            try:
+                soup = BeautifulSoup(text, "lxml")
+                text = soup.get_text(strip=True)
+            except:
+                pass
+        
+        return text.strip()
+    
+    def _get_node_full_text(self, node):
+        """Get full text including children for a node."""
+        texts = [self._extract_node_text(node)]
+        
+        childs = node.get("child_title", [])
+        for child in childs:
+            child_text = self._get_node_full_text(child)
+            if child_text:
+                texts.append(child_text)
+        
+        return " ".join(texts)
+    
+    def _collect_all_nodes(self, tree_nodes, visited=None):
+        """Recursively collect all nodes from the tree."""
+        if visited is None:
+            visited = set()
+        
+        nodes = []
+        for node in tree_nodes:
+            node_id = id(node)
+            if node_id in visited:
+                continue
+            visited.add(node_id)
+            
+            nodes.append(node)
+            childs = node.get("child_title", [])
+            if childs:
+                nodes.extend(self._collect_all_nodes(childs, visited))
+        
+        return nodes
+    
+    def build_index(self, html_content, auto_merge_table=True):
+        """
+        Build BM25 index from HTML content.
+        
+        Args:
+            html_content: HTML string to parse
+            auto_merge_table: Whether to auto-merge tables
+        """
+        self.html_content = html_content
+        
+        # Import the local parser/htmlparser.py explicitly.  The package name
+        # collides with Python's historical stdlib parser module on older
+        # interpreters, so a normal ``from parser...`` import is unreliable.
+        parser_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
+            "parser",
+            "htmlparser.py",
+        )
+        spec = importlib.util.spec_from_file_location("bidirag_htmlparser", parser_path)
+        htmlparser = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(htmlparser)
+        ParseDocument = htmlparser.ParseDocument
+        
+        # Parse HTML into tree structure
+        pd = ParseDocument(html_content, auto_merge_table=auto_merge_table)
+        self.tree = pd.tree
+        
+        # Collect all nodes
+        self.all_nodes = self._collect_all_nodes(self.tree)
+        
+        # Extract texts for BM25
+        self.all_texts = []
+        for node in self.all_nodes:
+            text = self._get_node_full_text(node)
+            self.all_texts.append(self._tokenize(text) if text else [])
+        
+        # Build BM25 index
+        BM25Okapi = get_bm25_okapi()
+        self.bm25 = BM25Okapi(self.all_texts) if self.all_nodes else None
+        
+        logger.info("BM25HTMLTreeRAG: indexed {} nodes from HTML tree".format(len(self.all_nodes)))
+    
+    def retrieve_subtrees(self, query, k=5, min_score=0.0):
+        """
+        Retrieve relevant subtrees using BM25.
+        
+        Args:
+            query: Search query
+            k: Number of top results to return
+            min_score: Minimum score threshold
+            
+        Returns:
+            List of (node, score, subtree_text) tuples
+        """
+        if self.bm25 is None or k <= 0:
+            return []
+        
+        query_tokens = self._tokenize(query)
+        if not query_tokens:
+            return []
+
+        scores = self.bm25.get_scores(query_tokens)
+        
+        # Build id-to-index mapping for efficient lookup
+        id_to_index = {}
+        for j, node in enumerate(self.all_nodes):
+            id_to_index[id(node)] = j
+        
+        # Compute enhanced scores for each node with multi-level boosting
+        enhanced_scores = []
+        max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0
+        
+        for i, node in enumerate(self.all_nodes):
+            base_score = scores[i]
+            normalized_score = base_score / max_score  # Normalize to [0, 1]
+            
+            # Optimization 1: Boost from children (max child score)
+            childs = node.get("child_title", [])
+            child_boost = 0.0
+            child_match_count = 0  # Count how many children have matches
+            for child in childs:
+                child_id = id(child)
+                if child_id in id_to_index:
+                    child_idx = id_to_index[child_id]
+                    child_norm_score = scores[child_idx] / max_score
+                    if child_norm_score > 0.1:  # Count significant matches
+                        child_match_count += 1
+                    child_boost = max(child_boost, child_norm_score * 0.4)
+            
+            # Optimization 2: Boost from parent (if parent matches, boost children)
+            parent = node.get("parent_title")
+            parent_boost = 0.0
+            if parent is not None:
+                parent_id = id(parent)
+                if parent_id in id_to_index:
+                    parent_idx = id_to_index[parent_id]
+                    parent_norm_score = scores[parent_idx] / max_score
+                    if parent_norm_score > 0.3:  # Only boost if parent has significant match
+                        parent_boost = parent_norm_score * 0.3
+            
+            # Optimization 3: Title boost (headers are more important)
+            has_match = normalized_score > 0 or child_boost > 0
+            title_boost = 0.0
+            if has_match:
+                # Check if this is a title/header node
+                if node.get("sentence_title") is not None:
+                    title_boost = 0.15
+                # Also check node type
+                node_type = node.get("type", "").lower()
+                if node_type in ["h1", "h2", "h3", "h4", "h5", "h6", "title", "header"]:
+                    title_boost = max(title_boost, 0.2)
+            
+            # Optimization 4: Depth penalty (prefer mid-level nodes over very deep or root)
+            depth = self._get_node_depth(node)
+            depth_factor = 1.0
+            if depth == 0:  # Root level
+                depth_factor = 0.7
+            elif depth > 5:  # Too deep
+                depth_factor = 0.8
+            elif 1 <= depth <= 3:  # Sweet spot
+                depth_factor = 1.1
+            
+            # Optimization 5: Child diversity bonus - reward nodes with multiple matching children
+            child_diversity_bonus = 0.0
+            if child_match_count >= 2:
+                child_diversity_bonus = 0.05 * min(child_match_count, 4)  # Cap at 4 children
+            
+            # Optimization 6: Structural coherence bonus
+            # If this node is a header/title and has matching content children, boost it
+            structural_bonus = 0.0
+            if node.get("sentence_title") is not None and child_match_count > 0:
+                # Header with matching content is very valuable
+                structural_bonus = 0.08 * min(child_match_count, 3)
+            
+            # Combined score with better weights
+            combined_before_depth = (
+                0.45 * normalized_score +      # Own content (reduced to give more weight to structure)
+                0.20 * child_boost +           # Children influence
+                0.15 * parent_boost +          # Parent context
+                0.10 * title_boost +           # Title importance
+                0.05 * child_diversity_bonus + # Child diversity
+                0.05 * structural_bonus        # NEW: Structural coherence
+            )
+            
+            enhanced_score = combined_before_depth * depth_factor
+            
+            enhanced_scores.append((node, enhanced_score, base_score))
+        
+        # Sort by enhanced score
+        enhanced_scores.sort(key=lambda x: x[1], reverse=True)
+        
+        # Filter by min_score and take top k with smart deduplication
+        results = []
+        section_counts = {}  # Track how many results per section
+        seen_paths = set()  # Track exact paths to avoid duplicates
+        
+        for node, enhanced_score, base_score in enhanced_scores:
+            if len(results) >= k:
+                break
+            if enhanced_score <= 0:
+                break
+            if enhanced_score < min_score:
+                continue
+            
+            # Get path for smarter deduplication
+            path = self.get_node_path(node)
+            
+            # Avoid exact duplicate paths
+            if path in seen_paths:
+                continue
+            seen_paths.add(path)
+            
+            # Extract the main section (first level)
+            section_parts = path.split(" > ")
+            main_section = section_parts[0] if section_parts else path
+            
+            # Dynamic max results per section based on node characteristics
+            # If this is a high-scoring parent node with multiple matching children, allow more results
+            childs = node.get("child_title", [])
+            matching_children = sum(1 for c in childs if id(c) in id_to_index and scores[id_to_index[id(c)]] / max_score > 0.1)
+            
+            # Check if this is a header node with strong structural bonus
+            is_structural_header = (
+                node.get("sentence_title") is not None and 
+                matching_children >= 2 and
+                enhanced_score > 0.3
+            )
+            
+            # Allow up to 3-4 results for rich sections, otherwise 2
+            if is_structural_header:
+                max_per_section = 4  # High-value structural nodes get more slots
+            elif matching_children >= 2:
+                max_per_section = 3  # Rich sub-structure
+            else:
+                max_per_section = 2  # Standard diversity
+            
+            # Count how many results we already have from this section
+            section_count = section_counts.get(main_section, 0)
+            
+            if section_count >= max_per_section:
+                continue
+            
+            # Get subtree text
+            subtree_text = self._get_subtree_text(node)
+            results.append((node, enhanced_score, subtree_text))
+            
+            # Update section count
+            section_counts[main_section] = section_count + 1
+        
+        return results
+    
+    def _get_subtree_text(self, node, indent=0):
+        """Get formatted text of a subtree with hierarchy."""
+        lines = []
+        
+        node_text = self._extract_node_text(node)
+        if node_text:
+            # Add title info if available
+            title_info = ""
+            if node.get("sentence_title") is not None:
+                title_text = node.get("sentence_title_text", "")
+                if title_text:
+                    title_info = " [标题: {}]".format(title_text)
+            
+            prefix = "  " * indent
+            lines.append("{}{}{}".format(prefix, node_text, title_info))
+        
+        # Recursively get children
+        childs = node.get("child_title", [])
+        for child in childs:
+            child_text = self._get_subtree_text(child, indent + 1)
+            if child_text:
+                lines.append(child_text)
+        
+        return "\n".join(lines)
+    
+    def _is_relevant_to_query(self, node, query_tokens, id_to_index, scores, max_score, threshold=0.1):
+        """
+        Check if a node or its subtree is relevant to the query.
+        This includes checking the node itself and all its children.
+        Also considers parent context for better relevance judgment.
+        """
+        # Check the node itself
+        node_id = id(node)
+        if node_id in id_to_index:
+            node_score = scores[id_to_index[node_id]] / max_score
+            if node_score > threshold:
+                return True
+        
+        # Check all children recursively
+        childs = node.get("child_title", [])
+        for child in childs:
+            if self._is_relevant_to_query(child, query_tokens, id_to_index, scores, max_score, threshold):
+                return True
+        
+        # Check parent context - if parent has strong match, child might be relevant too
+        parent = node.get("parent_title")
+        if parent is not None:
+            parent_id = id(parent)
+            if parent_id in id_to_index:
+                parent_score = scores[id_to_index[parent_id]] / max_score
+                if parent_score > 0.3:  # Parent has significant match
+                    return True
+        
+        return False
+    
+    def get_node_path(self, node):
+        """Get the path from root to this node."""
+        path = []
+        current = node
+        
+        while current is not None:
+            text = self._extract_node_text(current)
+            if text:
+                path.insert(0, text[:50])
+            current = current.get("parent_title")
+        
+        return " > ".join(path)
+    
+    def query(self, query, k=5):
+        """
+        Full query pipeline: retrieve subtrees and format results.
+        
+        Args:
+            query: Search query
+            k: Number of results
+            
+        Returns:
+            List of result dicts with node, score, path, and content
+        """
+        results = self.retrieve_subtrees(query, k)
+        
+        formatted_results = []
+        for node, score, subtree_text in results:
+            path = self.get_node_path(node)
+            
+            doc = Document(
+                page_content=subtree_text,
+                metadata={
+                    "node_type": node.get("type", "unknown"),
+                    "path": path,
+                    "score": score,
+                    "title": node.get("sentence_title_text", ""),
+                }
+            )
+            
+            formatted_results.append((doc, score))
+        
+        return deduplicate_ranked_results(formatted_results, k)
+
+
+# Example usage
+if __name__ == "__main__":
+    sample_html = """
+    <html>
+    <body>
+        <h1>招标公告</h1>
+        <h2>一、项目概况</h2>
+        <p>本项目预算金额为5000万元,招标编号为XX-ZB-2024-001。</p>
+        <h2>二、投标人资格要求</h2>
+        <p>1. 具有独立承担民事责任的能力</p>
+        <p>2. 具有相关资质证书</p>
+        <h2>三、评标方法</h2>
+        <p>采用综合评分法,技术分占比60%,商务分占比40%。</p>
+        <h2>四、投标文件递交</h2>
+        <p>截止时间:2024年12月31日</p>
+    </body>
+    </html>
+    """
+    
+    rag = BM25HTMLTreeRAG()
+    rag.build_index(sample_html)
+    
+    queries = ["预算金额", "资质要求", "评标方法"]
+    
+    for query in queries:
+        print("\n查询: {}".format(query))
+        print("-" * 60)
+        results = rag.query(query, k=3)
+        for i, (doc, score) in enumerate(results, 1):
+            print("  [{}] 分数: {:.4f}".format(i, score))
+            print("      路径: {}".format(doc.metadata.get("path", "")))
+            print("      内容: {}...".format(doc.page_content[:100]))

+ 87 - 0
bdirag/rag_methods/bm25_rag.py

@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+"""BM25 RAG - probabilistic retrieval model."""
+from .base import BaseRAG, RAGResult
+from .bm25_backend import get_bm25_okapi
+from .tokenization import bm25_tokenize
+from loguru import logger
+
+
+class BM25RAG(BaseRAG):
+    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
+        self.embedding_model = None
+        self.vector_store = None
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.name = self.__class__.__name__
+        self.bm25 = None
+        self._all_texts = []
+        self._all_documents = []
+
+    def _tokenize(self, text):
+        """Tokenize mixed Chinese/English text for BM25."""
+        return bm25_tokenize(text)
+
+    def index_documents(self, documents):
+        self._all_documents = list(documents or [])
+        self._all_texts = [self._tokenize(doc.page_content) for doc in self._all_documents]
+        if not self._all_documents:
+            self.bm25 = None
+            logger.info("BM25 index built with 0 documents")
+            return
+
+        BM25Okapi = get_bm25_okapi()
+        self.bm25 = BM25Okapi(self._all_texts)
+        logger.info("BM25 index built with {} documents".format(len(self._all_documents)))
+
+    def retrieve(self, query, k=10):
+        if self.bm25 is None or k <= 0:
+            return []
+        
+        query_tokens = self._tokenize(query)
+        if not query_tokens:
+            return []
+
+        scores = self.bm25.get_scores(query_tokens)
+
+        scored_docs = []
+        for i, score in enumerate(scores):
+            # Include all documents with non-zero scores (BM25 can return negative scores)
+            if score != 0:
+                scored_docs.append((self._all_documents[i], float(score)))
+
+        scored_docs.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(scored_docs, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(BM25关键词检索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+    def query(self, query, k=10):
+        import time
+        start_total = time.time()
+        
+        t0 = time.time()
+        docs = self.retrieve(query, k)
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        answer = self.generate(query, context)
+        generation_time = time.time() - t1
+
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=answer,
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={"method": self.name, "num_context_tokens": len(context)},
+        )

+ 46 - 0
bdirag/rag_methods/contextual_compression_rag.py

@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+"""Contextual Compression RAG - compress documents to extract relevant parts."""
+from .base import BaseRAG
+from loguru import logger
+
+
+class ContextualCompressionRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _compress_document(self, doc, query):
+        prompt = (
+            "提取以下文档中与问题相关的核心信息,去除不相关的内容。\n\n"
+            "问题:{}\n\n"
+            "文档内容:\n{}\n\n"
+            "提取的核心信息:".format(query, doc.page_content)
+        )
+        compressed_content = self._call_llm(prompt)
+        return compressed_content
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
+        
+        compressed_results = []
+        for doc, score in initial_results:
+            compressed_content = self._compress_document(doc, query)
+            compressed_doc = type(doc)(
+                page_content=compressed_content,
+                metadata=dict(doc.metadata, is_compressed=True)
+            )
+            compressed_results.append((compressed_doc, score))
+        
+        logger.info("ContextualCompression: compressed {} documents".format(len(compressed_results)))
+        
+        compressed_results.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(compressed_results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过上下文压缩的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 53 - 0
bdirag/rag_methods/corrective_rag.py

@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+"""Corrective RAG - correct the answer by checking against retrieved documents."""
+from .base import BaseRAG
+from loguru import logger
+
+
+class CorrectiveRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _verify_answer(self, query, answer, docs):
+        context = "\n\n".join([doc.page_content[:200] for doc, _ in docs[:3]])
+        prompt = (
+            "验证以下答案是否与参考文档一致。如果不一致,请指出需要修正的部分。\n\n"
+            "问题:{}\n\n"
+            "参考文档(前3个):\n{}\n\n"
+            "答案:\n{}\n\n"
+            "验证结果(一致/不一致)及需要修正的部分:".format(query, context, answer[:500])
+        )
+        return self._call_llm(prompt)
+
+    def _regenerate_answer(self, query, original_answer, verification_result, context):
+        prompt = (
+            "根据验证结果,修正以下答案。\n\n"
+            "原始答案:\n{}\n\n"
+            "验证结果:\n{}\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "修正后的答案:".format(original_answer[:500], verification_result, context, query)
+        )
+        return self._call_llm(prompt)
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
+
+    def generate(self, query, context):
+        answer = self._call_llm(
+            "根据以下参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        
+        docs = [(type('Doc', (), {'page_content': c.split('\n\n---\n\n')[0]})(), 0.0) for c in context.split('\n\n---\n\n')]
+        
+        verification = self._verify_answer(query, answer, docs[:3])
+        
+        if "不一致" in verification:
+            logger.info("CorrectiveRAG: answer verification failed, regenerating...")
+            answer = self._regenerate_answer(query, answer, verification, context)
+        
+        return answer

+ 46 - 0
bdirag/rag_methods/dedup.py

@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+"""Helpers for deduplicating ranked retrieval results by document content."""
+import re
+
+
+_WHITESPACE_RE = re.compile(r"\s+")
+
+
+def normalized_content(text):
+    """Normalize content for exact duplicate detection."""
+    if text is None:
+        return ""
+    return _WHITESPACE_RE.sub(" ", str(text).strip())
+
+
+def content_dedup_key(doc):
+    """Build a stable dedup key, preferring normalized page content."""
+    content = normalized_content(getattr(doc, "page_content", ""))
+    if content:
+        return ("content", content)
+    return ("object", id(doc))
+
+
+def deduplicate_ranked_results(results, k=None):
+    """
+    Deduplicate ``(doc, score)`` retrieval results by content.
+
+    The highest scoring duplicate wins. Ties keep the earlier result. Final
+    output is sorted by score descending, with original order as the tie-breaker.
+    """
+    if k is not None and k <= 0:
+        return []
+
+    best_by_key = {}
+    for order, (doc, score) in enumerate(results or []):
+        key = content_dedup_key(doc)
+        score = float(score)
+        current = best_by_key.get(key)
+        if current is None or score > current[2]:
+            best_by_key[key] = (order, doc, score)
+
+    ranked = sorted(best_by_key.values(), key=lambda item: (-item[2], item[0]))
+    deduped = [(doc, score) for _, doc, score in ranked]
+    if k is None:
+        return deduped
+    return deduped[:k]

+ 61 - 0
bdirag/rag_methods/ensemble_rag.py

@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+"""Ensemble RAG - combine multiple retrieval strategies for robust results."""
+from .base import BaseRAG
+from loguru import logger
+
+
+class EnsembleRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def retrieve(self, query, k=10):
+        strategies = [
+            ("semantic", self._semantic_retrieve),
+            ("keyword", self._keyword_retrieve),
+        ]
+        
+        all_results = {}
+        per_strategy_k = k
+        
+        for name, strategy in strategies:
+            try:
+                results = strategy(query, per_strategy_k)
+                for doc, score in results:
+                    key = self._dedup_key(doc)
+                    if key not in all_results:
+                        all_results[key] = (doc, 0.0)
+                    all_results[key] = (doc, all_results[key][1] + score)
+                logger.info("Ensemble strategy '{}' returned {} results".format(name, len(results)))
+            except Exception as e:
+                logger.warning("Ensemble strategy '{}' failed: {}".format(name, e))
+        
+        results = list(all_results.values())
+        results.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(results, k)
+
+    def _semantic_retrieve(self, query, k):
+        query_embedding = self.embedding_model.embed_query(query)
+        return self.vector_store.similarity_search(query_embedding, k)
+
+    def _keyword_retrieve(self, query, k):
+        query_embedding = self.embedding_model.embed_query(query)
+        results = self.vector_store.similarity_search(query_embedding, k * 5)
+        keyword_results = []
+        for doc, score in results:
+            query_words = set(query.split())
+            doc_words = set(doc.page_content.split())
+            overlap = len(query_words & doc_words) / len(query_words) if query_words else 0
+            if overlap > 0.3:
+                keyword_results.append((doc, score * overlap))
+        
+        keyword_results.sort(key=lambda x: x[1], reverse=True)
+        return keyword_results[:k]
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下多策略集成的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 49 - 0
bdirag/rag_methods/flare_rag.py

@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+"""FLARE RAG - Forward-Looking Active REtrieval augmented generation."""
+from .base import BaseRAG
+from loguru import logger
+
+
+class FLARERAG(BaseRAG):
+    def __init__(self, max_iterations=3, **kwargs):
+        super().__init__(**kwargs)
+        self.max_iterations = max_iterations
+
+    def _generate_with_retrieval(self, query):
+        prompt = "根据已知信息,逐步回答以下问题。如果某一步需要更多信息,请标记[需要检索]。\n\n问题:{}\n\n回答:".format(query)
+        
+        current_answer = ""
+        for i in range(self.max_iterations):
+            step_prompt = prompt + current_answer
+            partial_answer = self._call_llm(step_prompt)
+            
+            if "[需要检索]" not in partial_answer:
+                current_answer += partial_answer
+                break
+            
+            sentences = partial_answer.split("。")
+            new_sentences = []
+            for sent in sentences:
+                if "[需要检索]" in sent:
+                    retrieval_query = sent.replace("[需要检索]", "").strip()
+                    if retrieval_query:
+                        q_embedding = self.embedding_model.embed_query(retrieval_query)
+                        results = self.vector_store.similarity_search(q_embedding, 3)
+                        additional_context = "\n".join([doc.page_content[:200] for doc, _ in results])
+                        new_sentences.append(sent + "(补充信息:" + additional_context + ")")
+                    else:
+                        new_sentences.append(sent)
+                else:
+                    new_sentences.append(sent)
+            
+            current_answer += "。".join(new_sentences) + "。"
+            logger.info("FLARE iteration {}: retrieved additional context".format(i + 1))
+        
+        return current_answer
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
+
+    def generate(self, query, context):
+        return self._generate_with_retrieval(query)

+ 83 - 0
bdirag/rag_methods/graph_rag.py

@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+"""Graph RAG - knowledge graph enhanced retrieval for bidding domain."""
+import re
+from .base import BaseRAG
+from loguru import logger
+
+
+class GraphRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.graph = {}
+        self._build_graph_patterns = [
+            (r'预算金额[::]?\s*([0-9.]+)\s*万', 'budget_amount'),
+            (r'投标保证金[::]?\s*([0-9.]+)\s*万', 'bid_bond'),
+            (r'质保期[::]?\s*([^,,\n。]+)', 'warranty'),
+            (r'交货时间[::]?\s*([^,,\n。]+)', 'delivery_time'),
+            (r'资质要求[::]?\s*([^,,\n。]+)', 'qualification'),
+            (r'评标方法[::]?\s*([^,,\n。]+)', 'evaluation_method'),
+        ]
+
+    def _extract_entities(self, text):
+        entities = {}
+        for pattern, entity_type in self._build_graph_patterns:
+            match = re.search(pattern, text)
+            if match:
+                entities[entity_type] = match.group(1).strip()
+        return entities
+
+    def index_documents(self, documents):
+        self.graph = {}
+        for doc in documents:
+            entities = self._extract_entities(doc.page_content)
+            if entities:
+                self.graph[id(doc)] = {
+                    'doc': doc,
+                    'entities': entities
+                }
+        
+        texts = [doc.page_content for doc in documents]
+        embeddings = self.embedding_model.embed_documents(texts)
+        self.vector_store.add_documents(documents, embeddings)
+        
+        logger.info("GraphRAG built graph with {} nodes".format(len(self.graph)))
+
+    def _query_graph(self, query):
+        matching_nodes = []
+        for node_id, node_data in self.graph.items():
+            for entity_type, entity_value in node_data['entities'].items():
+                if entity_value in query:
+                    matching_nodes.append((node_data['doc'], 1.0))
+                    break
+        return matching_nodes
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        semantic_results = self.vector_store.similarity_search(query_embedding, k)
+        
+        graph_results = self._query_graph(query)
+        
+        combined = {}
+        for doc, score in semantic_results:
+            combined[self._dedup_key(doc)] = (doc, score)
+        
+        for doc, score in graph_results:
+            key = self._dedup_key(doc)
+            if key in combined:
+                combined[key] = (doc, combined[key][1] + 0.2)
+            else:
+                combined[key] = (doc, score)
+        
+        results = list(combined.values())
+        results.sort(key=lambda x: x[1], reverse=True)
+        
+        return self._deduplicate_results(results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(知识图谱增强检索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 59 - 0
bdirag/rag_methods/hybrid_search_rag.py

@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+"""Hybrid Search RAG - combines semantic and keyword (BM25) retrieval."""
+import numpy as np
+from .base import BaseRAG
+
+
+class HybridSearchRAG(BaseRAG):
+    def __init__(self, bm25_weight=0.3, semantic_weight=0.7, **kwargs):
+        super().__init__(**kwargs)
+        self.bm25_weight = bm25_weight
+        self.semantic_weight = semantic_weight
+        self.bm25 = None
+        self._all_texts = []
+        self._all_documents = []
+
+    def index_documents(self, documents):
+        texts = [doc.page_content for doc in documents]
+        embeddings = self.embedding_model.embed_documents(texts)
+        self.vector_store.add_documents(documents, embeddings)
+        
+        self._all_documents = documents
+        self._all_texts = [text.split() for text in texts]
+        
+        from rank_bm25 import BM25Okapi
+        self.bm25 = BM25Okapi(self._all_texts)
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        semantic_results = self.vector_store.similarity_search(query_embedding, k * 2)
+        
+        bm25_scores = np.zeros(len(self._all_documents))
+        query_tokens = query.split()
+        bm25_scores = self.bm25.get_scores(query_tokens)
+        
+        semantic_scores = np.zeros(len(self._all_documents))
+        for doc, score in semantic_results:
+            for i, ref_doc in enumerate(self._all_documents):
+                if doc.page_content == ref_doc.page_content:
+                    semantic_scores[i] = score
+                    break
+        
+        hybrid_scores = self.semantic_weight * semantic_scores + self.bm25_weight * bm25_scores
+        
+        top_indices = np.argsort(hybrid_scores)[::-1][:k]
+        results = []
+        for idx in top_indices:
+            if hybrid_scores[idx] > 0:
+                results.append((self._all_documents[idx], float(hybrid_scores[idx])))
+        
+        return self._deduplicate_results(results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下混合检索(语义+关键词)的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 30 - 0
bdirag/rag_methods/hyde_rag.py

@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+"""HyDE RAG - generate hypothetical document to improve embedding."""
+from .base import BaseRAG
+
+
+class HyDERAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _generate_hypothetical_doc(self, query):
+        prompt = (
+            "请根据以下问题,生成一个假设的、包含相关信息的答案文档。\n\n"
+            "问题:{}\n\n"
+            "假设文档:".format(query)
+        )
+        return self._call_llm(prompt)
+
+    def retrieve(self, query, k=10):
+        hypothetical_doc = self._generate_hypothetical_doc(query)
+        hypo_embedding = self.embedding_model.embed_query(hypothetical_doc)
+        return self._deduplicate_results(self.vector_store.similarity_search(hypo_embedding, k), k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(基于假设文档检索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 89 - 0
bdirag/rag_methods/keyword_rag.py

@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+"""Keyword RAG - unified interface for keyword-based retrieval (BM25 or TF-IDF)."""
+from .base import BaseRAG, RAGResult
+from loguru import logger
+
+
+class KeywordRAG(BaseRAG):
+    def __init__(self, search_method="bm25", llm_client=None, llm_model="gpt-4o", **kwargs):
+        self.embedding_model = None
+        self.vector_store = None
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.name = self.__class__.__name__
+        self.search_method = search_method
+        self._all_documents = []
+        self._all_texts = []
+        self.bm25 = None
+        self.tfidf_matrix = None
+        self.vectorizer = None
+
+    def index_documents(self, documents):
+        self._all_documents = documents
+        self._all_texts = [doc.page_content for doc in documents]
+
+        if self.search_method == "bm25":
+            from rank_bm25 import BM25Okapi
+            tokenized = [t.split() for t in self._all_texts]
+            self.bm25 = BM25Okapi(tokenized)
+        elif self.search_method == "tfidf":
+            from sklearn.feature_extraction.text import TfidfVectorizer
+            self.vectorizer = TfidfVectorizer()
+            self.tfidf_matrix = self.vectorizer.fit_transform(self._all_texts)
+
+        logger.info("KeywordRAG ({}) index built with {} documents".format(self.search_method, len(documents)))
+
+    def retrieve(self, query, k=10):
+        if self.search_method == "bm25":
+            query_tokens = query.split()
+            scores = self.bm25.get_scores(query_tokens)
+        elif self.search_method == "tfidf":
+            from sklearn.metrics.pairwise import cosine_similarity
+            query_vec = self.vectorizer.transform([query])
+            scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
+        else:
+            return []
+
+        import numpy as np
+        top_indices = np.argsort(scores)[::-1]
+
+        results = []
+        for idx in top_indices:
+            if scores[idx] > 0:
+                results.append((self._all_documents[idx], float(scores[idx])))
+
+        return self._deduplicate_results(results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(关键词检索:{}),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(self.search_method, context, query)
+        )
+        return self._call_llm(prompt)
+
+    def query(self, query, k=10):
+        import time
+        start_total = time.time()
+        
+        t0 = time.time()
+        docs = self.retrieve(query, k)
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        answer = self.generate(query, context)
+        generation_time = time.time() - t1
+
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=answer,
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={"method": self.name, "num_context_tokens": len(context)},
+        )

+ 44 - 0
bdirag/rag_methods/llm_filter_rag.py

@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+"""LLM Filter RAG - filter retrieved documents by LLM relevance scoring."""
+from .base import BaseRAG
+
+
+class LLMFilterRAG(BaseRAG):
+    def __init__(self, filter_threshold=0.5, **kwargs):
+        super().__init__(**kwargs)
+        self.filter_threshold = filter_threshold
+
+    def _score_relevance(self, query, doc):
+        prompt = (
+            "评估以下文档与问题的相关性,给出0-1之间的分数。只返回分数数字。\n\n"
+            "问题:{}\n\n"
+            "文档:{}\n\n"
+            "相关性分数:".format(query, doc.page_content[:300])
+        )
+        try:
+            response = self._call_llm(prompt).strip()
+            return float(response)
+        except:
+            return 0.5
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k * 3)
+        
+        filtered_results = []
+        for doc, score in initial_results:
+            relevance = self._score_relevance(query, doc)
+            if relevance >= self.filter_threshold:
+                filtered_results.append((doc, score * relevance))
+        
+        filtered_results.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(filtered_results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过LLM精选的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 58 - 0
bdirag/rag_methods/metadata_filter_rag.py

@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+"""Metadata Filter RAG - filter by document type before retrieval."""
+from .base import BaseRAG
+from loguru import logger
+
+
+class MetadataFilterRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.doc_type_keywords = {
+            "bid_announcement": ["招标公告", "采购公告", "公开招标公告"],
+            "tender_document": ["招标文件", "投标文件", "评标标准"],
+            "answer_to_queries": ["答疑", "澄清", "补充公告"],
+            "result": ["中标", "结果", "成交"],
+        }
+
+    def _classify_document(self, doc):
+        content = doc.page_content
+        best_type = "general"
+        max_score = 0
+        
+        for doc_type, keywords in self.doc_type_keywords.items():
+            score = sum(1 for kw in keywords if kw in content)
+            if score > max_score:
+                max_score = score
+                best_type = doc_type
+        
+        return best_type
+
+    def index_documents(self, documents):
+        for doc in documents:
+            doc.metadata["doc_type"] = self._classify_document(doc)
+        
+        texts = [doc.page_content for doc in documents]
+        embeddings = self.embedding_model.embed_documents(texts)
+        self.vector_store.add_documents(documents, embeddings)
+
+    def retrieve(self, query, k=10, doc_type=None):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k * 5)
+        
+        if doc_type:
+            filtered = [(doc, score) for doc, score in initial_results if doc.metadata.get("doc_type") == doc_type]
+            if filtered:
+                logger.info("MetadataFilter: filtered by type '{}' ({} docs)".format(doc_type, len(filtered)))
+                return self._deduplicate_results(filtered, k)
+        
+        initial_results.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(initial_results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过元数据过滤的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 46 - 0
bdirag/rag_methods/multi_query_rag.py

@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+"""Multi-Query RAG - generate multiple query variants for retrieval."""
+from .base import BaseRAG
+
+
+class MultiQueryRAG(BaseRAG):
+    def __init__(self, num_queries=3, **kwargs):
+        super().__init__(**kwargs)
+        self.num_queries = num_queries
+
+    def _generate_queries(self, query):
+        prompt = (
+            "请为以下问题生成{}个不同的查询变体,用于提高检索效果。\n\n"
+            "原始问题:{}\n\n"
+            "查询变体(用换行分隔):".format(self.num_queries, query)
+        )
+        response = self._call_llm(prompt)
+        return [q.strip() for q in response.strip().split("\n") if q.strip()][:self.num_queries]
+
+    def retrieve(self, query, k=10):
+        queries = self._generate_queries(query)
+        all_queries = [query] + queries
+        
+        all_results = {}
+        per_query_k = max(k // len(all_queries), 3)
+        
+        for q in all_queries:
+            q_embedding = self.embedding_model.embed_query(q)
+            results = self.vector_store.similarity_search(q_embedding, per_query_k)
+            for doc, score in results:
+                key = self._dedup_key(doc)
+                if key not in all_results or score > all_results[key][1]:
+                    all_results[key] = (doc, score)
+        
+        results = list(all_results.values())
+        results.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下多角度查询检索的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 22 - 0
bdirag/rag_methods/naive_rag.py

@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+"""Naive RAG - basic semantic retrieval."""
+from .base import BaseRAG
+
+
+class NaiveRAG(BaseRAG):
+    def __init__(self, retrieval_prompt_template=None, **kwargs):
+        super().__init__(**kwargs)
+        self.retrieval_prompt_template = retrieval_prompt_template or (
+            "根据以下参考文档,回答问题。\n\n"
+            "参考文档:\n{context}\n\n"
+            "问题:{query}\n\n"
+            "请详细回答,如果参考文档中没有相关信息,请说明无法从文档中找到答案。"
+        )
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
+
+    def generate(self, query, context):
+        prompt = self.retrieval_prompt_template.format(context=context, query=query)
+        return self._call_llm(prompt)

+ 71 - 0
bdirag/rag_methods/parent_document_rag.py

@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+"""Parent Document RAG - retrieve child chunks, return parent documents."""
+from .base import BaseRAG, Document
+from loguru import logger
+
+
+class ParentDocumentRAG(BaseRAG):
+    def __init__(self, parent_chunk_size=1500, **kwargs):
+        super().__init__(**kwargs)
+        self.parent_chunk_size = parent_chunk_size
+        self.parent_docs = []
+        self.child_to_parent = {}
+
+    def index_documents(self, documents):
+        self.parent_docs = []
+        self.child_to_parent = {}
+        
+        for i, doc in enumerate(documents):
+            words = doc.page_content.split()
+            for j in range(0, len(words), self.parent_chunk_size):
+                parent_text = " ".join(words[j:j + self.parent_chunk_size])
+                parent_doc = Document(
+                    page_content=parent_text,
+                    metadata=dict(doc.metadata, chunk_index=j, is_parent=True)
+                )
+                self.parent_docs.append(parent_doc)
+        
+        child_texts = []
+        for doc in documents:
+            words = doc.page_content.split()
+            for j in range(0, len(words), self.parent_chunk_size):
+                chunk_words = words[j:j + self.parent_chunk_size]
+                for k in range(0, len(chunk_words), 512):
+                    child_text = " ".join(chunk_words[k:k + 512])
+                    child_id = len(child_texts)
+                    child_texts.append(child_text)
+                    self.child_to_parent[child_id] = child_id // 3
+        
+        if child_texts:
+            embeddings = self.embedding_model.embed_documents(child_texts)
+            child_docs = [
+                Document(page_content=text, metadata={"is_parent": False})
+                for text in child_texts
+            ]
+            self.vector_store.add_documents(child_docs, embeddings)
+        
+        logger.info("ParentDocumentRAG: {} parents, {} children".format(len(self.parent_docs), len(child_texts)))
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        child_results = self.vector_store.similarity_search(query_embedding, k * 2)
+        
+        parent_map = {}
+        for child_doc, score in child_results:
+            for i, parent_doc in enumerate(self.parent_docs):
+                if child_doc.page_content[:50] in parent_doc.page_content:
+                    if i not in parent_map or score > parent_map[i]:
+                        parent_map[i] = (parent_doc, score)
+                    break
+        
+        parent_results = sorted(parent_map.values(), key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(parent_results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(父子文档检索,包含完整上下文),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 64 - 0
bdirag/rag_methods/query_routing_rag.py

@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+"""Query Routing RAG - classify query and boost retrieval by category keywords."""
+from .base import BaseRAG
+from loguru import logger
+
+
+class QueryRoutingRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _classify_query(self, query):
+        categories = [
+            "budget", "deadline", "qualification", "evaluation",
+            "payment", "warranty", "delivery", "contact", "scope"
+        ]
+        categories_str = ", ".join(categories)
+        prompt = (
+            "将以下问题分类到以下类别之一:{}\n\n"
+            "问题:{}\n\n"
+            "类别:".format(categories_str, query)
+        )
+        return self._call_llm(prompt).strip().lower()
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        semantic_results = self.vector_store.similarity_search(query_embedding, k)
+        
+        category = self._classify_query(query)
+        logger.info("QueryRouting: category={}".format(category))
+        
+        category_boost = {}
+        for doc, score in semantic_results:
+            content_lower = doc.page_content.lower()
+            category_keywords = {
+                "budget": ["预算", "金额", "价格", "费用", "报价"],
+                "deadline": ["截止", "时间", "日期", "开标"],
+                "qualification": ["资格", "要求", "证书", "业绩"],
+                "evaluation": ["评标", "评价", "分数", "方法"],
+                "payment": ["付款", "结算", "进度", "保证金"],
+                "warranty": ["质保", "维修", "售后", "服务"],
+                "delivery": ["交货", "工期", "交付", "地点"],
+                "contact": ["联系人", "电话", "邮箱"],
+                "scope": ["范围", "内容", "清单", "设备"],
+            }
+            keywords = category_keywords.get(category, [])
+            boost = sum(1 for kw in keywords if kw in content_lower) * 0.1
+            category_boost[self._dedup_key(doc)] = boost
+        
+        enhanced_results = []
+        for doc, score in semantic_results:
+            boost = category_boost.get(self._dedup_key(doc), 0)
+            enhanced_results.append((doc, score + boost))
+        
+        enhanced_results.sort(key=lambda x: x[1], reverse=True)
+        return self._deduplicate_results(enhanced_results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过查询路由的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 90 - 0
bdirag/rag_methods/raptor_rag.py

@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+"""RAPTOR RAG - Recursive Abstractive Processing for Tree-Organized Retrieval."""
+from .base import BaseRAG, Document
+from loguru import logger
+
+
+class RAPTORRAG(BaseRAG):
+    def __init__(self, max_tree_depth=3, cluster_size=5, **kwargs):
+        super().__init__(**kwargs)
+        self.max_tree_depth = max_tree_depth
+        self.cluster_size = cluster_size
+        self.tree_nodes = []
+
+    def _cluster_documents(self, docs, k):
+        if len(docs) <= k:
+            return [docs]
+        
+        embeddings = []
+        for doc in docs:
+            emb = self.embedding_model.embed_query(doc.page_content[:200])
+            embeddings.append(emb)
+        
+        import numpy as np
+        embeddings = np.array(embeddings)
+        
+        from sklearn.cluster import KMeans
+        k = min(k, len(docs))
+        kmeans = KMeans(n_clusters=k, random_state=42)
+        labels = kmeans.fit_predict(embeddings)
+        
+        clusters = {}
+        for i, label in enumerate(labels):
+            if label not in clusters:
+                clusters[label] = []
+            clusters[label].append(docs[i])
+        
+        return list(clusters.values())
+
+    def _summarize_cluster(self, docs):
+        texts = "\n".join([doc.page_content[:300] for doc in docs[:5]])
+        prompt = (
+            "总结以下文档的核心主题,用一段话概括。\n\n"
+            "文档内容:\n{}\n\n"
+            "核心主题:".format(texts)
+        )
+        summary = self._call_llm(prompt)
+        return summary
+
+    def build_tree(self, documents):
+        self.tree_nodes = []
+        current_level = documents
+        
+        for depth in range(self.max_tree_depth):
+            logger.info("RAPTOR building tree level {}".format(depth))
+            
+            clusters = self._cluster_documents(current_level, len(current_level) // self.cluster_size + 1)
+            next_level = []
+            
+            for cluster in clusters:
+                summary = self._summarize_cluster(cluster)
+                summary_doc = Document(
+                    page_content=summary,
+                    metadata={"level": depth, "num_children": len(cluster)}
+                )
+                next_level.append(summary_doc)
+            
+            self.tree_nodes.extend(current_level)
+            current_level = next_level
+            
+            if len(current_level) <= 1:
+                break
+        
+        self.tree_nodes.extend(current_level)
+        
+        texts = [doc.page_content for doc in self.tree_nodes]
+        embeddings = self.embedding_model.embed_documents(texts)
+        self.vector_store.add_documents(self.tree_nodes, embeddings)
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下树状组织的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 38 - 0
bdirag/rag_methods/rerank_rag.py

@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+"""Rerank RAG - semantic retrieval with LLM reranking."""
+from .base import BaseRAG
+
+
+class RerankRAG(BaseRAG):
+    def __init__(self, rerank_model=None, rerank_top_k=5, **kwargs):
+        super().__init__(**kwargs)
+        self.rerank_model = rerank_model
+        self.rerank_top_k = rerank_top_k
+        self.initial_k = 20
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_docs = self.vector_store.similarity_search(query_embedding, self.initial_k)
+
+        if self.rerank_model and len(initial_docs) > 0:
+            texts = [doc.page_content for doc, _ in initial_docs]
+            pairs = [(query, text) for text in texts]
+            scores = self.rerank_model.compute_score(pairs)
+
+            if isinstance(scores, (int, float)):
+                scores = [scores]
+
+            reranked = list(zip(initial_docs, scores))
+            reranked.sort(key=lambda x: x[1], reverse=True)
+            return self._deduplicate_results([(doc, float(score)) for (doc, _), score in reranked], k)
+
+        return self._deduplicate_results(initial_docs, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下经过重排序的参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 80 - 0
bdirag/rag_methods/self_rag.py

@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+"""Self-RAG - self-reflection on retrieval and generation quality."""
+from .base import BaseRAG
+from loguru import logger
+
+
+class SelfRAG(BaseRAG):
+    def __init__(self, reflection_threshold=0.5, max_reflections=2, **kwargs):
+        super().__init__(**kwargs)
+        self.reflection_threshold = reflection_threshold
+        self.max_reflections = max_reflections
+
+    def _assess_retrieval(self, query, docs):
+        context = "\n\n".join([doc.page_content[:200] for doc, _ in docs[:3]])
+        prompt = (
+            "评估以下检索到的文档对于回答问题的相关性和完整性。给出0-1的分数。\n\n"
+            "问题:{}\n\n"
+            "检索到的文档(前3个):\n{}\n\n"
+            "相关性分数:".format(query, context)
+        )
+        try:
+            response = self._call_llm(prompt).strip()
+            return float(response)
+        except:
+            return 0.5
+
+    def _assess_generation(self, query, answer):
+        prompt = (
+            "评估以下答案对于问题的准确性和完整性。给出0-1的分数。\n\n"
+            "问题:{}\n\n"
+            "答案:\n{}\n\n"
+            "质量分数:".format(query, answer[:300])
+        )
+        try:
+            response = self._call_llm(prompt).strip()
+            return float(response)
+        except:
+            return 0.5
+
+    def _generate_critique(self, query, answer):
+        prompt = (
+            "以下答案有哪些不足之处?请指出需要改进的地方。\n\n"
+            "问题:{}\n\n"
+            "答案:\n{}\n\n"
+            "不足:".format(query, answer[:300])
+        )
+        return self._call_llm(prompt)
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        initial_results = self.vector_store.similarity_search(query_embedding, k)
+        
+        relevance_score = self._assess_retrieval(query, initial_results)
+        logger.info("SelfRAG retrieval relevance: {:.3f}".format(relevance_score))
+        
+        return self._deduplicate_results(initial_results, k)
+
+    def generate(self, query, context):
+        answer = self._call_llm(
+            "根据以下参考文档,回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        
+        quality_score = self._assess_generation(query, answer)
+        logger.info("SelfRAG generation quality: {:.3f}".format(quality_score))
+        
+        if quality_score < self.reflection_threshold:
+            critique = self._generate_critique(query, answer)
+            refined_prompt = (
+                "之前的答案有以下不足:{}\n\n"
+                "请重新回答以下问题,弥补这些不足。\n\n"
+                "参考文档:\n{}\n\n"
+                "问题:{}\n\n"
+                "改进后的答案:".format(critique, context, query)
+            )
+            answer = self._call_llm(refined_prompt)
+        
+        return answer

+ 37 - 0
bdirag/rag_methods/step_back_rag.py

@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+"""Step-Back RAG - ask a more general question for context, then specific."""
+from .base import BaseRAG
+
+
+class StepBackRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _generate_step_back_query(self, query):
+        prompt = (
+            "针对以下具体问题,提出一个更宏观、更通用的背景问题。\n\n"
+            "具体问题:{}\n\n"
+            "宏观问题:".format(query)
+        )
+        return self._call_llm(prompt)
+
+    def retrieve(self, query, k=10):
+        step_back_query = self._generate_step_back_query(query)
+        
+        q1_embedding = self.embedding_model.embed_query(query)
+        specific_results = self.vector_store.similarity_search(q1_embedding, k // 2)
+        
+        q2_embedding = self.embedding_model.embed_query(step_back_query)
+        general_results = self.vector_store.similarity_search(q2_embedding, k // 2)
+        
+        combined = specific_results + general_results
+        return self._deduplicate_results(combined, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(宏观+具体检索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 88 - 0
bdirag/rag_methods/table_aware_rag.py

@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+"""Table-Aware RAG - specialized retrieval for tabular data in bidding documents."""
+import re
+from .base import BaseRAG
+from loguru import logger
+
+
+class TableAwareRAG(BaseRAG):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.table_pattern = re.compile(r'\|.*\|.*\|')
+
+    def _detect_tables(self, text):
+        lines = text.split('\n')
+        tables = []
+        current_table = []
+        
+        for line in lines:
+            if self.table_pattern.search(line):
+                current_table.append(line)
+            else:
+                if current_table:
+                    tables.append('\n'.join(current_table))
+                    current_table = []
+        
+        if current_table:
+            tables.append('\n'.join(current_table))
+        
+        return tables
+
+    def index_documents(self, documents):
+        table_docs = []
+        text_docs = []
+        
+        for doc in documents:
+            tables = self._detect_tables(doc.page_content)
+            if tables:
+                for i, table in enumerate(tables):
+                    table_doc = type(doc)(
+                        page_content=table,
+                        metadata=dict(doc.metadata, is_table=True, table_index=i)
+                    )
+                    table_docs.append(table_doc)
+                
+                table_text = '\n'.join(tables)
+                text_only = doc.page_content.replace(table_text, '')
+                if text_only.strip():
+                    text_doc = type(doc)(
+                        page_content=text_only,
+                        metadata=dict(doc.metadata, is_table=False)
+                    )
+                    text_docs.append(text_doc)
+            else:
+                text_docs.append(doc)
+        
+        all_docs = text_docs + table_docs
+        texts = [doc.page_content for doc in all_docs]
+        embeddings = self.embedding_model.embed_documents(texts)
+        self.vector_store.add_documents(all_docs, embeddings)
+        
+        logger.info("TableAwareRAG indexed {} text docs and {} tables".format(len(text_docs), len(table_docs)))
+
+    def retrieve(self, query, k=10):
+        query_embedding = self.embedding_model.embed_query(query)
+        results = self.vector_store.similarity_search(query_embedding, k * 2)
+        
+        table_results = []
+        text_results = []
+        
+        for doc, score in results:
+            if doc.metadata.get('is_table'):
+                table_results.append((doc, score * 1.1))
+            else:
+                text_results.append((doc, score))
+        
+        combined = table_results + text_results
+        combined.sort(key=lambda x: x[1], reverse=True)
+        
+        return self._deduplicate_results(combined, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(包含表格数据),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答,如果涉及表格数据,请准确提取相关数值。".format(context, query)
+        )
+        return self._call_llm(prompt)

+ 76 - 0
bdirag/rag_methods/tfidf_rag.py

@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+"""TF-IDF RAG - term frequency-inverse document frequency retrieval."""
+from .base import BaseRAG, RAGResult
+from loguru import logger
+
+
+class TFIDFRAG(BaseRAG):
+    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
+        self.embedding_model = None
+        self.vector_store = None
+        self.llm_client = llm_client
+        self.llm_model = llm_model
+        self.name = self.__class__.__name__
+        self.vectorizer = None
+        self.tfidf_matrix = None
+        self._all_documents = []
+
+    def index_documents(self, documents):
+        self._all_documents = documents
+        texts = [doc.page_content for doc in documents]
+        from sklearn.feature_extraction.text import TfidfVectorizer
+        self.vectorizer = TfidfVectorizer()
+        self.tfidf_matrix = self.vectorizer.fit_transform(texts)
+        logger.info("TF-IDF index built with {} documents, vocab size: {}".format(len(documents), len(self.vectorizer.vocabulary_)))
+
+    def retrieve(self, query, k=10):
+        if self.tfidf_matrix is None:
+            return []
+        
+        from sklearn.metrics.pairwise import cosine_similarity
+        query_vec = self.vectorizer.transform([query])
+        scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
+
+        import numpy as np
+        top_indices = np.argsort(scores)[::-1]
+
+        results = []
+        for idx in top_indices:
+            if scores[idx] > 0:
+                results.append((self._all_documents[idx], float(scores[idx])))
+
+        return self._deduplicate_results(results, k)
+
+    def generate(self, query, context):
+        prompt = (
+            "根据以下参考文档(TF-IDF关键词检索),回答问题。\n\n"
+            "参考文档:\n{}\n\n"
+            "问题:{}\n\n"
+            "请详细回答。".format(context, query)
+        )
+        return self._call_llm(prompt)
+
+    def query(self, query, k=10):
+        import time
+        start_total = time.time()
+        
+        t0 = time.time()
+        docs = self.retrieve(query, k)
+        retrieval_time = time.time() - t0
+
+        context = self._format_context(docs)
+
+        t1 = time.time()
+        answer = self.generate(query, context)
+        generation_time = time.time() - t1
+
+        total_time = time.time() - start_total
+
+        return RAGResult(
+            answer=answer,
+            retrieved_docs=docs,
+            latency_retrieval=retrieval_time,
+            latency_generation=generation_time,
+            latency_total=total_time,
+            metadata={"method": self.name, "num_context_tokens": len(context)},
+        )

+ 90 - 0
bdirag/rag_methods/tokenization.py

@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+"""Tokenization helpers for keyword-based retrieval."""
+import re
+import unicodedata
+
+try:
+    import jieba
+except ImportError:  # pragma: no cover - used only in minimal installs.
+    jieba = None
+
+
+_TOKEN_PATTERN = re.compile(
+    r"[a-z0-9]+(?:[-_./][a-z0-9]+)*|[\u4e00-\u9fff]+",
+    re.IGNORECASE,
+)
+
+
+_PHRASE_ALIASES = {
+    "预算金额": ["project", "budget", "amount"],
+    "项目预算": ["project", "budget"],
+    "采购预算": ["procurement", "budget"],
+    "最高限价": ["price", "ceiling", "budget"],
+    "预算": ["budget"],
+    "投标保证金": ["bid", "bond", "amount"],
+    "履约保证金": ["performance", "bond", "amount"],
+    "资格要求": ["qualification", "requirements"],
+    "资质要求": ["qualification", "requirements"],
+    "评标方法": ["evaluation", "method"],
+    "评审方法": ["evaluation", "method"],
+    "综合评分": ["comprehensive", "scoring", "method", "evaluation"],
+    "综合评估": ["comprehensive", "evaluation", "method"],
+    "质保期": ["warranty", "period"],
+    "保修期": ["warranty", "period"],
+    "三年": ["3", "years"],
+    "付款方式": ["payment", "terms"],
+    "支付方式": ["payment", "terms"],
+    "项目编号": ["project", "code"],
+    "采购编号": ["project", "code"],
+    "招标编号": ["project", "code"],
+    "交货时间": ["delivery", "time"],
+    "交付时间": ["delivery", "time"],
+    "投标截止": ["bid", "submission", "deadline"],
+    "开标时间": ["bid", "opening", "time"],
+    "采购人": ["purchaser"],
+    "招标代理": ["agency"],
+    "代理机构": ["agency"],
+    "联系方式": ["contact", "phone"],
+}
+
+
+def bm25_tokenize(text):
+    """
+    Tokenize mixed Chinese/English/number text for BM25.
+
+    Bidding documents commonly mix Chinese labels, English labels, project
+    codes and amounts.  The tokenizer keeps exact lexical tokens, splits
+    compound identifiers such as ``XX-ZB-2024-001``, adds Chinese bigrams, and
+    expands common bidding-domain phrases to English aliases so Chinese field
+    queries can recall English or bilingual source text.
+    """
+    if text is None:
+        return []
+
+    normalized = unicodedata.normalize("NFKC", str(text)).lower()
+    normalized = re.sub(r"(?<=\d),(?=\d)", "", normalized)
+    tokens = []
+
+    for match in _TOKEN_PATTERN.finditer(normalized):
+        raw = match.group(0)
+        if not raw:
+            continue
+
+        if re.search(r"[\u4e00-\u9fff]", raw):
+            if jieba is not None:
+                tokens.extend(t for t in jieba.cut(raw) if t.strip())
+            else:
+                tokens.append(raw)
+            if len(raw) > 1:
+                tokens.extend(raw[i:i + 2] for i in range(len(raw) - 1))
+            continue
+
+        tokens.append(raw)
+        if re.search(r"[-_./]", raw):
+            tokens.extend(part for part in re.split(r"[-_./]+", raw) if part)
+
+    for phrase, aliases in _PHRASE_ALIASES.items():
+        if phrase.lower() in normalized:
+            tokens.extend(aliases)
+
+    return tokens

+ 181 - 0
bdirag/vector_stores.py

@@ -0,0 +1,181 @@
+import os
+import json
+import pickle
+from abc import ABC, abstractmethod
+from typing import List, Dict, Tuple, Any, Optional
+import numpy as np
+from loguru import logger
+
+from .document_processor import Document
+
+
+class BaseVectorStore(ABC):
+    @abstractmethod
+    def add_documents(self, documents, embeddings):
+        pass
+
+    @abstractmethod
+    def similarity_search(self, query_embedding, k=10):
+        pass
+
+    @abstractmethod
+    def save(self, path):
+        pass
+
+    @abstractmethod
+    def load(self, path):
+        pass
+
+
+class FAISSStore(BaseVectorStore):
+    def __init__(self, embedding_model=None):
+        import faiss
+        self.faiss = faiss
+        self.index = None
+        self.documents = []
+        self.embedding_model = embedding_model
+        self.dimension = 0
+
+    def add_documents(self, documents, embeddings=None):
+        if embeddings is None and self.embedding_model:
+            texts = [doc.page_content for doc in documents]
+            embeddings = self.embedding_model.embed_documents(texts)
+
+        embeddings_np = np.array(embeddings, dtype=np.float32)
+        self.dimension = embeddings_np.shape[1]
+
+        if self.index is None:
+            self.index = self.faiss.IndexFlatIP(self.dimension)
+
+        self.faiss.normalize_L2(embeddings_np)
+        self.index.add(embeddings_np)
+        self.documents.extend(documents)
+        logger.info("Added {} documents to FAISS index, total: {}".format(len(documents), len(self.documents)))
+
+    def similarity_search(self, query_embedding, k=10):
+        if self.index is None:
+            return []
+        
+        k = min(k, len(self.documents))
+        query_np = np.array([query_embedding], dtype=np.float32)
+        self.faiss.normalize_L2(query_np)
+        scores, indices = self.index.search(query_np, k)
+        
+        results = []
+        for score, idx in zip(scores[0], indices[0]):
+            if idx != -1:
+                results.append((self.documents[idx], float(score)))
+        return results
+
+    def save(self, path):
+        os.makedirs(path, exist_ok=True)
+        self.faiss.write_index(self.index, os.path.join(path, "faiss.index"))
+        with open(os.path.join(path, "documents.pkl"), "wb") as f:
+            pickle.dump(self.documents, f)
+        logger.info("Saved FAISS index to {}".format(path))
+
+    def load(self, path):
+        self.index = self.faiss.read_index(os.path.join(path, "faiss.index"))
+        with open(os.path.join(path, "documents.pkl"), "rb") as f:
+            self.documents = pickle.load(f)
+        self.dimension = self.index.d
+        logger.info("Loaded FAISS index from {}, documents: {}".format(path, len(self.documents)))
+
+
+class ChromaStore(BaseVectorStore):
+    def __init__(self, embedding_model=None, collection_name="bidi_collection"):
+        import chromadb
+        self.chroma = chromadb
+        self.client = None
+        self.collection = None
+        self.collection_name = collection_name
+        self.embedding_model = embedding_model
+        self.documents = []
+
+    def _ensure_client(self):
+        if self.client is None:
+            self.client = self.chroma.Client()
+
+    def add_documents(self, documents, embeddings=None):
+        self._ensure_client()
+        
+        if embeddings is None and self.embedding_model:
+            texts = [doc.page_content for doc in documents]
+            embeddings = self.embedding_model.embed_documents(texts)
+
+        if self.collection is None:
+            if self.embedding_model:
+                self.collection = self.client.create_collection(
+                    name=self.collection_name,
+                    embedding_function=self._embedding_wrapper(),
+                )
+            else:
+                self.collection = self.client.create_collection(name=self.collection_name)
+
+        ids = ["doc_{}".format(i + len(self.documents)) for i in range(len(documents))]
+        texts = [doc.page_content for doc in documents]
+        metadatas = [doc.metadata for doc in documents]
+
+        self.collection.add(
+            ids=ids,
+            documents=texts,
+            embeddings=embeddings,
+            metadatas=metadatas,
+        )
+        self.documents.extend(documents)
+        logger.info("Added {} documents to Chroma, total: {}".format(len(documents), len(self.documents)))
+
+    def similarity_search(self, query_embedding, k=10):
+        if self.collection is None:
+            return []
+
+        k = min(k, len(self.documents))
+        results = self.collection.query(
+            query_embeddings=[query_embedding],
+            n_results=k,
+            include=["documents", "metadatas", "distances"],
+        )
+
+        docs_returned = []
+        for doc_text, metadata, distance in zip(
+            results["documents"][0], results["metadatas"][0], results["distances"][0]
+        ):
+            doc = Document(page_content=doc_text, metadata=metadata)
+            score = 1.0 - distance
+            docs_returned.append((doc, score))
+        return docs_returned
+
+    def _embedding_wrapper(self):
+        class EmbeddingFunc(object):
+            def __init__(self, model):
+                self.model = model
+
+            def __call__(self, input):
+                return self.model.embed_documents(input)
+        return EmbeddingFunc(self.embedding_model)
+
+    def save(self, path):
+        logger.info("Chroma auto-persists data, no explicit save needed for path: {}".format(path))
+
+    def load(self, path):
+        self._ensure_client()
+        try:
+            self.collection = self.client.get_collection(name=self.collection_name)
+            self.documents = [
+                Document(page_content="", metadata={}) 
+                for _ in range(len(self.collection.get()["ids"]))
+            ]
+            logger.info("Loaded Chroma collection: {}".format(self.collection_name))
+        except Exception as e:
+            logger.warning("Chroma collection not found: {}".format(e))
+
+
+def get_vector_store(store_type="faiss", **kwargs):
+    store_map = {
+        "faiss": FAISSStore,
+        "chroma": ChromaStore,
+    }
+    cls = store_map.get(store_type.lower())
+    if not cls:
+        raise ValueError("Unknown vector store type: {}".format(store_type))
+    return cls(**kwargs)

+ 46 - 0
convert_unicode_to_chinese.py

@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+"""Convert unicode escape sequences to Chinese characters in source files."""
+import re
+import os
+import codecs
+
+BASE_DIR = r'f:\Workspace2016\BidiRAG'
+
+def decode_unicode_in_file(filepath):
+    """Convert all \\uXXXX sequences in a file to actual Unicode characters."""
+    with open(filepath, 'r', encoding='utf-8') as f:
+        content = f.read()
+    
+    if not re.search(r'\\u[0-9a-fA-F]{4}', content):
+        return False
+    
+    def replace_unicode(match):
+        try:
+            return chr(int(match.group(0)[2:], 16))
+        except:
+            return match.group(0)
+    
+    new_content = re.sub(r'\\u[0-9a-fA-F]{4}', replace_unicode, content)
+    
+    with open(filepath, 'w', encoding='utf-8') as f:
+        f.write(new_content)
+    
+    return True
+
+files_to_convert = [
+    os.path.join(BASE_DIR, 'bdirag', 'rag_methods.py'),
+    os.path.join(BASE_DIR, 'bdirag', 'config.py'),
+    os.path.join(BASE_DIR, 'examples', 'test_bm25.py'),
+]
+
+for filepath in files_to_convert:
+    if os.path.exists(filepath):
+        changed = decode_unicode_in_file(filepath)
+        if changed:
+            print("Converted: {}".format(filepath))
+        else:
+            print("No changes needed: {}".format(filepath))
+    else:
+        print("File not found: {}".format(filepath))
+
+print("Done!")

二进制
doubao-page.png


+ 1 - 0
examples/__init__.py

@@ -0,0 +1 @@
+# Examples directory

+ 371 - 0
examples/benchmark_all_methods.py

@@ -0,0 +1,371 @@
+"""
+Complete benchmark script for comparing all RAG methods
+This script demonstrates various RAG methods and compares their performance
+"""
+
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from openai import OpenAI
+from FlagEmbedding import FlagReranker
+
+from bdirag.document_processor import Document, DocumentProcessor
+from bdirag.embedding_models import SentenceTransformerEmbedding
+from bdirag.vector_stores import FAISSStore
+from bdirag.rag_methods import (
+    NaiveRAG,
+    RerankRAG,
+    HybridSearchRAG,
+    MultiQueryRAG,
+    HyDERAG,
+    SelfRAG,
+    CorrectiveRAG,
+    FLARERAG,
+    RAPTORRAG,
+    BidFieldExtractionRAG,
+    TableAwareRAG,
+    EnsembleRAG,
+    GraphRAG,
+    StepBackRAG,
+    ContextualCompressionRAG,
+    BM25RAG,
+    TFIDFRAG,
+    KeywordRAG,
+)
+from bdirag.benchmark import RAGBenchmark
+from examples.sample_data import SAMPLE_BIDDING_DOCS
+
+
+def setup_rag_methods(embedding_model, vector_store, llm_client, llm_model="gpt-4o"):
+    """Initialize all RAG methods with shared components"""
+    rerank_model = FlagReranker("BAAI/bge-reranker-large", use_fp16=True)
+
+    methods = {}
+
+    # Basic RAG Methods
+    methods["NaiveRAG"] = NaiveRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+    )
+
+    methods["RerankRAG"] = RerankRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        rerank_model=rerank_model,
+        rerank_top_k=5,
+    )
+
+    methods["HybridSearchRAG"] = HybridSearchRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        semantic_weight=0.5,
+    )
+
+    # Advanced RAG Methods
+    methods["MultiQueryRAG"] = MultiQueryRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        num_queries=3,
+    )
+
+    methods["HyDERAG"] = HyDERAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        num_hypotheses=3,
+    )
+
+    methods["SelfRAG"] = SelfRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        relevance_threshold=0.5,
+        support_threshold=0.5,
+    )
+
+    methods["CorrectiveRAG"] = CorrectiveRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        correctness_threshold=0.6,
+    )
+
+    methods["FLARERAG"] = FLARERAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        max_iterations=3,
+    )
+
+    methods["RAPTORRAG"] = RAPTORRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        max_clusters=10,
+        summary_length=256,
+        num_tree_levels=2,
+    )
+
+    methods["StepBackRAG"] = StepBackRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+    )
+
+    methods["ContextualCompressionRAG"] = ContextualCompressionRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        compression_llm=llm_client,
+        llm_model=llm_model,
+    )
+
+    methods["EnsembleRAG"] = EnsembleRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        methods=["naive", "hybrid", "multi_query"],
+    )
+
+    # Bidding-specific RAG Methods
+    methods["BidFieldExtractionRAG"] = BidFieldExtractionRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+    )
+
+    methods["TableAwareRAG"] = TableAwareRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+    )
+
+    methods["GraphRAG"] = GraphRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+    )
+
+    # Keyword-based RAG Methods (BM25 / TF-IDF)
+    methods["BM25RAG"] = BM25RAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+    )
+
+    methods["TFIDFRAG"] = TFIDFRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+    )
+
+    methods["KeywordRAG_BM25"] = KeywordRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        search_method="bm25",
+    )
+
+    methods["KeywordRAG_TFIDF"] = KeywordRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=llm_model,
+        search_method="tfidf",
+    )
+
+    return methods
+
+
+def main():
+    print("=" * 60)
+    print("BidiRAG - RAG Methods Benchmark for Bidding Domain")
+    print("=" * 60)
+
+    # Configuration
+    LLM_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
+    LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+    LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
+    EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "BAAI/bge-large-zh-v1.5")
+
+    # Step 1: Initialize embedding model
+    print("\n[1/5] Loading embedding model...")
+    embedding_model = SentenceTransformerEmbedding(
+        model_name=EMBEDDING_MODEL_NAME,
+        device="cpu"
+    )
+    print("  Embedding dimension: {0}.format(embedding_model.dimension)")
+
+    # Step 2: Initialize vector store
+    print("\n[2/5] Initializing vector store...")
+    vector_store = FAISSStore(embedding_model=embedding_model)
+
+    # Step 3: Process documents
+    print("\n[3/5] Processing sample bidding documents...")
+    documents = [
+        Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
+        for doc in SAMPLE_BIDDING_DOCS
+    ]
+    print("  Loaded {0} documents.format(len(documents))")
+
+    # Step 4: Index documents
+    print("\n[4/5] Indexing documents...")
+    for method_name, method in setup_rag_methods(embedding_model, vector_store, None, LLM_MODEL).items():
+        if method_name == "RAPTORRAG":
+            method.build_tree(documents)
+        elif method_name == "GraphRAG":
+            method.build_graph(documents)
+        else:
+            method.index_documents(documents)
+    print(f"  Indexing complete")
+
+    # Step 5: Initialize LLM client
+    print("\n[5/5] Initializing LLM client...")
+    llm_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
+
+    # Setup RAG methods
+    print("\nSetting up RAG methods...")
+    methods = setup_rag_methods(embedding_model, vector_store, llm_client, LLM_MODEL)
+
+    # Define test queries
+    test_queries = [
+        "XX City Smart Transportation Project budget and deadline?",
+        "What are the qualification requirements for the hospital equipment procurement?",
+        "What is the warranty period for the university network project?",
+        "List all bid bond amounts in the announcements",
+        "What evaluation methods are used across different projects?",
+        "XX Road construction project payment terms?",
+        "Environmental monitoring system equipment list?",
+        "Which projects require Grade I qualification?",
+    ]
+
+    print("\nRunning benchmark with {0} queries across {1} methods....format(len(test_queries), len(methods))")
+    print("=" * 60)
+
+    # Run benchmark
+    benchmark = RAGBenchmark()
+
+    # Run methods one by one to avoid overwhelming the LLM API
+    results_summary = {}
+    for method_name, method in methods.items():
+        print("\n{0}.format('=' * 60)")
+        print("Testing: {0}.format(method_name)")
+        print("{0}.format('=' * 60)")
+
+        method_results = []
+        for i, query in enumerate(test_queries):
+            print("\n  Query {0}/{1}: {2}.format(i + 1, len(test_queries), query)")
+            try:
+                result = method.query(query, k=5)
+                method_results.append(result)
+                print("    Answer: {0}....format(result.answer[:100])")
+                print("    Total Latency: {0}s.format(result.latency_total:.3f)")
+                print("    Retrieval: {0}s | Generation: {1}s.format(result.latency_retrieval:.3f, result.latency_generation:.3f)")
+            except Exception as e:
+                print("    ERROR: {0}.format(e)")
+
+        results_summary[method_name] = method_results
+
+    # Generate comparison report
+    print("\n\n" + "=" * 60)
+    print("BENCHMARK RESULTS")
+    print("=" * 60)
+
+    # Create metrics manually from results
+    from bdirag.benchmark import BenchmarkMetrics, BenchmarkResult
+    import numpy as np
+    import datetime
+
+    metrics = []
+    for method_name, results in results_summary.items():
+        if not results:
+            continue
+
+        total_latencies = [r.latency_total for r in results]
+        retrieval_latencies = [r.latency_retrieval for r in results]
+        generation_latencies = [r.latency_generation for r in results]
+        num_docs = [len(r.retrieved_docs) for r in results]
+
+        metric = BenchmarkMetrics(
+            method_name=method_name,
+            avg_latency_total=np.mean(total_latencies),
+            avg_latency_retrieval=np.mean(retrieval_latencies),
+            avg_latency_generation=np.mean(generation_latencies),
+            avg_docs_retrieved=np.mean(num_docs),
+            total_queries=len(results),
+            latency_std=np.std(total_latencies),
+            retrieval_std=np.std(retrieval_latencies),
+            generation_std=np.std(generation_latencies),
+            min_latency=np.min(total_latencies),
+            max_latency=np.max(total_latencies),
+            p50_latency=np.percentile(total_latencies, 50),
+            p95_latency=np.percentile(total_latencies, 95),
+        )
+        metrics.append(metric)
+
+    # Print comparison table
+    print("\nLatency Comparison (sorted by average total latency):")
+    print("-" * 100)
+    print("{0} {1} {2} {3} {4} {5} {6} {7}.format('Method':<30, 'Avg Total':>10, 'Avg Retri':>10, 'Avg Gener':>10, 'P50':>8, 'P95':>8, 'Min':>8, 'Max':>8)")
+    print("-" * 100)
+
+    for m in sorted(metrics, key=lambda x: x.avg_latency_total):
+        print(
+            "{0} {1} {2} .format(m.method_name:<30, m.avg_latency_total:>10.3f, m.avg_latency_retrieval:>10.3f)"
+            "{0} {1} {2} .format(m.avg_latency_generation:>10.3f, m.p50_latency:>8.3f, m.p95_latency:>8.3f)"
+            "{0} {1}.format(m.min_latency:>8.3f, m.max_latency:>8.3f)"
+        )
+
+    print("-" * 100)
+    print("\nFastest Method: {0}.format(min(metrics, key=lambda x: x.avg_latency_total).method_name)")
+    print("Most Stable: {0}.format(min(metrics, key=lambda x: x.latency_std).method_name)")
+
+    # Save results
+    output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "output")
+    os.makedirs(output_dir, exist_ok=True)
+
+    benchmark_result = BenchmarkResult(
+        metrics=metrics,
+        detailed_results={},
+        timestamp=datetime.datetime.now().isoformat(),
+    )
+    benchmark_result.save(os.path.join(output_dir, "benchmark_results.json"))
+    benchmark.generate_report(benchmark_result, os.path.join(output_dir, "benchmark_report.md"))
+
+    try:
+        benchmark.plot_comparison(
+            metrics,
+            save_path=os.path.join(output_dir, "benchmark_comparison.png"),
+            show=False,
+        )
+    except Exception as e:
+        print("\nNote: Could not generate plot: {0}.format(e)")
+
+    print("\nResults saved to {0}.format(output_dir)")
+    print("\nBenchmark complete!")
+
+
+if __name__ == "__main__":
+    main()

+ 189 - 0
examples/benchmark_retrieval_speed.py

@@ -0,0 +1,189 @@
+"""
+Speed-focused benchmark script - compares retrieval speed without LLM generation
+Use this to quickly compare the performance of different retrieval methods
+"""
+
+import sys
+import os
+import time
+import numpy as np
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from rich.console import Console
+from rich.table import Table
+
+from bdirag.document_processor import Document
+from bdirag.embedding_models import SentenceTransformerEmbedding
+from bdirag.vector_stores import FAISSStore
+from bdirag.rag_methods import (
+    NaiveRAG,
+    HybridSearchRAG,
+    MultiQueryRAG,
+    HyDERAG,
+    SelfRAG,
+    StepBackRAG,
+    BidFieldExtractionRAG,
+    TableAwareRAG,
+    EnsembleRAG,
+    GraphRAG,
+    BM25RAG,
+    TFIDFRAG,
+    KeywordRAG,
+)
+from examples.sample_data import SAMPLE_BIDDING_DOCS
+
+
+class RetrievalBenchmark:
+    def __init__(self):
+        self.results = {}
+
+    def run(self, methods, queries, iterations=3):
+        for method_name, method in methods.items():
+            print("\nBenchmarking {0}....format(method_name)")
+            latencies = []
+
+            for query in queries:
+                query_latencies = []
+                for _ in range(iterations):
+                    start = time.time()
+                    docs = method.retrieve(query, k=5)
+                    elapsed = time.time() - start
+                    query_latencies.append(elapsed)
+
+                avg_latency = np.mean(query_latencies)
+                latencies.append(avg_latency)
+                print("  Query: {0}... -> {1}s.format(query[:50], avg_latency:.3f)")
+
+            self.results[method_name] = {
+                "latencies": latencies,
+                "avg": np.mean(latencies),
+                "std": np.std(latencies),
+                "min": np.min(latencies),
+                "max": np.max(latencies),
+                "p50": np.percentile(latencies, 50),
+                "p95": np.percentile(latencies, 95),
+            }
+
+        self.print_results()
+
+    def print_results(self):
+        console = Console()
+        table = Table(title="Retrieval Speed Comparison")
+
+        table.add_column("Method", style="cyan")
+        table.add_column("Avg (s)", justify="right", style="green")
+        table.add_column("Std (s)", justify="right", style="green")
+        table.add_column("Min (s)", justify="right", style="yellow")
+        table.add_column("Max (s)", justify="right", style="yellow")
+        table.add_column("P50 (s)", justify="right", style="magenta")
+        table.add_column("P95 (s)", justify="right", style="magenta")
+        table.add_column("QPS", justify="right", style="blue")
+
+        for name in sorted(self.results.keys(), key=lambda x: self.results[x]["avg"]):
+            r = self.results[name]
+            table.add_row(
+                name,
+                "{0}.format(r['avg']:.4f)",
+                "{0}.format(r['std']:.4f)",
+                "{0}.format(r['min']:.4f)",
+                "{0}.format(r['max']:.4f)",
+                "{0}.format(r['p50']:.4f)",
+                "{0}.format(r['p95']:.4f)",
+                "{0}.format(1/r['avg']:.1f)",
+            )
+
+        console.print(table)
+
+
+def main():
+    print("=" * 60)
+    print("BidiRAG - Retrieval Speed Benchmark")
+    print("=" * 60)
+
+    # Load embedding model
+    print("\nLoading embedding model...")
+    embedding_model = SentenceTransformerEmbedding(
+        model_name="BAAI/bge-large-zh-v1.5",
+        device="cpu"
+    )
+
+    # Create vector store
+    print("Creating vector store...")
+    vector_store = FAISSStore(embedding_model=embedding_model)
+
+    # Prepare documents
+    documents = [
+        Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
+        for doc in SAMPLE_BIDDING_DOCS
+    ]
+
+    # Initialize methods (no LLM needed for pure retrieval)
+    print("Initializing retrieval methods...")
+    methods = {
+        "BM25RAG": BM25RAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+        ),
+        "TFIDFRAG": TFIDFRAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+        ),
+        "KeywordRAG_BM25": KeywordRAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+            search_method="bm25",
+        ),
+        "KeywordRAG_TFIDF": KeywordRAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+            search_method="tfidf",
+        ),
+        "NaiveRAG": NaiveRAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+        ),
+        "HybridSearchRAG": HybridSearchRAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+        ),
+        "BidFieldExtractionRAG": BidFieldExtractionRAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+        ),
+        "TableAwareRAG": TableAwareRAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+        ),
+        "EnsembleRAG": EnsembleRAG(
+            embedding_model=embedding_model,
+            vector_store=vector_store,
+        ),
+    }
+
+    # Index documents
+    print("Indexing documents...")
+    for method in methods.values():
+        method.index_documents(documents)
+
+    # Test queries
+    test_queries = [
+        "What is the project budget?",
+        "What are the qualification requirements?",
+        "When is the bid deadline?",
+        "What is the warranty period?",
+        "What are the payment terms?",
+        "What is the evaluation method?",
+        "What equipment is needed?",
+        "What is the delivery time?",
+    ]
+
+    # Run benchmark
+    print("\nRunning benchmark with {0} queries (3 iterations each)....format(len(test_queries))")
+    benchmark = RetrievalBenchmark()
+    benchmark.run(methods, test_queries, iterations=3)
+
+    print("\nBenchmark complete!")
+
+
+if __name__ == "__main__":
+    main()

+ 82 - 0
examples/bid_field_extraction_demo.py

@@ -0,0 +1,82 @@
+"""
+Bid field extraction demo - demonstrates structured information extraction
+from bidding announcements using RAG
+"""
+
+import sys
+import os
+import json
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from openai import OpenAI
+
+from bdirag.document_processor import Document
+from bdirag.embedding_models import SentenceTransformerEmbedding
+from bdirag.vector_stores import FAISSStore
+from bdirag.rag_methods import BidFieldExtractionRAG
+from examples.sample_data import SAMPLE_BIDDING_DOCS
+
+
+def main():
+    print("=" * 60)
+    print("BidiRAG - Bid Field Extraction Demo")
+    print("=" * 60)
+
+    # Configuration
+    LLM_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
+    LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+    LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
+
+    # Load embedding model
+    print("\n[1/3] Loading embedding model...")
+    embedding_model = SentenceTransformerEmbedding(
+        model_name="BAAI/bge-large-zh-v1.5",
+        device="cpu"
+    )
+
+    # Create vector store and index
+    print("\n[2/3] Indexing bidding documents...")
+    vector_store = FAISSStore(embedding_model=embedding_model)
+
+    documents = [
+        Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
+        for doc in SAMPLE_BIDDING_DOCS
+    ]
+
+    llm_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
+
+    rag = BidFieldExtractionRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=LLM_MODEL,
+    )
+
+    rag.index_documents(documents)
+    print("  Indexed {0} documents.format(len(documents))")
+
+    # Extract fields for each bidding document
+    print("\n[3/3] Extracting fields from bidding documents...")
+
+    for doc in SAMPLE_BIDDING_DOCS:
+        print("\n{0}.format('=' * 60)")
+        print("Document: {0}.format(doc['title'])")
+        print("{0}.format('=' * 60)")
+
+        query = "Extract all information from {0}.format(doc['title'])"
+
+        try:
+            result = rag.query(query, k=10)
+            print(f"\nExtracted JSON:")
+            print(result.answer)
+            print("\nLatency: {0}s.format(result.latency_total:.3f)")
+            print("Retrieved {0} document chunks.format(len(result.retrieved_docs))")
+
+        except Exception as e:
+            print("ERROR: {0}.format(e)")
+
+    print("\n\nExtraction complete!")
+
+
+if __name__ == "__main__":
+    main()

+ 83 - 0
examples/debug_bm25.py

@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+"""调试 BM25RAG 的 retrieve 方法"""
+
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.document_processor import Document
+from bdirag.rag_methods.bm25_rag import BM25RAG
+
+# 创建测试文档
+doc = Document(
+    page_content="""XX市第一人民医院医疗设备招标公告
+
+项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
+项目编号:XX-ZB-2024-001
+预算金额:500万元
+采购内容:彩色多普勒超声诊断仪 1台
+
+投标人资格要求:
+1. 具有独立承担民事责任的能力
+2. 具有有效的医疗器械经营许可证
+3. 近三年内无不良经营记录
+4. 投标保证金:人民币5万元整
+
+技术需求:
+1. 彩色多普勒超声诊断仪技术参数
+   - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
+   - 显示屏:≥19英寸高清液晶显示器
+   - 质保期:整机质保三年
+2. 交货时间:合同签订后60天内交货
+3. 交货地点:XX市第一人民医院设备科
+
+评标方法:采用综合评分法
+   - 技术部分:60分
+   - 商务部分:30分
+   - 价格部分:10分
+
+付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%
+
+投标截止时间:2024年12月31日上午9:30
+开标时间:同投标截止时间
+投标文件递交地点:XX市公共资源交易中心""",
+    metadata={"title": "Test"}
+)
+
+# 初始化 BM25RAG
+print("初始化 BM25RAG...")
+bm25_rag = BM25RAG()
+bm25_rag.index_documents([doc])
+print(f"索引构建完成,文档数: {len(bm25_rag._all_documents)}")
+
+# 测试查询
+test_queries = [
+    "预算金额",
+    "投标保证金",
+    "质保期",
+    "评标方法",
+]
+
+for query in test_queries:
+    print(f"\n{'='*60}")
+    print(f"查询: {query}")
+    print(f"{'='*60}")
+    
+    # 查看 tokenize 结果
+    from bdirag.rag_methods.tokenization import bm25_tokenize
+    query_tokens = bm25_tokenize(query)
+    print(f"查询分词: {query_tokens}")
+    
+    # 检查 BM25 得分
+    if bm25_rag.bm25 is not None:
+        scores = bm25_rag.bm25.get_scores(query_tokens)
+        print(f"BM25 原始得分: {scores}")
+        print(f"最大得分: {max(scores) if len(scores) > 0 else 0:.4f}")
+    
+    # 执行检索
+    results = bm25_rag.retrieve(query, k=3)
+    print(f"召回 {len(results)} 个结果")
+    
+    for i, (doc, score) in enumerate(results, 1):
+        print(f"  [{i}] Score: {score:.4f}")
+        print(f"      Content: {doc.page_content[:100]}...")

+ 45 - 0
examples/debug_bm25_html.py

@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+"""调试 BM25 HTML Tree 检索"""
+
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
+
+# 创建测试 HTML
+html = """
+<html>
+<body>
+    <h1>政府采购中标公告</h1>
+    <table>
+        <tr><td>采购人</td><td>XX市财政局</td></tr>
+        <tr><td>中标人</td><td>XX科技有限公司</td></tr>
+        <tr><td>中标金额</td><td>50万元</td></tr>
+    </table>
+</body>
+</html>
+"""
+
+print("创建 BM25HTMLTreeRAG 实例...")
+rag = BM25HTMLTreeRAG()
+
+print("\n构建索引...")
+rag.build_index(html)
+
+print(f"\n索引了 {len(rag.all_nodes)} 个节点")
+
+print("\n测试查询: '采购人'")
+results = rag.query("采购人", k=3)
+print(f"返回 {len(results)} 个结果")
+for i, (doc, score) in enumerate(results, 1):
+    print(f"\n结果 {i} (score: {score:.3f}):")
+    print(doc.page_content[:200])
+
+print("\n" + "=" * 80)
+print("\n测试查询: '中标人'")
+results = rag.query("中标人", k=3)
+print(f"返回 {len(results)} 个结果")
+for i, (doc, score) in enumerate(results, 1):
+    print(f"\n结果 {i} (score: {score:.3f}):")
+    print(doc.page_content[:200])

+ 198 - 0
examples/demo_bid_announcement.py

@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+"""
+检索招投标公告中的招标人和中标人信息示例
+
+这个示例展示了如何使用 BidiRag 从招投标公告中提取:
+- 招标人(采购人/采购单位)
+- 中标人
+- 相关关键词
+"""
+
+import sys
+import os
+
+# 添加项目根目录到路径
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.bidi_rag import BidiRag
+
+
+def demo_bid_announcement_search():
+    """演示招投标公告检索"""
+    
+    print("=" * 80)
+    print("招投标公告信息检索示例")
+    print("=" * 80)
+    
+    # 1. 初始化 BidiRag(使用 BM25 HTML Tree 方法)
+    print("\n[步骤 1] 初始化 BidiRag...")
+    rag = BidiRag(rag_method='bm25_html_tree')
+    
+    # 2. 模拟招投标公告数据
+    print("\n[步骤 2] 添加招投标公告...")
+    
+    announcements = [
+        """<html>
+        <body>
+            <h1>政府采购意向公告</h1>
+            <table>
+                <tr><td>采购意向单位</td><td>大连长兴岛经济技术开发区交流岛街道办事处</td></tr>
+                <tr><td>采购项目名称</td><td>交流岛滨海路夜间出行照明提升工程</td></tr>
+                <tr><td>预算金额</td><td>147.060000万元</td></tr>
+                <tr><td>采购品目</td><td>路灯照明工程</td></tr>
+                <tr><td>采购需求概况</td><td>在滨海路安装太阳能路灯200盏(单排)</td></tr>
+                <tr><td>预计采购时间</td><td>2026-05</td></tr>
+            </table>
+        </body>
+        </html>""",
+        
+        """<html>
+        <body>
+            <h1>中标公告</h1>
+            <table>
+                <tr><td>采购人</td><td>XX市财政局</td></tr>
+                <tr><td>项目名称</td><td>办公设备采购项目</td></tr>
+                <tr><td>中标人</td><td>XX办公设备有限公司</td></tr>
+                <tr><td>中标金额</td><td>50万元</td></tr>
+                <tr><td>采购内容</td><td>电脑、打印机、复印机等办公设备</td></tr>
+            </table>
+        </body>
+        </html>""",
+        
+        """<html>
+        <body>
+            <h1>招标公告</h1>
+            <table>
+                <tr><td>招标人</td><td>XX市交通运输局</td></tr>
+                <tr><td>项目名称</td><td>智慧交通系统建设项目</td></tr>
+                <tr><td>项目预算</td><td>5000万元</td></tr>
+                <tr><td>招标内容</td><td>交通信号控制系统、视频监控系统</td></tr>
+                <tr><td>投标截止时间</td><td>2024-06-15</td></tr>
+            </table>
+        </body>
+        </html>""",
+        
+        """<html>
+        <body>
+            <h1>中标结果公告</h1>
+            <table>
+                <tr><td>招标人</td><td>XX市城市建设投资集团</td></tr>
+                <tr><td>中标人</td><td>XX建设工程有限公司</td></tr>
+                <tr><td>项目名称</td><td>道路建设项目</td></tr>
+                <tr><td>中标金额</td><td>18500万元</td></tr>
+                <tr><td>建设内容</td><td>道路路基、路面、桥梁工程</td></tr>
+            </table>
+        </body>
+        </html>"""
+    ]
+    
+    rag.add_texts(announcements)
+    print(f"✓ 已添加 {rag.get_document_count()} 个公告")
+    
+    # 3. 检索招标人信息
+    print("\n" + "=" * 80)
+    print("[示例 3] 检索招标人/采购人信息")
+    print("=" * 80)
+    
+    print("\n使用关键词: ['招标人', '采购人', '采购意向单位']")
+    results = rag.retrieve(
+        query="招标人信息",
+        top_k=5,
+        keywords=["招标人", "采购人", "采购意向单位", "采购单位"]
+    )
+    
+    print(f"\n找到 {len(results)} 个相关结果:\n")
+    for i, (doc, score) in enumerate(results, 1):
+        print(f"--- 结果 {i} (相关性: {score:.3f}) ---")
+        print(doc.page_content)
+        print()
+    
+    # 4. 检索中标人信息
+    print("\n" + "=" * 80)
+    print("[示例 4] 检索中标人信息")
+    print("=" * 80)
+    
+    print("\n使用关键词: ['中标人']")
+    results = rag.retrieve(
+        query="中标人信息",
+        top_k=5,
+        keywords=["中标人", "中标单位"]
+    )
+    
+    print(f"\n找到 {len(results)} 个相关结果:\n")
+    for i, (doc, score) in enumerate(results, 1):
+        print(f"--- 结果 {i} (相关性: {score:.3f}) ---")
+        print(doc.page_content)
+        print()
+    
+    # 5. 同时检索招标人和中标人
+    print("\n" + "=" * 80)
+    print("[示例 5] 同时检索招标人和中标人")
+    print("=" * 80)
+    
+    print("\n使用关键词: ['招标人', '中标人']")
+    results = rag.retrieve(
+        query="项目参与方",
+        top_k=5,
+        keywords=["招标人", "中标人"]
+    )
+    
+    print(f"\n找到 {len(results)} 个相关结果:\n")
+    for i, (doc, score) in enumerate(results, 1):
+        print(f"--- 结果 {i} (相关性: {score:.3f}) ---")
+        print(doc.page_content)
+        print()
+    
+    # 6. 关键词精确搜索
+    print("\n" + "=" * 80)
+    print("[示例 6] 关键词精确搜索")
+    print("=" * 80)
+    
+    print("\n搜索同时包含'招标人'和'中标人'的公告:")
+    results = rag.search_keywords(["招标人", "中标人"])
+    print(f"\n找到 {len(results)} 个公告:\n")
+    for i, doc in enumerate(results, 1):
+        print(f"--- 公告 {i} ---")
+        # 提取关键信息
+        content = doc.page_content
+        if "招标人" in content:
+            import re
+            match = re.search(r'招标人[::].{0,50}', content)
+            if match:
+                print(f"  {match.group(0)}")
+        if "中标人" in content:
+            match = re.search(r'中标人[::].{0,50}', content)
+            if match:
+                print(f"  {match.group(0)}")
+        print()
+    
+    # 7. 检索特定项目
+    print("\n" + "=" * 80)
+    print("[示例 7] 检索特定项目信息")
+    print("=" * 80)
+    
+    print("\n搜索'太阳能路灯'项目:")
+    results = rag.retrieve(
+        query="太阳能路灯",
+        top_k=3,
+        keywords=["路灯", "照明", "太阳能"]
+    )
+    
+    print(f"\n找到 {len(results)} 个相关结果:\n")
+    for i, (doc, score) in enumerate(results, 1):
+        print(f"--- 结果 {i} (相关性: {score:.3f}) ---")
+        print(doc.page_content[:300])
+        print()
+    
+    print("\n" + "=" * 80)
+    print("示例完成!")
+    print("=" * 80)
+    print("\n总结:")
+    print("1. 使用 bm25_html_tree 方法可以有效检索 HTML 结构的招投标公告")
+    print("2. 通过 keywords 参数可以精化检索结果")
+    print("3. search_keywords 方法可以进行精确的关键词匹配")
+    print("4. 可以从公告中提取招标人、中标人、项目信息等关键字段")
+
+
+if __name__ == "__main__":
+    demo_bid_announcement_search()

+ 276 - 0
examples/demo_bidi_rag.py

@@ -0,0 +1,276 @@
+# -*- coding: utf-8 -*-
+"""
+BidiRag 使用示例 - 演示如何使用 BidiRag 类进行 HTML 文档检索
+
+功能:
+1. 加载 HTML 文档
+2. 使用不同的 RAG 方法进行检索
+3. 支持关键词检索
+4. 支持自然语言查询
+"""
+
+import os
+from loguru import logger
+
+from bdirag.bidi_rag import BidiRag
+
+
+def demo_basic_usage():
+    """基础使用示例"""
+    print("=" * 80)
+    print("示例 1: 基础使用 - BM25 HTML Tree 方法")
+    print("=" * 80)
+    
+    # 1. 初始化 BidiRag(使用 BM25 HTML Tree 方法,适合 HTML 文档)
+    rag = BidiRag(
+        rag_method='bm25_html_tree',  # 使用 BM25 HTML Tree 方法
+        chunk_size=512,
+        chunk_overlap=50
+    )
+    
+    # 2. 添加 HTML 文档
+    html_files = [
+        "examples/sample_data.html",  # 假设的 HTML 文件
+        # "data/documents/announcement1.html",
+        # "data/documents/announcement2.html",
+    ]
+    
+    # 如果示例文件不存在,使用文本替代
+    if not any(os.path.exists(f) for f in html_files):
+        print("未找到 HTML 文件,使用示例文本...")
+        sample_texts = [
+            """
+            <html>
+            <body>
+            <h1>招标公告</h1>
+            <p>项目名称:大连长兴岛经济技术开发区交流岛街道办事处2026年5月至6月政府采购意向</p>
+            <p>采购单位:大连长兴岛经济技术开发区交流岛街道办事处</p>
+            <p>采购项目名称:交流岛滨海路夜间出行照明提升工程</p>
+            <p>预算金额:147.060000万元(人民币)</p>
+            <p>采购需求:在滨海路安装太阳能路灯200盏(单排)</p>
+            <p>预计采购时间:2026-05</p>
+            </body>
+            </html>
+            """,
+            """
+            <html>
+            <body>
+            <h1>中标公告</h1>
+            <p>项目名称:XX市智慧交通系统建设项目</p>
+            <p>采购人:XX市交通运输局</p>
+            <p>中标人:XX科技有限公司</p>
+            <p>中标金额:5000万元</p>
+            <p>项目内容:交通信号控制系统、视频监控系统、交通流量监测系统</p>
+            </body>
+            </html>
+            """
+        ]
+        rag.add_texts(sample_texts)
+    else:
+        rag.add_documents(html_files)
+    
+    print(f"\n已加载 {rag.get_document_count()} 个文档\n")
+    
+    # 3. 检索文档 - 使用关键词
+    print("-" * 80)
+    print("检索 1: 查找'招标人'和'中标人'相关信息")
+    print("-" * 80)
+    
+    results = rag.retrieve(
+        query="招标人和中标人",
+        top_k=3,
+        keywords=["招标人", "中标人", "采购人", "中标"]
+    )
+    
+    for i, (doc, score) in enumerate(results, 1):
+        print(f"\n结果 {i} (相似度: {score:.3f}):")
+        print(doc.page_content[:200])
+        print("...")
+    
+    print("\n" + "=" * 80)
+
+
+def demo_keyword_search():
+    """关键词搜索示例"""
+    print("\n示例 2: 关键词搜索")
+    print("=" * 80)
+    
+    # 初始化
+    rag = BidiRag(rag_method='bm25')
+    
+    # 添加示例数据
+    sample_docs = [
+        "招标公告:本项目招标人为XX市财政局,项目预算100万元",
+        "中标公告:中标人为XX建设有限公司,中标金额98万元",
+        "采购公告:采购单位XX医院,采购医疗设备一批",
+        "中标公告:本项目中标人为XX科技公司,中标价格50万元",
+    ]
+    rag.add_texts(sample_docs)
+    
+    # 搜索特定关键词
+    print("\n搜索关键词:'招标人'")
+    results = rag.search_keywords(["招标人"], top_k=5)
+    for i, doc in enumerate(results, 1):
+        print(f"{i}. {doc.page_content}")
+    
+    print("\n搜索关键词:'中标人'")
+    results = rag.search_keywords(["中标人"], top_k=5)
+    for i, doc in enumerate(results, 1):
+        print(f"{i}. {doc.page_content}")
+    
+    print("\n搜索关键词:'招标人' AND '中标人'")
+    results = rag.search_keywords(["招标人", "中标人"], top_k=5)
+    for i, doc in enumerate(results, 1):
+        print(f"{i}. {doc.page_content}")
+    
+    print("\n" + "=" * 80)
+
+
+def demo_different_methods():
+    """不同 RAG 方法对比"""
+    print("\n示例 3: 不同 RAG 方法对比")
+    print("=" * 80)
+    
+    # 准备相同的测试数据
+    test_texts = [
+        "招标公告:XX市政府采购项目,招标人:XX局,预算200万",
+        "中标公告:中标人XX公司,中标金额180万,项目:信息化建设",
+        "采购意向:XX学校设备采购,采购单位:XX学校,预计采购时间2026年",
+    ]
+    
+    # 测试不同的方法
+    methods = ['bm25', 'tfidf', 'keyword', 'naive']
+    
+    for method in methods:
+        print(f"\n测试方法: {method}")
+        print("-" * 40)
+        
+        try:
+            rag = BidiRag(rag_method=method)
+            rag.add_texts(test_texts)
+            
+            # 检索
+            results = rag.retrieve(
+                query="招标人信息",
+                top_k=2,
+                keywords=["招标人", "采购人"]
+            )
+            
+            for i, (doc, score) in enumerate(results, 1):
+                print(f"  {i}. (score: {score:.3f}) {doc.page_content[:100]}...")
+            
+        except Exception as e:
+            print(f"  方法 {method} 失败: {e}")
+    
+    print("\n" + "=" * 80)
+
+
+def demo_query_with_answer():
+    """完整问答示例"""
+    print("\n示例 4: 完整问答(需要 LLM)")
+    print("=" * 80)
+    
+    # 注意:这个方法需要配置 LLM
+    print("注意:此示例需要配置 LLM 客户端")
+    print("如果您有 OpenAI API key,可以这样使用:\n")
+    
+    example_code = '''
+from openai import OpenAI
+
+# 初始化 LLM 客户端
+llm_client = OpenAI(
+    api_key="your-api-key",
+    base_url="https://api.openai.com/v1"
+)
+
+# 使用 BidiRag
+rag = BidiRag(
+    rag_method='naive',
+    llm_client=llm_client,
+    llm_model='gpt-4o'
+)
+
+rag.add_texts([
+    "招标公告:招标人XX局,项目预算100万",
+    "中标公告:中标人XX公司,金额98万"
+])
+
+# 完整问答
+result = rag.query(
+    query="谁是招标人?",
+    keywords=["招标人", "采购人"]
+)
+
+print("答案:", result.answer)
+print("检索到的文档数:", len(result.retrieved_docs))
+'''
+    
+    print(example_code)
+    print("=" * 80)
+
+
+def demo_html_structure_aware():
+    """HTML 结构感知检索示例"""
+    print("\n示例 5: HTML 结构感知检索(推荐用于招投标公告)")
+    print("=" * 80)
+    
+    # BM25 HTML Tree 方法会保留 HTML 结构信息
+    rag = BidiRag(rag_method='bm25_html_tree')
+    
+    html_content = """
+    <html>
+    <body>
+        <div class="announcement">
+            <h1>政府采购中标公告</h1>
+            <table>
+                <tr><td>项目名称</td><td>办公设备采购项目</td></tr>
+                <tr><td>采购人</td><td>XX市财政局</td></tr>
+                <tr><td>中标人</td><td>XX办公设备有限公司</td></tr>
+                <tr><td>中标金额</td><td>50万元</td></tr>
+                <tr><td>采购内容</td><td>电脑、打印机、复印机等办公设备</td></tr>
+            </table>
+        </div>
+    </body>
+    </html>
+    """
+    
+    rag.add_texts([html_content])
+    
+    # 检索表格中的结构化信息
+    results = rag.retrieve(
+        query="中标人信息",
+        top_k=3,
+        keywords=["中标人", "采购人", "中标金额"]
+    )
+    
+    print(f"\n找到 {len(results)} 个相关结果:")
+    for i, (doc, score) in enumerate(results, 1):
+        print(f"\n结果 {i} (score: {score:.3f}):")
+        print(doc.page_content)
+    
+    print("\n" + "=" * 80)
+
+
+def main():
+    """运行所有示例"""
+    print("\n" + "🚀 " * 20)
+    print("BidiRag 使用示例")
+    print(" " * 20 + "\n")
+    
+    # 运行示例
+    demo_basic_usage()
+    demo_keyword_search()
+    demo_different_methods()
+    demo_query_with_answer()
+    demo_html_structure_aware()
+    
+    print("\n✅ 所有示例运行完成!")
+    print("\n提示:")
+    print("1. 对于招投标公告,推荐使用 'bm25_html_tree' 方法")
+    print("2. 可以使用 keywords 参数精化检索结果")
+    print("3. 如果需要生成答案,需要配置 LLM 客户端")
+    print("4. 可用的 RAG 方法:", BidiRag(rag_method='bm25').list_available_methods())
+
+
+if __name__ == "__main__":
+    main()

+ 129 - 0
examples/demo_bm25_retrieval.py

@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+"""
+使用 BidiRag 的 BM25 方法召回内容片段示例
+
+对比 bm25_html_tree 和 bm25 两种方法的效果
+"""
+
+import sys
+import os
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.bidi_rag import BidiRag
+
+
+def demo_bm25_retrieval():
+    """演示使用 BM25 方法召回内容片段"""
+    
+    print("=" * 80)
+    print("BidiRag - BM25 纯文本召回示例")
+    print("=" * 80)
+    
+    # 1. 初始化 BidiRag(使用 bm25 方法)
+    print("\n[步骤 1] 初始化 BidiRag (使用 bm25 方法)...")
+    rag = BidiRag(rag_method='bm25')
+    
+    # 2. 准备纯文本文档(BM25 更适合纯文本)
+    print("\n[步骤 2] 添加招投标公告(纯文本格式)...")
+    
+    sample_docs = [
+        """XX市第一人民医院医疗设备招标公告
+
+项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
+项目编号:XX-ZB-2024-001
+预算金额:500万元
+采购内容:彩色多普勒超声诊断仪 1台
+
+投标人资格要求:
+1. 具有独立承担民事责任的能力
+2. 具有有效的医疗器械经营许可证
+3. 近三年内无不良经营记录
+4. 投标保证金:人民币5万元整
+
+技术需求:
+1. 彩色多普勒超声诊断仪技术参数
+   - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
+   - 显示屏:≥19英寸高清液晶显示器
+   - 质保期:整机质保三年
+2. 交货时间:合同签订后60天内交货
+3. 交货地点:XX市第一人民医院设备科
+
+评标方法:采用综合评分法
+   - 技术部分:60分
+   - 商务部分:30分
+   - 价格部分:10分
+
+付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%
+
+投标截止时间:2024年12月31日上午9:30
+开标时间:同投标截止时间
+投标文件递交地点:XX市公共资源交易中心"""
+    ]
+    
+    rag.add_texts(sample_docs)
+    print("✓ 文档已添加")
+    
+    # 3. 测试不同查询,召回内容片段
+    test_queries = [
+        ("预算金额", ["预算", "Budget"]),
+        ("投标保证金", ["保证金", "Bond"]),
+        ("技术参数 探头", ["探头", "technical"]),
+        ("评标方法 综合评分", ["评标", "综合评分"]),
+        ("质保期", ["质保", "Warranty"]),
+        ("付款方式", ["付款", "Payment"]),
+        ("交货时间", ["交货", "Delivery"]),
+    ]
+    
+    print("\n" + "=" * 80)
+    print("开始召回测试")
+    print("=" * 80)
+    
+    for query, keywords in test_queries:
+        print(f"\n{'=' * 80}")
+        print(f"查询: {query}")
+        print(f"{'-' * 80}")
+        
+        # 使用 retrieve 方法召回内容片段
+        results = rag.retrieve(
+            query=query,
+            top_k=3,
+            keywords=keywords
+        )
+        
+        print(f"召回 {len(results)} 个内容片段:\n")
+        
+        # 评估相关性
+        relevant_count = 0
+        for i, (doc, score) in enumerate(results, 1):
+            # 检查相关性
+            is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
+            if is_relevant:
+                relevant_count += 1
+                marker = "[✓]"
+            else:
+                marker = "[ ]"
+            
+            print(f"  片段 {i} {marker} (分数: {score:.4f})")
+            print(f"  内容:")
+            # 显示内容片段(前200字符)
+            content_preview = doc.page_content[:200].replace("\n", " ")
+            print(f"    {content_preview}...")
+            print()
+        
+        # 计算精确度
+        precision = relevant_count / len(results) if results else 0
+        print(f"  精确度@3: {precision:.1%}")
+    
+    print("\n" + "=" * 80)
+    print("示例完成!")
+    print("=" * 80)
+    print("\n说明:")
+    print("1. 使用 bm25 方法召回的是纯文本片段(非树节点)")
+    print("2. 片段基于文档分块,不包含层级结构信息")
+    print("3. BM25 算法擅长关键词匹配")
+    print("4. 适合快速检索和简单场景")
+
+
+if __name__ == "__main__":
+    demo_bm25_retrieval()

+ 149 - 0
examples/demo_tree_node_retrieval.py

@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+"""
+使用 BidiRag 召回 HTML 树节点内容片段示例
+
+参考 test_bm25.py 的 test_bm25_html_tree() 方法
+展示如何从 HTML 树结构中召回相关节点片段
+"""
+
+import sys
+import os
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.bidi_rag import BidiRag
+
+
+def demo_tree_node_retrieval():
+    """演示如何召回 HTML 树节点片段"""
+    
+    print("=" * 80)
+    print("BidiRag - HTML 树节点片段召回示例")
+    print("=" * 80)
+    
+    # 1. 初始化 BidiRag(使用 bm25_html_tree 方法)
+    print("\n[步骤 1] 初始化 BidiRag (使用 bm25_html_tree 方法)...")
+    rag = BidiRag(rag_method='bm25_html_tree')
+    
+    # 2. 准备 HTML 文档
+    print("\n[步骤 2] 添加 HTML 公告...")
+    
+    sample_html = """
+    <html>
+    <body>
+        <h1>XX市第一人民医院医疗设备招标公告</h1>
+        <div>
+            <h2>一、项目概况</h2>
+            <p>项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目</p>
+            <p>项目编号:XX-ZB-2024-001</p>
+            <p>预算金额:500万元</p>
+            <p>采购内容:彩色多普勒超声诊断仪 1台</p>
+        </div>
+        <div>
+            <h2>二、投标人资格要求</h2>
+            <p>1. 具有独立承担民事责任的能力</p>
+            <p>2. 具有有效的医疗器械经营许可证</p>
+            <p>3. 近三年内无不良经营记录</p>
+            <p>4. 投标保证金:人民币5万元整</p>
+        </div>
+        <div>
+            <h2>三、技术需求</h2>
+            <p>1. 彩色多普勒超声诊断仪技术参数</p>
+            <p>   - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头</p>
+            <p>   - 显示屏:≥19英寸高清液晶显示器</p>
+            <p>   - 质保期:整机质保三年</p>
+            <p>2. 交货时间:合同签订后60天内交货</p>
+            <p>3. 交货地点:XX市第一人民医院设备科</p>
+        </div>
+        <div>
+            <h2>四、评标方法</h2>
+            <p>采用综合评分法:</p>
+            <p>   - 技术部分:60分</p>
+            <p>   - 商务部分:30分</p>
+            <p>   - 价格部分:10分</p>
+        </div>
+        <div>
+            <h2>五、付款方式</h2>
+            <p>合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%</p>
+        </div>
+        <div>
+            <h2>六、投标截止时间</h2>
+            <p>投标截止时间:2024年12月31日上午9:30</p>
+            <p>开标时间:同投标截止时间</p>
+            <p>投标文件递交地点:XX市公共资源交易中心</p>
+        </div>
+    </body>
+    </html>
+    """
+    
+    rag.add_texts([sample_html])
+    print("✓ HTML 文档已添加")
+    
+    # 3. 测试不同查询,召回树节点片段
+    test_queries = [
+        ("预算金额", ["预算", "Budget"]),
+        ("投标保证金", ["保证金", "Bond"]),
+        ("技术参数 探头", ["探头", "technical"]),
+        ("评标方法 综合评分", ["评标", "综合评分"]),
+        ("质保期", ["质保", "Warranty"]),
+        ("付款方式", ["付款", "Payment"]),
+        ("交货时间", ["交货", "Delivery"]),
+    ]
+    
+    print("\n" + "=" * 80)
+    print("开始召回测试")
+    print("=" * 80)
+    
+    for query, keywords in test_queries:
+        print(f"\n{'=' * 80}")
+        print(f"查询: {query}")
+        print(f"{'-' * 80}")
+        
+        # 使用 retrieve 方法召回树节点片段
+        results = rag.retrieve(
+            query=query,
+            top_k=3,
+            keywords=keywords
+        )
+        
+        print(f"召回 {len(results)} 个树节点片段:\n")
+        
+        # 评估相关性
+        relevant_count = 0
+        for i, (doc, score) in enumerate(results, 1):
+            # 从 metadata 中获取路径信息
+            path = doc.metadata.get("path", "")
+            title = doc.metadata.get("title", "")
+            
+            # 检查相关性
+            is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
+            if is_relevant:
+                relevant_count += 1
+                marker = "[✓]"
+            else:
+                marker = "[ ]"
+            
+            print(f"  片段 {i} {marker} (分数: {score:.4f})")
+            print(f"  路径: {path}")
+            print(f"  内容:")
+            # 显示内容片段(前200字符)
+            content_preview = doc.page_content[:200].replace("\n", " ")
+            print(f"    {content_preview}...")
+            print()
+        
+        # 计算精确度
+        precision = relevant_count / len(results) if results else 0
+        print(f"  精确度@3: {precision:.1%}")
+    
+    print("\n" + "=" * 80)
+    print("示例完成!")
+    print("=" * 80)
+    print("\n说明:")
+    print("1. 使用 bm25_html_tree 方法可以召回 HTML 树结构的节点片段")
+    print("2. 每个片段包含完整的路径信息(parent path)")
+    print("3. 内容片段保留了层级结构信息")
+    print("4. 通过 keywords 参数可以过滤和精化召回结果")
+
+
+if __name__ == "__main__":
+    demo_tree_node_retrieval()

+ 282 - 0
examples/extract_bid_info.py

@@ -0,0 +1,282 @@
+# -*- coding: utf-8 -*-
+"""
+从招投标公告中提取招标人、中标人等关键信息
+
+这个示例展示了如何使用 BidiRag 召回 HTML 公告中的内容片段,
+并从中提取关键信息。
+"""
+
+import sys
+import os
+import re
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.bidi_rag import BidiRag
+
+
+def extract_key_info(html_content):
+    """从 HTML 内容中提取关键信息"""
+    from bs4 import BeautifulSoup
+    
+    info = {}
+    
+    try:
+        # 尝试使用 lxml,如果失败则使用内置的 html.parser
+        try:
+            soup = BeautifulSoup(html_content, 'lxml')
+        except:
+            soup = BeautifulSoup(html_content, 'html.parser')
+        
+        # 查找所有表格行
+        for tr in soup.find_all('tr'):
+            cells = tr.find_all(['td', 'th'])
+            if len(cells) >= 2:
+                key = cells[0].get_text(strip=True)
+                value = cells[1].get_text(strip=True)
+                
+                # 匹配关键字段
+                if '招标人' in key:
+                    info['招标人'] = value
+                elif '采购人' in key and '意向' not in key:
+                    info['采购人'] = value
+                elif '采购意向单位' in key:
+                    info['采购意向单位'] = value
+                elif '中标人' in key:
+                    info['中标人'] = value
+                elif '项目名称' in key or '采购项目名称' in key:
+                    info['项目名称'] = value
+                elif '中标金额' in key:
+                    info['中标金额'] = value
+                elif '预算金额' in key:
+                    info['预算金额'] = value
+    except Exception as e:
+        # 如果解析失败,使用正则表达式作为后备
+        patterns = {
+            '招标人': r'<td>招标人</td>\s*<td>([^<]+)</td>',
+            '采购人': r'<td>采购人</td>\s*<td>([^<]+)</td>',
+            '采购意向单位': r'<td>采购意向单位</td>\s*<td>([^<]+)</td>',
+            '中标人': r'<td>中标人</td>\s*<td>([^<]+)</td>',
+            '项目名称': r'<td>项目名称</td>\s*<td>([^<]+)</td>',
+            '中标金额': r'<td>中标金额</td>\s*<td>([^<]+)</td>',
+            '预算金额': r'<td>预算金额</td>\s*<td>([^<]+)</td>',
+        }
+        
+        for key, pattern in patterns.items():
+            match = re.search(pattern, html_content)
+            if match:
+                info[key] = match.group(1).strip()
+    
+    return info
+
+
+def demo_extract_bid_info():
+    """演示从招投标公告中提取信息"""
+    
+    print("=" * 80)
+    print("招投标公告信息提取示例")
+    print("=" * 80)
+    
+    # 1. 初始化 BidiRag
+    print("\n[步骤 1] 初始化 BidiRag (使用 bm25_html_tree 方法)...")
+    rag = BidiRag(rag_method='bm25_html_tree')
+    
+    # 2. 添加公告数据
+    print("\n[步骤 2] 添加招投标公告...")
+    
+    announcements = [
+        """<html>
+        <body>
+            <h1>政府采购意向公告</h1>
+            <table>
+                <tr><td>采购意向单位</td><td>大连长兴岛经济技术开发区交流岛街道办事处</td></tr>
+                <tr><td>采购项目名称</td><td>交流岛滨海路夜间出行照明提升工程</td></tr>
+                <tr><td>预算金额</td><td>147.060000万元</td></tr>
+                <tr><td>采购品目</td><td>路灯照明工程</td></tr>
+                <tr><td>采购需求概况</td><td>在滨海路安装太阳能路灯200盏(单排)</td></tr>
+                <tr><td>预计采购时间</td><td>2026-05</td></tr>
+            </table>
+        </body>
+        </html>""",
+        
+        """<html>
+        <body>
+            <h1>中标公告</h1>
+            <table>
+                <tr><td>采购人</td><td>XX市财政局</td></tr>
+                <tr><td>项目名称</td><td>办公设备采购项目</td></tr>
+                <tr><td>中标人</td><td>XX办公设备有限公司</td></tr>
+                <tr><td>中标金额</td><td>50万元</td></tr>
+                <tr><td>采购内容</td><td>电脑、打印机、复印机等办公设备</td></tr>
+            </table>
+        </body>
+        </html>""",
+        
+        """<html>
+        <body>
+            <h1>招标公告</h1>
+            <table>
+                <tr><td>招标人</td><td>XX市交通运输局</td></tr>
+                <tr><td>项目名称</td><td>智慧交通系统建设项目</td></tr>
+                <tr><td>项目预算</td><td>5000万元</td></tr>
+                <tr><td>招标内容</td><td>交通信号控制系统、视频监控系统</td></tr>
+                <tr><td>投标截止时间</td><td>2024-06-15</td></tr>
+            </table>
+        </body>
+        </html>""",
+        
+        """<html>
+        <body>
+            <h1>中标结果公告</h1>
+            <table>
+                <tr><td>招标人</td><td>XX市城市建设投资集团</td></tr>
+                <tr><td>中标人</td><td>XX建设工程有限公司</td></tr>
+                <tr><td>项目名称</td><td>道路建设项目</td></tr>
+                <tr><td>中标金额</td><td>18500万元</td></tr>
+                <tr><td>建设内容</td><td>道路路基、路面、桥梁工程</td></tr>
+            </table>
+        </body>
+        </html>"""
+    ]
+    
+    rag.add_texts(announcements)
+    print(f"✓ 已添加 {rag.get_document_count()} 个公告")
+    
+    # 3. 检索招标人信息
+    print("\n" + "=" * 80)
+    print("[示例 1] 检索所有包含'招标人'或'采购人'的公告")
+    print("=" * 80)
+    
+    # 使用 OR 逻辑:分别搜索
+    for keyword in ['招标人', '采购人', '采购意向单位']:
+        print(f"\n搜索关键词: '{keyword}'")
+        results = rag.search_keywords([keyword], top_k=10)
+        
+        if results:
+            print(f"找到 {len(results)} 个相关公告:\n")
+            for i, doc in enumerate(results, 1):
+                print(f"--- 公告 {i} ---")
+                info = extract_key_info(doc.page_content)
+                for key, value in info.items():
+                    print(f"  {key}: {value}")
+                print()
+        else:
+            print("未找到相关结果\n")
+    
+    # 4. 检索中标人信息
+    print("\n" + "=" * 80)
+    print("[示例 2] 检索所有包含'中标人'的公告")
+    print("=" * 80)
+    
+    results = rag.search_keywords(['中标人'], top_k=10)
+    
+    if results:
+        print(f"\n找到 {len(results)} 个中标公告:\n")
+        for i, doc in enumerate(results, 1):
+            print(f"--- 中标公告 {i} ---")
+            info = extract_key_info(doc.page_content)
+            for key, value in info.items():
+                print(f"  {key}: {value}")
+            print()
+    
+    # 5. 同时包含招标人和中标人的公告
+    print("\n" + "=" * 80)
+    print("[示例 3] 检索同时包含'招标人'和'中标人'的公告(完整项目信息)")
+    print("=" * 80)
+    
+    results = rag.search_keywords(['招标人', '中标人'], top_k=10)
+    
+    if results:
+        print(f"\n找到 {len(results)} 个完整项目公告:\n")
+        for i, doc in enumerate(results, 1):
+            print(f"--- 项目 {i} ---")
+            info = extract_key_info(doc.page_content)
+            
+            # 分类显示
+            purchaser = info.get('招标人') or info.get('采购人')
+            winner = info.get('中标人')
+            project = info.get('项目名称')
+            amount = info.get('中标金额') or info.get('预算金额')
+            
+            print(f"  招标人: {purchaser}")
+            print(f"  中标人: {winner}")
+            print(f"  项目: {project}")
+            print(f"  金额: {amount}")
+            print()
+    
+    # 6. 检索特定类型项目
+    print("\n" + "=" * 80)
+    print("[示例 4] 检索特定项目类型")
+    print("=" * 80)
+    
+    search_terms = [
+        (['路灯', '照明'], "照明工程"),
+        (['交通', '系统'], "交通系统"),
+        (['道路', '建设'], "道路建设"),
+        (['办公', '设备'], "办公设备"),
+    ]
+    
+    for keywords, category in search_terms:
+        print(f"\n搜索 '{category}' 相关项目:")
+        results = rag.search_keywords(keywords, top_k=5)
+        
+        if results:
+            for i, doc in enumerate(results, 1):
+                info = extract_key_info(doc.page_content)
+                purchaser = info.get('招标人') or info.get('采购人') or info.get('采购意向单位')
+                project = info.get('项目名称')
+                amount = info.get('中标金额') or info.get('预算金额')
+                
+                print(f"  {i}. {project}")
+                print(f"     招标方: {purchaser}")
+                print(f"     金额: {amount}")
+        else:
+            print("  未找到相关项目")
+    
+    # 7. 总结
+    print("\n" + "=" * 80)
+    print("提取结果总结")
+    print("=" * 80)
+    
+    # 统计所有公告
+    all_results = rag.search_keywords(['招标人', '采购人', '中标人'], top_k=100)
+    
+    purchasers = set()
+    winners = set()
+    total_amount = 0
+    
+    for doc in all_results:
+        info = extract_key_info(doc.page_content)
+        
+        purchaser = info.get('招标人') or info.get('采购人') or info.get('采购意向单位')
+        if purchaser:
+            purchasers.add(purchaser)
+        
+        winner = info.get('中标人')
+        if winner:
+            winners.add(winner)
+        
+        amount_str = info.get('中标金额') or info.get('预算金额') or ''
+        # 提取金额数字(简单处理)
+        amount_match = re.search(r'(\d+(?:\.\d+)?)万', amount_str)
+        if amount_match:
+            total_amount += float(amount_match.group(1))
+    
+    print(f"\n共处理 {len(all_results)} 个公告")
+    print(f"涉及招标方 {len(purchasers)} 个:")
+    for p in purchasers:
+        print(f"  - {p}")
+    
+    print(f"\n涉及中标方 {len(winners)} 个:")
+    for w in winners:
+        print(f"  - {w}")
+    
+    print(f"\n项目总金额: 约 {total_amount:.2f} 万元")
+    
+    print("\n" + "=" * 80)
+    print("示例完成!")
+    print("=" * 80)
+
+
+if __name__ == "__main__":
+    demo_extract_bid_info()

+ 115 - 0
examples/quick_demo.py

@@ -0,0 +1,115 @@
+"""
+Simple demo script to test basic RAG methods without full benchmark
+Good for quick validation and understanding
+"""
+
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from openai import OpenAI
+
+from bdirag.document_processor import Document
+from bdirag.embedding_models import SentenceTransformerEmbedding
+from bdirag.vector_stores import FAISSStore
+from bdirag.rag_methods import (
+    NaiveRAG,
+    BidFieldExtractionRAG,
+    HyDERAG,
+)
+from examples.sample_data import SAMPLE_BIDDING_DOCS
+
+
+def main():
+    print("=" * 60)
+    print("BidiRAG - Quick Demo")
+    print("=" * 60)
+
+    # Configuration - modify these as needed
+    LLM_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
+    LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+    LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
+
+    # Step 1: Load embedding model
+    print("\n[1/4] Loading embedding model...")
+    embedding_model = SentenceTransformerEmbedding(
+        model_name="BAAI/bge-large-zh-v1.5",
+        device="cpu"
+    )
+    print("  Model loaded, dimension: {0}.format(embedding_model.dimension)")
+
+    # Step 2: Create vector store and index documents
+    print("\n[2/4] Creating vector store and indexing documents...")
+    vector_store = FAISSStore(embedding_model=embedding_model)
+
+    documents = [
+        Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
+        for doc in SAMPLE_BIDDING_DOCS
+    ]
+    print("  Prepared {0} documents.format(len(documents))")
+
+    # Step 3: Initialize RAG methods
+    print("\n[3/4] Initializing RAG methods...")
+    llm_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
+
+    naive_rag = NaiveRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=LLM_MODEL,
+    )
+
+    bid_rag = BidFieldExtractionRAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=LLM_MODEL,
+    )
+
+    hyde_rag = HyDERAG(
+        embedding_model=embedding_model,
+        vector_store=vector_store,
+        llm_client=llm_client,
+        llm_model=LLM_MODEL,
+    )
+
+    naive_rag.index_documents(documents)
+    bid_rag.index_documents(documents)
+    hyde_rag.index_documents(documents)
+    print("  Indexing complete")
+
+    # Step 4: Test queries
+    print("\n[4/4] Running test queries...")
+
+    queries = [
+        "What is the budget for the smart transportation project?",
+        "List the qualification requirements for all projects",
+        "What are the payment terms for the road construction project?",
+    ]
+
+    methods = [
+        ("NaiveRAG", naive_rag),
+        ("BidFieldExtractionRAG", bid_rag),
+        ("HyDERAG", hyde_rag),
+    ]
+
+    for query in queries:
+        print("\n{0}.format('=' * 60)")
+        print("Query: {0}.format(query)")
+        print("{0}.format('=' * 60)")
+
+        for method_name, method in methods:
+            print("\n--- {0} ---.format(method_name)")
+            try:
+                result = method.query(query, k=5)
+                print("Answer: {0}.format(result.answer)")
+                print("Latency: {0}s (retrieval: {1}s, generation: {2}s).format(result.latency_total:.3f, result.latency_retrieval:.3f, result.latency_generation:.3f)")
+                print("Retrieved {0} documents.format(len(result.retrieved_docs))")
+            except Exception as e:
+                print("ERROR: {0}.format(e)")
+
+    print("\n\nDemo complete!")
+
+
+if __name__ == "__main__":
+    main()

+ 133 - 0
examples/quick_test_methods.py

@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+"""
+简单测试 BidiRag 的几种无需 embedding 的方法
+"""
+
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.bidi_rag import BidiRag
+
+
+# 测试文档
+TEST_DOCS = [
+    """XX市第一人民医院医疗设备招标公告
+
+项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
+项目编号:XX-ZB-2024-001
+预算金额:500万元
+采购内容:彩色多普勒超声诊断仪 1台
+
+投标人资格要求:
+1. 具有独立承担民事责任的能力
+2. 具有有效的医疗器械经营许可证
+3. 近三年内无不良经营记录
+4. 投标保证金:人民币5万元整
+
+技术需求:
+1. 彩色多普勒超声诊断仪技术参数
+   - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
+   - 显示屏:≥19英寸高清液晶显示器
+   - 质保期:整机质保三年
+2. 交货时间:合同签订后60天内交货
+3. 交货地点:XX市第一人民医院设备科
+
+评标方法:采用综合评分法
+   - 技术部分:60分
+   - 商务部分:30分
+   - 价格部分:10分
+
+付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%
+
+投标截止时间:2024年12月31日上午9:30""",
+
+    """XX市智慧交通系统建设项目招标公告
+
+项目名称:XX市智慧交通系统建设项目
+项目编号:XX-ZB-2024-002
+招标人:XX市交通运输局
+预算金额:5000万元
+
+项目内容:
+1. 交通信号控制系统
+2. 视频监控系统
+3. 交通流量监测系统
+4. 数据分析平台
+
+资质要求:
+1. 电子与智能化工程专业承包二级以上资质
+2. 近三年至少完成2个类似项目业绩
+
+评标方法:综合评分法
+   - 技术部分:60分
+   - 商务部分:40分
+
+交货时间:合同签订后180天内
+质保期:3年"""
+]
+
+
+def test_method(method_name):
+    """测试单个方法"""
+    print("\n" + "=" * 80)
+    print("测试方法: " + method_name)
+    print("=" * 80)
+    
+    try:
+        # 初始化
+        t0 = time.time()
+        rag = BidiRag(rag_method=method_name)
+        init_time = time.time() - t0
+        
+        # 添加文档
+        t0 = time.time()
+        rag.add_texts(TEST_DOCS)
+        add_time = time.time() - t0
+        
+        # 检索
+        t0 = time.time()
+        results = rag.retrieve(query="预算金额", top_k=3, keywords=["预算"])
+        retrieve_time = time.time() - t0
+        
+        # 评估
+        relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
+        precision = relevant / len(results) if results else 0
+        
+        print("初始化时间: {:.2f}s".format(init_time))
+        print("索引构建时间: {:.2f}s".format(add_time))
+        print("检索时间: {:.4f}s".format(retrieve_time))
+        print("召回数量: {}".format(len(results)))
+        print("精确度: {:.1%}".format(precision))
+        
+        if results:
+            print("\n结果预览:")
+            for i, (doc, score) in enumerate(results[:2], 1):
+                preview = doc.page_content[:80].replace("\n", " ")
+                print("  [{}] Score={:.4f} | {}".format(i, score, preview))
+        
+        return True
+        
+    except Exception as e:
+        print("失败: " + str(e))
+        return False
+
+
+if __name__ == "__main__":
+    print("=" * 80)
+    print("BidiRag - RAG 方法快速测试")
+    print("=" * 80)
+    
+    methods = ['bm25', 'tfidf', 'keyword', 'bm25_html_tree']
+    
+    success_count = 0
+    for i, method in enumerate(methods, 1):
+        print("\n[{}/{}] ".format(i, len(methods)), end="")
+        if test_method(method):
+            success_count += 1
+    
+    print("\n\n" + "=" * 80)
+    print("测试完成! 成功: {}/{}".format(success_count, len(methods)))
+    print("=" * 80)

+ 317 - 0
examples/rag_test_utils.py

@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+"""Shared helpers for standalone RAG example tests."""
+import math
+import os
+import re
+import sys
+import time
+import types
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if ROOT_DIR not in sys.path:
+    sys.path.insert(0, ROOT_DIR)
+
+from bdirag.document_processor import Document
+from examples.sample_data import SAMPLE_BIDDING_DOCS
+
+
+TEST_QUERIES = [
+    "project budget amount",
+    "bid bond amount",
+    "qualification requirements",
+    "evaluation method",
+    "warranty period",
+    "payment terms",
+    "project code XX-ZB",
+    "delivery time",
+]
+
+
+def install_rank_bm25_fallback():
+    """Expose the repo BM25 fallback as rank_bm25 for methods that import it."""
+    if "rank_bm25" in sys.modules:
+        return
+    from bdirag.rag_methods.bm25_backend import SimpleBM25Okapi
+
+    module = types.ModuleType("rank_bm25")
+    module.BM25Okapi = SimpleBM25Okapi
+    sys.modules["rank_bm25"] = module
+
+
+class _ScoreList(list):
+    def flatten(self):
+        return self
+
+
+class _SimpleTfidfVectorizer(object):
+    def __init__(self):
+        self.vocabulary_ = {}
+
+    def fit_transform(self, texts):
+        for text in texts:
+            for token in _tokens(text):
+                if token not in self.vocabulary_:
+                    self.vocabulary_[token] = len(self.vocabulary_)
+        return self.transform(texts)
+
+    def transform(self, texts):
+        vectors = []
+        for text in texts:
+            vector = [0.0] * len(self.vocabulary_)
+            for token in _tokens(text):
+                index = self.vocabulary_.get(token)
+                if index is not None:
+                    vector[index] += 1.0
+            norm = math.sqrt(sum(v * v for v in vector)) or 1.0
+            vectors.append([v / norm for v in vector])
+        return vectors
+
+
+def _simple_cosine_similarity(query_vecs, matrix):
+    query = query_vecs[0] if query_vecs else []
+    scores = _ScoreList()
+    for vector in matrix:
+        scores.append(sum(a * b for a, b in zip(query, vector)))
+    return scores
+
+
+class _SimpleKMeans(object):
+    def __init__(self, n_clusters=2, random_state=None):
+        self.n_clusters = max(1, int(n_clusters))
+        self.random_state = random_state
+
+    def fit_predict(self, embeddings):
+        return [i % self.n_clusters for i in range(len(embeddings))]
+
+
+def install_sklearn_fallback():
+    """Install tiny sklearn-compatible modules used by the example tests."""
+    sklearn = types.ModuleType("sklearn")
+    feature_extraction = types.ModuleType("sklearn.feature_extraction")
+    text = types.ModuleType("sklearn.feature_extraction.text")
+    metrics = types.ModuleType("sklearn.metrics")
+    pairwise = types.ModuleType("sklearn.metrics.pairwise")
+    cluster = types.ModuleType("sklearn.cluster")
+
+    text.TfidfVectorizer = _SimpleTfidfVectorizer
+    pairwise.cosine_similarity = _simple_cosine_similarity
+    cluster.KMeans = _SimpleKMeans
+
+    sklearn.feature_extraction = feature_extraction
+    sklearn.metrics = metrics
+    sklearn.cluster = cluster
+    feature_extraction.text = text
+    metrics.pairwise = pairwise
+
+    sys.modules["sklearn"] = sklearn
+    sys.modules["sklearn.feature_extraction"] = feature_extraction
+    sys.modules["sklearn.feature_extraction.text"] = text
+    sys.modules["sklearn.metrics"] = metrics
+    sys.modules["sklearn.metrics.pairwise"] = pairwise
+    sys.modules["sklearn.cluster"] = cluster
+
+
+def make_documents():
+    return [
+        Document(
+            page_content=doc["content"],
+            metadata={"title": doc["title"], "source": doc["title"]},
+        )
+        for doc in SAMPLE_BIDDING_DOCS
+    ]
+
+
+def _tokens(text):
+    return re.findall(r"[A-Za-z0-9]+", (text or "").lower())
+
+
+class FakeEmbedding(object):
+    def __init__(self, dimension=64):
+        self.dimension = dimension
+
+    def _embed(self, text):
+        vector = [0.0] * self.dimension
+        for token in _tokens(text):
+            index = sum(ord(ch) for ch in token) % self.dimension
+            vector[index] += 1.0
+        norm = math.sqrt(sum(v * v for v in vector)) or 1.0
+        return [v / norm for v in vector]
+
+    def embed_documents(self, texts):
+        return [self._embed(text) for text in texts]
+
+    def embed_query(self, text):
+        return self._embed(text)
+
+
+class SimpleVectorStore(object):
+    def __init__(self):
+        self.documents = []
+        self.embeddings = []
+
+    def add_documents(self, documents, embeddings):
+        self.documents.extend(documents)
+        self.embeddings.extend(embeddings)
+
+    def similarity_search(self, query_embedding, k=10):
+        scored = []
+        for doc, embedding in zip(self.documents, self.embeddings):
+            score = sum(a * b for a, b in zip(query_embedding, embedding))
+            if score > 0:
+                scored.append((doc, float(score)))
+        scored.sort(key=lambda item: item[1], reverse=True)
+        return scored[:k]
+
+    def save(self, path):
+        return None
+
+    def load(self, path):
+        return None
+
+
+class _FakeMessage(object):
+    def __init__(self, content):
+        self.content = content
+
+
+class _FakeChoice(object):
+    def __init__(self, content):
+        self.message = _FakeMessage(content)
+
+
+class _FakeResponse(object):
+    def __init__(self, content):
+        self.choices = [_FakeChoice(content)]
+
+
+class _FakeCompletions(object):
+    def create(self, model=None, messages=None, temperature=None, max_tokens=None):
+        prompt = messages[-1]["content"] if messages else ""
+        return _FakeResponse(fake_llm_text(prompt))
+
+
+class _FakeChat(object):
+    def __init__(self):
+        self.completions = _FakeCompletions()
+
+
+class FakeLLMClient(object):
+    def __init__(self):
+        self.chat = _FakeChat()
+
+
+def fake_llm_text(prompt):
+    lower = (prompt or "").lower()
+    if "json" in lower:
+        return (
+            '{"project_name": "sample bidding project", '
+            '"budget_amount": "sample budget", '
+            '"evaluation_method": "sample evaluation method"}'
+        )
+    if "0-1" in prompt or "0-1" in lower or "score" in lower:
+        return "0.8"
+    if "category" in lower or "classify" in lower:
+        return "budget"
+    if "simple" in lower and "complex" in lower:
+        return "simple"
+    if "sub" in lower or "variant" in lower:
+        return "project budget\nqualification requirements\ndelivery time"
+    if "step" in lower or "general" in lower:
+        return "general bidding project information"
+    return "Sample offline LLM answer for the standalone RAG test."
+
+
+class FakeRerankModel(object):
+    def compute_score(self, pairs):
+        scores = []
+        for query, text in pairs:
+            query_words = set(_tokens(query))
+            text_words = set(_tokens(text))
+            scores.append(float(len(query_words & text_words)) + 0.1)
+        return scores
+
+
+def build_vector_rag(rag_cls, **kwargs):
+    options = {
+        "embedding_model": FakeEmbedding(),
+        "vector_store": SimpleVectorStore(),
+        "llm_client": FakeLLMClient(),
+    }
+    options.update(kwargs)
+    return rag_cls(**options)
+
+
+def print_results(method_name, query, results, elapsed):
+    print("\nQuery: {}".format(query))
+    print("-" * 60)
+    print("  Retrieved {} documents in {:.4f}s".format(len(results), elapsed))
+    for i, (doc, score) in enumerate(results, 1):
+        title = doc.metadata.get("title", doc.metadata.get("source", "Unknown"))
+        preview = doc.page_content[:100].replace("\n", " ")
+        print("  [{}] {} (Score: {:.4f})".format(i, title, float(score)))
+        print("      Preview: {}...".format(preview))
+
+
+def jinrun_retrieval_test(method_name, rag, index_func=None, queries=None, k=3):
+    print("=" * 60)
+    print("{} - Standalone Retrieval Test".format(method_name))
+    print("=" * 60)
+
+    documents = make_documents()
+    print("\n[1/2] Preparing documents...")
+    print("  Prepared {} documents".format(len(documents)))
+
+    print("\n[2/2] Building index...")
+    if index_func is None:
+        rag.index_documents(documents)
+    else:
+        index_func(rag, documents)
+    print("  Index built successfully")
+
+    print("\n" + "=" * 60)
+    print("{} Retrieval Test Results".format(method_name))
+    print("=" * 60)
+
+    for query in queries or TEST_QUERIES:
+        start = time.time()
+        results = rag.retrieve(query, k=k)
+        elapsed = time.time() - start
+        print_results(method_name, query, results, elapsed)
+
+    print("\n{} test complete!".format(method_name))
+
+
+run_retrieval_test = jinrun_retrieval_test
+
+
+def run_html_tree_test(rag_cls):
+    sample_html = """
+    <html><body>
+      <h1>Sample Medical Equipment Procurement Bidding Announcement</h1>
+      <h2>Project Overview</h2>
+      <p>Project budget: 28,000,000 RMB</p>
+      <p>Project code: XX-HOSP-2024-015</p>
+      <h2>Qualification Requirements</h2>
+      <p>Must have Medical Device Operation License and ISO 13485 certification.</p>
+      <h2>Delivery and Warranty</h2>
+      <p>Delivery time: within 90 calendar days. Warranty period: minimum 5 years.</p>
+      <h2>Payment Terms</h2>
+      <p>100% payment after installation, debugging and acceptance.</p>
+    </body></html>
+    """
+
+    print("=" * 60)
+    print("{} - Standalone HTML Tree Test".format(rag_cls.__name__))
+    print("=" * 60)
+
+    rag = rag_cls()
+    rag.build_index(sample_html)
+    print("  HTML tree index built successfully")
+
+    for query in ["project budget", "qualification requirements", "delivery time", "payment terms"]:
+        start = time.time()
+        results = rag.query(query, k=3)
+        elapsed = time.time() - start
+        print_results(rag_cls.__name__, query, results, elapsed)
+
+    print("\n{} test complete!".format(rag_cls.__name__))

+ 242 - 0
examples/sample_data.py

@@ -0,0 +1,242 @@
+"""
+Sample bidding announcement documents for testing RAG methods
+"""
+
+SAMPLE_BIDDING_DOCS = [
+    {
+        "title": "XX City Smart Transportation System Project Bidding Announcement",
+        "content": """XX City Smart Transportation System Project Bidding Announcement
+
+Project Name: XX City Smart Transportation System Construction Project
+Project Code: XX-ZB-2024-001
+Purchaser: XX City Transportation Bureau
+Purchaser Contact: Zhang San
+Purchaser Phone: 010-12345678
+
+Bidding Agency: XX Bidding Agency Co., Ltd.
+Agency Contact: Li Si
+Agency Phone: 010-87654321
+
+Project Budget: 50,000,000 RMB
+Bid Bond Amount: 500,000 RMB
+Performance Bond Amount: 5,000,000 RMB (10% of contract amount)
+
+Bid Submission Deadline: 2024-06-15 09:30 (Beijing Time)
+Bid Opening Time: 2024-06-15 09:30 (Beijing Time)
+Bid Submission Location: XX City Public Resource Trading Center, 3rd Floor, Room 301
+
+Scope of Work: Including the design, development, implementation and maintenance of the smart transportation system, including traffic signal control system, video surveillance system, traffic flow monitoring system and data analysis platform.
+
+Qualification Requirements:
+1. Bidders must have independent legal person qualification
+2. Must have electronic and intelligent engineering professional contracting Grade II or above qualification
+3. Must have completed at least 2 similar project performance in the past 3 years
+4. Project manager must have electromechanical engineering professional first-class registered constructor certificate
+
+Evaluation Method: Comprehensive Evaluation Method
+Technical Score Weight: 60%
+Commercial Score Weight: 40%
+
+Delivery Time: Within 180 calendar days after contract signing
+Delivery Location: XX City Transportation Bureau designated location
+Warranty Period: 3 years
+Payment Terms: 30% advance payment, 40% upon delivery and acceptance, 25% upon final acceptance, 5% quality guarantee deposit"""
+    },
+    {
+        "title": "XX Hospital Medical Equipment Procurement Bidding Announcement",
+        "content": """XX Hospital Medical Equipment Procurement Bidding Announcement
+
+Project Name: XX Hospital Medical Equipment Procurement Project
+Project Code: XX-HOSP-2024-015
+Purchaser: XX Hospital
+Purchaser Contact: Wang Wu
+Purchaser Phone: 021-55667788
+
+Bidding Agency: XX Medical Equipment Bidding Center
+Agency Contact: Zhao Liu
+Agency Phone: 021-88776655
+
+Project Budget: 28,000,000 RMB
+Currency: RMB
+Bid Bond Amount: 280,000 RMB
+Performance Bond Amount: 2,800,000 RMB
+
+Bid Submission Deadline: 2024-07-20 14:00 (Beijing Time)
+Bid Opening Time: 2024-07-20 14:00 (Beijing Time)
+Bid Location: XX City Public Resource Trading Center Medical Equipment Branch
+
+Procurement List:
+1. MRI System - 1 set
+2. CT Scanner - 2 sets
+3. Digital X-ray Machine - 3 sets
+4. Ultrasound Diagnostic Equipment - 5 sets
+5. Blood Analyzer - 2 sets
+
+Qualification Requirements:
+1. Must have Medical Device Operation License
+2. Must be authorized agent or manufacturer of the equipment
+3. Must have ISO 13485 medical device quality management system certification
+4. Similar project performance in the past 3 years
+
+Evaluation Method: Lowest Bid Price Method with Technical Threshold
+Technical Threshold: All parameters must meet or exceed the requirements in the bidding documents
+
+Delivery Time: Within 90 calendar days after contract signing
+Delivery Location: XX Hospital Equipment Department
+Warranty Period: Minimum 5 years for main equipment
+Payment Terms: 100% payment after installation, debugging and acceptance
+
+Special Notes:
+1. Imported equipment must provide customs declaration certificates
+2. Domestic equipment must provide medical device registration certificates
+3. All equipment must provide after-sales service commitment letter"""
+    },
+    {
+        "title": "XX University Campus Network Upgrade Project Bidding Announcement",
+        "content": """XX University Campus Network Upgrade Project Bidding Announcement
+
+Project Name: XX University Campus Network Upgrade and Expansion Project
+Project Code: XX-EDU-2024-008
+Purchaser: XX University Information Technology Center
+Purchaser Contact: Chen Qi
+Purchaser Phone: 025-11223344
+
+Bidding Agency: XX Government Procurement Center
+Agency Contact: Sun Ba
+Agency Phone: 025-44332211
+
+Project Budget: 15,800,000 RMB
+Bid Bond Amount: 158,000 RMB
+Performance Bond Amount: 1,580,000 RMB
+
+Bid Deadline: 2024-08-10 10:00 (Beijing Time)
+Bid Opening Time: 2024-08-10 10:00 (Beijing Time)
+Bid Location: XX City Government Procurement Center Bid Room 2
+
+Scope of Work:
+1. Core network equipment upgrade (core switches, routers, firewalls)
+2. Wireless network coverage (all teaching buildings, dormitories, libraries)
+3. Network security system construction (intrusion detection, log audit, behavior management)
+4. Server room renovation (UPS, precision air conditioning, environment monitoring)
+5. Network management platform development
+
+Technical Requirements:
+1. Core switch switching capacity not less than 10Tbps
+2. Wireless AP supports Wi-Fi 6 standard
+3. Firewall throughput not less than 20Gbps
+4. UPS backup time not less than 2 hours
+
+Qualification Requirements:
+1. Electronic and intelligent engineering professional contracting Grade I qualification
+2. Information system integration enterprise qualification Grade II or above
+3. Information security service qualification
+4. Project manager must have information system project management engineer certificate
+
+Evaluation Method: Comprehensive Scoring Method
+Price Score: 30 points
+Technical Score: 50 points
+Business Score: 20 points
+
+Delivery Time: Within 120 calendar days
+Warranty Period: 5 years free warranty, lifetime maintenance
+Payment Terms: 3-4-2-1 payment method (30% advance, 40% mid-term, 20% acceptance, 10% warranty)"""
+    },
+    {
+        "title": "XX City Environmental Monitoring System Project Bidding Announcement",
+        "content": """XX City Environmental Monitoring System Project Bidding Announcement
+
+Project Name: XX City Air and Water Quality Monitoring System Construction Project
+Project Code: XX-ENV-2024-022
+Purchaser: XX City Ecology and Environment Bureau
+Purchaser Contact: Zhou Jiu
+Purchaser Phone: 0512-99887766
+
+Agency Name: XX Environmental Consulting Co., Ltd.
+Agency Contact: Wu Shi
+Agency Phone: 0512-66778899
+
+Budget Amount: 35,600,000 RMB
+Bid Bond: 356,000 RMB
+Performance Bond: 3,560,000 RMB (10%)
+
+Bid Submission Deadline: 2024-09-05 09:00 (Beijing Time)
+Bid Opening: 2024-09-05 09:00
+Venue: XX City Ecology and Environment Bureau Conference Room
+
+Scope: Construction of 15 air quality monitoring stations and 8 water quality monitoring points in the city, including equipment procurement, installation, commissioning and networking operation.
+
+Equipment List:
+| Equipment Name | Quantity | Technical Requirements |
+| Air Quality Monitor | 15 sets | Monitor PM2.5, PM10, SO2, NO2, CO, O3 |
+| Water Quality Monitor | 8 sets | Monitor COD, ammonia nitrogen, total phosphorus, total nitrogen, pH |
+| Data Transmission Equipment | 23 sets | 4G/5G dual mode, support breakpoint transmission |
+| Solar Power System | 23 sets | Meet 7 consecutive rainy days power requirements |
+
+Qualification Requirements:
+1. Environmental engineering professional contracting qualification
+2. CMA certification (China Metrology Accreditation)
+3. Environmental monitoring instrument manufacturer authorization
+4. Similar project performance: at least 1 contract amount over 20 million in the past 3 years
+
+Evaluation Method: Comprehensive Evaluation Method
+Technical Part: 55 points
+Commercial Part: 35 points
+Price Part: 10 points
+
+Delivery Time: 150 calendar days
+Installation Location: Designated monitoring points throughout the city
+Warranty: 3 years free warranty, including quarterly calibration service
+Payment: 30% advance, 40% upon installation, 25% upon acceptance, 5% after warranty period"""
+    },
+    {
+        "title": "XX City Road Construction Project Bidding Announcement",
+        "content": """XX City Key Road Construction Project Bidding Announcement
+
+Project Name: XX City East-West Second Expressway Section III Construction Project
+Project Code: XX-ROAD-2024-035
+Purchaser: XX City Urban Construction Investment Group Co., Ltd.
+Contact Person: Zheng Shiyi
+Contact Phone: 0571-12340987
+
+Bidding Agency: XX Construction Project Bidding Co., Ltd.
+Agency Contact: Feng Shier
+Agency Phone: 0571-78901234
+
+Budget: 185,000,000 RMB
+Bid Bond Amount: 2,000,000 RMB
+Performance Bond Amount: 18,500,000 RMB
+
+Bid Deadline: 2024-10-20 09:30 (Beijing Time)
+Bid Opening Time: 2024-10-20 09:30
+Bid Location: XX City Construction Project Trading Center
+
+Scope of Work:
+The section III of East-West Second Expressway is about 5.8 kilometers long, including roadbed engineering, pavement engineering, bridge engineering (2 medium bridges), culvert engineering, traffic engineering, greening engineering and drainage engineering.
+
+Main Engineering Quantities:
+- Earthwork excavation: 350,000 cubic meters
+- Earthwork filling: 280,000 cubic meters
+- Cement stabilized碎石 base: 125,000 square meters
+- Asphalt concrete surface: 110,000 square meters
+- Medium bridge: 2 bridges, total length 260 meters
+- Pipe culvert: 35
+
+Qualification Requirements:
+1. Highway engineering construction general contracting Grade I or above qualification
+2. Highway engineering construction enterprise安全生产许可证
+3. Project manager must have highway engineering professional first-class registered constructor certificate and安全生产考核合格证书 (B certificate)
+4. Technical director must have highway engineering related professional senior technical title
+5. Similar project performance: completed at least 1 highway engineering construction project with contract amount over 100 million in the past 5 years
+
+Evaluation Method: Comprehensive Evaluation Method
+Technical Proposal: 45 points
+Project Management Institution: 15 points
+Commercial Reputation and Performance: 10 points
+Bid Price: 30 points
+
+Construction Period: 540 calendar days
+Quality Target: Qualified, striving for provincial quality engineering award
+Warranty Period: As per "Construction Project Quality Management Regulations", infrastructure project warranty period is reasonable service life
+Payment Terms: Monthly progress payment 80%, completion settlement payment to 97%, 3% quality guarantee deposit"""
+    },
+]

+ 12 - 0
examples/test_adaptive_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""AdaptiveRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.adaptive_rag import AdaptiveRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("AdaptiveRAG", build_vector_rag(AdaptiveRAG))

+ 226 - 0
examples/test_all_rag_methods.py

@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+"""
+测试 BidiRag 支持的所有 RAG 方法
+
+这个脚本会逐一测试每种 RAG 方法的召回效果
+"""
+
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.bidi_rag import BidiRag
+
+
+# 测试文档(纯文本格式)
+TEST_DOCS = [
+    """XX市第一人民医院医疗设备招标公告
+
+项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
+项目编号:XX-ZB-2024-001
+预算金额:500万元
+采购内容:彩色多普勒超声诊断仪 1台
+
+投标人资格要求:
+1. 具有独立承担民事责任的能力
+2. 具有有效的医疗器械经营许可证
+3. 近三年内无不良经营记录
+4. 投标保证金:人民币5万元整
+
+技术需求:
+1. 彩色多普勒超声诊断仪技术参数
+   - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
+   - 显示屏:≥19英寸高清液晶显示器
+   - 质保期:整机质保三年
+2. 交货时间:合同签订后60天内交货
+3. 交货地点:XX市第一人民医院设备科
+
+评标方法:采用综合评分法
+   - 技术部分:60分
+   - 商务部分:30分
+   - 价格部分:10分
+
+付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%
+
+投标截止时间:2024年12月31日上午9:30
+开标时间:同投标截止时间
+投标文件递交地点:XX市公共资源交易中心""",
+
+    """XX市智慧交通系统建设项目招标公告
+
+项目名称:XX市智慧交通系统建设项目
+项目编号:XX-ZB-2024-002
+招标人:XX市交通运输局
+预算金额:5000万元
+
+项目内容:
+1. 交通信号控制系统
+2. 视频监控系统
+3. 交通流量监测系统
+4. 数据分析平台
+
+资质要求:
+1. 电子与智能化工程专业承包二级以上资质
+2. 近三年至少完成2个类似项目业绩
+3. 项目经理须具备机电工程专业一级建造师证书
+
+评标方法:综合评分法
+   - 技术部分:60分
+   - 商务部分:40分
+
+交货时间:合同签订后180天内
+质保期:3年
+付款方式:30%预付款,40%到货验收,25%最终验收,5%质保金"""
+]
+
+
+def test_rag_method(method_name, query="预算金额", keywords=None):
+    """测试单个 RAG 方法"""
+    
+    if keywords is None:
+        keywords = ["预算"]
+    
+    print("\n" + "=" * 80)
+    print("测试方法: {}".format(method_name))
+    print("=" * 80)
+    
+    try:
+        # 初始化 RAG
+        start_time = time.time()
+        rag = BidiRag(rag_method=method_name)
+        init_time = time.time() - start_time
+        
+        # 添加文档
+        start_time = time.time()
+        rag.add_texts(TEST_DOCS)
+        add_time = time.time() - start_time
+        
+        # 执行检索
+        start_time = time.time()
+        results = rag.retrieve(query=query, top_k=3, keywords=keywords)
+        retrieve_time = time.time() - start_time
+        
+        # 评估结果
+        relevant_count = 0
+        for doc, score in results:
+            is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
+            if is_relevant:
+                relevant_count += 1
+        
+        precision = relevant_count / len(results) if results else 0
+        
+        # 输出结果
+        print("✓ 初始化时间: {:.2f}s".format(init_time))
+        print("✓ 索引构建时间: {:.2f}s".format(add_time))
+        print("✓ 检索时间: {:.4f}s".format(retrieve_time))
+        print("✓ 召回数量: {}".format(len(results)))
+        print("✓ 精确度@3: {:.1%}".format(precision))
+        
+        if results:
+            print("\n前3个结果:")
+            for i, (doc, score) in enumerate(results[:3], 1):
+                preview = doc.page_content[:100].replace("\n", " ")
+                print("  [{}] Score: {:.4f}".format(i, score))
+                print("      Preview: {}...".format(preview))
+        
+        return {
+            'method': method_name,
+            'success': True,
+            'results_count': len(results),
+            'precision': precision,
+            'init_time': init_time,
+            'add_time': add_time,
+            'retrieve_time': retrieve_time,
+        }
+    
+    except Exception as e:
+        print("✗ 测试失败: {}".format(str(e)))
+        import traceback
+        traceback.print_exc()
+        return {
+            'method': method_name,
+            'success': False,
+            'error': str(e),
+        }
+
+
+def main():
+    """主测试函数"""
+    
+    print("=" * 80)
+    print("BidiRag - 所有 RAG 方法测试")
+    print("=" * 80)
+    
+    # 只测试不需要 embedding 模型的方法
+    methods_to_test = ['bm25', 'tfidf', 'keyword', 'bm25_html_tree']
+    
+    print("\n测试方法数量: {}".format(len(methods_to_test)))
+    print("方法列表: {}".format(', '.join(methods_to_test)))
+    print("\n注意: 其他方法需要 embedding 模型,由于网络限制暂时跳过")
+    
+    # 测试结果汇总
+    results_summary = []
+    
+    # 测试每个方法
+    for i, method in enumerate(methods_to_test, 1):
+        print("\n[{}/{}] ".format(i, len(methods_to_test)), end="")
+        result = test_rag_method(method)
+        results_summary.append(result)
+    
+    # 打印汇总报告
+    print("\n\n" + "=" * 80)
+    print("测试汇总报告")
+    print("=" * 80)
+    
+    successful = [r for r in results_summary if r['success']]
+    failed = [r for r in results_summary if not r['success']]
+    
+    print("\n总测试数: {}".format(len(results_summary)))
+    print("成功: {}".format(len(successful)))
+    print("失败: {}".format(len(failed)))
+    
+    if successful:
+        print("\n{:<30} {:>8} {:>10} {:>12}".format("方法名称", "召回数", "精确度", "检索时间"))
+        print("-" * 80)
+        
+        # 按精确度排序
+        successful.sort(key=lambda x: x['precision'], reverse=True)
+        
+        for r in successful:
+            print("{:<30} {:>8} {:>9.1%} {:>11.4f}s".format(
+                r['method'], r['results_count'], r['precision'], r['retrieve_time']))
+    
+    if failed:
+        print("\n失败的方法:")
+        for r in failed:
+            print("  - {}: {}".format(r['method'], r.get('error', 'Unknown error')))
+    
+    # 推荐方法
+    print("\n" + "=" * 80)
+    print("推荐使用方法")
+    print("=" * 80)
+    
+    # 找出精确度最高的方法
+    best_methods = [r for r in successful if r['precision'] == 1.0]
+    if best_methods:
+        print("\n精确度 100% 的方法:")
+        for r in best_methods:
+            print("  ✓ {} (召回 {} 个结果, 耗时 {:.4f}s)".format(
+                r['method'], r['results_count'], r['retrieve_time']))
+    
+    # 快速方法(检索时间 < 0.01s)
+    fast_methods = [r for r in successful if r['retrieve_time'] < 0.01]
+    if fast_methods:
+        print("\n快速检索方法 (< 10ms):")
+        for r in fast_methods:
+            print("  ✓ {} (精确度 {:.1%})".format(r['method'], r['precision']))
+    
+    print("\n" + "=" * 80)
+    print("测试完成!")
+    print("=" * 80)
+
+
+if __name__ == "__main__":
+    main()

+ 21 - 0
examples/test_bid_field_extraction_rag.py

@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+"""BidFieldExtractionRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.bid_field_extraction_rag import BidFieldExtractionRAG
+from examples.rag_test_utils import build_vector_rag, make_documents, run_retrieval_test
+
+
+if __name__ == "__main__":
+    rag = build_vector_rag(BidFieldExtractionRAG)
+    run_retrieval_test("BidFieldExtractionRAG", rag)
+
+    print("\n" + "=" * 60)
+    print("BidFieldExtractionRAG Field Extraction Test")
+    print("=" * 60)
+    rag = build_vector_rag(BidFieldExtractionRAG)
+    rag.index_documents(make_documents())
+    result = rag.extract_fields("extract project budget and evaluation method", k=3)
+    print(result.answer)

+ 292 - 0
examples/test_bidi_rag.py

@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+"""BidiRag 测试脚本"""
+
+import sys
+import os
+
+# 添加项目根目录到路径
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.bidi_rag import BidiRag
+from loguru import logger
+
+
+def test_basic_retrieval():
+    """测试基础检索功能"""
+    print("\n" + "=" * 80)
+    print("测试 1: 基础检索功能")
+    print("=" * 80)
+    
+    try:
+        # 初始化
+        rag = BidiRag(rag_method='bm25')
+        
+        # 添加测试文档
+        docs = [
+            "招标公告:招标人为XX市财政局,项目预算100万元",
+            "中标公告:中标人为XX建设有限公司,中标金额98万元",
+            "采购公告:采购单位XX医院,采购医疗设备",
+        ]
+        rag.add_texts(docs)
+        
+        assert rag.get_document_count() == 3, "文档数量不正确"
+        print("✓ 文档添加成功")
+        
+        # 检索
+        results = rag.retrieve("招标人", top_k=2)
+        assert len(results) > 0, "检索结果为空"
+        print(f"✓ 检索成功,返回 {len(results)} 个结果")
+        
+        # 验证结果包含关键词
+        found = False
+        for doc, score in results:
+            if "招标人" in doc.page_content:
+                found = True
+                break
+        assert found, "结果中未找到关键词"
+        print("✓ 检索结果正确")
+        
+        print("✅ 测试 1 通过\n")
+        return True
+        
+    except Exception as e:
+        print(f"❌ 测试 1 失败: {e}\n")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_keyword_search():
+    """测试关键词搜索"""
+    print("\n" + "=" * 80)
+    print("测试 2: 关键词搜索")
+    print("=" * 80)
+    
+    try:
+        rag = BidiRag(rag_method='bm25')
+        
+        docs = [
+            "招标人:XX局,项目A",
+            "中标人:XX公司,项目B",
+            "招标人:XX委,中标人:XX企业,项目C",
+        ]
+        rag.add_texts(docs)
+        
+        # 搜索单个关键词
+        results = rag.search_keywords(["招标人"])
+        assert len(results) == 2, f"应找到2个文档,实际找到 {len(results)} 个"
+        print("✓ 单关键词搜索成功")
+        
+        # 搜索多个关键词(AND 逻辑)
+        results = rag.search_keywords(["招标人", "中标人"])
+        assert len(results) == 1, f"应找到1个文档,实际找到 {len(results)} 个"
+        assert "项目C" in results[0].page_content
+        print("✓ 多关键词搜索成功")
+        
+        print("✅ 测试 2 通过\n")
+        return True
+        
+    except Exception as e:
+        print(f"❌ 测试 2 失败: {e}\n")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_html_retrieval():
+    """测试 HTML 文档检索"""
+    print("\n" + "=" * 80)
+    print("测试 3: HTML 文档检索")
+    print("=" * 80)
+    
+    try:
+        # 使用 HTML Tree 方法
+        rag = BidiRag(rag_method='bm25_html_tree')
+        
+        html_docs = [
+            """<html><body>
+                <h1>招标公告</h1>
+                <p>招标人:XX市政府采购中心</p>
+                <p>项目预算:200万元</p>
+            </body></html>""",
+            """<html><body>
+                <h1>中标公告</h1>
+                <p>中标人:XX科技有限公司</p>
+                <p>中标金额:180万元</p>
+            </body></html>"""
+        ]
+        
+        rag.add_texts(html_docs)
+        print("✓ HTML 文档添加成功")
+        
+        # 检索
+        results = rag.retrieve(
+            "招标人信息",
+            top_k=2,
+            keywords=["招标人", "中标人"]
+        )
+        
+        assert len(results) > 0, "检索结果为空"
+        print(f"✓ HTML 检索成功,返回 {len(results)} 个结果")
+        
+        print("✅ 测试 3 通过\n")
+        return True
+        
+    except Exception as e:
+        print(f"❌ 测试 3 失败: {e}\n")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_different_methods():
+    """测试不同的 RAG 方法"""
+    print("\n" + "=" * 80)
+    print("测试 4: 不同 RAG 方法")
+    print("=" * 80)
+    
+    test_docs = [
+        "招标人A,项目1",
+        "中标人B,项目2",
+    ]
+    
+    methods = ['bm25', 'tfidf', 'keyword']
+    success_count = 0
+    
+    for method in methods:
+        try:
+            print(f"\n测试方法: {method}")
+            rag = BidiRag(rag_method=method)
+            rag.add_texts(test_docs)
+            
+            results = rag.retrieve("招标人", top_k=1)
+            if len(results) > 0:
+                print(f"  ✓ {method} 方法正常工作")
+                success_count += 1
+            else:
+                print(f"  ⚠ {method} 方法返回空结果")
+                
+        except Exception as e:
+            print(f"  ❌ {method} 方法失败: {e}")
+    
+    print(f"\n✅ {success_count}/{len(methods)} 个方法测试通过\n")
+    return success_count == len(methods)
+
+
+def test_filter_by_keywords():
+    """测试关键词过滤"""
+    print("\n" + "=" * 80)
+    print("测试 5: 关键词过滤")
+    print("=" * 80)
+    
+    try:
+        rag = BidiRag(rag_method='bm25')
+        
+        docs = [
+            "招标人:A公司,项目1",
+            "中标人:B公司,项目2",
+            "招标人:C公司,中标人:D公司,项目3",
+            "采购人:E单位,项目4",
+        ]
+        rag.add_texts(docs)
+        
+        # 使用关键词过滤
+        results = rag.retrieve(
+            "公司信息",
+            top_k=5,
+            keywords=["招标人"]
+        )
+        
+        # 所有结果都应包含"招标人"
+        for doc, score in results:
+            assert "招标人" in doc.page_content, f"结果中应包含'招标人': {doc.page_content}"
+        
+        print(f"✓ 关键词过滤成功,返回 {len(results)} 个结果")
+        print("✅ 测试 5 通过\n")
+        return True
+        
+    except Exception as e:
+        print(f"❌ 测试 5 失败: {e}\n")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def test_available_methods():
+    """测试可用方法列表"""
+    print("\n" + "=" * 80)
+    print("测试 6: 可用方法列表")
+    print("=" * 80)
+    
+    try:
+        rag = BidiRag(rag_method='bm25')
+        methods = rag.list_available_methods()
+        
+        assert len(methods) > 0, "方法列表为空"
+        print(f"✓ 可用方法数量: {len(methods)}")
+        print(f"✓ 方法列表: {', '.join(methods[:10])}...")
+        
+        # 测试方法信息
+        info = rag.get_method_info()
+        assert 'method_name' in info
+        assert 'document_count' in info
+        print(f"✓ 方法信息获取成功: {info['method_name']}")
+        
+        print("✅ 测试 6 通过\n")
+        return True
+        
+    except Exception as e:
+        print(f"❌ 测试 6 失败: {e}\n")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def main():
+    """运行所有测试"""
+    print("\n" + " " * 20)
+    print("BidiRag 测试套件")
+    print(" " * 20 + "\n")
+    
+    tests = [
+        ("基础检索", test_basic_retrieval),
+        ("关键词搜索", test_keyword_search),
+        ("HTML 检索", test_html_retrieval),
+        ("不同方法", test_different_methods),
+        ("关键词过滤", test_filter_by_keywords),
+        ("可用方法", test_available_methods),
+    ]
+    
+    results = []
+    for name, test_func in tests:
+        try:
+            result = test_func()
+            results.append((name, result))
+        except Exception as e:
+            print(f"❌ {name} 测试异常: {e}")
+            results.append((name, False))
+    
+    # 汇总
+    print("\n" + "=" * 80)
+    print("测试汇总")
+    print("=" * 80)
+    
+    passed = sum(1 for _, r in results if r)
+    total = len(results)
+    
+    for name, result in results:
+        status = "✅ 通过" if result else "❌ 失败"
+        print(f"{status} - {name}")
+    
+    print(f"\n总计: {passed}/{total} 测试通过")
+    
+    if passed == total:
+        print("\n🎉 所有测试通过!")
+        return 0
+    else:
+        print(f"\n⚠️  {total - passed} 个测试失败")
+        return 1
+
+
+if __name__ == "__main__":
+    exit(main())

+ 267 - 0
examples/test_bm25.py

@@ -0,0 +1,267 @@
+"""
+BM25 RAG method standalone test
+Tests BM25 keyword retrieval performance on bidding documents
+"""
+
+import sys
+import os
+import time
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.document_processor import Document
+from bdirag.rag_methods.bm25_rag import BM25RAG
+from bdirag.rag_methods.naive_rag import NaiveRAG
+from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
+from examples.sample_data import SAMPLE_BIDDING_DOCS
+
+SAMPLE_HTML = """
+<html>
+<body>
+    <h1>XX市第一人民医院医疗设备招标公告</h1>
+    <div>
+        <h2>一、项目概况</h2>
+        <p>项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目</p>
+        <p>项目编号:XX-ZB-2024-001</p>
+        <p>预算金额:500万元</p>
+        <p>采购内容:彩色多普勒超声诊断仪 1台</p>
+    </div>
+    <div>
+        <h2>二、投标人资格要求</h2>
+        <p>1. 具有独立承担民事责任的能力</p>
+        <p>2. 具有有效的医疗器械经营许可证</p>
+        <p>3. 近三年内无不良经营记录</p>
+        <p>4. 投标保证金:人民币5万元整</p>
+    </div>
+    <div>
+        <h2>三、技术需求</h2>
+        <p>1. 彩色多普勒超声诊断仪技术参数</p>
+        <p>   - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头</p>
+        <p>   - 显示屏:≥19英寸高清液晶显示器</p>
+        <p>   - 质保期:整机质保三年</p>
+        <p>2. 交货时间:合同签订后60天内交货</p>
+        <p>3. 交货地点:XX市第一人民医院设备科</p>
+    </div>
+    <div>
+        <h2>四、评标方法</h2>
+        <p>采用综合评分法:</p>
+        <p>   - 技术部分:60分</p>
+        <p>   - 商务部分:30分</p>
+        <p>   - 价格部分:10分</p>
+    </div>
+    <div>
+        <h2>五、付款方式</h2>
+        <p>合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%</p>
+    </div>
+    <div>
+        <h2>六、投标截止时间</h2>
+        <p>投标截止时间:2024年12月31日上午9:30</p>
+        <p>开标时间:同投标截止时间</p>
+        <p>投标文件递交地点:XX市公共资源交易中心</p>
+    </div>
+</body>
+</html>
+"""
+
+
+def test_bm25_retrieval():
+    print("=" * 60)
+    print("BM25 RAG - Standalone Test (Plain Text)")
+    print("=" * 60)
+
+    print("\n[1/2] Preparing documents...")
+    documents = [
+        Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
+        for doc in SAMPLE_BIDDING_DOCS
+    ]
+    print("  Prepared {} documents".format(len(documents)))
+
+    print("\n[2/2] Initializing BM25RAG...")
+    bm25_rag = BM25RAG()
+    bm25_rag.index_documents(documents)
+    print("  BM25 index built successfully")
+
+    test_queries = [
+        ("预算金额", ["budget", "Budget", "预算"]),
+        ("投标保证金", ["bid bond", "Bid Bond", "保证金"]),
+        ("资质要求", ["qualification", "Qualification", "资质"]),
+        ("评标方法", ["evaluation", "Evaluation", "评标"]),
+        ("质保期", ["warranty", "Warranty", "质保"]),
+        ("付款方式", ["payment", "Payment", "付款"]),
+        ("项目编号 XX-ZB", ["XX-ZB", "Project Code"]),
+        ("交货时间", ["delivery", "Delivery", "交货"]),
+    ]
+
+    print("\n" + "=" * 60)
+    print("BM25 Retrieval Test Results")
+    print("=" * 60)
+
+    for query, keywords in test_queries:
+        print("\nQuery: {}".format(query))
+        print("-" * 60)
+        
+        start = time.time()
+        results = bm25_rag.retrieve(query, k=3)
+        elapsed = time.time() - start
+        
+        print("  Retrieved {} documents in {:.4f}s".format(len(results), elapsed))
+        
+        # 评估相关性
+        relevant_count = 0
+        for i, (doc, score) in enumerate(results, 1):
+            title = doc.metadata.get("title", "Unknown")
+            preview = doc.page_content[:80].replace("\n", " ")
+            
+            # 检查是否包含关键词
+            is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
+            if is_relevant:
+                relevant_count += 1
+                marker = "[OK]"
+            else:
+                marker = "[ --]"
+            
+            print("  [{}] {} {} (Score: {:.4f})".format(i, marker, title, score))
+            print("      Preview: {}...".format(preview))
+        
+        precision = relevant_count / len(results) if results else 0
+        print("  Precision@3: {:.1%}".format(precision))
+    
+    print("\n\nBM25 plain text test complete!")
+
+
+def test_bm25_html_tree():
+    print("\n\n" + "=" * 60)
+    print("BM25 HTML Tree RAG - Standalone Test")
+    print("=" * 60)
+
+    print("\n[1/2] Parsing HTML and building tree...")
+    tree_rag = BM25HTMLTreeRAG()
+    tree_rag.build_index(SAMPLE_HTML)
+    print("  HTML tree index built successfully")
+
+    test_queries = [
+        ("预算金额", ["预算", "Budget"]),
+        ("投标保证金", ["保证金", "Bond"]),
+        ("技术参数 探头", ["探头", "technical"]),
+        ("评标方法 综合评分", ["评标", "综合评分"]),
+        ("质保期", ["质保", "Warranty"]),
+        ("付款方式", ["付款", "Payment"]),
+        ("交货时间", ["交货", "Delivery"]),
+    ]
+
+    print("\n[2/2] Testing BM25 HTML Tree Retrieval...")
+    print("=" * 60)
+
+    for query, keywords in test_queries:
+        print("\nQuery: {}".format(query))
+        print("-" * 60)
+        
+        start = time.time()
+        results = tree_rag.query(query, k=3)
+        elapsed = time.time() - start
+        
+        print("  Retrieved {} subtrees in {:.4f}s".format(len(results), elapsed))
+        
+        relevant_count = 0
+        for i, (doc, score) in enumerate(results, 1):
+            path = doc.metadata.get("path", "")
+            title = doc.metadata.get("title", "")
+            
+            # 检查相关性
+            is_relevant = any(kw.lower() in doc.page_content.lower() for kw in keywords)
+            if is_relevant:
+                relevant_count += 1
+                marker = "[OK]"
+            else:
+                marker = "[ --]"
+            
+            print("  [{}] {} Score: {:.4f}".format(i, marker, score))
+            print("      Path: {}".format(path))
+            print("      Content: {}...".format(doc.page_content[:120].replace("\n", " ")))
+        
+        precision = relevant_count / len(results) if results else 0
+        print("  Precision@3: {:.1%}".format(precision))
+    
+    print("\n\nBM25 HTML Tree test complete!")
+
+
+def compare_bm25_vs_html_tree():
+    print("\n\n" + "=" * 60)
+    print("BM25 Plain Text vs BM25 HTML Tree Comparison")
+    print("=" * 60)
+
+    documents = [
+        Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
+        for doc in SAMPLE_BIDDING_DOCS
+    ]
+
+    bm25_rag = BM25RAG()
+    bm25_rag.index_documents(documents)
+
+    tree_rag = BM25HTMLTreeRAG()
+    tree_rag.build_index(SAMPLE_HTML)
+
+    test_queries = [
+        ("预算金额", ["预算", "Budget"]),
+        ("质保期", ["质保", "Warranty"]),
+        ("评标方法", ["评标", "Evaluation"]),
+    ]
+
+    print("\n{:<15} | {:>15} | {:>15}".format("Query", "BM25 Docs", "HTML Tree Docs"))
+    print("-" * 60)
+
+    bm25_total_precision = 0
+    tree_total_precision = 0
+    num_queries = len(test_queries)
+
+    for query, keywords in test_queries:
+        t0 = time.time()
+        bm25_results = bm25_rag.retrieve(query, k=3)
+        bm25_time = time.time() - t0
+
+        # 计算BM25的precision
+        bm25_relevant = sum(1 for doc, _ in bm25_results if any(kw.lower() in doc.page_content.lower() for kw in keywords))
+        bm25_precision = bm25_relevant / len(bm25_results) if bm25_results else 0
+
+        t1 = time.time()
+        tree_results = tree_rag.query(query, k=3)
+        tree_time = time.time() - t1
+
+        # 计算HTML Tree的precision - 考虑父节点上下文
+        def is_relevant_with_context(doc, keywords):
+            """Check relevance considering parent context from path."""
+            content = doc.page_content.lower()
+            path = doc.metadata.get("path", "").lower()
+            
+            # Check content
+            if any(kw.lower() in content for kw in keywords):
+                return True
+            
+            # Check path (which includes parent nodes)
+            if any(kw.lower() in path for kw in keywords):
+                return True
+            
+            return False
+        
+        tree_relevant = sum(1 for doc, _ in tree_results if is_relevant_with_context(doc, keywords))
+        tree_precision = tree_relevant / len(tree_results) if tree_results else 0
+
+        print("{:<15} | {:>10} ({:.4f}s, P@3:{:.1%}) | {:>10} ({:.4f}s, P@3:{:.1%})".format(
+            query, len(bm25_results), bm25_time, bm25_precision, 
+            len(tree_results), tree_time, tree_precision))
+        
+        bm25_total_precision += bm25_precision
+        tree_total_precision += tree_precision
+
+    print("-" * 60)
+    print("{:<15} | {:>15} | {:>15}".format(
+        "Average", 
+        "P@3: {:.1%}".format(bm25_total_precision / num_queries),
+        "P@3: {:.1%}".format(tree_total_precision / num_queries)))
+    print("\nComparison complete!")
+
+
+if __name__ == "__main__":
+    test_bm25_retrieval()
+    test_bm25_html_tree()
+    # 如果需要对比,取消下面的注释
+    compare_bm25_vs_html_tree()

+ 12 - 0
examples/test_bm25_html_tree_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""BM25HTMLTreeRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
+from examples.rag_test_utils import run_html_tree_test
+
+
+if __name__ == "__main__":
+    run_html_tree_test(BM25HTMLTreeRAG)

+ 12 - 0
examples/test_contextual_compression_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""ContextualCompressionRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.contextual_compression_rag import ContextualCompressionRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("ContextualCompressionRAG", build_vector_rag(ContextualCompressionRAG))

+ 12 - 0
examples/test_corrective_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""CorrectiveRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.corrective_rag import CorrectiveRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("CorrectiveRAG", build_vector_rag(CorrectiveRAG))

+ 12 - 0
examples/test_ensemble_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""EnsembleRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.ensemble_rag import EnsembleRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("EnsembleRAG", build_vector_rag(EnsembleRAG))

+ 12 - 0
examples/test_flare_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""FLARERAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.flare_rag import FLARERAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("FLARERAG", build_vector_rag(FLARERAG, max_iterations=2))

+ 12 - 0
examples/test_graph_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""GraphRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.graph_rag import GraphRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("GraphRAG", build_vector_rag(GraphRAG))

+ 15 - 0
examples/test_hybrid_search_rag.py

@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+"""HybridSearchRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from examples.rag_test_utils import install_rank_bm25_fallback
+install_rank_bm25_fallback()
+
+from bdirag.rag_methods.hybrid_search_rag import HybridSearchRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("HybridSearchRAG", build_vector_rag(HybridSearchRAG))

+ 12 - 0
examples/test_hyde_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""HyDERAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.hyde_rag import HyDERAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("HyDERAG", build_vector_rag(HyDERAG))

+ 22 - 0
examples/test_keyword_rag.py

@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+"""KeywordRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from examples.rag_test_utils import (
+    FakeLLMClient,
+    install_rank_bm25_fallback,
+    install_sklearn_fallback,
+    run_retrieval_test,
+)
+install_rank_bm25_fallback()
+install_sklearn_fallback()
+
+from bdirag.rag_methods.keyword_rag import KeywordRAG
+
+
+if __name__ == "__main__":
+    run_retrieval_test("KeywordRAG-BM25", KeywordRAG(search_method="bm25", llm_client=FakeLLMClient()))
+    print("\n")
+    run_retrieval_test("KeywordRAG-TFIDF", KeywordRAG(search_method="tfidf", llm_client=FakeLLMClient()))

+ 12 - 0
examples/test_llm_filter_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""LLMFilterRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.llm_filter_rag import LLMFilterRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("LLMFilterRAG", build_vector_rag(LLMFilterRAG, filter_threshold=0.3))

+ 12 - 0
examples/test_metadata_filter_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""MetadataFilterRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.metadata_filter_rag import MetadataFilterRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("MetadataFilterRAG", build_vector_rag(MetadataFilterRAG))

+ 213 - 0
examples/test_methods_direct.py

@@ -0,0 +1,213 @@
+# -*- coding: utf-8 -*-
+"""
+直接测试各种 RAG 方法(不通过 BidiRag 封装)
+"""
+
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.document_processor import Document
+from bdirag.rag_methods.bm25_rag import BM25RAG
+from bdirag.rag_methods.tfidf_rag import TFIDFRAG
+from bdirag.rag_methods.keyword_rag import KeywordRAG
+from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
+
+
+TEST_DOCS = [
+    Document(page_content="""XX市第一人民医院医疗设备招标公告
+
+项目名称:XX市第一人民医院彩色多普勒超声诊断仪采购项目
+项目编号:XX-ZB-2024-001
+预算金额:500万元
+采购内容:彩色多普勒超声诊断仪 1台
+
+投标人资格要求:
+1. 具有独立承担民事责任的能力
+2. 具有有效的医疗器械经营许可证
+3. 近三年内无不良经营记录
+4. 投标保证金:人民币5万元整
+
+技术需求:
+1. 彩色多普勒超声诊断仪技术参数
+   - 探头配置:腹部凸阵探头、高频线阵探头、心脏相控阵探头
+   - 显示屏:≥19英寸高清液晶显示器
+   - 质保期:整机质保三年
+2. 交货时间:合同签订后60天内交货
+3. 交货地点:XX市第一人民医院设备科
+
+评标方法:采用综合评分法
+   - 技术部分:60分
+   - 商务部分:30分
+   - 价格部分:10分
+
+付款方式:合同签订后支付30%,交货验收合格后支付65%,质保期满后支付5%"""),
+
+    Document(page_content="""XX市智慧交通系统建设项目招标公告
+
+项目名称:XX市智慧交通系统建设项目
+项目编号:XX-ZB-2024-002
+招标人:XX市交通运输局
+预算金额:5000万元
+
+项目内容:
+1. 交通信号控制系统
+2. 视频监控系统
+3. 交通流量监测系统
+4. 数据分析平台
+
+资质要求:
+1. 电子与智能化工程专业承包二级以上资质
+2. 近三年至少完成2个类似项目业绩
+
+评标方法:综合评分法
+   - 技术部分:60分
+   - 商务部分:40分
+
+交货时间:合同签订后180天内
+质保期:3年""")
+]
+
+
+def test_bm25():
+    """测试 BM25"""
+    print("\n" + "=" * 80)
+    print("测试 BM25")
+    print("=" * 80)
+    
+    rag = BM25RAG()
+    rag.index_documents(TEST_DOCS)
+    
+    t0 = time.time()
+    results = rag.retrieve("预算金额", k=3)
+    retrieve_time = time.time() - t0
+    
+    relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
+    precision = relevant / len(results) if results else 0
+    
+    print("召回数量: {}".format(len(results)))
+    print("检索时间: {:.4f}s".format(retrieve_time))
+    print("精确度: {:.1%}".format(precision))
+    
+    if results:
+        print("\n结果预览:")
+        for i, (doc, score) in enumerate(results[:2], 1):
+            preview = doc.page_content[:80].replace("\n", " ")
+            print("  [{}] Score={:.4f} | {}".format(i, score, preview))
+
+
+def test_tfidf():
+    """测试 TF-IDF"""
+    print("\n" + "=" * 80)
+    print("测试 TF-IDF")
+    print("=" * 80)
+    
+    rag = TFIDFRAG()
+    rag.index_documents(TEST_DOCS)
+    
+    t0 = time.time()
+    results = rag.retrieve("预算金额", k=3)
+    retrieve_time = time.time() - t0
+    
+    relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
+    precision = relevant / len(results) if results else 0
+    
+    print("召回数量: {}".format(len(results)))
+    print("检索时间: {:.4f}s".format(retrieve_time))
+    print("精确度: {:.1%}".format(precision))
+    
+    if results:
+        print("\n结果预览:")
+        for i, (doc, score) in enumerate(results[:2], 1):
+            preview = doc.page_content[:80].replace("\n", " ")
+            print("  [{}] Score={:.4f} | {}".format(i, score, preview))
+
+
+def test_keyword():
+    """测试 Keyword"""
+    print("\n" + "=" * 80)
+    print("测试 Keyword")
+    print("=" * 80)
+    
+    rag = KeywordRAG()
+    rag.index_documents(TEST_DOCS)
+    
+    t0 = time.time()
+    results = rag.retrieve("预算", k=3)
+    retrieve_time = time.time() - t0
+    
+    relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
+    precision = relevant / len(results) if results else 0
+    
+    print("召回数量: {}".format(len(results)))
+    print("检索时间: {:.4f}s".format(retrieve_time))
+    print("精确度: {:.1%}".format(precision))
+    
+    if results:
+        print("\n结果预览:")
+        for i, (doc, score) in enumerate(results[:2], 1):
+            preview = doc.page_content[:80].replace("\n", " ")
+            print("  [{}] Score={:.4f} | {}".format(i, score, preview))
+
+
+def test_bm25_html_tree():
+    """测试 BM25 HTML Tree"""
+    print("\n" + "=" * 80)
+    print("测试 BM25 HTML Tree")
+    print("=" * 80)
+    
+    html_content = "\n".join([doc.page_content for doc in TEST_DOCS])
+    
+    rag = BM25HTMLTreeRAG()
+    rag.build_index(html_content)
+    
+    t0 = time.time()
+    results = rag.query("预算金额", k=3)
+    retrieve_time = time.time() - t0
+    
+    relevant = sum(1 for doc, _ in results if "预算" in doc.page_content)
+    precision = relevant / len(results) if results else 0
+    
+    print("召回数量: {}".format(len(results)))
+    print("检索时间: {:.4f}s".format(retrieve_time))
+    print("精确度: {:.1%}".format(precision))
+    
+    if results:
+        print("\n结果预览:")
+        for i, (doc, score) in enumerate(results[:2], 1):
+            path = doc.metadata.get("path", "")
+            preview = doc.page_content[:80].replace("\n", " ")
+            print("  [{}] Score={:.4f} | Path: {}".format(i, score, path[:60]))
+            print("      Content: {}...".format(preview))
+
+
+if __name__ == "__main__":
+    print("=" * 80)
+    print("RAG 方法直接测试")
+    print("=" * 80)
+    
+    try:
+        test_bm25()
+    except Exception as e:
+        print("BM25 失败: " + str(e))
+    
+    try:
+        test_tfidf()
+    except Exception as e:
+        print("TF-IDF 失败: " + str(e))
+    
+    try:
+        test_keyword()
+    except Exception as e:
+        print("Keyword 失败: " + str(e))
+    
+    try:
+        test_bm25_html_tree()
+    except Exception as e:
+        print("BM25 HTML Tree 失败: " + str(e))
+    
+    print("\n\n" + "=" * 80)
+    print("测试完成!")
+    print("=" * 80)

+ 12 - 0
examples/test_multi_query_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""MultiQueryRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.multi_query_rag import MultiQueryRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("MultiQueryRAG", build_vector_rag(MultiQueryRAG, num_queries=3))

+ 12 - 0
examples/test_naive_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""NaiveRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.naive_rag import NaiveRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("NaiveRAG", build_vector_rag(NaiveRAG))

+ 12 - 0
examples/test_parent_document_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""ParentDocumentRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.parent_document_rag import ParentDocumentRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("ParentDocumentRAG", build_vector_rag(ParentDocumentRAG, parent_chunk_size=120))

+ 12 - 0
examples/test_query_routing_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""QueryRoutingRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.query_routing_rag import QueryRoutingRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("QueryRoutingRAG", build_vector_rag(QueryRoutingRAG))

+ 19 - 0
examples/test_raptor_rag.py

@@ -0,0 +1,19 @@
+# -*- coding: utf-8 -*-
+"""RAPTORRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.raptor_rag import RAPTORRAG
+from examples.rag_test_utils import build_vector_rag, install_sklearn_fallback, run_retrieval_test
+
+install_sklearn_fallback()
+
+
+def build_tree(rag, documents):
+    rag.build_tree(documents)
+
+
+if __name__ == "__main__":
+    rag = build_vector_rag(RAPTORRAG, max_tree_depth=2, cluster_size=2)
+    run_retrieval_test("RAPTORRAG", rag, index_func=build_tree)

+ 13 - 0
examples/test_rerank_rag.py

@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+"""RerankRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.rerank_rag import RerankRAG
+from examples.rag_test_utils import FakeRerankModel, build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    rag = build_vector_rag(RerankRAG, rerank_model=FakeRerankModel(), rerank_top_k=5)
+    run_retrieval_test("RerankRAG", rag)

+ 90 - 0
examples/test_retrieval_dedup.py

@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+"""Focused tests for content-level retrieval deduplication."""
+import os
+import sys
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if ROOT_DIR not in sys.path:
+    sys.path.insert(0, ROOT_DIR)
+
+from bdirag.document_processor import Document
+from bdirag.rag_methods.bm25_html_tree_rag import BM25HTMLTreeRAG
+from bdirag.rag_methods.bm25_rag import BM25RAG
+from bdirag.rag_methods.dedup import deduplicate_ranked_results
+from bdirag.rag_methods.ensemble_rag import EnsembleRAG
+from examples.rag_test_utils import FakeEmbedding, SimpleVectorStore, install_rank_bm25_fallback
+
+
+def test_deduplicate_ranked_results_keeps_highest_score():
+    low = Document(page_content=" Duplicate   content ", metadata={"source": "low"})
+    high = Document(page_content="Duplicate content", metadata={"source": "high"})
+    other = Document(page_content="Other content", metadata={"source": "other"})
+
+    results = deduplicate_ranked_results([(low, 0.1), (other, 0.2), (high, 0.9)], k=10)
+
+    assert len(results) == 2
+    assert results[0][0].metadata["source"] == "high"
+    assert results[0][1] == 0.9
+
+
+def test_deduplicate_ranked_results_keeps_first_on_score_tie():
+    first = Document(page_content="Same content", metadata={"source": "first"})
+    second = Document(page_content=" Same   content ", metadata={"source": "second"})
+
+    results = deduplicate_ranked_results([(first, 0.5), (second, 0.5)], k=10)
+
+    assert len(results) == 1
+    assert results[0][0].metadata["source"] == "first"
+
+
+def test_bm25_retrieve_deduplicates_equal_content_documents():
+    install_rank_bm25_fallback()
+    rag = BM25RAG()
+    docs = [
+        Document(page_content="alpha beta project budget", metadata={"source": "a"}),
+        Document(page_content="alpha beta project budget", metadata={"source": "b"}),
+        Document(page_content="alpha delivery schedule", metadata={"source": "c"}),
+        Document(page_content="gamma warranty terms", metadata={"source": "d"}),
+        Document(page_content="delta payment terms", metadata={"source": "e"}),
+    ]
+    rag.index_documents(docs)
+
+    results = rag.retrieve("alpha beta", k=3)
+
+    contents = [doc.page_content for doc, _ in results]
+    assert contents.count("alpha beta project budget") == 1
+    assert len(contents) == len(set(contents))
+
+
+def test_ensemble_retrieve_merges_duplicate_content_from_distinct_objects():
+    docs = [
+        Document(page_content="alpha beta project budget", metadata={"source": "a"}),
+        Document(page_content="alpha beta project budget", metadata={"source": "b"}),
+        Document(page_content="alpha delivery schedule", metadata={"source": "c"}),
+    ]
+    rag = EnsembleRAG(embedding_model=FakeEmbedding(), vector_store=SimpleVectorStore())
+    embeddings = rag.embedding_model.embed_documents([doc.page_content for doc in docs])
+    rag.vector_store.add_documents(docs, embeddings)
+
+    results = rag.retrieve("alpha beta", k=3)
+
+    contents = [doc.page_content for doc, _ in results]
+    assert contents.count("alpha beta project budget") == 1
+    assert len(contents) == len(set(contents))
+
+
+def test_html_tree_query_deduplicates_formatted_documents():
+    rag = BM25HTMLTreeRAG()
+    node_a = {"type": "p", "sentence_title_text": "A"}
+    node_b = {"type": "p", "sentence_title_text": "B"}
+    rag.retrieve_subtrees = lambda query, k: [
+        (node_a, 0.7, "Repeated subtree text"),
+        (node_b, 0.9, " Repeated   subtree   text "),
+    ]
+    rag.get_node_path = lambda node: node["sentence_title_text"]
+
+    results = rag.query("repeated", k=5)
+
+    assert len(results) == 1
+    assert results[0][0].metadata["title"] == "B"
+    assert results[0][1] == 0.9

+ 12 - 0
examples/test_self_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""SelfRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.self_rag import SelfRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("SelfRAG", build_vector_rag(SelfRAG, reflection_threshold=0.3))

+ 12 - 0
examples/test_step_back_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""StepBackRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.step_back_rag import StepBackRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("StepBackRAG", build_vector_rag(StepBackRAG))

+ 12 - 0
examples/test_table_aware_rag.py

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""TableAwareRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.table_aware_rag import TableAwareRAG
+from examples.rag_test_utils import build_vector_rag, run_retrieval_test
+
+
+if __name__ == "__main__":
+    run_retrieval_test("TableAwareRAG", build_vector_rag(TableAwareRAG))

+ 14 - 0
examples/test_tfidf_rag.py

@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+"""TFIDFRAG standalone test."""
+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from bdirag.rag_methods.tfidf_rag import TFIDFRAG
+from examples.rag_test_utils import FakeLLMClient, install_sklearn_fallback, run_retrieval_test
+
+install_sklearn_fallback()
+
+
+if __name__ == "__main__":
+    run_retrieval_test("TFIDFRAG", TFIDFRAG(llm_client=FakeLLMClient()))

+ 1 - 0
examples/untitled-1.py

@@ -0,0 +1 @@
+import lxml

+ 90 - 0
fix_fstrings.py

@@ -0,0 +1,90 @@
+"""
+Script to convert f-strings to .format() for Python 3.5 compatibility
+"""
+import re
+import os
+
+files_to_fix = [
+    "bdirag/document_processor.py",
+    "bdirag/embedding_models.py",
+    "bdirag/vector_stores.py",
+    "bdirag/rag_methods.py",
+    "bdirag/benchmark.py",
+    "bdirag/config.py",
+    "examples/test_bm25.py",
+    "examples/benchmark_all_methods.py",
+    "examples/quick_demo.py",
+    "examples/benchmark_retrieval_speed.py",
+    "examples/bid_field_extraction_demo.py",
+]
+
+
+def convert_fstring_in_line(line):
+    # Find f"..." or f'...' patterns on a single line
+    # Handle double quotes
+    pattern_dq = r'f"([^"]*?\{[^}]*\}[^"]*?)"'
+    pattern_sq = r"f'([^']*?\{[^}]*\}[^']*?)'"
+
+    for pattern, quote in [(pattern_dq, '"'), (pattern_sq, "'")]:
+        while True:
+            match = re.search(pattern, line)
+            if not match:
+                break
+
+            inner = match.group(1)
+            values = []
+            new_str = ""
+            i = 0
+            while i < len(inner):
+                if inner[i] == "{" and i + 1 < len(inner) and inner[i + 1] != "{":
+                    j = inner.index("}", i)
+                    expr = inner[i + 1:j].strip()
+                    values.append(expr)
+                    new_str += "{" + str(len(values) - 1) + "}"
+                    i = j + 1
+                elif inner[i:i + 2] == "{{":
+                    new_str += "{{"
+                    i += 2
+                elif inner[i:i + 2] == "}}":
+                    new_str += "}}"
+                    i += 2
+                else:
+                    new_str += inner[i]
+                    i += 1
+
+            format_call = ".format(" + ", ".join(values) + ")"
+            replacement = new_str + format_call
+
+            line = line[:match.start()] + quote + replacement + quote + line[match.end():]
+
+    return line
+
+
+for filepath in files_to_fix:
+    if not os.path.exists(filepath):
+        print("Not found: " + filepath)
+        continue
+
+    with open(filepath, "r", encoding="utf-8") as f:
+        content = f.read()
+
+    lines = content.split("\n")
+    new_lines = []
+    changed = False
+
+    for line in lines:
+        if "f'" in line or 'f"' in line:
+            new_line = convert_fstring_in_line(line)
+            if new_line != line:
+                changed = True
+            new_lines.append(new_line)
+        else:
+            new_lines.append(line)
+
+    if changed:
+        new_content = "\n".join(new_lines)
+        with open(filepath, "w", encoding="utf-8") as f:
+            f.write(new_content)
+        print("Fixed: " + filepath)
+    else:
+        print("No changes: " + filepath)

+ 0 - 0
parser/__init__.py


+ 1239 - 0
parser/htmlparser.py

@@ -0,0 +1,1239 @@
+#coding:utf8
+
+import re
+
+import logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+import Levenshtein
+
+from bs4 import BeautifulSoup
+import copy
+
+end_pattern = "商务要求|评分标准|商务条件|商务条件"
+_param_pattern = "(产品|技术|清单|配置|参数|具体|明细|项目|招标|货物|服务|规格|工作|具体)[及和与]?(指标|配置|条件|要求|参数|需求|规格|条款|名称及要求)|配置清单|(质量|技术).{,10}要求|验收标准|^(参数|功能)$"
+meter_pattern = "[><≤≥±]\d+|\d+(?:[μucmkK微毫千]?[米升LlgGmMΩ]|摄氏度|英寸|度|天|VA|dB|bpm|rpm|kPa|mol|cmH20|%|°|Mpa|Hz|K?HZ|℃|W|min|[*×xX])|[*×xX]\d+|/min|\ds[^a-zA-Z]|GB.{,20}标准|PVC|PP|角度|容积|色彩|自动|流量|外径|轴位|折射率|帧率|柱镜|振幅|磁场|镜片|防漏|强度|允差|心率|倍数|瞳距|底座|色泽|噪音|间距|材质|材料|表面|频率|阻抗|浓度|兼容|防尘|防水|内径|实时|一次性|误差|性能|距离|精确|温度|超温|范围|跟踪|对比度|亮度|[横纵]向|均压|负压|正压|可调|设定值|功能|检测|高度|厚度|宽度|深度|[单双多]通道|效果|指数|模式|尺寸|重量|峰值|谷值|容量|寿命|稳定性|高温|信号|电源|电流|转换率|效率|释放量|转速|离心力|向心力|弯曲|电压|功率|气量|国标|标准协议|灵敏度|最大值|最小值|耐磨|波形|高压|性强|工艺|光源|低压|压力|压强|速度|湿度|重量|毛重|[MLX大中小]+码|净重|颜色|[红橙黄绿青蓝紫]色|不锈钢|输入|输出|噪声|认证|配置"
+not_meter_pattern = "投标报价|中标金额|商务部分|公章|分值构成|业绩|详见|联系人|联系电话|合同价|金额|采购预算|资金来源|费用|质疑|评审因素|评审标准|商务资信|商务评分|专家论证意见|评标方法|代理服务费|售后服务|评分类型|评分项目|预算金额|得\d+分|项目金额|详见招标文件|乙方"
+
+def judge_pur_chinese(keyword):
+    """
+    中文字符的编码范围为: u'\u4e00' -- u'\u9fff:只要在此范围内就可以判断为中文字符串
+    @param keyword:
+    @return:
+    """
+    # 定义一个需要删除的标点符号字符串列表
+    remove_chars = '[·’!"\#$%&\'()#!()*+,-./:;<=>?\@,:?¥★、….>【】[]《》?“”‘’\[\\]^_`{|}~]+'
+    # 利用re.sub来删除中文字符串中的标点符号
+    strings = re.sub(remove_chars, "", keyword)  # 将keyword中文字符串中remove_chars中包含的标点符号替换为空字符串
+    for ch in strings:
+        if u'\u4e00' <= ch <= u'\u9fff':
+            pass
+        else:
+            return False
+    return True
+
+
+def jaccard_score(source,target):
+    source_set = set([s for s in source])
+    target_set = set([s for s in target])
+    if len(source_set)==0 or len(target_set)==0:
+        return 0
+    return max(len(source_set&target_set)/len(source_set),len(source_set&target_set)/len(target_set))
+
+
+def is_similar(source,target,_radio=None):
+    source = str(source).lower()
+    target = str(target).lower()
+    max_len = max(len(source),len(target))
+    min_len = min(len(source),len(target))
+
+    min_ratio = 90
+    if min_len>=3:
+        min_ratio = 87
+    if min_len>=5:
+        min_ratio = 85
+    if _radio is not None:
+        min_ratio = _radio
+    # dis_len = abs(len(source)-len(target))
+    # min_dis = min(max_len*0.2,4)
+    if min_len==0 and max_len>0:
+        return False
+    if max_len<=2:
+        if source==target:
+            return True
+    if min_len<2:
+        return False
+    #判断相似度
+    similar = Levenshtein.ratio(source,target)*100
+    if similar>=min_ratio:
+        # log("%s and %s similar_jaro %d"%(source,target,similar))
+        return True
+    similar_jaro = Levenshtein.jaro(source,target)
+    if similar_jaro*100>=min_ratio:
+        # log("%s and %s similar_jaro %d"%(source,target,similar_jaro*100))
+        return True
+    similar_jarow = Levenshtein.jaro_winkler(source,target)
+    if similar_jarow*100>=min_ratio:
+        # log("%s and %s similar_jaro %d"%(source,target,similar_jarow*100))
+        return True
+
+    if min_len>=5:
+        if len(source)==max_len and str(source).find(target)>=0:
+            return True
+        elif len(target)==max_len and target.find(source)>=0:
+            return True
+        elif jaccard_score(source, target)==1 and judge_pur_chinese(source) and judge_pur_chinese(target):
+            return True
+    return False
+
+def getTrs(tbody):
+    #获取所有的tr
+    trs = []
+    if tbody.name=="table":
+        body = tbody.find("tbody",recursive=False)
+        if body is not None:
+            tbody = body
+    objs = tbody.find_all(recursive=False)
+    for obj in objs:
+        if obj.name=="tr":
+            trs.append(obj)
+        if obj.name=="tbody" or obj.name=="table":
+            for tr in obj.find_all("tr",recursive=False):
+                trs.append(tr)
+    return trs
+
+def fixSpan(tbody):
+    # 处理colspan, rowspan信息补全问题
+    #trs = tbody.findChildren('tr', recursive=False)
+
+    trs = getTrs(tbody)
+    ths_len = 0
+    ths = list()
+    trs_set = set()
+    #修改为先进行列补全再进行行补全,否则可能会出现表格解析混乱
+    # 遍历每一个tr
+
+    for indtr, tr in enumerate(trs):
+        ths_tmp = tr.findChildren('th', recursive=False)
+        #不补全含有表格的tr
+        if len(tr.findChildren('table'))>0:
+            continue
+        if len(ths_tmp) > 0:
+            ths_len = ths_len + len(ths_tmp)
+            for th in ths_tmp:
+                ths.append(th)
+            trs_set.add(tr)
+        # 遍历每行中的element
+        tds = tr.findChildren(recursive=False)
+        for indtd, td in enumerate(tds):
+            # 若有colspan 则补全同一行下一个位置
+            if 'colspan' in td.attrs:
+                if str(re.sub("[^0-9]","",str(td['colspan'])))!="":
+                    col = int(re.sub("[^0-9]","",str(td['colspan'])))
+                    if col<100 and len(td.get_text())<1000:
+                        td['colspan'] = 1
+                        for i in range(1, col, 1):
+                            td.insert_after(copy.copy(td))
+
+    for indtr, tr in enumerate(trs):
+        ths_tmp = tr.findChildren('th', recursive=False)
+        #不补全含有表格的tr
+        if len(tr.findChildren('table'))>0:
+            continue
+        if len(ths_tmp) > 0:
+            ths_len = ths_len + len(ths_tmp)
+            for th in ths_tmp:
+                ths.append(th)
+            trs_set.add(tr)
+        # 遍历每行中的element
+        tds = tr.findChildren(recursive=False)
+        for indtd, td in enumerate(tds):
+            # 若有rowspan 则补全下一行同样位置
+            if 'rowspan' in td.attrs:
+                if str(re.sub("[^0-9]","",str(td['rowspan'])))!="":
+                    row = int(re.sub("[^0-9]","",str(td['rowspan'])))
+                    td['rowspan'] = 1
+                    for i in range(1, row, 1):
+                        # 获取下一行的所有td, 在对应的位置插入
+                        if indtr+i<len(trs):
+                            tds1 = trs[indtr + i].findChildren(['td','th'], recursive=False)
+                            if len(tds1) >= (indtd) and len(tds1)>0:
+                                if indtd > 0:
+                                    tds1[indtd - 1].insert_after(copy.copy(td))
+                                else:
+                                    tds1[0].insert_before(copy.copy(td))
+                            elif indtd-2>0 and len(tds1) > 0 and len(tds1) == indtd - 1:  # 修正某些表格最后一列没补全
+                                tds1[indtd-2].insert_after(copy.copy(td))
+def getTable(tbody):
+    #trs = tbody.findChildren('tr', recursive=False)
+    fixSpan(tbody)
+    trs = getTrs(tbody)
+    inner_table = []
+    for tr in trs:
+        tr_line = []
+        tds = tr.findChildren(['td','th'], recursive=False)
+        if len(tds)==0:
+            tr_line.append([re.sub('\xa0','',tr.get_text()),0]) # 2021/12/21 修复部分表格没有td 造成数据丢失
+        for td in tds:
+            tr_line.append([re.sub('\xa0','',td.get_text()),0])
+            #tr_line.append([td.get_text(),0])
+        inner_table.append(tr_line)
+    return inner_table
+
+class ParseDocument():
+
+    def __init__(self,_html,auto_merge_table=True):
+        if _html is None:
+            _html = ""
+        self.html = _html
+
+
+        # self.soup = BeautifulSoup(self.html,"html.parser")
+        self.auto_merge_table = auto_merge_table
+
+        self.soup = BeautifulSoup(self.html,"lxml")
+        # self.soup = BeautifulSoup(self.html,"html5lib")
+        _body = self.soup.find("body")
+        if _body is not None:
+            self.soup = _body
+        self.list_obj = self.get_soup_objs(self.soup)
+
+        # for obj in self.list_obj:
+        #     print("obj",obj.get_text()[:20])
+
+        self.tree = self.buildParsetree(self.list_obj,[],auto_merge_table)
+
+
+        # #识别目录树
+        # if self.parseTree:
+        #     self.parseTree.printParseTree()
+        # self.print_tree(self.tree,"-|")
+
+    def get_soup_objs(self,soup,list_obj=None):
+        if list_obj is None:
+            list_obj = []
+        childs = soup.find_all(recursive=False)
+        for _obj in childs:
+            childs1 = _obj.find_all(recursive=False)
+            # Optimization: Keep more granular nodes for better retrieval
+            # Changed threshold from 40 to 30 to capture more meaningful chunks
+            if len(childs1)==0 or len(_obj.get_text())<30 or _obj.name=="table":
+                list_obj.append(_obj)
+            else:
+                self.get_soup_objs(_obj,list_obj)
+        return list_obj
+
+    def fix_tree(self,_product):
+        products = extract_products(self.tree,_product)
+        if len(products)>0:
+            self.tree = self.buildParsetree(self.list_obj,products,self.auto_merge_table)
+
+    def print_tree(self,tree,append=""):
+        self.set_tree_id = set()
+        if append=="":
+            for t in tree:
+                logger.debug("%s text:%s title:%s title_text:%s before:%s after%s product:%s"%("==>",t["text"][:50],t["sentence_title"],t["sentence_title_text"],t["title_before"],t["title_after"],t["has_product"]))
+
+        for t in tree:
+            _id = id(t)
+            if _id in self.set_tree_id:
+                continue
+            self.set_tree_id.add(_id)
+            logger.info("%s text:%s title:%s title_text:%s before:%s after%s product:%s"%(append,t["text"][:50],t["sentence_title"],t["sentence_title_text"],t["title_before"],t["title_after"],t["has_product"]))
+            childs = t["child_title"]
+            self.print_tree(childs,append=append+"-|")
+
+    def is_title_first(self,title):
+        if title in ("一","1","Ⅰ","a","A"):
+            return True
+        return False
+
+    def find_title_by_pattern(self,_text,_pattern="(^|★|▲|:|:|\s+)(?P<title_1>(?P<title_1_index_0_0>第?)(?P<title_1_index_1_1>[一二三四五六七八九十ⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]+)(?P<title_1_index_2_0>[、章册包标部.::、、]+))|" \
+                                             "([\s★▲\*]*)(?P<title_3>(?P<title_3_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?)(?P<title_3_index_0_1>[ⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]+)(?P<title_3_index_0_2>[、章册包标部.::、、]+))|" \
+                                             "([\s★▲\*]*)(?P<title_4>(?P<title_4_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?第?)(?P<title_4_index_1_1>[一二三四五六七八九十]+)(?P<title_4_index_2_0>[节章册部\.::、、]+))|" \
+                                             "([\s★▲\*]*)(?P<title_5>(?P<title_5_index_0_0>^)(?P<title_5_index_1_1>[一二三四五六七八九十]+)(?P<title_5_index_2_0>)[^一二三四五六七八九十节章册部\.::、、])|" \
+                                             "([\s★▲\*]*)(?P<title_12>(?P<title_12_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?\d{1,2}[\..、\s\-]\d{1,2}[\..、\s\-]\d{1,2}[\..、\s\-]\d{1,2}[\..、\s\-])(?P<title_12_index_1_1>\d{1,2})(?P<title_12_index_2_0>[\..、\s\-]?))|"\
+                                             "([\s★▲\*]*)(?P<title_11>(?P<title_11_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?\d{1,2}[\..、\s\-]\d{1,2}[\..、\s\-]\d{1,2}[\..、\s\-])(?P<title_11_index_1_1>\d{1,2})(?P<title_11_index_2_0>[\..、\s\-]?))|" \
+                                             "([\s★▲\*]*)(?P<title_10>(?P<title_10_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?\d{1,2}[\..、\s\-]\d{1,2}[\..、\s\-])(?P<title_10_index_1_1>\d{1,2})(?P<title_10_index_2_0>[\..、\s\-]?))|" \
+                                             "([\s★▲\*]*)(?P<title_7>(?P<title_7_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?\d{1,2}[\..\s\-])(?P<title_7_index_1_1>\d{1,2})(?P<title_7_index_2_0>[\..包标::、\s\-]*))|" \
+                                             "(^ [\s★▲\*]*)(?P<title_6>(?P<title_6_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?包?)(?P<title_6_index_0_1>\d{1,2})(?P<title_6_index_2_0>[\..、\s\-包标]*))|" \
+                                             "([\s★▲\*]*)(?P<title_15>(?P<title_15_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?[((]?)(?P<title_15_index_1_1>\d{1,2})(?P<title_15_index_2_0>[))包标\..::、]+))|" \
+                                             "([\s★▲\*]*)(?P<title_17>(?P<title_17_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?[((]?)(?P<title_17_index_1_1>[a-zA-Z]+)(?P<title_17_index_2_0>[))包标\..::、]+))|" \
+                                             "([\s★▲\*]*)(?P<title_19>(?P<title_19_index_0_0>[^一二三四五六七八九十\dⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]{,3}?[((]?)(?P<title_19_index_1_1>[一二三四五六七八九十ⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]+)(?P<title_19_index_2_0>[))]))"
+                              ):
+        _se = re.search(_pattern,_text)
+        groups = []
+        if _se is not None:
+            _gd = _se.groupdict()
+            for k,v in _gd.items():
+                if v is not None:
+                    groups.append((k,v))
+        if len(groups):
+            groups.sort(key=lambda x:x[0])
+            # Optimization: Filter out false positives - title should be at the beginning or after clear delimiter
+            if groups:
+                first_group = groups[0]
+                title_text = first_group[1]
+                # Check if title appears at reasonable position (beginning or after delimiter)
+                title_pos = _text.find(title_text)
+                if title_pos > 0 and title_pos < 5:
+                    # Title is near the beginning, likely valid
+                    pass
+                elif title_pos == 0:
+                    # Title at the very beginning, definitely valid
+                    pass
+                elif title_pos > 5:
+                    # Title too far from beginning, might be false positive
+                    # Check if there's a clear delimiter before it
+                    before_text = _text[:title_pos]
+                    if not re.search(r'[::;;,,.。\s★▲\*]', before_text[-1:]):
+                        return None
+            return groups
+        return None
+
+    def make_increase(self,_sort,_title,_add=1):
+        if len(_title)==0 and _add==0:
+            return ""
+        if len(_title)==0 and _add==1:
+            return _sort[0]
+        _index = _sort.index(_title[-1])
+        next_index = (_index+_add)%len(_sort)
+        next_chr = _sort[next_index]
+        if _index==len(_sort)-1:
+            _add = 1
+        else:
+            _add = 0
+        return next_chr+self.make_increase(_sort,_title[:-1],_add)
+
+
+    def get_next_title(self,_title):
+        if re.search("^\d+$",_title) is not None:
+            return str(int(_title)+1)
+        if re.search("^[一二三四五六七八九十百]+$",_title) is not None:
+            if _title[-1]=="十":
+                return _title+"一"
+            if _title[-1]=="百":
+                return _title+"零一"
+
+            if _title[-1]=="九":
+                if len(_title)==1:
+                    return "十"
+                if len(_title)==2:
+                    if _title[0]=="十":
+                        return "二十"
+                if len(_title)==3:
+                    if _title[0]=="九":
+                        return "一百"
+                    else:
+                        _next_title = self.make_increase(['一','二','三','四','五','六','七','八','九','十'],re.sub("[十百]",'',_title[0]))
+                        return _next_title+"十"
+
+            _next_title = self.make_increase(['一','二','三','四','五','六','七','八','九','十'],re.sub("[十百]",'',_title))
+            _next_title = list(_next_title)
+            _next_title.reverse()
+            if _next_title[-1]!="十":
+                if len(_next_title)>=2:
+                    _next_title.insert(-1,'十')
+            if len(_next_title)>=4:
+                _next_title.insert(-3,'百')
+            if _title[0]=="十":
+                if _next_title=="十":
+                    _next_title = ["二","十"]
+                _next_title.insert(0,"十")
+            _next_title = "".join(_next_title)
+            return _next_title
+        if re.search("^[a-z]+$",_title) is not None:
+            _next_title = self.make_increase([chr(i+ord('a')) for i in range(26)],_title)
+            _next_title = list(_next_title)
+            _next_title.reverse()
+            return "".join(_next_title)
+        if re.search("^[A-Z]+$",_title) is not None:
+            _next_title = self.make_increase([chr(i+ord('A')) for i in range(26)],_title)
+            _next_title = list(_next_title)
+            _next_title.reverse()
+            return "".join(_next_title)
+        if re.search("^[ⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩⅪⅫ]$",_title) is not None:
+            _sort = ["Ⅰ","Ⅱ","Ⅲ","Ⅳ","Ⅴ","Ⅵ","Ⅶ","Ⅷ","Ⅸ","Ⅹ","Ⅺ","Ⅻ"]
+            _index = _sort.index(_title)
+            if _index<len(_sort)-1:
+                return _sort[_index+1]
+            return None
+
+    def count_title_before(self,list_obj):
+        dict_before = {}
+        dict_sentence_count = {}
+        illegal_sentence = set()
+        for obj_i in range(len(list_obj)):
+            obj = list_obj[obj_i]
+            _type = "sentence"
+            _text = obj.text.strip()
+            if obj.name=="table":
+                _type = "table"
+                _text = str(obj)
+            _append = False
+
+
+            if _type=="sentence":
+                if len(_text)>10 and len(_text)<100:
+                    if _text not in dict_sentence_count:
+                        dict_sentence_count[_text] = 0
+                    dict_sentence_count[_text] += 1
+                    if re.search("\d+页",_text) is not None:
+                        illegal_sentence.add(_text)
+                elif len(_text)<10:
+                    if re.search("第\d+页",_text) is not None:
+                        illegal_sentence.add(_text)
+
+                sentence_groups = self.find_title_by_pattern(_text[:10])
+                if sentence_groups:
+                    # c062f53cf83401e671822003d63c1828print("sentence_groups",sentence_groups)
+                    sentence_title = sentence_groups[0][0]
+                    sentence_title_text = sentence_groups[0][1]
+                    title_index = sentence_groups[-2][1]
+                    title_before = sentence_groups[1][1].replace("(","(").replace(":",":").replace(":",";").replace(",",".").replace(",",".").replace("、",".")
+                    title_after = sentence_groups[-1][1].replace(")",")").replace(":",":").replace(":",";").replace(",",".").replace(",",".").replace("、",".")
+                    next_index = self.get_next_title(title_index)
+                    if title_before not in dict_before:
+                        dict_before[title_before] = 0
+                    dict_before[title_before] += 1
+
+        for k,v in dict_sentence_count.items():
+            if v>10:
+                illegal_sentence.add(k)
+        return dict_before,illegal_sentence
+
+    def is_page_no(self,sentence):
+        if len(sentence)<10:
+            if re.search("\d+页|^\-\d+\-$",sentence) is not None:
+                return True
+
+    def block_tree(self,childs):
+        for child in childs:
+
+            if not child["block"]:
+                child["block"] = True
+                childs2 = child["child_title"]
+                self.block_tree(childs2)
+
+
+    def buildParsetree(self,list_obj,products=[],auto_merge_table=True):
+
+        self.parseTree = None
+        trees = []
+        list_length = []
+        for obj in list_obj[:200]:
+            if obj.name!="table":
+                list_length.append(len(obj.get_text()))
+        if len(list_length)>0:
+            max_length = max(list_length)
+        else:
+            max_length = 40
+        max_length = min(max_length,40)
+
+        # logger.debug("%s:%d"%("max_length",max_length))
+
+
+        list_data = []
+        last_table_index = None
+        last_table_columns = None
+        last_table = None
+        dict_before,illegal_sentence = self.count_title_before(list_obj)
+        for obj_i in range(len(list_obj)):
+            obj = list_obj[obj_i]
+
+            # logger.debug("==obj %s"%obj.text[:20])
+
+            _type = "sentence"
+            _text = standard_product(obj.text)
+            if obj.name=="table":
+                _type = "table"
+                _text = standard_product(str(obj))
+            _append = False
+            sentence_title = None
+            sentence_title_text = None
+            sentence_groups = None
+            title_index = None
+            next_index = None
+            parent_title = None
+            title_before = None
+            title_after = None
+            title_next = None
+            childs = []
+
+            list_table = None
+            block = False
+
+            has_product = False
+
+            if _type=="sentence":
+                if _text in illegal_sentence:
+                    continue
+
+
+                sentence_groups = self.find_title_by_pattern(_text[:10])
+                if sentence_groups:
+                    title_before = standard_title_context(sentence_groups[1][1])
+                    title_after = sentence_groups[-1][1]
+                    sentence_title_text = sentence_groups[0][1]
+                    other_text = _text.replace(sentence_title_text,"")
+                    if (title_before in dict_before and dict_before[title_before]>1) or title_after!="":
+                        sentence_title = sentence_groups[0][0]
+
+                        title_index = sentence_groups[-2][1]
+                        next_index = self.get_next_title(title_index)
+
+                        other_text = _text.replace(sentence_title_text,"")
+
+                        for p in products:
+                            if other_text.strip()==p.strip():
+                                has_product = True
+
+                    else:
+                        _fix = False
+
+                        for p in products:
+                            if other_text.strip()==p.strip():
+                                title_before = "=产品"
+                                sentence_title = "title_0"
+                                sentence_title_text = p
+                                title_index = "0"
+                                title_after = "产品="
+                                next_index = "0"
+                                _fix = True
+                                has_product = True
+                                break
+                        if not _fix:
+                            title_before = None
+                            title_after = None
+                            sentence_title_text = None
+                else:
+                    if len(_text)<40 and re.search(_param_pattern,_text) is not None:
+                        for p in products:
+                            if _text.find(p)>=0:
+                                title_before = "=产品"
+                                sentence_title = "title_0"
+                                sentence_title_text = p
+                                title_index = "0"
+                                title_after = "产品="
+                                next_index = "0"
+                                _fix = True
+                                has_product = True
+                                break
+
+            if _type=="sentence":
+                # Optimization: Better text merging strategy
+                # Only merge if the previous node doesn't have a title (to preserve structure)
+                if sentence_title is None and len(list_data)>0 and list_data[-1]["sentence_title"] is None:
+                    # Merge short consecutive sentences without titles
+                    if list_data[-1]["line_width"]>=max_length*0.4 and len(_text)<max_length*0.6:
+                        list_data[-1]["text"] += " " + _text  # Add space separator
+                        list_data[-1]["line_width"] = len(list_data[-1]["text"])
+                        _append = True
+                elif sentence_title is None and len(list_data)>0 and _type==list_data[-1]["type"]:
+                    # Merge very short fragments
+                    if list_data[-1]["line_width"]>=max_length*0.5 and len(_text)<max_length*0.3:
+                        list_data[-1]["text"] += " " + _text
+                        list_data[-1]["line_width"] = len(list_data[-1]["text"])
+                        _append = True
+
+            if _type=="table":
+                _soup = BeautifulSoup(_text,"lxml")
+                _table = _soup.find("table")
+                if _table is not None:
+                    list_table = getTable(_table)
+                    if len(list_table)==0:
+                        continue
+                    table_columns = len(list_table[0])
+
+                    if auto_merge_table:
+                        if last_table_index is not None and abs(obj_i-last_table_index)<=2 and last_table_columns is not None and last_table_columns==table_columns:
+                            if last_table is not None:
+                                trs = getTrs(_table)
+                                last_tbody = BeautifulSoup(last_table["text"],"lxml")
+                                _table = last_tbody.find("table")
+                                last_trs = getTrs(_table)
+                                _append = True
+
+                                for _line in list_table:
+                                    last_table["list_table"].append(_line)
+                                if len(last_trs)>0:
+                                    for _tr in trs:
+                                        last_trs[-1].insert_after(copy.copy(_tr))
+                                    last_table["text"] = re.sub("</?html>|</?body>","",str(last_tbody))
+
+                                last_table_index = obj_i
+                                last_table_columns = len(list_table[-1])
+
+
+            if not _append:
+                _data = {"type":_type, "text":_text,"list_table":list_table,"line_width":len(_text),"sentence_title":sentence_title,"title_index":title_index,
+                         "sentence_title_text":sentence_title_text,"sentence_groups":sentence_groups,"parent_title":parent_title,
+                         "child_title":childs,"title_before":title_before,"title_after":title_after,"title_next":title_next,"next_index":next_index,
+                         "block":block,"has_product":has_product}
+
+                if _type=="table":
+                    last_table = _data
+                    last_table_index = obj_i
+                    if list_table:
+                        last_table_columns = last_table_columns = len(list_table[-1])
+
+                if sentence_title is not None:
+                    if len(list_data)>0:
+                        if self.is_title_first(title_index):
+                            for i in range(1,len(list_data)+1):
+                                _d = list_data[-i]
+                                if _d["sentence_title"] is not None:
+                                    _data["parent_title"] = _d
+                                    _d["child_title"].append(_data)
+                                    break
+                        else:
+                            _find = False
+                            for i in range(1,len(list_data)+1):
+                                if _find:
+                                    break
+                                _d = list_data[-i]
+                                if _d.get("sentence_title")==sentence_title and title_before==_d["title_before"] and title_after==_d["title_after"]:
+                                    if _d["next_index"]==title_index and _d["title_next"] is None and not _d["block"]:
+                                        _data["parent_title"] = _d["parent_title"]
+                                        _d["title_next"] = _data
+                                        if len(_d["child_title"])>0:
+                                            _d["child_title"][-1]["title_next"] = ""
+                                            self.block_tree(_d["child_title"])
+                                        if _d["parent_title"] is not None:
+                                            _d["parent_title"]["child_title"].append(_data)
+                                        _find = True
+                                        break
+                            for i in range(1,len(list_data)+1):
+                                if _find:
+                                    break
+                                _d = list_data[-i]
+                                if i==1 and not _d["block"] and _d.get("sentence_title")==sentence_title and title_before==_d["title_before"] and title_after==_d["title_after"]:
+                                    _data["parent_title"] = _d["parent_title"]
+                                    _d["title_next"] = _data
+                                    if len(_d["child_title"])>0:
+                                        _d["child_title"][-1]["title_next"] = ""
+                                        self.block_tree(_d["child_title"])
+                                    if _d["parent_title"] is not None:
+                                        _d["parent_title"]["child_title"].append(_data)
+                                    _find = True
+                                    break
+                            title_before = standard_title_context(title_before)
+                            title_after = standard_title_context(title_after)
+                            for i in range(1,len(list_data)+1):
+                                if _find:
+                                    break
+                                _d = list_data[-i]
+                                if _d.get("sentence_title")==sentence_title and title_before==standard_title_context(_d["title_before"]) and title_after==standard_title_context(_d["title_after"]):
+                                    if _d["next_index"]==title_index and _d["title_next"] is None and not _d["block"]:
+                                        _data["parent_title"] = _d["parent_title"]
+                                        _d["title_next"] = _data
+                                        if len(_d["child_title"])>0:
+                                            _d["child_title"][-1]["title_next"] = ""
+                                            self.block_tree(_d["child_title"])
+                                        if _d["parent_title"] is not None:
+                                            _d["parent_title"]["child_title"].append(_data)
+                                        _find = True
+                                        break
+                            for i in range(1,len(list_data)+1):
+                                if _find:
+                                    break
+                                _d = list_data[-i]
+                                if not _d["block"] and _d.get("sentence_title")==sentence_title and title_before==standard_title_context(_d["title_before"]) and title_after==standard_title_context(_d["title_after"]):
+                                    _data["parent_title"] = _d["parent_title"]
+                                    _d["title_next"] = _data
+                                    if len(_d["child_title"])>0:
+                                        _d["child_title"][-1]["title_next"] = ""
+                                        # self.block_tree(_d["child_title"])
+                                    if _d["parent_title"] is not None:
+                                        _d["parent_title"]["child_title"].append(_data)
+                                    _find = True
+                                    break
+                            for i in range(1,min(len(list_data)+1,20)):
+                                if _find:
+                                    break
+                                _d = list_data[-i]
+                                if not _d["block"] and _d.get("sentence_title")==sentence_title and title_before==standard_title_context(_d["title_before"]):
+                                    _data["parent_title"] = _d["parent_title"]
+                                    _d["title_next"] = _data
+                                    if len(_d["child_title"])>0:
+                                        _d["child_title"][-1]["title_next"] = ""
+                                        # self.block_tree(_d["child_title"])
+                                    if _d["parent_title"] is not None:
+                                        _d["parent_title"]["child_title"].append(_data)
+                                    _find = True
+                                    break
+
+                            if not _find:
+                                if len(list_data)>0:
+                                    for i in range(1,len(list_data)+1):
+                                        _d = list_data[-i]
+                                        if _d.get("sentence_title") is not None:
+                                            _data["parent_title"] = _d
+                                            _d["child_title"].append(_data)
+                                            break
+
+
+                else:
+                    if len(list_data)>0:
+                        for i in range(1,len(list_data)+1):
+                            _d = list_data[-i]
+                            if _d.get("sentence_title") is not None:
+                                _data["parent_title"] = _d
+                                _d["child_title"].append(_data)
+                                break
+
+                list_data.append(_data)
+
+        for _data in list_data:
+
+            childs = _data["child_title"]
+
+            for c_i in range(len(childs)):
+                cdata = childs[c_i]
+                if cdata["has_product"]:
+                    continue
+                else:
+                    if c_i>0:
+                        last_cdata = childs[c_i-1]
+                        if cdata["sentence_title"] is not None and last_cdata["sentence_title"] is not None and last_cdata["title_before"]==cdata["title_before"] and last_cdata["title_after"]==cdata["title_after"] and last_cdata["has_product"]:
+                            cdata["has_product"] = True
+                    if c_i<len(childs)-1:
+                        last_cdata = childs[c_i+1]
+                        if cdata["sentence_title"] is not None and last_cdata["sentence_title"] is not None and last_cdata["title_before"]==cdata["title_before"] and last_cdata["title_after"]==cdata["title_after"] and last_cdata["has_product"]:
+                            cdata["has_product"] = True
+            for c_i in range(len(childs)):
+                cdata = childs[len(childs)-1-c_i]
+                if cdata["has_product"]:
+                    continue
+                else:
+                    if c_i>0:
+                        last_cdata = childs[c_i-1]
+                        if cdata["sentence_title"] is not None and last_cdata["sentence_title"] is not None and last_cdata["title_before"]==cdata["title_before"] and last_cdata["title_after"]==cdata["title_after"] and last_cdata["has_product"]:
+                            cdata["has_product"] = True
+                    if c_i<len(childs)-1:
+                        last_cdata = childs[c_i+1]
+                        if cdata["sentence_title"] is not None and last_cdata["sentence_title"] is not None and last_cdata["title_before"]==cdata["title_before"] and last_cdata["title_after"]==cdata["title_after"] and last_cdata["has_product"]:
+                            cdata["has_product"] = True
+
+
+        return list_data
+
+
+def standard_title_context(_title_context):
+    return _title_context.replace("(","(").replace(")",")").replace(":",":").replace(":",";").replace(",",".").replace(",",".").replace("、",".").replace(".",".")
+
+def standard_product(sentence):
+    return sentence.replace("(","(").replace(")",")")
+
+def extract_products(list_data,_product,_param_pattern = "产品名称|设备材料|采购内存|标的名称|采购内容|(标的|维修|系统|报价构成|商品|产品|物料|物资|货物|设备|采购品|采购条目|物品|材料|印刷品?|采购|物装|配件|资产|耗材|清单|器材|仪器|器械|备件|拍卖物|标的物|物件|药品|药材|药械|货品|食品|食材|品目|^品名|气体|标项|分项|项目|计划|包组|标段|[分子]?包|子目|服务|招标|中标|成交|工程|招标内容)[\))的]?([、\w]{,4}名称|内容|描述)|标的|标项|项目$|商品|产品|物料|物资|货物|设备|采购品|采购条目|物品|材料|印刷品|物装|配件|资产|招标内容|耗材|清单|器材|仪器|器械|备件|拍卖物|标的物|物件|药品|药材|药械|货品|食品|食材|菜名|^品目$|^品名$|^名称|^内容$"):
+    _product = standard_product(_product)
+    list_result = []
+    list_table_products = []
+    for _data_i in range(len(list_data)):
+        _data = list_data[_data_i]
+        _type = _data["type"]
+        _text = _data["text"]
+
+        if _type=="table":
+            list_table = _data["list_table"]
+            if list_table is None:
+                continue
+            _check = True
+            max_length = max([len(a) for a in list_table])
+            min_length = min([len(a) for a in list_table])
+            if min_length<max_length/2:
+                continue
+            list_head_index = []
+            _begin_index = 0
+            head_cell_text = ""
+            for line_i in range(len(list_table[:2])):
+                line = list_table[line_i]
+                line_text = ",".join([cell[0] for cell in line])
+                for cell_i in range(len(line)):
+                    cell = line[cell_i]
+                    cell_text = cell[0]
+                    if len(cell_text)<10 and re.search(_param_pattern,cell_text) is not None and re.search("单价|数量|预算|限价|总价|品牌|规格|型号|用途|要求|采购量",line_text) is not None:
+                        _begin_index = line_i+1
+                        list_head_index.append(cell_i)
+
+            for line_i in range(len(list_table)):
+                line = list_table[line_i]
+                for cell_i in list_head_index:
+                    if cell_i>=len(line):
+                        continue
+                    cell = line[cell_i]
+                    cell_text = cell[0]
+                    head_cell_text += cell_text
+
+            # print("===head_cell_text",head_cell_text)
+            if re.search("招标人|采购人|项目编号|项目名称|金额|^\d+$",head_cell_text) is not None:
+                list_head_index = []
+
+            for line in list_table:
+                line_text = ",".join([cell[0] for cell in line])
+                for cell_i in range(len(line)):
+                    cell = line[cell_i]
+                    cell_text = cell[0]
+                    if cell_text is not None and _product is not None and len(cell_text)<len(_product)*10 and cell_text.find(_product)>=0 and re.search("单价|数量|总价|规格|品牌|型号|用途|要求|采购量",line_text) is not None:
+                        list_head_index.append(cell_i)
+
+            list_head_index = list(set(list_head_index))
+            if len(list_head_index)>0:
+                has_number = False
+                for cell_i in list_head_index:
+                    table_products = []
+
+                    for line_i in range(_begin_index,len(list_table)):
+                        line = list_table[line_i]
+
+                        for _i in range(len(line)):
+                            cell = line[_i]
+                            cell_text = cell[0]
+                            if re.search("^\d+$",cell_text) is not None:
+                                has_number = True
+
+                        if cell_i>=len(line):
+                            continue
+                        cell = line[cell_i]
+                        cell_text = cell[0]
+                        if re.search(_param_pattern,cell_text) is None or has_number:
+                            if re.search("^[\da-zA-Z]+$",cell_text) is None:
+                                table_products.append(cell_text)
+
+                    if len(table_products)>0:
+                        logger.debug("table products %s"%(str(table_products)))
+                        if min([len(x) for x in table_products])>0 and max([len(x) for x in table_products])<=30:
+                            if re.search("招标人|代理人|预算|数量|交货期|品牌|产地","".join(table_products)) is None:
+                                list_table_products.append(table_products)
+    _find = False
+    for table_products in list_table_products:
+        for _p in table_products:
+            if is_similar(_product,_p,90):
+                _find = True
+                logger.debug("similar table_products %s"%(str(table_products)))
+                list_result = list(set([a for a in table_products if len(a)>1 and len(a)<20 and re.search("费用|预算|合计|金额|万元|运费|^其他$",a) is None]))
+                break
+    if not _find:
+        for table_products in list_table_products:
+            list_result.extend(table_products)
+        list_result = list(set([a for a in list_result if len(a)>1 and len(a)<30 and re.search("费用|预算|合计|金额|万元|运费",a) is None]))
+    return list_result
+
+
+def get_childs(childs):
+    list_data = []
+    for _child in childs:
+        list_data.append(_child)
+        childs2 = _child.get("child_title",[])
+
+        if len(childs2)>0:
+            for _child2 in childs2:
+                list_data.extend(get_childs([_child2]))
+    return list_data
+
+def get_range_data_by_childs(list_data,childs):
+    range_data = []
+    list_child = get_childs(childs)
+    list_index = []
+    set_child = set([id(x) for x in list_child])
+    for _data_i in range(len(list_data)):
+        _data = list_data[_data_i]
+        _id = id(_data)
+        if _id in set_child:
+            list_index.append(_data_i)
+    if len(list_index)>0:
+        range_data = list_data[min(list_index):max(list_index)+1]
+    return range_data
+
+def get_correct_product(product,products):
+    list_data = []
+    for p in products:
+        is_sim = is_similar(product,p)
+        _d = {"product":p,"distance":abs(len(product)-len(p)),"is_sim":is_sim}
+        list_data.append(_d)
+    list_data.sort(key=lambda x:x["distance"])
+    for _d in list_data:
+        is_sim = _d["is_sim"]
+        if is_sim:
+            if len(_d["product"])>len(product) and _d["product"].find(product)>=0:
+                return product
+            return _d["product"]
+    return product
+
+def get_childs_text(childs,_product,products,is_begin=False,is_end=False):
+    _text = ""
+
+    end_next = False
+    for _child in childs:
+
+        child_text = _child.get("text")
+
+
+        if child_text.find(_product)>=0:
+            if not is_begin:
+                is_begin = True
+                if not end_next:
+                    if _child["sentence_title"] is not None and isinstance(_child["title_next"],dict) and _child["title_next"]["sentence_title"] is not None:
+                        end_next = True
+                        end_title = _child["title_next"]
+                        # logger.debug("end_title %s "%end_title["text"])
+
+        # logger.debug("%s-%s-%s"%("get_childs_text",child_text[:10],str(is_begin)))
+
+        for p in products:
+            if child_text.find(p)>=0 and is_similar(_product,p,90):
+                is_begin = True
+
+            if child_text.find(_product)<0  and not is_similar(_product,p,80) and  (child_text.find(p)>=0 or _child["has_product"]):
+                if is_begin:
+                    is_end = True
+                    # logger.debug("%s-%s-%s"%("get_childs_text end",child_text[:10],p))
+                break
+        if re.search(end_pattern,child_text) is not None:
+            if is_begin:
+                is_end = True
+                # logger.debug("%s-%s-%s"%("get_childs_text end",child_text[:10],str(is_end)))
+
+        if is_begin and is_end:
+            break
+
+        if is_begin:
+            _text += _child.get("text")+"\r\n"
+        childs2 = _child.get("child_title",[])
+
+
+        if len(childs2)>0:
+            for _child2 in childs2:
+                child_text,is_begin,is_end = get_childs_text([_child2],_product,products,is_begin)
+                if is_begin:
+                    _text += child_text
+                    if is_end:
+                        break
+
+        if end_next:
+            is_end = True
+
+    #     logger.debug("%s-%s-%s"%("get_childs_text1",_text,str(is_begin)))
+    # logger.debug("%s-%s-%s"%("get_childs_text2",_text,str(is_begin)))
+    return _text,is_begin,is_end
+
+def extract_parameters_by_tree(_product,products,list_data,_data_i,parent_title,list_result,):
+    _data = list_data[_data_i]
+    childs = _data.get("child_title",[])
+    if len(childs)>0:
+        child_text,_,_ = get_childs_text([_data],_product,products)
+        if len(child_text)>0:
+            logger.info("extract_type by_tree child_text:%s"%child_text)
+            list_result.append(child_text)
+    if parent_title is not None:
+        child_text,_,_ = get_childs_text([parent_title],_product,products)
+        if len(child_text)>0:
+            logger.info("extract_type by_tree child_text:%s"%child_text)
+            list_result.append(child_text)
+
+        childs = parent_title.get("child_title",[])
+        if len(childs)>0:
+
+            range_data = get_range_data_by_childs(list_data[_data_i:],childs)
+            p_text = ""
+            _find = False
+            end_id = id(_data["title_next"]) if isinstance(_data["sentence_title"],dict) and _data["title_next"] is not None and _data["title_next"]["sentence_title"] is not None else None
+            for pdata in range_data:
+                ptext = pdata["text"]
+                for p in products:
+                    if ptext.find(_product)<0 and  (ptext.find(p)>=0 or pdata["has_product"]):
+                        _find = True
+                        break
+                if re.search(end_pattern,ptext) is not None:
+                    _find = True
+                if _find:
+                    break
+                if id(pdata)==end_id:
+                    break
+                p_text += ptext+"\r\n"
+            if len(p_text)>0:
+                logger.debug("extract_type by parent range_text:%s"%p_text)
+                list_result.append(p_text)
+                return True
+    return False
+
+
+def get_table_pieces(_text,_product,products,list_result,_find):
+    _soup = BeautifulSoup(_text,"html5lib")
+    _table = _soup.find("table")
+    if _table is not None:
+        trs = getTrs(_table)
+        list_trs = []
+        for tr in trs:
+            tr_text = tr.get_text()
+            if tr_text.find(_product)>=0:
+                _find = True
+
+            # logger.debug("%s-%s"%("table_html_tr",tr_text))
+            for p in products:
+                if _find and p!=_product and tr_text.find(p)>=0:
+                    _find = False
+                    break
+            if re.search(end_pattern,tr_text) is not None:
+                _find = False
+                break
+            if _find:
+                list_trs.append(tr)
+        if len(list_trs)>0:
+            table_html = "<table>%s</table>"%("\r\n".join([str(a) for a in list_trs]))
+            # logger.debug("extract_type table slices %s"%(table_html))
+            list_result.append(table_html)
+
+def extract_parameters_by_table(_product,products,_param_pattern,list_data,_data_i,list_result):
+    _data = list_data[_data_i]
+    _text = _data["text"]
+    list_table = _data["list_table"]
+    parent_title = _data["parent_title"]
+    if list_table is not None:
+        _check = True
+        max_length = max([len(a) for a in list_table])
+        min_length = min([len(a) for a in list_table])
+        text_line_first = ",".join(a[0] for a in list_table[0])
+        if max_length>10:
+            if min_length<max_length/2:
+                return
+        last_data = list_data[_data_i-1]
+        _flag = False
+        if last_data["type"]=="sentence" and last_data["text"].find(_product)>=0:
+            # logger.debug("last sentence find product %s-%s"%(_product,last_data["text"]))
+            _flag = True
+        # print(text_line_first,"text_line_first",re.search(_param_pattern,text_line_first) is not None and text_line_first.find(_product)>=0)
+        if re.search(_param_pattern,text_line_first) is not None and text_line_first.find(_product)>=0:
+            _flag = True
+        if _flag:
+            if len(products)==0:
+                # logger.debug("extract_type whole table by param and product %s"%(_text))
+                list_result.append(_text)
+            else:
+                for p in products:
+                    if p!=_product and _text.find(p)>=0:
+                        # logger.debug("extract_type add all table failed %s-%s"%(_product,p))
+                        _flag = False
+                        break
+                if _flag:
+                    # logger.debug("extract_type add all table succeed")
+                    get_table_pieces(_text,_product,products,list_result,True)
+        else:
+            list_head_index = []
+            for line in list_table[:2]:
+                for cell_i in range(len(line)):
+                    cell = line[cell_i]
+                    cell_text = cell[0]
+                    if len(cell_text)<20 and re.search(_param_pattern,cell_text) is not None:
+                        list_head_index.append(cell_i)
+            list_head_index = list(set(list_head_index))
+            for line in list_table:
+                for cell in line:
+                    cell_text = cell[0]
+                    if len(cell_text)>50 and len(re.findall(meter_pattern,cell_text))>5 and cell_text.find(_product)>=0:
+                        _f = True
+                        for cell in line:
+                            if not _f:
+                                break
+                            cell_text = cell[0]
+                            for p in products:
+                                if cell_text.find(p)>=0 and p!=_product:
+                                    _f = False
+                                    break
+                        if _f:
+                            logger.debug("extract_type param column %s"%(cell_text))
+                            list_result.append(cell_text)
+                    if len(cell_text)<len(_product)*10 and str(cell_text).find(_product)>=0:
+                        for _index in list_head_index:
+                            if _index>=len(line):
+                                continue
+                            _cell = line[_index]
+                            if len(cell[0])>0:
+                                # logger.info("%s-%s"%("extract_type add on table text:",_cell[0]))
+                                list_result.append(_cell[0])
+        if not _flag and (re.search(_param_pattern,_text) is not None or (parent_title is not None and re.search(_param_pattern,parent_title["text"]) is not None)) and _text.find(_product)>=0:
+            get_table_pieces(_text,_product,products,list_result,False)
+
+
+def extract_parameters_by_sentence(list_data,_data,_data_i,_product,products,list_result,is_project):
+    _text = _data["text"]
+    if _text.find(_product)>=0:
+        parent_title = _data.get("parent_title")
+        parent_text = ""
+        parent_parent_title = None
+        parent_parent_text = ""
+        parent_title_index = None
+        parent_parent_title_index = None
+        childs = get_childs([_data])
+
+        child_find = False
+        for c in childs:
+            if re.search(_param_pattern,c["text"]) is not None and len(c["text"])<30:
+                # logger.debug("child text %s"%(c["text"]))
+                child_find = True
+                break
+
+        extract_text,_,_ = get_childs_text([_data],_product,products)
+        # logger.debug("childs found extract_text %s %s"%(str(child_find),extract_text))
+        if child_find:
+            if len(extract_text)>0:
+                list_result.append(extract_text)
+        else:
+            limit_nums = len(_product)*2+5
+            if len(_product)<=3:
+                limit_nums += 6
+            if _text.find("数量")>=0:
+                limit_nums += 6
+            if len(_text)<=limit_nums and _data["sentence_title"] is not None:
+                if re.search(meter_pattern,extract_text) is not None:
+                    list_result.append(extract_text)
+            elif len(re.findall(meter_pattern,extract_text))>2:
+                list_result.append(extract_text)
+
+        if parent_title is not None:
+            parent_text = parent_title.get("text","")
+            parent_parent_title = parent_title.get("parent_title")
+            parent_title_index = parent_title["title_index"]
+            if parent_parent_title is not None:
+                parent_parent_text = parent_parent_title.get("text","")
+                parent_parent_title_index = parent_parent_title["title_index"]
+
+        _suit = False
+        if re.search(_param_pattern,_text) is not None and len(_text)<50:
+            _suit = True
+        if re.search(_param_pattern,parent_text) is not None and len(parent_text)<50:
+            _suit = True
+        if re.search(_param_pattern,parent_parent_text) is not None and len(parent_parent_text)<50:
+            _suit = True
+        if _suit:
+            # logger.debug("extract_type sentence %s"%("extract_parameters_by_tree"))
+            if not extract_parameters_by_tree(_product,products,list_data,_data_i,parent_title,list_result):
+                logger.debug("extract_type sentence %s"%("extract_parameters_by_tree"))
+                extract_parameters_by_tree(_product,products,list_data,_data_i,parent_parent_title,list_result)
+
+    if re.search(_param_pattern,_text) is not None and len(_text)<50:
+        childs = _data["child_title"]
+        if len(childs)>0:
+            extract_text,_,_ = get_childs_text([_data],_product,products)
+            if len(extract_text)>0:
+                # logger.debug("extract_type param-product %s"%(extract_text))
+                list_result.append(extract_text)
+            elif is_project:
+                extract_text,_,_ = get_childs_text([_data],_product,products,is_begin=True)
+                if len(extract_text)>0 and re.search(meter_pattern,extract_text) is not None:
+                    # logger.debug("extract_type sentence is_project param-product is product %s"%(extract_text))
+                    list_result.append(extract_text)
+
+def getBestProductText(list_result,_product,products):
+    list_result.sort(key=lambda x:len(re.findall(meter_pattern+"|"+'[::;;]|\d+[%A-Za-z]+',BeautifulSoup(x,"html5lib").get_text())), reverse=True)
+
+    # logger.debug("+++++++++++++++++++++")
+    # for i in range(len(list_result)):
+    #     logger.debug("result%d %s"%(i,list_result[i]))
+    # logger.debug("+++++++++++++++++++++")
+
+    for i in range(len(list_result)):
+        _result = list_result[i]
+        _check = True
+        _result_text = BeautifulSoup(_result,"html5lib").get_text()
+        _search = re.search("项目编号[::]|项目名称[::]|联合体投标|开户银行",_result)
+        if _search is not None:
+            # logger.debug("result%d error illegal text %s"%(i,str(_search)))
+            _check = False
+        if not (len(_result_text)<1000 and _result[:6]!="<table"):
+            for p in products:
+                if _result_text.find(p)>0 and not (is_similar(_product,p,80) or p.find(_product)>=0 or _product.find(p)>=0):
+                    # logger.debug("result%d error product scoss %s"%(i,p))
+                    _check = False
+        if len(_result_text)<100:
+            if re.search(meter_pattern,_result_text) is None:
+                # logger.debug("result%d error text min count"%(i))
+                _check = False
+        if len(_result_text)>5000:
+            if len(_result_text)>10000:
+                # logger.debug("result%d error text max count"%(i))
+                _check = False
+            elif len(re.findall(meter_pattern,_result_text))<10:
+                # logger.debug("result%d error text max count less meter"%(i))
+                _check = False
+
+        list_find = list(set(re.findall(meter_pattern,_result_text)))
+
+        not_list_find = list(set(re.findall(not_meter_pattern,_result_text)))
+        _count = len(list_find)-len(not_list_find)
+        has_num = False
+        for _find in list_find:
+            if re.search('[0-9a-zA-Z]',_find) is not None:
+                has_num = True
+                break
+        if not(_count>=2 and has_num or _count>=5):
+            # logger.debug("result%d error match not enough"%(i))
+            _check = False
+
+        if _check:
+            return _result
+
+def format_text(_result):
+    list_result = re.split("\r|\n",_result)
+    _result = ""
+    for _r in list_result:
+        if len(_r)>0:
+            _result+="%s\n"%(_r)
+    _result = '<div style="white-space:pre">%s</div>'%(_result)
+    return _result
+
+def extract_product_parameters(list_data,_product):
+
+    list_result = []
+    _product = standard_product(_product.strip())
+    products = extract_products(list_data,_product)
+
+    _product = get_correct_product(_product,products)
+    logger.debug("all products %s-%s"%(_product,str(products)))
+    is_project = False
+    if re.search("项目名称|采购项目",_product) is not None:
+        is_project = True
+        
+    if len(products)==1 and is_similar(products[0],_product,90):
+        is_project = True
+    _find_count = 0
+    for _data_i in range(len(list_data)):
+        _data = list_data[_data_i]
+        _type = _data["type"]
+        _text = _data["text"]
+        if _type=="sentence":
+            if _text.find(_product)>=0:
+                _find_count += 1
+                if re.search("项目名称|采购项目",_text) is not None and re.search("等",_text) is not None:
+                    is_project = True
+            extract_parameters_by_sentence(list_data,_data,_data_i,_product,products,list_result,is_project)
+
+        elif _type=="table":
+            if _text.find(_product)>=0:
+                _find_count += 1
+            extract_parameters_by_table(_product,products,_param_pattern,list_data,_data_i,list_result)
+
+    _text = getBestProductText(list_result,_product,products)
+    return _text,_find_count
+
+
+if __name__ == '__main__':
+
+    filepath = "download/4597dcc128bfabc7584d10590ae50656.html"
+    _product = "彩色多普勒超声诊断仪"
+
+    _html = open(filepath, "r", encoding="utf8").read()
+
+    pd = ParseDocument(_html,False)
+
+    pd.fix_tree(_product)
+    list_data = pd.tree
+    pd.print_tree(list_data)
+
+    _text,_count = extract_product_parameters(list_data,_product)
+    logger.info("find count:%d"%(_count))
+    logger.info("extract_parameter_text::%s"%(_text))
+
+

+ 51 - 0
requirements.txt

@@ -0,0 +1,51 @@
+# Core dependencies
+langchain>=0.2.0
+langchain-community>=0.2.0
+langchain-openai>=0.1.0
+langchain-text-splitters>=0.2.0
+
+# Embedding models
+sentence-transformers>=2.2.0
+openai>=1.0.0
+
+# Vector stores
+faiss-cpu>=1.7.4
+chromadb>=0.4.0
+
+# Document processing
+beautifulsoup4>=4.12.0
+lxml>=4.9.0
+pypdf>=3.0.0
+python-docx>=0.8.11
+pymupdf>=1.23.0
+pandas>=2.0.0
+openpyxl>=3.1.0
+unstructured>=0.10.0
+
+# Reranking
+FlagEmbedding>=1.2.0
+
+# Advanced RAG
+jieba>=0.42.1
+python-Levenshtein>=0.21.0
+rank-bm25>=0.2.2
+numpy>=1.24.0
+scikit-learn>=1.3.0
+
+# Async support
+aiohttp>=3.8.0
+asyncio
+
+# Utilities
+pydantic>=2.0.0
+pyyaml>=6.0
+tqdm>=4.65.0
+rich>=13.0.0
+loguru>=0.7.0
+tiktoken>=0.5.0
+
+# Benchmarking
+matplotlib>=3.7.0
+seaborn>=0.12.0
+timeit
+