| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498 |
- # -*- coding: utf-8 -*-
- """BidiRag - Unified interface for RAG-based document retrieval."""
- import os
- from typing import List, Dict, Optional, Union
- from loguru import logger
- from .config import (
- EMBEDDING_MODEL_NAME,
- CHUNK_SIZE,
- CHUNK_OVERLAP,
- VECTOR_STORE_TYPE,
- RETRIEVAL_TOP_K,
- )
- from .document_processor import DocumentProcessor, Document
- from .embedding_models import SentenceTransformerEmbedding
- from .vector_stores import get_vector_store
- from .rag_methods import (
- NaiveRAG, RerankRAG, ParentDocumentRAG, LLMFilterRAG,
- QueryRoutingRAG, MetadataFilterRAG, AdaptiveRAG, HybridSearchRAG,
- MultiQueryRAG, HyDERAG, StepBackRAG, ContextualCompressionRAG,
- SelfRAG, CorrectiveRAG, FLARERAG, RAPTORRAG, EnsembleRAG,
- BidFieldExtractionRAG, TableAwareRAG, GraphRAG,
- BM25RAG, TFIDFRAG, KeywordRAG, BM25HTMLTreeRAG,
- )
# Registry of the RAG strategies that can be selected by name through
# ``BidiRag(rag_method=...)``. Each key is the public identifier accepted by
# the constructor; each value is the class that ``BidiRag._init_rag_method``
# instantiates for that identifier. Insertion order is preserved and is the
# order reported by ``BidiRag.list_available_methods``.
RAG_METHODS = dict(
    # Dense / embedding-based strategies
    naive=NaiveRAG,
    rerank=RerankRAG,
    parent_document=ParentDocumentRAG,
    llm_filter=LLMFilterRAG,
    query_routing=QueryRoutingRAG,
    metadata_filter=MetadataFilterRAG,
    adaptive=AdaptiveRAG,
    hybrid_search=HybridSearchRAG,
    multi_query=MultiQueryRAG,
    hyde=HyDERAG,
    step_back=StepBackRAG,
    contextual_compression=ContextualCompressionRAG,
    self_rag=SelfRAG,
    corrective=CorrectiveRAG,
    flare=FLARERAG,
    raptor=RAPTORRAG,
    ensemble=EnsembleRAG,
    bid_field_extraction=BidFieldExtractionRAG,
    table_aware=TableAwareRAG,
    graph=GraphRAG,
    # Lexical strategies (BidiRag treats these as embedding/vector-store free)
    bm25=BM25RAG,
    tfidf=TFIDFRAG,
    keyword=KeywordRAG,
    bm25_html_tree=BM25HTMLTreeRAG,
)
class BidiRag:
    """
    BidiRag - Unified interface for RAG-based document retrieval and extraction.

    This class provides a high-level API to:
    1. Load and process HTML documents
    2. Build an index (embedding/vector-store based, or the method's own
       lexical index, depending on the chosen RAG method)
    3. Retrieve relevant text using various RAG methods
    4. Support keyword-based search

    Usage:
        >>> rag = BidiRag(rag_method='bm25_html_tree')
        >>> rag.add_documents(["path/to/document1.html", "path/to/document2.html"])
        >>> # Query: "tendering party and winning bidder"
        >>> results = rag.retrieve("招标人和中标人", top_k=5)
        >>> for doc, score in results:  # (doc, score) pairs by default
        ...     print(doc.page_content)
    """

    def __init__(
        self,
        rag_method: str = 'bm25_html_tree',
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        vector_store_type: Optional[str] = None,
        embedding_model_name: Optional[str] = None,
        llm_client=None,
        llm_model: str = "gpt-4o",
        index_name: str = "default",
        **kwargs
    ):
        """
        Initialize BidiRag.

        Args:
            rag_method: RAG method to use; must be a key of ``RAG_METHODS``
                (e.g. 'bm25_html_tree', 'naive', 'hybrid_search').
            chunk_size: Document chunk size. None or 0 uses the config default.
            chunk_overlap: Chunk overlap size. None uses the config default;
                0 is honored and means non-overlapping chunks.
            vector_store_type: Vector store type, e.g. 'faiss' or 'chroma'
                (default: from config).
            embedding_model_name: Embedding model name (default: from config).
            llm_client: LLM client instance (optional, for methods requiring an LLM).
            llm_model: LLM model name; only forwarded when ``llm_client`` is given.
            index_name: Index name for storage.
            **kwargs: Additional arguments forwarded to the RAG method constructor.

        Raises:
            ValueError: If ``rag_method`` is not a known method name.
        """
        # Configuration
        self.rag_method_name = rag_method
        # A falsy chunk size (None or 0) falls back to the configured default.
        self.chunk_size = chunk_size or CHUNK_SIZE
        # Only None means "use the default": an overlap of 0 is a legitimate
        # request and must not be silently replaced by the config value.
        self.chunk_overlap = CHUNK_OVERLAP if chunk_overlap is None else chunk_overlap
        self.vector_store_type = vector_store_type or VECTOR_STORE_TYPE
        self.embedding_model_name = embedding_model_name or EMBEDDING_MODEL_NAME
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.index_name = index_name

        # Validate RAG method
        if rag_method not in RAG_METHODS:
            available = ', '.join(RAG_METHODS.keys())
            raise ValueError(
                f"Unknown RAG method '{rag_method}'. "
                f"Available methods: {available}"
            )

        # Methods that don't require an embedding model
        self.embedding_free_methods = {'bm25', 'tfidf', 'keyword', 'bm25_html_tree'}
        # Methods that don't require a vector store (they have their own indexing)
        self.vector_store_free_methods = {'bm25', 'tfidf', 'keyword', 'bm25_html_tree'}

        # Initialize components
        self.document_processor = DocumentProcessor(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap
        )

        # Initialize embedding model (only if needed)
        if rag_method not in self.embedding_free_methods:
            logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformerEmbedding(model_name=self.embedding_model_name)
        else:
            logger.info(f"Method {rag_method} doesn't require embedding model")
            self.embedding_model = None

        # Initialize vector store (only if needed)
        if rag_method not in self.vector_store_free_methods:
            logger.info(f"Initializing vector store: {self.vector_store_type}")
            self.vector_store = get_vector_store(
                store_type=self.vector_store_type,
                embedding_model=self.embedding_model
            )
        else:
            logger.info(f"Method {rag_method} doesn't require vector store")
            self.vector_store = None

        # Initialize RAG method (also sets self.is_special_method)
        self._init_rag_method(**kwargs)

        # Document storage
        self.documents = []
        self.indexed = False

        logger.info(f"BidiRag initialized with method: {rag_method}")

    def _init_rag_method(self, **kwargs):
        """
        Instantiate the configured RAG method and store it on ``self.rag_method``.

        Special methods (currently only 'bm25_html_tree') are constructed with
        no arguments and expose a different interface (``build_index`` /
        ``query``); all other methods receive the embedding model, vector store,
        optional LLM settings and any extra ``kwargs``.

        Sets ``self.is_special_method`` accordingly. Construction errors are
        logged and re-raised.
        """
        rag_class = RAG_METHODS[self.rag_method_name]

        # Methods that don't follow the BaseRAG interface
        special_methods = {'bm25_html_tree'}

        if self.rag_method_name in special_methods:
            # BM25HTMLTreeRAG has its own interface and takes no constructor args,
            # so make it visible that caller-supplied options are not used.
            if kwargs:
                logger.warning(
                    f"Extra arguments {sorted(kwargs)} are ignored by special "
                    f"RAG method {self.rag_method_name}"
                )
            try:
                self.rag_method = rag_class()
                self.is_special_method = True
                logger.info(f"Special RAG method {self.rag_method_name} initialized")
            except Exception as e:
                logger.error(f"Failed to initialize special RAG method {self.rag_method_name}: {e}")
                raise
        else:
            # Common parameters for BaseRAG methods
            common_params = {
                'embedding_model': self.embedding_model,
                'vector_store': self.vector_store,
            }

            # LLM parameters are only forwarded when a client is configured
            if self.llm_client is not None:
                common_params['llm_client'] = self.llm_client
                common_params['llm_model'] = self.llm_model

            # Caller-supplied kwargs take precedence over the defaults above
            common_params.update(kwargs)

            try:
                self.rag_method = rag_class(**common_params)
                self.is_special_method = False
                logger.info(f"RAG method {self.rag_method_name} initialized successfully")
            except Exception as e:
                logger.error(f"Failed to initialize RAG method {self.rag_method_name}: {e}")
                raise

    @staticmethod
    def _normalize_keywords(keywords):
        """
        Return ``keywords`` as a list of strings.

        A bare string is wrapped into a one-element list so that it is not
        iterated character by character; an empty string becomes an empty
        list (i.e. "no keywords"). Any other value is returned unchanged.
        """
        if isinstance(keywords, str):
            return [keywords] if keywords else []
        return keywords

    def add_documents(
        self,
        sources: Union[str, os.PathLike, List[str]],
        rebuild_index: bool = True
    ) -> int:
        """
        Add documents from file paths or directories.

        Args:
            sources: Single file/directory path (``str`` or ``os.PathLike``),
                or a list of such paths. Missing paths are skipped with a warning.
            rebuild_index: Whether to rebuild the index after adding documents.

        Returns:
            Number of documents produced by the document processor for the
            given sources (0 if nothing was loaded; the index is then not rebuilt).
        """
        if isinstance(sources, (str, os.PathLike)):
            sources = [sources]

        new_docs = []
        for source in sources:
            # Accept pathlib.Path and other path-like objects transparently.
            if isinstance(source, os.PathLike):
                source = os.fspath(source)

            if not os.path.exists(source):
                logger.warning(f"Source not found: {source}")
                continue

            logger.info(f"Processing: {source}")
            docs = self.document_processor.process(source)
            new_docs.extend(docs)

        if not new_docs:
            logger.warning("No documents were loaded")
            return 0

        # Store documents
        self.documents.extend(new_docs)
        logger.info(f"Added {len(new_docs)} documents (total: {len(self.documents)})")

        # Build index
        if rebuild_index:
            self.build_index()

        return len(new_docs)

    def add_texts(
        self,
        texts: Union[str, List[str]],
        metadata: Optional[List[Dict]] = None,
        rebuild_index: bool = True
    ) -> int:
        """
        Add documents from text strings.

        Args:
            texts: List of text contents, or a single text string.
            metadata: Optional list of metadata dicts, one per text.
            rebuild_index: Whether to rebuild the index.

        Returns:
            Number of documents added.

        Raises:
            ValueError: If ``metadata`` is given and its length differs from ``texts``.
        """
        # A single string would otherwise be split into one document per character.
        if isinstance(texts, str):
            texts = [texts]

        if metadata is None:
            metadata = [{} for _ in texts]

        if len(texts) != len(metadata):
            raise ValueError("texts and metadata must have the same length")

        new_docs = [
            Document(page_content=text, metadata=meta)
            for text, meta in zip(texts, metadata)
        ]

        self.documents.extend(new_docs)
        logger.info(f"Added {len(new_docs)} text documents")

        if rebuild_index:
            self.build_index()

        return len(new_docs)

    def build_index(self):
        """
        Build or rebuild the document index from ``self.documents``.

        - Special methods: all ``page_content`` strings are joined with "\\n"
          and passed to ``rag_method.build_index``.
        - Methods with a vector store: the store is cleared, then
          ``rag_method.index_documents`` is called with all documents.
        - Other methods: ``rag_method.index_documents`` is called if it exists.

        Does nothing (and leaves ``self.indexed`` untouched) when no documents
        are loaded.
        """
        if not self.documents:
            logger.warning("No documents to index")
            return

        logger.info(f"Building index with {len(self.documents)} documents...")

        if self.is_special_method:
            # For BM25HTMLTreeRAG, we need to combine all HTML content
            html_content = "\n".join(doc.page_content for doc in self.documents)
            self.rag_method.build_index(html_content)
        elif self.vector_store is not None:
            # Clear existing index so re-indexing does not duplicate entries
            self.vector_store.clear()
            self.rag_method.index_documents(self.documents)
        elif hasattr(self.rag_method, 'index_documents'):
            # Methods like BM25 / TFIDF maintain their own index
            self.rag_method.index_documents(self.documents)
        else:
            logger.info(f"Method {self.rag_method_name} handles indexing internally")

        self.indexed = True
        logger.info("Index built successfully")

    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        keywords: Union[str, List[str], None] = None,
        return_scores: bool = True
    ) -> List:
        """
        Retrieve relevant documents/fragments for a query.

        Keywords, if given, are appended to the query text. Results are
        filtered to those containing ANY keyword only when the method returned
        more than ``top_k`` results. If retrieval yields nothing, falls back to
        ``search_keywords`` (documents containing ALL keywords), each paired
        with a placeholder score of 1.0.

        Args:
            query: Search query (natural language or keywords).
            top_k: Number of results to request (None or 0 uses the config default).
            keywords: Additional keyword(s); a single string is accepted.
            return_scores: If True, items are returned as produced by the
                method (typically ``(doc, score)`` tuples). If False, any
                ``(doc, score)`` tuple is reduced to just ``doc``.

        Returns:
            List of retrieved documents/fragments; empty if the index has not
            been built.
        """
        if not self.indexed:
            logger.warning("Index not built. Call add_documents() first.")
            return []

        top_k = top_k or RETRIEVAL_TOP_K
        keywords = self._normalize_keywords(keywords)

        # Build query with keywords if provided
        if keywords:
            enhanced_query = f"{query} {' '.join(keywords)}"
            logger.info(f"Enhanced query with keywords: {enhanced_query}")
        else:
            enhanced_query = query

        logger.info(f"Retrieving with query: {query}")

        if self.is_special_method:
            # BM25HTMLTreeRAG returns (Document, score) tuples of tree-node fragments
            results = self.rag_method.query(enhanced_query, k=top_k)
            logger.info(f"Retrieved {len(results)} tree node fragments")
        else:
            # BaseRAG methods use retrieve()
            results = self.rag_method.retrieve(enhanced_query, k=top_k)

        # Only apply strict keyword filtering when there are more results than needed
        if keywords and len(results) > top_k:
            results = self._filter_by_keywords(results, keywords)
            logger.info(f"Filtered to {len(results)} results matching keywords")

        # Nothing retrieved: fall back to exact keyword search with dummy scores
        if not results and keywords:
            logger.info("No results from retrieval, trying search_keywords")
            results = [(doc, 1.0) for doc in self.search_keywords(keywords, top_k)]

        if not return_scores:
            results = [item[0] if isinstance(item, tuple) else item for item in results]

        return results

    def _filter_by_keywords(
        self,
        results: List,
        keywords: List[str]
    ) -> List:
        """
        Keep only results whose content contains at least one keyword.

        Matching is a case-insensitive substring test. Each kept item is
        returned in its original shape (``(doc, score)`` tuple or plain doc).
        """
        # Lowercase keywords once instead of once per document.
        lowered = [kw.lower() for kw in keywords]
        filtered = []
        for item in results:
            doc = item[0] if isinstance(item, tuple) else item
            content_lower = doc.page_content.lower()
            if any(kw in content_lower for kw in lowered):
                filtered.append(item)
        return filtered

    def query(
        self,
        query: str,
        top_k: Optional[int] = None,
        keywords: Union[str, List[str], None] = None
    ):
        """
        Full RAG pipeline: retrieve + generate answer.

        For special methods (which cannot generate), the retrieved fragments
        are wrapped in a ``RAGResult`` with an empty answer.

        Args:
            query: Search query.
            top_k: Number of documents to retrieve (None or 0 uses the config default).
            keywords: Additional keyword(s); appended to the query, and the
                retrieved docs are filtered to those containing ANY keyword.

        Returns:
            RAGResult with answer and retrieved documents, or None if the index
            has not been built.
        """
        if not self.indexed:
            logger.warning("Index not built. Call add_documents() first.")
            return None

        keywords = self._normalize_keywords(keywords)
        enhanced_query = f"{query} {' '.join(keywords)}" if keywords else query
        k = top_k or RETRIEVAL_TOP_K

        logger.info(f"Querying: {query}")

        if self.is_special_method:
            # BM25HTMLTreeRAG has no generate() step, only retrieval
            result = self.rag_method.query(enhanced_query, k=k)
            from .rag_methods.base import RAGResult
            rag_result = RAGResult(
                answer="",
                retrieved_docs=result,
                metadata={"method": self.rag_method_name}
            )
        else:
            rag_result = self.rag_method.query(enhanced_query, k=k)

        # Filter retrieved docs by keywords if provided
        if keywords and rag_result is not None and rag_result.retrieved_docs:
            rag_result.retrieved_docs = self._filter_by_keywords(
                rag_result.retrieved_docs, keywords
            )

        return rag_result

    def search_keywords(
        self,
        keywords: Union[str, List[str]],
        top_k: Optional[int] = None
    ) -> List[Document]:
        """
        Search loaded documents for ones containing ALL keywords.

        Matching is a case-insensitive substring test over ``page_content``,
        scanning documents in insertion order. Note that an empty keyword list
        matches every document.

        Args:
            keywords: Keyword(s) to search for; a single string is accepted.
            top_k: Maximum number of results (None or 0 means no limit).

        Returns:
            List of matching documents (at most ``top_k``).
        """
        if not self.documents:
            logger.warning("No documents loaded")
            return []

        keywords = self._normalize_keywords(keywords)
        limit = top_k or len(self.documents)
        # Lowercase keywords once instead of once per document.
        lowered = [kw.lower() for kw in keywords]

        results = []
        for doc in self.documents:
            content_lower = doc.page_content.lower()
            if all(kw in content_lower for kw in lowered):
                results.append(doc)
                if len(results) >= limit:
                    break

        logger.info(f"Found {len(results)} documents matching keywords: {keywords}")
        return results

    def get_document_count(self) -> int:
        """Get the number of loaded documents."""
        return len(self.documents)

    def clear(self):
        """
        Clear all loaded documents and mark the index as not built.

        The vector store (if any) is cleared. Indexes held internally by
        vector-store-free methods are not reset here; they are rebuilt on the
        next ``build_index`` call, and ``retrieve``/``query`` refuse to run
        until then because ``self.indexed`` is False.
        """
        self.documents = []
        if self.vector_store is not None:
            self.vector_store.clear()
        self.indexed = False
        logger.info("Cleared all documents and index")

    @staticmethod
    def list_available_methods() -> List[str]:
        """List all available RAG method names (callable on the class or an instance)."""
        return list(RAG_METHODS.keys())

    def get_method_info(self) -> Dict:
        """
        Get information about the current RAG method.

        Note: 'embedding_model' and 'vector_store' report the configured names
        even for methods that do not use them.
        """
        return {
            'method_name': self.rag_method_name,
            'method_class': self.rag_method.__class__.__name__,
            'embedding_model': self.embedding_model_name,
            'vector_store': self.vector_store_type,
            'document_count': len(self.documents),
            'indexed': self.indexed,
        }
|