bidi_rag.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. # -*- coding: utf-8 -*-
  2. """BidiRag - Unified interface for RAG-based document retrieval."""
  3. import os
  4. from typing import List, Dict, Optional, Union
  5. from loguru import logger
  6. from .config import (
  7. EMBEDDING_MODEL_NAME,
  8. CHUNK_SIZE,
  9. CHUNK_OVERLAP,
  10. VECTOR_STORE_TYPE,
  11. RETRIEVAL_TOP_K,
  12. )
  13. from .document_processor import DocumentProcessor, Document
  14. from .embedding_models import SentenceTransformerEmbedding
  15. from .vector_stores import get_vector_store
  16. from .rag_methods import (
  17. NaiveRAG, RerankRAG, ParentDocumentRAG, LLMFilterRAG,
  18. QueryRoutingRAG, MetadataFilterRAG, AdaptiveRAG, HybridSearchRAG,
  19. MultiQueryRAG, HyDERAG, StepBackRAG, ContextualCompressionRAG,
  20. SelfRAG, CorrectiveRAG, FLARERAG, RAPTORRAG, EnsembleRAG,
  21. BidFieldExtractionRAG, TableAwareRAG, GraphRAG,
  22. BM25RAG, TFIDFRAG, KeywordRAG, BM25HTMLTreeRAG,
  23. )
  # Registry of available RAG methods: the public name accepted by
  # ``BidiRag(rag_method=...)`` mapped to the class implementing it. Dict
  # insertion order is preserved, so ``list_available_methods()`` and the
  # "Unknown RAG method" error message list the names in the order below.
  RAG_METHODS = dict(
      naive=NaiveRAG,
      rerank=RerankRAG,
      parent_document=ParentDocumentRAG,
      llm_filter=LLMFilterRAG,
      query_routing=QueryRoutingRAG,
      metadata_filter=MetadataFilterRAG,
      adaptive=AdaptiveRAG,
      hybrid_search=HybridSearchRAG,
      multi_query=MultiQueryRAG,
      hyde=HyDERAG,
      step_back=StepBackRAG,
      contextual_compression=ContextualCompressionRAG,
      self_rag=SelfRAG,
      corrective=CorrectiveRAG,
      flare=FLARERAG,
      raptor=RAPTORRAG,
      ensemble=EnsembleRAG,
      bid_field_extraction=BidFieldExtractionRAG,
      table_aware=TableAwareRAG,
      graph=GraphRAG,
      bm25=BM25RAG,
      tfidf=TFIDFRAG,
      keyword=KeywordRAG,
      bm25_html_tree=BM25HTMLTreeRAG,
  )
  class BidiRag:
      """
      BidiRag - Unified interface for RAG-based document retrieval and extraction.

      This class provides a high-level API to:
      1. Load and process HTML documents
      2. Build an index (with an embedding model, or the method's own indexing)
      3. Retrieve relevant text using various RAG methods
      4. Support keyword-based search

      Usage:
          >>> rag = BidiRag(rag_method='bm25_html_tree')
          >>> rag.add_documents(["path/to/document1.html", "path/to/document2.html"])
          >>> # query text means "tenderer and winning bidder"
          >>> results = rag.retrieve("招标人和中标人", top_k=5, return_scores=False)
          >>> for doc in results:
          ...     print(doc.page_content)
      """

      def __init__(
          self,
          rag_method: str = 'bm25_html_tree',
          chunk_size: Optional[int] = None,
          chunk_overlap: Optional[int] = None,
          vector_store_type: Optional[str] = None,
          embedding_model_name: Optional[str] = None,
          llm_client=None,
          llm_model: str = "gpt-4o",
          index_name: str = "default",
          **kwargs
      ):
          """
          Initialize BidiRag.

          Args:
              rag_method: Key of ``RAG_METHODS`` to use (e.g. 'bm25_html_tree',
                  'naive', 'hybrid_search').
              chunk_size: Document chunk size; the config default is used when
                  this is None or 0.
              chunk_overlap: Chunk overlap size; the config default is used only
                  when this is None, so an explicit 0 disables overlap.
              vector_store_type: Vector store type 'faiss' or 'chroma'
                  (default: from config).
              embedding_model_name: Embedding model name (default: from config).
              llm_client: LLM client instance (optional, for methods requiring an LLM).
              llm_model: LLM model name; forwarded only when ``llm_client`` is given.
              index_name: Index name, stored as ``self.index_name``.
              **kwargs: Extra keyword arguments forwarded to the RAG method's
                  constructor (not used by 'bm25_html_tree').

          Raises:
              ValueError: If ``rag_method`` is not a key of ``RAG_METHODS``.
          """
          # Validate first so an unknown method fails before any model is loaded.
          if rag_method not in RAG_METHODS:
              available = ', '.join(RAG_METHODS.keys())
              raise ValueError(
                  f"Unknown RAG method '{rag_method}'. "
                  f"Available methods: {available}"
              )

          # Configuration
          self.rag_method_name = rag_method
          self.chunk_size = chunk_size or CHUNK_SIZE
          # Compare against None: ``chunk_overlap or CHUNK_OVERLAP`` would silently
          # replace a deliberate overlap of 0 with the config default.
          self.chunk_overlap = CHUNK_OVERLAP if chunk_overlap is None else chunk_overlap
          self.vector_store_type = vector_store_type or VECTOR_STORE_TYPE
          self.embedding_model_name = embedding_model_name or EMBEDDING_MODEL_NAME
          self.llm_client = llm_client
          self.llm_model = llm_model
          self.index_name = index_name

          # Methods that don't require an embedding model
          self.embedding_free_methods = {'bm25', 'tfidf', 'keyword', 'bm25_html_tree'}
          # Methods that don't require a vector store (they have their own indexing)
          self.vector_store_free_methods = {'bm25', 'tfidf', 'keyword', 'bm25_html_tree'}

          self.document_processor = DocumentProcessor(
              chunk_size=self.chunk_size,
              chunk_overlap=self.chunk_overlap
          )

          # Embedding model: only loaded when the method needs one.
          if rag_method not in self.embedding_free_methods:
              logger.info(f"Loading embedding model: {self.embedding_model_name}")
              self.embedding_model = SentenceTransformerEmbedding(model_name=self.embedding_model_name)
          else:
              logger.info(f"Method {rag_method} doesn't require embedding model")
              self.embedding_model = None

          # Vector store: only created when the method needs one.
          if rag_method not in self.vector_store_free_methods:
              logger.info(f"Initializing vector store: {self.vector_store_type}")
              self.vector_store = get_vector_store(
                  store_type=self.vector_store_type,
                  embedding_model=self.embedding_model
              )
          else:
              logger.info(f"Method {rag_method} doesn't require vector store")
              self.vector_store = None

          self._init_rag_method(**kwargs)

          # Document storage; ``indexed`` gates retrieve() and query().
          self.documents = []
          self.indexed = False
          logger.info(f"BidiRag initialized with method: {rag_method}")

      @staticmethod
      def _as_list(value):
          """Wrap a bare string in a one-item list; return any other value unchanged."""
          # Without this, a single string passed where a list is expected would be
          # iterated character by character.
          return [value] if isinstance(value, str) else value

      def _init_rag_method(self, **kwargs):
          """
          Instantiate the configured RAG method and set ``self.is_special_method``.

          'bm25_html_tree' does not follow the BaseRAG constructor interface and is
          built with no arguments, so ``kwargs`` and the LLM settings are not passed
          to it. Every other method receives the embedding model, the vector store,
          the LLM settings (only when an LLM client was given) and ``kwargs``.
          Entries in ``kwargs`` override the common parameters on key collisions.

          Raises:
              Exception: Whatever the method's constructor raises (logged, then re-raised).
          """
          rag_class = RAG_METHODS[self.rag_method_name]
          special_methods = {'bm25_html_tree'}

          if self.rag_method_name in special_methods:
              # BM25HTMLTreeRAG has its own interface
              try:
                  self.rag_method = rag_class()
              except Exception as e:
                  logger.error(f"Failed to initialize special RAG method {self.rag_method_name}: {e}")
                  raise
              self.is_special_method = True
              logger.info(f"Special RAG method {self.rag_method_name} initialized")
              return

          common_params = {
              'embedding_model': self.embedding_model,
              'vector_store': self.vector_store,
          }
          if self.llm_client is not None:
              common_params['llm_client'] = self.llm_client
              common_params['llm_model'] = self.llm_model
          common_params.update(kwargs)

          try:
              self.rag_method = rag_class(**common_params)
          except Exception as e:
              logger.error(f"Failed to initialize RAG method {self.rag_method_name}: {e}")
              raise
          self.is_special_method = False
          logger.info(f"RAG method {self.rag_method_name} initialized successfully")

      def add_documents(
          self,
          sources: Union[str, List[str]],
          rebuild_index: bool = True
      ) -> int:
          """
          Add documents from file paths or directories.

          Paths that do not exist are logged and skipped. Each existing path is
          passed to ``DocumentProcessor.process`` and the documents it returns are
          appended to ``self.documents``.

          Args:
              sources: Single file path, directory path, or list of paths.
              rebuild_index: Whether to rebuild the index after adding documents.
                  If False, call ``build_index()`` yourself before the new content
                  can be retrieved.

          Returns:
              Number of documents added. When nothing is loaded, 0 is returned and
              the index is not rebuilt.
          """
          new_docs = []
          for source in self._as_list(sources):
              if not os.path.exists(source):
                  logger.warning(f"Source not found: {source}")
                  continue
              logger.info(f"Processing: {source}")
              new_docs.extend(self.document_processor.process(source))

          if not new_docs:
              logger.warning("No documents were loaded")
              return 0

          self.documents.extend(new_docs)
          logger.info(f"Added {len(new_docs)} documents (total: {len(self.documents)})")
          if rebuild_index:
              self.build_index()
          return len(new_docs)

      def add_texts(
          self,
          texts: Union[str, List[str]],
          metadata: Optional[List[Dict]] = None,
          rebuild_index: bool = True
      ) -> int:
          """
          Add documents from text strings.

          Args:
              texts: Text content; a single string is treated as a one-item list.
              metadata: One metadata dict per text (optional; defaults to empty dicts).
              rebuild_index: Whether to rebuild the index.

          Returns:
              Number of documents added.

          Raises:
              ValueError: If ``texts`` and ``metadata`` have different lengths.
          """
          texts = self._as_list(texts)
          if metadata is None:
              metadata = [{} for _ in texts]
          if len(texts) != len(metadata):
              raise ValueError("texts and metadata must have the same length")

          new_docs = [
              Document(page_content=text, metadata=meta)
              for text, meta in zip(texts, metadata)
          ]
          self.documents.extend(new_docs)
          logger.info(f"Added {len(new_docs)} text documents")
          if rebuild_index:
              self.build_index()
          return len(new_docs)

      def build_index(self):
          """
          Build or rebuild the document index from all of ``self.documents``.

          How the index is built depends on the method:
          - 'bm25_html_tree': the page contents of all documents are joined with
            newlines into one string, which is passed to the method's ``build_index``.
          - Methods with a vector store: the store is cleared, then every document
            is re-indexed through the method's ``index_documents``.
          - Other methods: ``index_documents`` is called if the method defines it.

          Does nothing (apart from logging a warning) when no documents are loaded.
          """
          if not self.documents:
              logger.warning("No documents to index")
              return

          logger.info(f"Building index with {len(self.documents)} documents...")
          if self.is_special_method:
              html_content = "\n".join(doc.page_content for doc in self.documents)
              self.rag_method.build_index(html_content)
          elif self.vector_store is not None:
              # Clear first so documents from a previous build are not indexed twice.
              self.vector_store.clear()
              self.rag_method.index_documents(self.documents)
          elif hasattr(self.rag_method, 'index_documents'):
              # e.g. BM25 / TF-IDF, which keep their own index
              self.rag_method.index_documents(self.documents)
          else:
              logger.info(f"Method {self.rag_method_name} handles indexing internally")

          self.indexed = True
          logger.info("Index built successfully")

      def retrieve(
          self,
          query: str,
          top_k: Optional[int] = None,
          keywords: Optional[List[str]] = None,
          return_scores: bool = True
      ) -> List:
          """
          Retrieve relevant documents/fragments for a query.

          Args:
              query: Search query (natural language or keywords).
              top_k: Number of results to request (default: from config).
              keywords: Extra keywords. They are appended to the query text, and
                  they filter the results when the method returned more than
                  ``top_k`` items. A single string is treated as one keyword.
              return_scores: If True, items are returned in the shape the
                  underlying method produces; the 'bm25_html_tree' method and the
                  keyword fallback yield ``(doc, score)`` tuples. If False, every
                  ``(doc, score)`` tuple is reduced to ``doc``.

          Returns:
              List of retrieved documents/fragments. The list is empty when the
              index has not been built.
          """
          if not self.indexed:
              logger.warning("Index not built. Call add_documents() first.")
              return []

          top_k = top_k or RETRIEVAL_TOP_K
          keywords = self._as_list(keywords)

          if keywords:
              enhanced_query = f"{query} {' '.join(keywords)}"
              logger.info(f"Enhanced query with keywords: {enhanced_query}")
          else:
              enhanced_query = query

          logger.info(f"Retrieving with query: {query}")
          if self.is_special_method:
              # BM25HTMLTreeRAG returns (Document, score) tuples of tree node content
              results = self.rag_method.query(enhanced_query, k=top_k)
              logger.info(f"Retrieved {len(results)} tree node fragments")
          else:
              results = self.rag_method.retrieve(enhanced_query, k=top_k)

          # Only apply strict keyword filtering when there are more results than needed.
          if keywords and len(results) > top_k:
              results = self._filter_by_keywords(results, keywords)
              logger.info(f"Filtered to {len(results)} results matching keywords")

          # Fall back to exact keyword search, with dummy scores for a consistent shape.
          if not results and keywords:
              logger.info("No results from retrieval, trying search_keywords")
              results = [(doc, 1.0) for doc in self.search_keywords(keywords, top_k)]

          if not return_scores:
              results = [item[0] if isinstance(item, tuple) else item for item in results]
          return results

      def _filter_by_keywords(
          self,
          results: List,
          keywords: List[str]
      ) -> List:
          """
          Keep the results whose page content contains at least one keyword.

          Matching is a case-insensitive substring test. Items may be plain
          documents or tuples whose first element is the document. Each kept
          item is returned unchanged, so its shape is preserved.
          """
          needles = [kw.lower() for kw in keywords]
          filtered = []
          for item in results:
              doc = item[0] if isinstance(item, tuple) else item
              content_lower = doc.page_content.lower()
              if any(needle in content_lower for needle in needles):
                  filtered.append(item)
          return filtered

      def query(
          self,
          query: str,
          top_k: Optional[int] = None,
          keywords: Optional[List[str]] = None
      ):
          """
          Full RAG pipeline: retrieve + generate answer.

          For 'bm25_html_tree', which has no generation step, the retrieval output
          is wrapped in a ``RAGResult`` with an empty answer. When keywords are
          given, ``retrieved_docs`` is filtered afterwards. The answer itself (if
          any) was produced by the method before this filtering.

          Args:
              query: Search query.
              top_k: Number of documents to retrieve (default: from config).
              keywords: Keywords appended to the query and used to filter the
                  retrieved documents. A single string is treated as one keyword.

          Returns:
              A RAGResult with the answer and retrieved documents, or None if the
              index has not been built.
          """
          if not self.indexed:
              logger.warning("Index not built. Call add_documents() first.")
              return None

          keywords = self._as_list(keywords)
          k = top_k or RETRIEVAL_TOP_K
          enhanced_query = f"{query} {' '.join(keywords)}" if keywords else query

          logger.info(f"Querying: {query}")
          if self.is_special_method:
              result = self.rag_method.query(enhanced_query, k=k)
              from .rag_methods.base import RAGResult
              rag_result = RAGResult(
                  answer="",
                  retrieved_docs=result,
                  metadata={"method": self.rag_method_name}
              )
          else:
              rag_result = self.rag_method.query(enhanced_query, k=k)

          if keywords and rag_result.retrieved_docs:
              rag_result.retrieved_docs = self._filter_by_keywords(
                  rag_result.retrieved_docs, keywords
              )
          return rag_result

      def search_keywords(
          self,
          keywords: List[str],
          top_k: Optional[int] = None
      ) -> List["Document"]:
          """
          Search the loaded documents for ones that contain every keyword.

          Matching is a case-insensitive substring test over ``page_content``. An
          empty keyword list matches every document. The index is not used.

          Args:
              keywords: Keywords that must all be present. A single string is
                  treated as one keyword.
              top_k: Maximum number of results (None or 0 means no limit).

          Returns:
              The matching documents, in load order.
          """
          if not self.documents:
              logger.warning("No documents loaded")
              return []

          needles = [kw.lower() for kw in self._as_list(keywords)]
          limit = top_k or len(self.documents)
          results = []
          for doc in self.documents:
              content_lower = doc.page_content.lower()
              if all(needle in content_lower for needle in needles):
                  results.append(doc)
                  if len(results) >= limit:
                      break

          logger.info(f"Found {len(results)} documents matching keywords: {keywords}")
          return results

      def get_document_count(self) -> int:
          """Return the number of loaded documents."""
          return len(self.documents)

      def clear(self):
          """
          Clear all documents and mark the index as not built.

          The vector store (if any) is cleared. The RAG method object's own index
          is not reset, but retrieve() and query() refuse to run until
          build_index() sets ``indexed`` again.
          """
          self.documents = []
          if self.vector_store is not None:
              self.vector_store.clear()
          self.indexed = False
          logger.info("Cleared all documents and index")

      def list_available_methods(self) -> List[str]:
          """Return the names of all registered RAG methods."""
          return list(RAG_METHODS.keys())

      def get_method_info(self) -> Dict:
          """
          Return a summary of the current configuration and state.

          'embedding_model' and 'vector_store' report the configured names, even
          for methods that do not load them.
          """
          return {
              'method_name': self.rag_method_name,
              'method_class': self.rag_method.__class__.__name__,
              'embedding_model': self.embedding_model_name,
              'vector_store': self.vector_store_type,
              'document_count': len(self.documents),
              'indexed': self.indexed,
          }