bm25_html_tree_rag.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # -*- coding: utf-8 -*-
  2. """BM25 HTML Tree RAG - BM25 retrieval on HTML hierarchical tree structure."""
  3. import importlib.util
  4. import os
  5. import numpy as np
  6. from bs4 import BeautifulSoup
  7. from loguru import logger
  8. from ..document_processor import Document
  9. from .bm25_backend import get_bm25_okapi
  10. from .dedup import deduplicate_ranked_results
  11. from .tokenization import bm25_tokenize
  12. class BM25HTMLTreeRAG:
  13. """
  14. BM25-based retrieval on HTML hierarchical tree.
  15. This class:
  16. 1. Parses HTML into a hierarchical tree using ParseDocument
  17. 2. Extracts all text nodes from the tree
  18. 3. Builds BM25 index on tree node texts
  19. 4. Retrieves relevant subtrees based on query
  20. """
  21. def __init__(self):
  22. self.tree = []
  23. self.all_nodes = []
  24. self.all_texts = []
  25. self.bm25 = None
  26. self.html_content = ""
  27. def _tokenize(self, text):
  28. """Tokenize mixed Chinese/English text for BM25."""
  29. return bm25_tokenize(text)
  30. def _get_node_depth(self, node, visited=None):
  31. """Calculate the depth of a node in the tree."""
  32. if visited is None:
  33. visited = set()
  34. depth = 0
  35. current = node
  36. while current is not None:
  37. current_id = id(current)
  38. if current_id in visited:
  39. break
  40. visited.add(current_id)
  41. parent = current.get("parent_title")
  42. if parent is None:
  43. break
  44. current = parent
  45. depth += 1
  46. return depth
  47. def _extract_node_text(self, node):
  48. """Extract clean text from a tree node."""
  49. text = node.get("text", "")
  50. if not text:
  51. return ""
  52. # If it's HTML content, extract text
  53. if text.startswith("<") and text.endswith(">"):
  54. try:
  55. soup = BeautifulSoup(text, "lxml")
  56. text = soup.get_text(strip=True)
  57. except:
  58. pass
  59. return text.strip()
  60. def _get_node_full_text(self, node):
  61. """Get full text including children for a node."""
  62. texts = [self._extract_node_text(node)]
  63. childs = node.get("child_title", [])
  64. for child in childs:
  65. child_text = self._get_node_full_text(child)
  66. if child_text:
  67. texts.append(child_text)
  68. return " ".join(texts)
  69. def _collect_all_nodes(self, tree_nodes, visited=None):
  70. """Recursively collect all nodes from the tree."""
  71. if visited is None:
  72. visited = set()
  73. nodes = []
  74. for node in tree_nodes:
  75. node_id = id(node)
  76. if node_id in visited:
  77. continue
  78. visited.add(node_id)
  79. nodes.append(node)
  80. childs = node.get("child_title", [])
  81. if childs:
  82. nodes.extend(self._collect_all_nodes(childs, visited))
  83. return nodes
  84. def build_index(self, html_content, auto_merge_table=True):
  85. """
  86. Build BM25 index from HTML content.
  87. Args:
  88. html_content: HTML string to parse
  89. auto_merge_table: Whether to auto-merge tables
  90. """
  91. self.html_content = html_content
  92. # Import the local parser/htmlparser.py explicitly. The package name
  93. # collides with Python's historical stdlib parser module on older
  94. # interpreters, so a normal ``from parser...`` import is unreliable.
  95. parser_path = os.path.join(
  96. os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
  97. "parser",
  98. "htmlparser.py",
  99. )
  100. spec = importlib.util.spec_from_file_location("bidirag_htmlparser", parser_path)
  101. htmlparser = importlib.util.module_from_spec(spec)
  102. spec.loader.exec_module(htmlparser)
  103. ParseDocument = htmlparser.ParseDocument
  104. # Parse HTML into tree structure
  105. pd = ParseDocument(html_content, auto_merge_table=auto_merge_table)
  106. self.tree = pd.tree
  107. # Collect all nodes
  108. self.all_nodes = self._collect_all_nodes(self.tree)
  109. # Extract texts for BM25
  110. self.all_texts = []
  111. for node in self.all_nodes:
  112. text = self._get_node_full_text(node)
  113. self.all_texts.append(self._tokenize(text) if text else [])
  114. # Build BM25 index
  115. BM25Okapi = get_bm25_okapi()
  116. self.bm25 = BM25Okapi(self.all_texts) if self.all_nodes else None
  117. logger.info("BM25HTMLTreeRAG: indexed {} nodes from HTML tree".format(len(self.all_nodes)))
  118. def retrieve_subtrees(self, query, k=5, min_score=0.0):
  119. """
  120. Retrieve relevant subtrees using BM25.
  121. Args:
  122. query: Search query
  123. k: Number of top results to return
  124. min_score: Minimum score threshold
  125. Returns:
  126. List of (node, score, subtree_text) tuples
  127. """
  128. if self.bm25 is None or k <= 0:
  129. return []
  130. query_tokens = self._tokenize(query)
  131. if not query_tokens:
  132. return []
  133. scores = self.bm25.get_scores(query_tokens)
  134. # Build id-to-index mapping for efficient lookup
  135. id_to_index = {}
  136. for j, node in enumerate(self.all_nodes):
  137. id_to_index[id(node)] = j
  138. # Compute enhanced scores for each node with multi-level boosting
  139. enhanced_scores = []
  140. max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0
  141. for i, node in enumerate(self.all_nodes):
  142. base_score = scores[i]
  143. normalized_score = base_score / max_score # Normalize to [0, 1]
  144. # Optimization 1: Boost from children (max child score)
  145. childs = node.get("child_title", [])
  146. child_boost = 0.0
  147. child_match_count = 0 # Count how many children have matches
  148. for child in childs:
  149. child_id = id(child)
  150. if child_id in id_to_index:
  151. child_idx = id_to_index[child_id]
  152. child_norm_score = scores[child_idx] / max_score
  153. if child_norm_score > 0.1: # Count significant matches
  154. child_match_count += 1
  155. child_boost = max(child_boost, child_norm_score * 0.4)
  156. # Optimization 2: Boost from parent (if parent matches, boost children)
  157. parent = node.get("parent_title")
  158. parent_boost = 0.0
  159. if parent is not None:
  160. parent_id = id(parent)
  161. if parent_id in id_to_index:
  162. parent_idx = id_to_index[parent_id]
  163. parent_norm_score = scores[parent_idx] / max_score
  164. if parent_norm_score > 0.3: # Only boost if parent has significant match
  165. parent_boost = parent_norm_score * 0.3
  166. # Optimization 3: Title boost (headers are more important)
  167. has_match = normalized_score > 0 or child_boost > 0
  168. title_boost = 0.0
  169. if has_match:
  170. # Check if this is a title/header node
  171. if node.get("sentence_title") is not None:
  172. title_boost = 0.15
  173. # Also check node type
  174. node_type = node.get("type", "").lower()
  175. if node_type in ["h1", "h2", "h3", "h4", "h5", "h6", "title", "header"]:
  176. title_boost = max(title_boost, 0.2)
  177. # Optimization 4: Depth penalty (prefer mid-level nodes over very deep or root)
  178. depth = self._get_node_depth(node)
  179. depth_factor = 1.0
  180. if depth == 0: # Root level
  181. depth_factor = 0.7
  182. elif depth > 5: # Too deep
  183. depth_factor = 0.8
  184. elif 1 <= depth <= 3: # Sweet spot
  185. depth_factor = 1.1
  186. # Optimization 5: Child diversity bonus - reward nodes with multiple matching children
  187. child_diversity_bonus = 0.0
  188. if child_match_count >= 2:
  189. child_diversity_bonus = 0.05 * min(child_match_count, 4) # Cap at 4 children
  190. # Optimization 6: Structural coherence bonus
  191. # If this node is a header/title and has matching content children, boost it
  192. structural_bonus = 0.0
  193. if node.get("sentence_title") is not None and child_match_count > 0:
  194. # Header with matching content is very valuable
  195. structural_bonus = 0.08 * min(child_match_count, 3)
  196. # Combined score with better weights
  197. combined_before_depth = (
  198. 0.45 * normalized_score + # Own content (reduced to give more weight to structure)
  199. 0.20 * child_boost + # Children influence
  200. 0.15 * parent_boost + # Parent context
  201. 0.10 * title_boost + # Title importance
  202. 0.05 * child_diversity_bonus + # Child diversity
  203. 0.05 * structural_bonus # NEW: Structural coherence
  204. )
  205. enhanced_score = combined_before_depth * depth_factor
  206. enhanced_scores.append((node, enhanced_score, base_score))
  207. # Sort by enhanced score
  208. enhanced_scores.sort(key=lambda x: x[1], reverse=True)
  209. # Filter by min_score and take top k with smart deduplication
  210. results = []
  211. section_counts = {} # Track how many results per section
  212. seen_paths = set() # Track exact paths to avoid duplicates
  213. for node, enhanced_score, base_score in enhanced_scores:
  214. if len(results) >= k:
  215. break
  216. if enhanced_score <= 0:
  217. break
  218. if enhanced_score < min_score:
  219. continue
  220. # Get path for smarter deduplication
  221. path = self.get_node_path(node)
  222. # Avoid exact duplicate paths
  223. if path in seen_paths:
  224. continue
  225. seen_paths.add(path)
  226. # Extract the main section (first level)
  227. section_parts = path.split(" > ")
  228. main_section = section_parts[0] if section_parts else path
  229. # Dynamic max results per section based on node characteristics
  230. # If this is a high-scoring parent node with multiple matching children, allow more results
  231. childs = node.get("child_title", [])
  232. matching_children = sum(1 for c in childs if id(c) in id_to_index and scores[id_to_index[id(c)]] / max_score > 0.1)
  233. # Check if this is a header node with strong structural bonus
  234. is_structural_header = (
  235. node.get("sentence_title") is not None and
  236. matching_children >= 2 and
  237. enhanced_score > 0.3
  238. )
  239. # Allow up to 3-4 results for rich sections, otherwise 2
  240. if is_structural_header:
  241. max_per_section = 4 # High-value structural nodes get more slots
  242. elif matching_children >= 2:
  243. max_per_section = 3 # Rich sub-structure
  244. else:
  245. max_per_section = 2 # Standard diversity
  246. # Count how many results we already have from this section
  247. section_count = section_counts.get(main_section, 0)
  248. if section_count >= max_per_section:
  249. continue
  250. # Get subtree text
  251. subtree_text = self._get_subtree_text(node)
  252. results.append((node, enhanced_score, subtree_text))
  253. # Update section count
  254. section_counts[main_section] = section_count + 1
  255. return results
  256. def _get_subtree_text(self, node, indent=0):
  257. """Get formatted text of a subtree with hierarchy."""
  258. lines = []
  259. node_text = self._extract_node_text(node)
  260. if node_text:
  261. # Add title info if available
  262. title_info = ""
  263. if node.get("sentence_title") is not None:
  264. title_text = node.get("sentence_title_text", "")
  265. if title_text:
  266. title_info = " [标题: {}]".format(title_text)
  267. prefix = " " * indent
  268. lines.append("{}{}{}".format(prefix, node_text, title_info))
  269. # Recursively get children
  270. childs = node.get("child_title", [])
  271. for child in childs:
  272. child_text = self._get_subtree_text(child, indent + 1)
  273. if child_text:
  274. lines.append(child_text)
  275. return "\n".join(lines)
  276. def _is_relevant_to_query(self, node, query_tokens, id_to_index, scores, max_score, threshold=0.1):
  277. """
  278. Check if a node or its subtree is relevant to the query.
  279. This includes checking the node itself and all its children.
  280. Also considers parent context for better relevance judgment.
  281. """
  282. # Check the node itself
  283. node_id = id(node)
  284. if node_id in id_to_index:
  285. node_score = scores[id_to_index[node_id]] / max_score
  286. if node_score > threshold:
  287. return True
  288. # Check all children recursively
  289. childs = node.get("child_title", [])
  290. for child in childs:
  291. if self._is_relevant_to_query(child, query_tokens, id_to_index, scores, max_score, threshold):
  292. return True
  293. # Check parent context - if parent has strong match, child might be relevant too
  294. parent = node.get("parent_title")
  295. if parent is not None:
  296. parent_id = id(parent)
  297. if parent_id in id_to_index:
  298. parent_score = scores[id_to_index[parent_id]] / max_score
  299. if parent_score > 0.3: # Parent has significant match
  300. return True
  301. return False
  302. def get_node_path(self, node):
  303. """Get the path from root to this node."""
  304. path = []
  305. current = node
  306. while current is not None:
  307. text = self._extract_node_text(current)
  308. if text:
  309. path.insert(0, text[:50])
  310. current = current.get("parent_title")
  311. return " > ".join(path)
  312. def query(self, query, k=5):
  313. """
  314. Full query pipeline: retrieve subtrees and format results.
  315. Args:
  316. query: Search query
  317. k: Number of results
  318. Returns:
  319. List of result dicts with node, score, path, and content
  320. """
  321. results = self.retrieve_subtrees(query, k)
  322. formatted_results = []
  323. for node, score, subtree_text in results:
  324. path = self.get_node_path(node)
  325. doc = Document(
  326. page_content=subtree_text,
  327. metadata={
  328. "node_type": node.get("type", "unknown"),
  329. "path": path,
  330. "score": score,
  331. "title": node.get("sentence_title_text", ""),
  332. }
  333. )
  334. formatted_results.append((doc, score))
  335. return deduplicate_ranked_results(formatted_results, k)
  336. # Example usage
  337. if __name__ == "__main__":
  338. sample_html = """
  339. <html>
  340. <body>
  341. <h1>招标公告</h1>
  342. <h2>一、项目概况</h2>
  343. <p>本项目预算金额为5000万元,招标编号为XX-ZB-2024-001。</p>
  344. <h2>二、投标人资格要求</h2>
  345. <p>1. 具有独立承担民事责任的能力</p>
  346. <p>2. 具有相关资质证书</p>
  347. <h2>三、评标方法</h2>
  348. <p>采用综合评分法,技术分占比60%,商务分占比40%。</p>
  349. <h2>四、投标文件递交</h2>
  350. <p>截止时间:2024年12月31日</p>
  351. </body>
  352. </html>
  353. """
  354. rag = BM25HTMLTreeRAG()
  355. rag.build_index(sample_html)
  356. queries = ["预算金额", "资质要求", "评标方法"]
  357. for query in queries:
  358. print("\n查询: {}".format(query))
  359. print("-" * 60)
  360. results = rag.query(query, k=3)
  361. for i, (doc, score) in enumerate(results, 1):
  362. print(" [{}] 分数: {:.4f}".format(i, score))
  363. print(" 路径: {}".format(doc.metadata.get("path", "")))
  364. print(" 内容: {}...".format(doc.page_content[:100]))