# -*- coding: utf-8 -*-
"""BM25 HTML Tree RAG - BM25 retrieval on HTML hierarchical tree structure."""

import importlib.util
import os

import numpy as np
from bs4 import BeautifulSoup
from loguru import logger

from ..document_processor import Document
from .bm25_backend import get_bm25_okapi
from .dedup import deduplicate_ranked_results
from .tokenization import bm25_tokenize

# Node "type" values that receive the heading boost in retrieve_subtrees.
_HEADING_TYPES = frozenset(["h1", "h2", "h3", "h4", "h5", "h6", "title", "header"])


class BM25HTMLTreeRAG:
    """
    BM25-based retrieval on HTML hierarchical tree.

    This class:
    1. Parses HTML into a hierarchical tree using ParseDocument
    2. Extracts all text nodes from the tree
    3. Builds BM25 index on tree node texts
    4. Retrieves relevant subtrees based on query

    Tree nodes are dicts produced by parser/htmlparser.py's ParseDocument. The
    keys read here are "text", "child_title", "parent_title", "sentence_title",
    "sentence_title_text" and "type".
    """

    def __init__(self):
        self.tree = []  # top-level nodes produced by ParseDocument
        self.all_nodes = []  # every reachable node, pre-order depth-first
        self.all_texts = []  # BM25 token lists, parallel to all_nodes
        self.bm25 = None  # BM25 index, or None when there is nothing to index
        self.html_content = ""  # HTML passed to the most recent build_index call

    def _tokenize(self, text):
        """Tokenize mixed Chinese/English text for BM25 (delegates to bm25_tokenize)."""
        return bm25_tokenize(text)

    def _get_node_depth(self, node, visited=None):
        """
        Return the number of ``parent_title`` hops from ``node`` to its root.

        Args:
            node: Tree node dict.
            visited: Optional set of ``id()`` values already seen. Used to stop
                on a cyclic parent chain instead of looping forever.

        Returns:
            0 for a root node, 1 for its direct children, and so on.
        """
        if visited is None:
            visited = set()
        depth = 0
        current = node
        while current is not None:
            current_id = id(current)
            if current_id in visited:
                break
            visited.add(current_id)
            parent = current.get("parent_title")
            if parent is None:
                break
            current = parent
            depth += 1
        return depth

    def _extract_node_text(self, node):
        """
        Return the cleaned text of a single node (children are not included).

        If the trimmed text looks like an HTML fragment (starts with "<" and
        ends with ">"), its markup is stripped with BeautifulSoup. Parser
        failures are logged at debug level and the raw text is kept.
        """
        text = node.get("text", "")
        if not text:
            return ""
        # Trim first so fragments wrapped in whitespace/newlines are detected.
        text = text.strip()
        if text.startswith("<") and text.endswith(">"):
            try:
                text = BeautifulSoup(text, "lxml").get_text(strip=True)
            except Exception as exc:  # best-effort: fall back to the raw text
                logger.debug("BM25HTMLTreeRAG: could not strip HTML from node text: {}", exc)
        return text.strip()

    def _get_node_full_text(self, node, _ancestors=None):
        """
        Return the node's own text followed by the full text of its descendants.

        Texts are joined with single spaces in pre-order depth-first order.
        ``_ancestors`` is internal. It holds the ``id()`` of nodes on the
        current recursion path so that a cyclic ``child_title`` link is cut
        instead of recursing forever.
        """
        if _ancestors is None:
            _ancestors = set()
        node_id = id(node)
        if node_id in _ancestors:
            return ""
        _ancestors.add(node_id)

        texts = [self._extract_node_text(node)]
        for child in node.get("child_title") or []:
            child_text = self._get_node_full_text(child, _ancestors)
            if child_text:
                texts.append(child_text)

        _ancestors.discard(node_id)
        return " ".join(texts)

    def _collect_all_nodes(self, tree_nodes, visited=None):
        """
        Recursively collect all nodes from the tree in pre-order.

        Args:
            tree_nodes: Iterable of node dicts (roots on the first call).
            visited: Set of ``id()`` values already collected. It is shared
                across the recursion, so each node object appears at most once.

        Returns:
            Flat list of node dicts.
        """
        if visited is None:
            visited = set()
        nodes = []
        for node in tree_nodes:
            node_id = id(node)
            if node_id in visited:
                continue
            visited.add(node_id)
            nodes.append(node)
            childs = node.get("child_title", [])
            if childs:
                nodes.extend(self._collect_all_nodes(childs, visited))
        return nodes

    def build_index(self, html_content, auto_merge_table=True):
        """
        Build BM25 index from HTML content.

        Each indexed "document" is one tree node's full text (its own text plus
        all descendant text), tokenized with ``_tokenize``.

        Args:
            html_content: HTML string to parse
            auto_merge_table: Whether to auto-merge tables (forwarded to ParseDocument)
        """
        self.html_content = html_content

        # Import the local parser/htmlparser.py explicitly. The package name
        # collides with Python's historical stdlib parser module on older
        # interpreters, so a normal ``from parser...`` import is unreliable.
        parser_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
            "parser",
            "htmlparser.py",
        )
        spec = importlib.util.spec_from_file_location("bidirag_htmlparser", parser_path)
        htmlparser = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(htmlparser)
        ParseDocument = htmlparser.ParseDocument

        # Parse HTML into tree structure
        pd = ParseDocument(html_content, auto_merge_table=auto_merge_table)
        self.tree = pd.tree

        # Collect all nodes
        self.all_nodes = self._collect_all_nodes(self.tree)

        # Extract texts for BM25 (empty text -> empty token list, keeping the
        # list aligned with self.all_nodes).
        full_texts = [self._get_node_full_text(node) for node in self.all_nodes]
        self.all_texts = [self._tokenize(text) if text else [] for text in full_texts]

        # Build BM25 index only when at least one node produced tokens. With
        # an all-empty corpus there is nothing to score, and rank_bm25-style
        # implementations divide by the vocabulary size / average document
        # length during construction (presumably true of get_bm25_okapi's
        # backend as well; verify).
        if any(len(tokens) > 0 for tokens in self.all_texts):
            BM25Okapi = get_bm25_okapi()
            self.bm25 = BM25Okapi(self.all_texts)
        else:
            self.bm25 = None

        logger.info("BM25HTMLTreeRAG: indexed {} nodes from HTML tree".format(len(self.all_nodes)))

    def retrieve_subtrees(self, query, k=5, min_score=0.0):
        """
        Retrieve relevant subtrees using BM25.

        Raw BM25 scores are normalized by the best score. Each node then gets
        a weighted mix of its own score, child/parent boosts, heading boosts
        and structural bonuses, scaled by a depth factor. Results are picked
        greedily by that score, with de-duplication by node path and a
        per-top-level-section quota.

        Args:
            query: Search query
            k: Number of top results to return
            min_score: Minimum enhanced score threshold

        Returns:
            List of (node, score, subtree_text) tuples, best first.
        """
        if self.bm25 is None or k <= 0:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scores = self.bm25.get_scores(query_tokens)

        # Map node identity -> position in self.all_nodes / scores.
        id_to_index = {id(node): j for j, node in enumerate(self.all_nodes)}

        # Normalize by the best raw score so the boosts operate on values <= 1.
        # Fall back to 1.0 when nothing scored positively.
        peak = float(np.max(scores)) if len(scores) > 0 else 0.0
        max_score = peak if peak > 0 else 1.0

        # Compute enhanced scores for each node with multi-level boosting.
        # Each entry is (node, enhanced_score, significant_child_count).
        enhanced_scores = []
        for i, node in enumerate(self.all_nodes):
            normalized_score = scores[i] / max_score

            # Optimization 1: boost from children (40% of the best child score),
            # and count children with a significant (> 0.1) normalized match.
            child_boost = 0.0
            child_match_count = 0
            for child in node.get("child_title") or []:
                child_idx = id_to_index.get(id(child))
                if child_idx is None:
                    continue
                child_norm_score = scores[child_idx] / max_score
                if child_norm_score > 0.1:
                    child_match_count += 1
                child_boost = max(child_boost, child_norm_score * 0.4)

            # Optimization 2: boost from a parent with a significant (> 0.3) match.
            parent_boost = 0.0
            parent = node.get("parent_title")
            if parent is not None:
                parent_idx = id_to_index.get(id(parent))
                if parent_idx is not None:
                    parent_norm_score = scores[parent_idx] / max_score
                    if parent_norm_score > 0.3:
                        parent_boost = parent_norm_score * 0.3

            # Optimization 3: title boost, applied only to nodes that match
            # themselves or through a child.
            title_boost = 0.0
            if normalized_score > 0 or child_boost > 0:
                if node.get("sentence_title") is not None:
                    title_boost = 0.15
                node_type = str(node.get("type") or "").lower()
                if node_type in _HEADING_TYPES:
                    title_boost = max(title_boost, 0.2)

            # Optimization 4: depth factor. Penalize roots and very deep nodes,
            # and favour depths 1-3.
            depth = self._get_node_depth(node)
            if depth == 0:
                depth_factor = 0.7
            elif depth > 5:
                depth_factor = 0.8
            elif 1 <= depth <= 3:
                depth_factor = 1.1
            else:
                depth_factor = 1.0

            # Optimization 5: reward nodes with several matching children (capped at 4).
            child_diversity_bonus = 0.0
            if child_match_count >= 2:
                child_diversity_bonus = 0.05 * min(child_match_count, 4)

            # Optimization 6: a titled node with matching content children is
            # structurally valuable (capped at 3 children).
            structural_bonus = 0.0
            if node.get("sentence_title") is not None and child_match_count > 0:
                structural_bonus = 0.08 * min(child_match_count, 3)

            combined_before_depth = (
                0.45 * normalized_score  # own content
                + 0.20 * child_boost  # children influence
                + 0.15 * parent_boost  # parent context
                + 0.10 * title_boost  # title importance
                + 0.05 * child_diversity_bonus  # child diversity
                + 0.05 * structural_bonus  # structural coherence
            )
            enhanced_scores.append((node, combined_before_depth * depth_factor, child_match_count))

        # Sort by enhanced score (stable: ties keep tree order).
        enhanced_scores.sort(key=lambda x: x[1], reverse=True)

        # Greedy selection: threshold, path de-duplication, per-section quota.
        results = []
        section_counts = {}  # results taken per top-level section
        seen_paths = set()  # exact paths already considered
        for node, enhanced_score, matching_children in enhanced_scores:
            if len(results) >= k:
                break
            if enhanced_score <= 0:
                break  # sorted descending: nothing better follows
            if enhanced_score < min_score:
                continue

            path = self.get_node_path(node)
            if path in seen_paths:
                continue
            # Marked as seen even if the section quota rejects it below.
            seen_paths.add(path)

            # First path segment identifies the top-level section.
            main_section = path.split(" > ")[0]

            is_structural_header = (
                node.get("sentence_title") is not None
                and matching_children >= 2
                and enhanced_score > 0.3
            )
            if is_structural_header:
                max_per_section = 4  # high-value structural nodes get more slots
            elif matching_children >= 2:
                max_per_section = 3  # rich sub-structure
            else:
                max_per_section = 2  # standard diversity

            section_count = section_counts.get(main_section, 0)
            if section_count >= max_per_section:
                continue

            results.append((node, enhanced_score, self._get_subtree_text(node)))
            section_counts[main_section] = section_count + 1

        return results

    def _get_subtree_text(self, node, indent=0, _ancestors=None):
        """
        Return the subtree rooted at ``node`` as indented, newline-joined text.

        Each non-empty node contributes one line, indented by ``indent`` spaces
        and suffixed with its title text when the node carries a
        ``sentence_title``. ``_ancestors`` is internal and cuts cyclic
        ``child_title`` links.
        """
        if _ancestors is None:
            _ancestors = set()
        node_id = id(node)
        if node_id in _ancestors:
            return ""
        _ancestors.add(node_id)

        lines = []
        node_text = self._extract_node_text(node)
        if node_text:
            title_info = ""
            if node.get("sentence_title") is not None:
                title_text = node.get("sentence_title_text", "")
                if title_text:
                    title_info = " [标题: {}]".format(title_text)
            prefix = " " * indent
            lines.append("{}{}{}".format(prefix, node_text, title_info))

        for child in node.get("child_title") or []:
            child_text = self._get_subtree_text(child, indent + 1, _ancestors)
            if child_text:
                lines.append(child_text)

        _ancestors.discard(node_id)
        return "\n".join(lines)

    def _is_relevant_to_query(self, node, query_tokens, id_to_index, scores, max_score, threshold=0.1):
        """
        Check if a node or its subtree is relevant to the query.

        A node is relevant when its own normalized score, or any descendant's,
        exceeds ``threshold``, or when its parent's normalized score exceeds 0.3.
        ``query_tokens`` is not used by the current logic. It is kept for
        signature compatibility. Not called within this module.
        """
        node_id = id(node)
        if node_id in id_to_index:
            node_score = scores[id_to_index[node_id]] / max_score
            if node_score > threshold:
                return True

        for child in node.get("child_title") or []:
            if self._is_relevant_to_query(child, query_tokens, id_to_index, scores, max_score, threshold):
                return True

        # Parent context: a strongly matching parent makes the child relevant too.
        parent = node.get("parent_title")
        if parent is not None:
            parent_id = id(parent)
            if parent_id in id_to_index:
                parent_score = scores[id_to_index[parent_id]] / max_score
                if parent_score > 0.3:
                    return True

        return False

    def get_node_path(self, node):
        """
        Return the root-to-node path as " > "-joined text.

        Each segment is a node's first 50 characters, and nodes without text
        are skipped. Stops if a parent is revisited, so a cyclic
        ``parent_title`` chain cannot loop forever.
        """
        segments = []
        visited = set()
        current = node
        while current is not None and id(current) not in visited:
            visited.add(id(current))
            text = self._extract_node_text(current)
            if text:
                segments.append(text[:50])
            current = current.get("parent_title")
        return " > ".join(reversed(segments))

    def query(self, query, k=5):
        """
        Full query pipeline: retrieve subtrees and format results.

        Args:
            query: Search query
            k: Number of results

        Returns:
            The output of ``deduplicate_ranked_results`` over ``(Document, score)``
            pairs, presumably a list of at most ``k`` such pairs. Each Document's
            metadata holds "node_type", "path", "score" and "title".
        """
        results = self.retrieve_subtrees(query, k)

        formatted_results = []
        for node, score, subtree_text in results:
            doc = Document(
                page_content=subtree_text,
                metadata={
                    "node_type": node.get("type", "unknown"),
                    "path": self.get_node_path(node),
                    "score": score,
                    "title": node.get("sentence_title_text", ""),
                },
            )
            formatted_results.append((doc, score))

        return deduplicate_ranked_results(formatted_results, k)


# Example usage
if __name__ == "__main__":
    sample_html = """
本项目预算金额为5000万元,招标编号为XX-ZB-2024-001。
1. 具有独立承担民事责任的能力
2. 具有相关资质证书
采用综合评分法,技术分占比60%,商务分占比40%。
截止时间:2024年12月31日
"""

    rag = BM25HTMLTreeRAG()
    rag.build_index(sample_html)

    queries = ["预算金额", "资质要求", "评标方法"]
    for query in queries:
        print("\n查询: {}".format(query))
        print("-" * 60)
        results = rag.query(query, k=3)
        for i, (doc, score) in enumerate(results, 1):
            print("  [{}] 分数: {:.4f}".format(i, score))
            print("  路径: {}".format(doc.metadata.get("path", "")))
            print("  内容: {}...".format(doc.page_content[:100]))