# -*- coding: utf-8 -*-
"""Table-Aware RAG - specialized retrieval for tabular data in bidding documents."""

import re

from .base import BaseRAG
from loguru import logger


class TableAwareRAG(BaseRAG):
    """RAG variant that indexes pipe-delimited tables separately from prose.

    Each input document is split into its table blocks (indexed with
    ``is_table=True`` and a ``table_index``) and its remaining non-table text
    (indexed with ``is_table=False``). During retrieval, table hits receive a
    score boost before ranking.
    """

    # Boost factor applied to the similarity score of table hits in ``retrieve``.
    TABLE_SCORE_BOOST = 1.1

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # A line is treated as a table row when it contains at least three '|'
        # characters, e.g. a Markdown row such as "| item | price |".
        self.table_pattern = re.compile(r'\|.*\|.*\|')

    def _is_table_line(self, line):
        """Return True if *line* looks like a pipe-delimited table row."""
        return self.table_pattern.search(line) is not None

    def _detect_tables(self, text):
        """Return each run of consecutive table rows in *text* as one string.

        Rows inside a run are joined with '\\n'; any non-table line ends the
        current run.
        """
        tables = []
        current_table = []
        for line in text.split('\n'):
            if self._is_table_line(line):
                current_table.append(line)
            elif current_table:
                tables.append('\n'.join(current_table))
                current_table = []
        if current_table:
            tables.append('\n'.join(current_table))
        return tables

    def _strip_tables(self, text):
        """Return *text* with every table row removed, other lines kept in order."""
        return '\n'.join(line for line in text.split('\n') if not self._is_table_line(line))

    def index_documents(self, documents):
        """Split *documents* into prose and table chunks, embed them, and store them.

        Args:
            documents: document objects exposing ``page_content`` and ``metadata``
                and constructible as ``type(doc)(page_content=..., metadata=...)``.

        Documents without tables are indexed unchanged. For documents with
        tables, each table becomes its own document (``is_table=True``,
        ``table_index=i``), and the remaining text, if not blank, becomes a
        document with ``is_table=False``. Prose documents are stored before
        table documents.
        """
        table_docs = []
        text_docs = []
        for doc in documents:
            tables = self._detect_tables(doc.page_content)
            if not tables:
                text_docs.append(doc)
                continue
            for i, table in enumerate(tables):
                table_docs.append(type(doc)(
                    page_content=table,
                    metadata=dict(doc.metadata, is_table=True, table_index=i),
                ))
            # Drop table rows line by line: tables separated by prose are not
            # contiguous in the source, so removing their concatenation as a
            # single substring would leave them in the text and duplicate them.
            text_only = self._strip_tables(doc.page_content)
            if text_only.strip():
                text_docs.append(type(doc)(
                    page_content=text_only,
                    metadata=dict(doc.metadata, is_table=False),
                ))

        all_docs = text_docs + table_docs
        if not all_docs:
            # Nothing to embed; avoid calling the embedding model with an empty batch.
            logger.info("TableAwareRAG indexed {} text docs and {} tables", 0, 0)
            return

        texts = [doc.page_content for doc in all_docs]
        embeddings = self.embedding_model.embed_documents(texts)
        self.vector_store.add_documents(all_docs, embeddings)
        logger.info(
            "TableAwareRAG indexed {} text docs and {} tables",
            len(text_docs),
            len(table_docs),
        )

    def _boost_table_score(self, score):
        """Return *score* raised by ``TABLE_SCORE_BOOST``.

        Higher scores rank first in ``retrieve``. Multiplying a negative score
        (e.g. a negative cosine similarity) by the boost would lower it, so
        negative scores are divided by the boost instead, moving them upward.
        """
        boost = self.TABLE_SCORE_BOOST
        return score * boost if score >= 0 else score / boost

    def retrieve(self, query, k=10):
        """Return up to *k* results for *query*, favouring table chunks.

        Fetches ``2 * k`` candidates from the vector store, boosts table hits,
        sorts by score (highest first) and passes the list to
        ``_deduplicate_results`` (defined in BaseRAG; presumably trims to *k*).
        """
        query_embedding = self.embedding_model.embed_query(query)
        results = self.vector_store.similarity_search(query_embedding, k * 2)

        table_results = []
        text_results = []
        for doc, score in results:
            if doc.metadata.get('is_table'):
                table_results.append((doc, self._boost_table_score(score)))
            else:
                text_results.append((doc, score))

        # Tables are listed first so that, with the stable sort, they win ties.
        combined = table_results + text_results
        combined.sort(key=lambda x: x[1], reverse=True)
        return self._deduplicate_results(combined, k)

    def generate(self, query, context):
        """Ask the LLM to answer *query* from *context*, extracting table values exactly.

        Args:
            query: the user question.
            context: retrieved reference text (may contain tables).

        Returns:
            Whatever ``self._call_llm`` returns for the built prompt.
        """
        prompt = (
            "根据以下参考文档(包含表格数据),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答,如果涉及表格数据,请准确提取相关数值。".format(context, query)
        )
        return self._call_llm(prompt)