table_aware_rag.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # -*- coding: utf-8 -*-
  2. """Table-Aware RAG - specialized retrieval for tabular data in bidding documents."""
  3. import re
  4. from .base import BaseRAG
  5. from loguru import logger
  6. class TableAwareRAG(BaseRAG):
  7. def __init__(self, **kwargs):
  8. super().__init__(**kwargs)
  9. self.table_pattern = re.compile(r'\|.*\|.*\|')
  10. def _detect_tables(self, text):
  11. lines = text.split('\n')
  12. tables = []
  13. current_table = []
  14. for line in lines:
  15. if self.table_pattern.search(line):
  16. current_table.append(line)
  17. else:
  18. if current_table:
  19. tables.append('\n'.join(current_table))
  20. current_table = []
  21. if current_table:
  22. tables.append('\n'.join(current_table))
  23. return tables
  24. def index_documents(self, documents):
  25. table_docs = []
  26. text_docs = []
  27. for doc in documents:
  28. tables = self._detect_tables(doc.page_content)
  29. if tables:
  30. for i, table in enumerate(tables):
  31. table_doc = type(doc)(
  32. page_content=table,
  33. metadata=dict(doc.metadata, is_table=True, table_index=i)
  34. )
  35. table_docs.append(table_doc)
  36. table_text = '\n'.join(tables)
  37. text_only = doc.page_content.replace(table_text, '')
  38. if text_only.strip():
  39. text_doc = type(doc)(
  40. page_content=text_only,
  41. metadata=dict(doc.metadata, is_table=False)
  42. )
  43. text_docs.append(text_doc)
  44. else:
  45. text_docs.append(doc)
  46. all_docs = text_docs + table_docs
  47. texts = [doc.page_content for doc in all_docs]
  48. embeddings = self.embedding_model.embed_documents(texts)
  49. self.vector_store.add_documents(all_docs, embeddings)
  50. logger.info("TableAwareRAG indexed {} text docs and {} tables".format(len(text_docs), len(table_docs)))
  51. def retrieve(self, query, k=10):
  52. query_embedding = self.embedding_model.embed_query(query)
  53. results = self.vector_store.similarity_search(query_embedding, k * 2)
  54. table_results = []
  55. text_results = []
  56. for doc, score in results:
  57. if doc.metadata.get('is_table'):
  58. table_results.append((doc, score * 1.1))
  59. else:
  60. text_results.append((doc, score))
  61. combined = table_results + text_results
  62. combined.sort(key=lambda x: x[1], reverse=True)
  63. return self._deduplicate_results(combined, k)
  64. def generate(self, query, context):
  65. prompt = (
  66. "根据以下参考文档(包含表格数据),回答问题。\n\n"
  67. "参考文档:\n{}\n\n"
  68. "问题:{}\n\n"
  69. "请详细回答,如果涉及表格数据,请准确提取相关数值。".format(context, query)
  70. )
  71. return self._call_llm(prompt)