graph_rag.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # -*- coding: utf-8 -*-
  2. """Graph RAG - knowledge graph enhanced retrieval for bidding domain."""
  3. import re
  4. from .base import BaseRAG
  5. from loguru import logger
  6. class GraphRAG(BaseRAG):
  7. def __init__(self, **kwargs):
  8. super().__init__(**kwargs)
  9. self.graph = {}
  10. self._build_graph_patterns = [
  11. (r'预算金额[::]?\s*([0-9.]+)\s*万', 'budget_amount'),
  12. (r'投标保证金[::]?\s*([0-9.]+)\s*万', 'bid_bond'),
  13. (r'质保期[::]?\s*([^,,\n。]+)', 'warranty'),
  14. (r'交货时间[::]?\s*([^,,\n。]+)', 'delivery_time'),
  15. (r'资质要求[::]?\s*([^,,\n。]+)', 'qualification'),
  16. (r'评标方法[::]?\s*([^,,\n。]+)', 'evaluation_method'),
  17. ]
  18. def _extract_entities(self, text):
  19. entities = {}
  20. for pattern, entity_type in self._build_graph_patterns:
  21. match = re.search(pattern, text)
  22. if match:
  23. entities[entity_type] = match.group(1).strip()
  24. return entities
  25. def index_documents(self, documents):
  26. self.graph = {}
  27. for doc in documents:
  28. entities = self._extract_entities(doc.page_content)
  29. if entities:
  30. self.graph[id(doc)] = {
  31. 'doc': doc,
  32. 'entities': entities
  33. }
  34. texts = [doc.page_content for doc in documents]
  35. embeddings = self.embedding_model.embed_documents(texts)
  36. self.vector_store.add_documents(documents, embeddings)
  37. logger.info("GraphRAG built graph with {} nodes".format(len(self.graph)))
  38. def _query_graph(self, query):
  39. matching_nodes = []
  40. for node_id, node_data in self.graph.items():
  41. for entity_type, entity_value in node_data['entities'].items():
  42. if entity_value in query:
  43. matching_nodes.append((node_data['doc'], 1.0))
  44. break
  45. return matching_nodes
  46. def retrieve(self, query, k=10):
  47. query_embedding = self.embedding_model.embed_query(query)
  48. semantic_results = self.vector_store.similarity_search(query_embedding, k)
  49. graph_results = self._query_graph(query)
  50. combined = {}
  51. for doc, score in semantic_results:
  52. combined[self._dedup_key(doc)] = (doc, score)
  53. for doc, score in graph_results:
  54. key = self._dedup_key(doc)
  55. if key in combined:
  56. combined[key] = (doc, combined[key][1] + 0.2)
  57. else:
  58. combined[key] = (doc, score)
  59. results = list(combined.values())
  60. results.sort(key=lambda x: x[1], reverse=True)
  61. return self._deduplicate_results(results, k)
  62. def generate(self, query, context):
  63. prompt = (
  64. "根据以下参考文档(知识图谱增强检索),回答问题。\n\n"
  65. "参考文档:\n{}\n\n"
  66. "问题:{}\n\n"
  67. "请详细回答。".format(context, query)
  68. )
  69. return self._call_llm(prompt)