| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448 |
- # -*- coding: utf-8 -*-
- """BM25 HTML Tree RAG - BM25 retrieval on HTML hierarchical tree structure."""
- import importlib.util
- import os
- import numpy as np
- from bs4 import BeautifulSoup
- from loguru import logger
- from ..document_processor import Document
- from .bm25_backend import get_bm25_okapi
- from .dedup import deduplicate_ranked_results
- from .tokenization import bm25_tokenize
class BM25HTMLTreeRAG:
    """
    BM25-based retrieval on an HTML hierarchical tree.

    Pipeline:
    1. Parse HTML into a hierarchical tree using ``ParseDocument`` from the
       project's ``parser/htmlparser.py``.
    2. Collect every node of that tree (depth-first, parents before children).
    3. Build a BM25 index where each node's document is its own text plus the
       text of all of its descendants.
    4. At query time, re-rank nodes with structural signals (children, parent,
       titles, depth) and return a diversified list of subtrees.

    Tree nodes are dicts; the keys read here are ``text``, ``child_title``
    (list of child nodes), ``parent_title`` (parent node or None),
    ``sentence_title``, ``sentence_title_text`` and ``type``.
    """

    # Node ``type`` values that receive the heading boost in _score_nodes.
    _HEADER_TYPES = ("h1", "h2", "h3", "h4", "h5", "h6", "title", "header")

    # Cached ``parser/htmlparser.py`` module, shared by all instances so the
    # file is executed at most once per process (see _load_htmlparser).
    _htmlparser_module = None

    def __init__(self):
        # Root-level nodes produced by ParseDocument.
        self.tree = []
        # Every collected node; index i matches BM25 document i.
        self.all_nodes = []
        # Tokenized BM25 documents, parallel to self.all_nodes.
        self.all_texts = []
        # BM25 index; None until build_index finds at least one token.
        self.bm25 = None
        # Raw HTML passed to the most recent build_index call.
        self.html_content = ""

    def _tokenize(self, text):
        """Tokenize mixed Chinese/English text for BM25."""
        return bm25_tokenize(text)

    def _get_node_depth(self, node, visited=None):
        """
        Return the number of ``parent_title`` hops from ``node`` up to its root.

        A node without a parent has depth 0. ``visited`` collects ``id()``
        values so that a cyclic parent chain terminates.
        """
        if visited is None:
            visited = set()

        depth = 0
        current = node
        while current is not None and id(current) not in visited:
            visited.add(id(current))
            parent = current.get("parent_title")
            if parent is None:
                break
            current = parent
            depth += 1
        return depth

    def _extract_node_text(self, node):
        """
        Return the node's ``text`` with surrounding whitespace stripped.

        Text that looks like an HTML fragment (starts with ``<`` and ends with
        ``>``) is reduced to its visible text with BeautifulSoup. If that
        parsing fails, the raw text is kept (best effort).
        """
        text = node.get("text", "")
        if not text:
            return ""
        if not isinstance(text, str):
            text = str(text)

        if text.startswith("<") and text.endswith(">"):
            try:
                text = BeautifulSoup(text, "lxml").get_text(strip=True)
            except Exception as exc:
                # Deliberate best effort: fall back to the raw markup, but do
                # not swallow KeyboardInterrupt/SystemExit like a bare except.
                logger.debug("BM25HTMLTreeRAG: could not strip HTML from node text: {}", exc)

        return text.strip()

    def _get_node_full_text(self, node, cache=None):
        """
        Return the node's own text followed by the full text of each child.

        Pieces are joined with single spaces, and children whose full text is
        empty are skipped. ``cache`` is an optional dict keyed by ``id(node)``.
        When given, it memoizes results so that computing the text of every
        node does not re-walk the same subtrees.
        """
        if cache is not None and id(node) in cache:
            return cache[id(node)]

        texts = [self._extract_node_text(node)]
        for child in node.get("child_title") or []:
            child_text = self._get_node_full_text(child, cache)
            if child_text:
                texts.append(child_text)
        full_text = " ".join(texts)

        if cache is not None:
            cache[id(node)] = full_text
        return full_text

    def _collect_all_nodes(self, tree_nodes, visited=None):
        """
        Return every node reachable from ``tree_nodes`` in depth-first pre-order.

        ``visited`` holds ``id()`` values of nodes already emitted. A node that
        appears twice, or that is reached through a cycle, is returned only once.
        """
        if visited is None:
            visited = set()

        nodes = []
        for node in tree_nodes:
            if id(node) in visited:
                continue
            visited.add(id(node))
            nodes.append(node)
            children = node.get("child_title") or []
            if children:
                nodes.extend(self._collect_all_nodes(children, visited))
        return nodes

    @classmethod
    def _load_htmlparser(cls):
        """
        Load (once) and return the project's ``parser/htmlparser.py`` module.

        The package name ``parser`` collides with Python's historical stdlib
        ``parser`` module on older interpreters, so a normal
        ``from parser...`` import is unreliable. The file is therefore loaded
        by path: ``parser/htmlparser.py`` under the directory two levels above
        this module's directory.

        Raises:
            ImportError: if no import spec/loader can be created for the file.
        """
        if cls._htmlparser_module is None:
            parser_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
                "parser",
                "htmlparser.py",
            )
            spec = importlib.util.spec_from_file_location("bidirag_htmlparser", parser_path)
            if spec is None or spec.loader is None:
                raise ImportError("Cannot load HTML parser from {}".format(parser_path))
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            # Cache only after a successful exec so a failed load is retried.
            cls._htmlparser_module = module
        return cls._htmlparser_module

    def build_index(self, html_content, auto_merge_table=True):
        """
        Parse ``html_content`` into a tree and build the BM25 index over it.

        Each node becomes one BM25 document: its own text plus the text of all
        its descendants (see ``_get_node_full_text``).

        Args:
            html_content: HTML string to parse.
            auto_merge_table: Forwarded to ``ParseDocument``; whether to
                auto-merge tables.
        """
        self.html_content = html_content

        ParseDocument = self._load_htmlparser().ParseDocument
        pd = ParseDocument(html_content, auto_merge_table=auto_merge_table)
        self.tree = pd.tree

        self.all_nodes = self._collect_all_nodes(self.tree)

        # Shared memo so parent documents reuse their children's full text.
        text_cache = {}
        self.all_texts = []
        for node in self.all_nodes:
            text = self._get_node_full_text(node, text_cache)
            self.all_texts.append(self._tokenize(text) if text else [])

        # Build the index only if some document has a token. A corpus without
        # any token can never match a query. Some BM25 backends may also fail
        # on it: rank_bm25, for example, averages IDF over an empty vocabulary.
        if any(self.all_texts):
            BM25Okapi = get_bm25_okapi()
            self.bm25 = BM25Okapi(self.all_texts)
        else:
            self.bm25 = None

        logger.info("BM25HTMLTreeRAG: indexed {} nodes from HTML tree".format(len(self.all_nodes)))

    def retrieve_subtrees(self, query, k=5, min_score=0.0):
        """
        Retrieve relevant subtrees using BM25 plus structural re-ranking.

        Args:
            query: Search query.
            k: Maximum number of results to return.
            min_score: Minimum *enhanced* (re-ranked) score a node must reach.

        Returns:
            List of ``(node, score, subtree_text)`` tuples, best first, where
            ``score`` is the enhanced score. The list is empty when no index is
            built, ``k <= 0``, or the query yields no tokens.
        """
        if self.bm25 is None or k <= 0:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []
        scores = self.bm25.get_scores(query_tokens)

        ranked = self._score_nodes(scores)
        ranked.sort(key=lambda item: item[1], reverse=True)
        return self._select_diverse_results(ranked, k, min_score)

    def _score_nodes(self, scores):
        """
        Combine raw BM25 ``scores`` (parallel to ``self.all_nodes``) with
        structural signals.

        Returns:
            Unsorted list of ``(node, enhanced_score, base_score,
            child_match_count)`` tuples. ``child_match_count`` is the number of
            direct children whose normalized BM25 score exceeds 0.1.
        """
        id_to_index = {id(node): i for i, node in enumerate(self.all_nodes)}
        # Normalize by the best raw score. Fall back to 1.0 when nothing
        # scored positively, so every division below is safe.
        max_score = float(np.max(scores)) if len(scores) > 0 and np.max(scores) > 0 else 1.0

        def normalized_score_of(other):
            """Normalized BM25 score of ``other``, or None if it is not indexed."""
            idx = id_to_index.get(id(other))
            return None if idx is None else scores[idx] / max_score

        scored = []
        for i, node in enumerate(self.all_nodes):
            base_score = scores[i]
            normalized_score = base_score / max_score

            # Children: the best child contributes 0.4x its normalized score.
            # Children scoring above 0.1 count as "matching".
            child_boost = 0.0
            child_match_count = 0
            for child in node.get("child_title") or []:
                child_norm_score = normalized_score_of(child)
                if child_norm_score is None:
                    continue
                if child_norm_score > 0.1:
                    child_match_count += 1
                child_boost = max(child_boost, child_norm_score * 0.4)

            # Parent context: counts only when the parent itself matches well.
            parent_boost = 0.0
            parent_norm_score = normalized_score_of(node.get("parent_title"))
            if parent_norm_score is not None and parent_norm_score > 0.3:
                parent_boost = parent_norm_score * 0.3

            # Titles/headings are favored, but only if something matched.
            title_boost = 0.0
            if normalized_score > 0 or child_boost > 0:
                if node.get("sentence_title") is not None:
                    title_boost = 0.15
                node_type = str(node.get("type") or "").lower()
                if node_type in self._HEADER_TYPES:
                    title_boost = max(title_boost, 0.2)

            # Prefer mid-level nodes over roots and very deep nodes.
            depth = self._get_node_depth(node)
            if depth == 0:
                depth_factor = 0.7
            elif depth > 5:
                depth_factor = 0.8
            elif 1 <= depth <= 3:
                depth_factor = 1.1
            else:
                depth_factor = 1.0

            # Reward several matching children (capped at 4).
            child_diversity_bonus = 0.0
            if child_match_count >= 2:
                child_diversity_bonus = 0.05 * min(child_match_count, 4)

            # A title node with matching content children is structurally coherent.
            structural_bonus = 0.0
            if node.get("sentence_title") is not None and child_match_count > 0:
                structural_bonus = 0.08 * min(child_match_count, 3)

            combined_before_depth = (
                0.45 * normalized_score +  # own content
                0.20 * child_boost +  # children influence
                0.15 * parent_boost +  # parent context
                0.10 * title_boost +  # title importance
                0.05 * child_diversity_bonus +  # child diversity
                0.05 * structural_bonus  # structural coherence
            )
            enhanced_score = combined_before_depth * depth_factor

            scored.append((node, enhanced_score, base_score, child_match_count))
        return scored

    def _select_diverse_results(self, ranked, k, min_score):
        """
        Pick up to ``k`` results from ``ranked``, which must be sorted best first.

        Nodes whose path was already returned are skipped. The number of
        results sharing a first-level path section is capped:
        - 4 for title nodes with >= 2 matching children and a score above 0.3;
        - 3 for other nodes with >= 2 matching children;
        - 2 otherwise.

        Returns:
            List of ``(node, enhanced_score, subtree_text)`` tuples.
        """
        results = []
        section_counts = {}  # results returned per first-level section
        returned_paths = set()  # paths already present in results

        for node, enhanced_score, _base_score, child_match_count in ranked:
            if len(results) >= k:
                break
            # Input is sorted descending: nothing after a non-positive score qualifies.
            if enhanced_score <= 0:
                break
            if enhanced_score < min_score:
                continue

            path = self.get_node_path(node)
            if path in returned_paths:
                continue

            main_section = path.split(" > ")[0]

            is_structural_header = (
                node.get("sentence_title") is not None
                and child_match_count >= 2
                and enhanced_score > 0.3
            )
            if is_structural_header:
                max_per_section = 4
            elif child_match_count >= 2:
                max_per_section = 3
            else:
                max_per_section = 2

            section_count = section_counts.get(main_section, 0)
            if section_count >= max_per_section:
                continue

            results.append((node, enhanced_score, self._get_subtree_text(node)))
            # Record the path only once it is actually returned, so a node
            # rejected by the section cap does not hide later same-path nodes.
            returned_paths.add(path)
            section_counts[main_section] = section_count + 1

        return results

    def _get_subtree_text(self, node, indent=0):
        """
        Render ``node`` and its descendants as newline-separated, indented lines.

        Each node with non-empty text contributes one line, prefixed with
        ``indent`` spaces (one per tree level). A node that has a
        ``sentence_title`` and a non-empty ``sentence_title_text`` gets a
        ``[标题: ...]`` ("title") suffix. Children are rendered one level deeper.
        """
        lines = []

        node_text = self._extract_node_text(node)
        if node_text:
            title_info = ""
            if node.get("sentence_title") is not None:
                title_text = node.get("sentence_title_text", "")
                if title_text:
                    title_info = " [标题: {}]".format(title_text)
            lines.append("{}{}{}".format(" " * indent, node_text, title_info))

        for child in node.get("child_title") or []:
            child_text = self._get_subtree_text(child, indent + 1)
            if child_text:
                lines.append(child_text)

        return "\n".join(lines)

    def _is_relevant_to_query(self, node, query_tokens, id_to_index, scores, max_score, threshold=0.1):
        """
        Return True if ``node``, any descendant, or its parent matches strongly enough.

        A node or descendant is relevant when its normalized score
        (``scores[i] / max_score``) exceeds ``threshold``. A parent counts when
        its normalized score exceeds 0.3. ``query_tokens`` is accepted but not
        used here.
        """
        node_idx = id_to_index.get(id(node))
        if node_idx is not None and scores[node_idx] / max_score > threshold:
            return True

        for child in node.get("child_title") or []:
            if self._is_relevant_to_query(child, query_tokens, id_to_index, scores, max_score, threshold):
                return True

        parent = node.get("parent_title")
        if parent is not None:
            parent_idx = id_to_index.get(id(parent))
            if parent_idx is not None and scores[parent_idx] / max_score > 0.3:
                return True

        return False

    def get_node_path(self, node):
        """
        Return the root-to-node path as node texts joined with ``" > "``.

        Each component is the node's extracted text truncated to 50
        characters, and ancestors with empty text are omitted. The walk stops
        if the ``parent_title`` chain revisits a node, so a cyclic chain cannot
        loop forever.
        """
        parts = []
        visited = set()
        current = node
        while current is not None and id(current) not in visited:
            visited.add(id(current))
            text = self._extract_node_text(current)
            if text:
                parts.append(text[:50])
            current = current.get("parent_title")
        parts.reverse()
        return " > ".join(parts)

    def query(self, query, k=5, min_score=0.0):
        """
        Full query pipeline: retrieve subtrees and wrap them as Documents.

        Args:
            query: Search query.
            k: Number of results.
            min_score: Minimum enhanced score, forwarded to ``retrieve_subtrees``.

        Returns:
            The output of ``deduplicate_ranked_results`` applied to
            ``(Document, score)`` pairs. Each Document's metadata holds the node
            type, path, score and title text.
        """
        results = self.retrieve_subtrees(query, k, min_score)

        formatted_results = []
        for node, score, subtree_text in results:
            doc = Document(
                page_content=subtree_text,
                metadata={
                    "node_type": node.get("type", "unknown"),
                    "path": self.get_node_path(node),
                    "score": score,
                    "title": node.get("sentence_title_text", ""),
                },
            )
            formatted_results.append((doc, score))

        return deduplicate_ranked_results(formatted_results, k)
# Example usage: index a small Chinese tender notice and run a few sample queries.
if __name__ == "__main__":
    sample_html = """
<html>
<body>
<h1>招标公告</h1>
<h2>一、项目概况</h2>
<p>本项目预算金额为5000万元,招标编号为XX-ZB-2024-001。</p>
<h2>二、投标人资格要求</h2>
<p>1. 具有独立承担民事责任的能力</p>
<p>2. 具有相关资质证书</p>
<h2>三、评标方法</h2>
<p>采用综合评分法,技术分占比60%,商务分占比40%。</p>
<h2>四、投标文件递交</h2>
<p>截止时间:2024年12月31日</p>
</body>
</html>
"""

    def _show_hits(hits):
        """Print each (Document, score) hit with its rank, path and a content preview."""
        for rank, (hit_doc, hit_score) in enumerate(hits, start=1):
            print(" [{}] 分数: {:.4f}".format(rank, hit_score))
            print(" 路径: {}".format(hit_doc.metadata.get("path", "")))
            print(" 内容: {}...".format(hit_doc.page_content[:100]))

    demo_rag = BM25HTMLTreeRAG()
    demo_rag.build_index(sample_html)

    for demo_query in ("预算金额", "资质要求", "评标方法"):
        # Header first, then run the query, matching the original output order.
        print("\n查询: {}".format(demo_query))
        print("-" * 60)
        _show_hits(demo_rag.query(demo_query, k=3))
|