| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- # -*- coding: utf-8 -*-
- """Table-Aware RAG - specialized retrieval for tabular data in bidding documents."""
- import re
- from .base import BaseRAG
- from loguru import logger
class TableAwareRAG(BaseRAG):
    """RAG variant that indexes pipe-delimited table blocks as separate documents.

    A line is treated as a table row when it contains at least three ``|``
    characters (see ``table_pattern``). Consecutive table rows form one table.
    At indexing time each table becomes its own document flagged with
    ``is_table=True``; at retrieval time those documents get a score boost.

    ``self.embedding_model``, ``self.vector_store``, ``self._deduplicate_results``
    and ``self._call_llm`` are not defined in this class.
    NOTE(review): presumably they are provided by ``BaseRAG`` -- confirm.
    """

    # Factor applied to the score of table documents in ``retrieve``.
    TABLE_SCORE_BOOST = 1.1
    # ``retrieve`` requests ``k * CANDIDATE_MULTIPLIER`` candidates from the
    # vector store before re-scoring and de-duplicating down to ``k``.
    CANDIDATE_MULTIPLIER = 2

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base class and compile the row pattern."""
        super().__init__(**kwargs)
        # Matches any line holding three or more pipe characters.
        self.table_pattern = re.compile(r'\|.*\|.*\|')

    def _split_tables_and_text(self, text):
        """Partition ``text`` into table blocks and the remaining prose.

        Args:
            text: Raw document text; lines are separated on ``'\\n'``.

        Returns:
            A ``(tables, text_only)`` tuple. ``tables`` is a list of strings,
            one per run of consecutive table rows, in document order.
            ``text_only`` is the non-table lines joined with ``'\\n'``, with a
            single empty line left where each table was removed.
        """
        tables = []
        remaining_lines = []
        current_table = []

        for line in text.split('\n'):
            if self.table_pattern.search(line):
                current_table.append(line)
                continue
            if current_table:
                tables.append('\n'.join(current_table))
                # Placeholder keeps the prose before and after the table on
                # separate lines, exactly as cutting the table text out would.
                remaining_lines.append('')
                current_table = []
            remaining_lines.append(line)

        # Flush a table that runs to the end of the text.
        if current_table:
            tables.append('\n'.join(current_table))
            remaining_lines.append('')

        return tables, '\n'.join(remaining_lines)

    def _detect_tables(self, text):
        """Return the table blocks found in ``text`` as a list of strings."""
        tables, _ = self._split_tables_and_text(text)
        return tables

    def index_documents(self, documents):
        """Split each document into table and prose parts, embed them and store them.

        Documents without any table row are indexed unchanged. For a document
        with tables, every table becomes a new document of the same type with
        ``is_table=True`` and its ``table_index`` in the metadata, and the
        remaining prose (if not blank) becomes one document with
        ``is_table=False``.

        Args:
            documents: Iterable of objects exposing ``page_content`` and
                ``metadata`` and constructible as
                ``type(doc)(page_content=..., metadata=...)``.
        """
        table_docs = []
        text_docs = []

        for doc in documents:
            tables, text_only = self._split_tables_and_text(doc.page_content)
            if not tables:
                text_docs.append(doc)
                continue

            doc_cls = type(doc)
            # Tolerate documents whose metadata is None.
            base_metadata = doc.metadata or {}

            for i, table in enumerate(tables):
                table_docs.append(
                    doc_cls(
                        page_content=table,
                        metadata=dict(base_metadata, is_table=True, table_index=i),
                    )
                )

            # The prose is built from the non-table lines directly. Removing
            # the newline-joined tables with one str.replace only works when
            # the document holds a single table; with several tables the
            # joined string never occurs in the page and nothing is removed.
            if text_only.strip():
                text_docs.append(
                    doc_cls(
                        page_content=text_only,
                        metadata=dict(base_metadata, is_table=False),
                    )
                )

        all_docs = text_docs + table_docs
        # Skip the embedding and storage calls when there is nothing to index.
        if all_docs:
            texts = [doc.page_content for doc in all_docs]
            embeddings = self.embedding_model.embed_documents(texts)
            self.vector_store.add_documents(all_docs, embeddings)

        logger.info(
            "TableAwareRAG indexed {} text docs and {} tables",
            len(text_docs),
            len(table_docs),
        )

    def retrieve(self, query, k=10):
        """Fetch candidates for ``query`` and rank them with a boost for tables.

        Args:
            query: Query text passed to ``embedding_model.embed_query``.
            k: Result count passed to ``_deduplicate_results``; the vector
                store is asked for ``k * CANDIDATE_MULTIPLIER`` candidates.

        Returns:
            Whatever ``self._deduplicate_results(combined, k)`` returns, where
            ``combined`` is a list of ``(doc, score)`` pairs sorted by
            descending score.
        """
        query_embedding = self.embedding_model.embed_query(query)
        results = self.vector_store.similarity_search(
            query_embedding, k * self.CANDIDATE_MULTIPLIER
        )

        table_results = []
        text_results = []

        for doc, score in results:
            if doc.metadata.get('is_table'):
                # NOTE(review): the multiplication only promotes tables when
                # scores are positive and higher means more similar -- confirm
                # that the vector store returns similarities, not distances.
                table_results.append((doc, score * self.TABLE_SCORE_BOOST))
            else:
                text_results.append((doc, score))

        # Tables go first so that the stable sort keeps them ahead of prose
        # documents with an equal score.
        combined = table_results + text_results
        combined.sort(key=lambda pair: pair[1], reverse=True)

        return self._deduplicate_results(combined, k)

    def generate(self, query, context):
        """Build the answer prompt from ``context`` and ``query`` and call the LLM.

        The prompt is written in Chinese: it presents the reference documents
        (noting that they include table data), states the question, and asks
        for a detailed answer that extracts table values accurately.

        Args:
            query: The user question inserted into the prompt.
            context: The reference text inserted into the prompt.

        Returns:
            The result of ``self._call_llm(prompt)``.
        """
        # Adjacent string literals are concatenated before ``.format`` is
        # applied, so both placeholders are filled in a single call.
        prompt = (
            "根据以下参考文档(包含表格数据),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答,如果涉及表格数据,请准确提取相关数值。".format(context, query)
        )
        return self._call_llm(prompt)
|