# -*- coding: utf-8 -*-
"""BidiRag - Unified interface for RAG-based document retrieval."""

import os
from typing import List, Dict, Optional, Union

from loguru import logger

from .config import (
    EMBEDDING_MODEL_NAME,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
    VECTOR_STORE_TYPE,
    RETRIEVAL_TOP_K,
)
from .document_processor import DocumentProcessor, Document
from .embedding_models import SentenceTransformerEmbedding
from .vector_stores import get_vector_store
from .rag_methods import (
    NaiveRAG,
    RerankRAG,
    ParentDocumentRAG,
    LLMFilterRAG,
    QueryRoutingRAG,
    MetadataFilterRAG,
    AdaptiveRAG,
    HybridSearchRAG,
    MultiQueryRAG,
    HyDERAG,
    StepBackRAG,
    ContextualCompressionRAG,
    SelfRAG,
    CorrectiveRAG,
    FLARERAG,
    RAPTORRAG,
    EnsembleRAG,
    BidFieldExtractionRAG,
    TableAwareRAG,
    GraphRAG,
    BM25RAG,
    TFIDFRAG,
    KeywordRAG,
    BM25HTMLTreeRAG,
)

# Available RAG methods mapping (public method name -> implementing class)
RAG_METHODS = {
    'naive': NaiveRAG,
    'rerank': RerankRAG,
    'parent_document': ParentDocumentRAG,
    'llm_filter': LLMFilterRAG,
    'query_routing': QueryRoutingRAG,
    'metadata_filter': MetadataFilterRAG,
    'adaptive': AdaptiveRAG,
    'hybrid_search': HybridSearchRAG,
    'multi_query': MultiQueryRAG,
    'hyde': HyDERAG,
    'step_back': StepBackRAG,
    'contextual_compression': ContextualCompressionRAG,
    'self_rag': SelfRAG,
    'corrective': CorrectiveRAG,
    'flare': FLARERAG,
    'raptor': RAPTORRAG,
    'ensemble': EnsembleRAG,
    'bid_field_extraction': BidFieldExtractionRAG,
    'table_aware': TableAwareRAG,
    'graph': GraphRAG,
    'bm25': BM25RAG,
    'tfidf': TFIDFRAG,
    'keyword': KeywordRAG,
    'bm25_html_tree': BM25HTMLTreeRAG,
}


class BidiRag:
    """
    BidiRag - Unified interface for RAG-based document retrieval and extraction.

    This class provides a high-level API to:
    1. Load and process HTML documents
    2. Build index using embedding models
    3. Retrieve relevant text using various RAG methods
    4. Support keyword-based search

    Usage:
        >>> rag = BidiRag(rag_method='bm25_html_tree')
        >>> rag.add_documents(["path/to/document1.html", "path/to/document2.html"])
        >>> # Query text means "tenderer and winning bidder".
        >>> results = rag.retrieve("招标人和中标人", top_k=5)
        >>> for doc, score in results:
        ...     print(doc.page_content)
    """

    def __init__(
        self,
        rag_method: str = 'bm25_html_tree',
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        vector_store_type: Optional[str] = None,
        embedding_model_name: Optional[str] = None,
        llm_client=None,
        llm_model: str = "gpt-4o",
        index_name: str = "default",
        **kwargs
    ):
        """
        Initialize BidiRag.

        Args:
            rag_method: RAG method to use (e.g., 'bm25_html_tree', 'naive', 'hybrid_search')
            chunk_size: Document chunk size (default: from config)
            chunk_overlap: Chunk overlap size (default: from config); an
                explicit 0 is honored
            vector_store_type: Vector store type 'faiss' or 'chroma' (default: from config)
            embedding_model_name: Embedding model name (default: from config)
            llm_client: LLM client instance (optional, for methods requiring LLM)
            llm_model: LLM model name
            index_name: Index name for storage
            **kwargs: Additional arguments for specific RAG methods

        Raises:
            ValueError: If ``rag_method`` is not a key of ``RAG_METHODS``.
        """
        # Configuration
        self.rag_method_name = rag_method
        self.chunk_size = chunk_size or CHUNK_SIZE
        # 0 is a meaningful overlap, so only fall back to the config when the
        # caller passed nothing at all.
        self.chunk_overlap = CHUNK_OVERLAP if chunk_overlap is None else chunk_overlap
        self.vector_store_type = vector_store_type or VECTOR_STORE_TYPE
        self.embedding_model_name = embedding_model_name or EMBEDDING_MODEL_NAME
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.index_name = index_name

        # Validate RAG method
        if rag_method not in RAG_METHODS:
            available = ', '.join(RAG_METHODS)
            raise ValueError(
                f"Unknown RAG method '{rag_method}'. "
                f"Available methods: {available}"
            )

        # Methods that don't require embedding model
        self.embedding_free_methods = {'bm25', 'tfidf', 'keyword', 'bm25_html_tree'}
        # Methods that don't require vector store (they have their own indexing)
        self.vector_store_free_methods = {'bm25', 'tfidf', 'keyword', 'bm25_html_tree'}

        # Initialize components
        self.document_processor = DocumentProcessor(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap
        )

        # Initialize embedding model (only if needed)
        if rag_method not in self.embedding_free_methods:
            logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformerEmbedding(
                model_name=self.embedding_model_name
            )
        else:
            logger.info(f"Method {rag_method} doesn't require embedding model")
            self.embedding_model = None

        # Initialize vector store (only if needed)
        if rag_method not in self.vector_store_free_methods:
            logger.info(f"Initializing vector store: {self.vector_store_type}")
            self.vector_store = get_vector_store(
                store_type=self.vector_store_type,
                embedding_model=self.embedding_model
            )
        else:
            logger.info(f"Method {rag_method} doesn't require vector store")
            self.vector_store = None

        # Initialize RAG method
        self._init_rag_method(**kwargs)

        # Document storage
        self.documents = []
        self.indexed = False

        logger.info(f"BidiRag initialized with method: {rag_method}")

    def _init_rag_method(self, **kwargs):
        """
        Instantiate the configured RAG method and store it on ``self.rag_method``.

        Also sets ``self.is_special_method``: True for methods constructed
        without arguments and driven through ``build_index``/``query``, False
        for methods constructed with the embedding model and vector store.

        Args:
            **kwargs: Extra constructor arguments; only forwarded to
                non-special methods, where they override the common parameters.
        """
        rag_class = RAG_METHODS[self.rag_method_name]

        # Check if this is a special method that doesn't follow BaseRAG interface
        special_methods = {'bm25_html_tree'}

        if self.rag_method_name in special_methods:
            # BM25HTMLTreeRAG has its own interface
            try:
                self.rag_method = rag_class()
                self.is_special_method = True
                logger.info(f"Special RAG method {self.rag_method_name} initialized")
            except Exception as e:
                logger.error(
                    f"Failed to initialize special RAG method {self.rag_method_name}: {e}"
                )
                raise
        else:
            # Common parameters for BaseRAG methods
            common_params = {
                'embedding_model': self.embedding_model,
                'vector_store': self.vector_store,
            }

            # Add LLM parameters if provided
            if self.llm_client is not None:
                common_params['llm_client'] = self.llm_client
                common_params['llm_model'] = self.llm_model

            # Merge with additional kwargs
            common_params.update(kwargs)

            try:
                self.rag_method = rag_class(**common_params)
                self.is_special_method = False
                logger.info(f"RAG method {self.rag_method_name} initialized successfully")
            except Exception as e:
                logger.error(f"Failed to initialize RAG method {self.rag_method_name}: {e}")
                raise

    @staticmethod
    def _enhance_query(query: str, keywords: Optional[List[str]]) -> str:
        """Return ``query`` with the keywords appended (space separated), if any."""
        if keywords:
            return f"{query} {' '.join(keywords)}"
        return query

    def add_documents(
        self,
        sources: Union[str, List[str]],
        rebuild_index: bool = True
    ) -> int:
        """
        Add documents from file paths or directories.

        Sources that do not exist on disk are skipped with a warning.

        Args:
            sources: Single file path, directory path, or list of paths
            rebuild_index: Whether to rebuild the index after adding documents

        Returns:
            Number of documents added
        """
        if isinstance(sources, str):
            sources = [sources]

        new_docs = []
        for source in sources:
            if not os.path.exists(source):
                logger.warning(f"Source not found: {source}")
                continue

            logger.info(f"Processing: {source}")
            new_docs.extend(self.document_processor.process(source))

        if not new_docs:
            logger.warning("No documents were loaded")
            return 0

        # Store documents
        self.documents.extend(new_docs)
        logger.info(f"Added {len(new_docs)} documents (total: {len(self.documents)})")

        # Build index
        if rebuild_index:
            self.build_index()

        return len(new_docs)

    def add_texts(
        self,
        texts: Union[str, List[str]],
        metadata: Optional[Union[Dict, List[Dict]]] = None,
        rebuild_index: bool = True
    ) -> int:
        """
        Add documents from text strings.

        Args:
            texts: List of text content; a single string is treated as a
                one-element list
            metadata: List of metadata dictionaries (optional); a single
                dictionary is treated as a one-element list
            rebuild_index: Whether to rebuild the index

        Returns:
            Number of documents added

        Raises:
            ValueError: If ``texts`` and ``metadata`` differ in length.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(texts, str):
            texts = [texts]
        if isinstance(metadata, dict):
            metadata = [metadata]

        if metadata is None:
            metadata = [{} for _ in texts]

        if len(texts) != len(metadata):
            raise ValueError("texts and metadata must have the same length")

        new_docs = [
            Document(page_content=text, metadata=meta)
            for text, meta in zip(texts, metadata)
        ]

        self.documents.extend(new_docs)
        logger.info(f"Added {len(new_docs)} text documents")

        if rebuild_index:
            self.build_index()

        return len(new_docs)

    def build_index(self):
        """
        Build or rebuild the document index from ``self.documents``.

        Does nothing (beyond a warning) when no documents are loaded.
        Sets ``self.indexed`` to True on completion.
        """
        if not self.documents:
            logger.warning("No documents to index")
            return

        logger.info(f"Building index with {len(self.documents)} documents...")

        if self.is_special_method:
            # For BM25HTMLTreeRAG, we need to combine all HTML content
            html_content = "\n".join(doc.page_content for doc in self.documents)
            self.rag_method.build_index(html_content)
        elif self.vector_store is not None:
            # Clear existing index so all documents are re-indexed from scratch
            self.vector_store.clear()
            # Index documents using the RAG method
            self.rag_method.index_documents(self.documents)
        else:
            # For methods like BM25, TFIDF that have their own indexing
            # Call their index_documents method directly
            if hasattr(self.rag_method, 'index_documents'):
                self.rag_method.index_documents(self.documents)
            else:
                logger.info(f"Method {self.rag_method_name} handles indexing internally")

        self.indexed = True
        logger.info("Index built successfully")

    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        keywords: Optional[List[str]] = None,
        return_scores: bool = True
    ) -> List:
        """
        Retrieve relevant documents/fragments for a query.

        Args:
            query: Search query (can be natural language or keywords)
            top_k: Number of results to return (default: from config)
            keywords: Additional keywords; appended to the query and used
                for filtering / fallback keyword search (optional)
            return_scores: If False, ``(document, score)`` tuples are reduced
                to plain documents before returning

        Returns:
            List of retrieved documents/fragments, in whatever shape the
            RAG method produces (with scores stripped if return_scores=False).
            Empty list if the index has not been built.
        """
        if not self.indexed:
            logger.warning("Index not built. Call add_documents() first.")
            return []

        top_k = top_k or RETRIEVAL_TOP_K

        # Build query with keywords if provided
        enhanced_query = self._enhance_query(query, keywords)
        if keywords:
            logger.info(f"Enhanced query with keywords: {enhanced_query}")

        # Retrieve using the RAG method
        logger.info(f"Retrieving with query: {query}")

        if self.is_special_method:
            # BM25HTMLTreeRAG returns tree node fragments directly
            # This returns (Document, score) tuples with tree node content
            results = self.rag_method.query(enhanced_query, k=top_k)
            logger.info(f"Retrieved {len(results)} tree node fragments")
        else:
            # BaseRAG methods use retrieve() method
            results = self.rag_method.retrieve(enhanced_query, k=top_k)

        # Filter by keywords if provided
        # Only apply strict filtering if we have more results than needed.
        # NOTE: the method was asked for k=top_k results, so this only runs
        # when it returns more than requested.
        if keywords and len(results) > top_k:
            results = self._filter_by_keywords(results, keywords)
            logger.info(f"Filtered to {len(results)} results matching keywords")

        # If no results from retrieval but keywords provided, try search_keywords
        if not results and keywords:
            logger.info("No results from retrieval, trying search_keywords")
            # Add dummy scores for consistency
            results = [(doc, 1.0) for doc in self.search_keywords(keywords, top_k)]

        if not return_scores:
            results = [
                item[0] if isinstance(item, tuple) else item
                for item in results
            ]

        return results

    def _filter_by_keywords(
        self,
        results: List,
        keywords: List[str]
    ) -> List:
        """
        Keep only results whose content contains at least one keyword.

        Matching is a case-insensitive substring test on ``page_content``.
        Items may be ``(doc, score, ...)`` tuples or plain documents; kept
        items are returned unchanged and in their original order.
        """
        lowered_keywords = [kw.lower() for kw in keywords]

        filtered = []
        for item in results:
            # Handle both (doc, score) tuples and plain documents
            doc = item[0] if isinstance(item, tuple) else item

            # Check if any keyword is in the document content
            content_lower = doc.page_content.lower()
            if any(kw in content_lower for kw in lowered_keywords):
                filtered.append(item)

        return filtered

    def query(
        self,
        query: str,
        top_k: Optional[int] = None,
        keywords: Optional[List[str]] = None
    ):
        """
        Full RAG pipeline: retrieve + generate answer.

        For the special method (BM25HTMLTreeRAG) no answer is generated; the
        retrieved fragments are wrapped in a RAGResult with an empty answer.

        Args:
            query: Search query
            top_k: Number of documents to retrieve (default: from config)
            keywords: Additional keywords; appended to the query and used to
                filter the retrieved documents

        Returns:
            RAGResult object with answer and retrieved documents, or None if
            the index has not been built
        """
        if not self.indexed:
            logger.warning("Index not built. Call add_documents() first.")
            return None

        # Build query with keywords
        enhanced_query = self._enhance_query(query, keywords)

        logger.info(f"Querying: {query}")

        k = top_k or RETRIEVAL_TOP_K
        if self.is_special_method:
            # BM25HTMLTreeRAG doesn't have generate() method, only retrieve
            result = self.rag_method.query(enhanced_query, k=k)
            # Wrap in a simple result object
            from .rag_methods.base import RAGResult
            rag_result = RAGResult(
                answer="",
                retrieved_docs=result,
                metadata={"method": self.rag_method_name}
            )
        else:
            rag_result = self.rag_method.query(enhanced_query, k=k)

        # Filter retrieved docs by keywords if provided
        if keywords and rag_result.retrieved_docs:
            rag_result.retrieved_docs = self._filter_by_keywords(
                rag_result.retrieved_docs, keywords
            )

        return rag_result

    def search_keywords(
        self,
        keywords: List[str],
        top_k: Optional[int] = None
    ) -> List[Document]:
        """
        Search loaded documents for ones containing every keyword.

        Matching is a case-insensitive substring test on ``page_content``;
        documents are scanned in insertion order. An empty keyword list
        matches every document.

        Args:
            keywords: List of keywords to search for
            top_k: Maximum number of results (default: all documents)

        Returns:
            List of documents containing all the keywords
        """
        if not self.documents:
            logger.warning("No documents loaded")
            return []

        top_k = top_k or len(self.documents)
        lowered_keywords = [kw.lower() for kw in keywords]

        results = []
        for doc in self.documents:
            content_lower = doc.page_content.lower()
            # Check if all keywords are present
            if all(kw in content_lower for kw in lowered_keywords):
                results.append(doc)
                if len(results) >= top_k:
                    break

        logger.info(f"Found {len(results)} documents matching keywords: {keywords}")
        return results

    def get_document_count(self) -> int:
        """Get the number of loaded documents."""
        return len(self.documents)

    def clear(self):
        """Clear all documents, the vector store (if any), and the indexed flag."""
        self.documents = []
        if self.vector_store is not None:
            self.vector_store.clear()
        self.indexed = False
        logger.info("Cleared all documents and index")

    @staticmethod
    def list_available_methods() -> List[str]:
        """List all available RAG method names (the keys of ``RAG_METHODS``)."""
        return list(RAG_METHODS)

    def get_method_info(self) -> Dict:
        """
        Get information about the current RAG method.

        Returns:
            Dict with the method name, its class name, the configured
            embedding model name and vector store type, the number of
            loaded documents, and whether the index has been built.
        """
        return {
            'method_name': self.rag_method_name,
            'method_class': self.rag_method.__class__.__name__,
            'embedding_model': self.embedding_model_name,
            'vector_store': self.vector_store_type,
            'document_count': len(self.documents),
            'indexed': self.indexed,
        }