| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- # -*- coding: utf-8 -*-
- """Graph RAG - knowledge graph enhanced retrieval for bidding domain."""
- import re
- from .base import BaseRAG
- from loguru import logger
class GraphRAG(BaseRAG):
    """Retriever that augments vector search with a small entity "graph".

    At indexing time a handful of bidding-domain fields (budget, bid bond,
    warranty, delivery time, qualification, evaluation method) are pulled out
    of each document with regular expressions.  At query time, documents whose
    extracted values literally occur in the query are merged with the
    semantic-search hits and boosted.

    NOTE(review): ``self.embedding_model``, ``self.vector_store``,
    ``self._dedup_key``, ``self._deduplicate_results`` and ``self._call_llm``
    are not defined in this class; they are presumably supplied by
    ``BaseRAG`` -- confirm against the base class.
    """

    # Score given to a document that is reached only through the entity graph.
    _GRAPH_ONLY_SCORE = 1.0
    # Bonus added to the semantic score when the entity graph also matches.
    _GRAPH_BOOST = 0.2

    def __init__(self, **kwargs):
        """Initialise the base RAG and the entity-extraction patterns.

        Args:
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        # Maps id(doc) -> {'doc': doc, 'entities': {entity_type: value}}.
        self.graph = {}

        # Label/value separator: optional ASCII or full-width colon, then
        # optional whitespace.  The full-width form is written as an explicit
        # escape so it cannot be lost to Unicode normalisation of this file.
        sep = r'[:\uff1a]?\s*'
        # Free-text value: runs until an ASCII/full-width comma, a newline or
        # an ideographic full stop.
        free_text = r'([^,\uff0c\n。]+)'
        # Amount expressed in units of 万 (ten thousand).
        amount = r'([0-9.]+)\s*万'

        self._build_graph_patterns = [
            (r'预算金额' + sep + amount, 'budget_amount'),
            (r'投标保证金' + sep + amount, 'bid_bond'),
            (r'质保期' + sep + free_text, 'warranty'),
            (r'交货时间' + sep + free_text, 'delivery_time'),
            (r'资质要求' + sep + free_text, 'qualification'),
            (r'评标方法' + sep + free_text, 'evaluation_method'),
        ]

    def _extract_entities(self, text):
        """Extract the first occurrence of each known entity type from *text*.

        Args:
            text (str): Raw document text.

        Returns:
            dict: ``{entity_type: value}`` for every pattern that matched with
            a non-empty (after stripping) value.
        """
        entities = {}
        for pattern, entity_type in self._build_graph_patterns:
            match = re.search(pattern, text)
            if not match:
                continue
            value = match.group(1).strip()
            # A whitespace-only capture strips to ''. Storing it would make
            # the later substring test ('' in query) succeed for every query.
            if value:
                entities[entity_type] = value
        return entities

    def index_documents(self, documents):
        """Rebuild the entity graph and add *documents* to the vector store.

        Args:
            documents: Iterable of document objects exposing ``page_content``.
                It is materialised once, so a generator is accepted.
        """
        # The documents are traversed more than once below; a one-shot
        # iterator would otherwise be exhausted after the first pass.
        documents = list(documents)

        self.graph = {}
        for doc in documents:
            entities = self._extract_entities(doc.page_content)
            if entities:
                # id(doc) is a stable key here because the node itself keeps
                # a reference to the document, so the object stays alive.
                self.graph[id(doc)] = {
                    'doc': doc,
                    'entities': entities,
                }

        # Nothing to embed or store for empty input.
        if documents:
            texts = [doc.page_content for doc in documents]
            embeddings = self.embedding_model.embed_documents(texts)
            self.vector_store.add_documents(documents, embeddings)

        logger.info("GraphRAG built graph with {} nodes", len(self.graph))

    def _query_graph(self, query):
        """Return graph nodes having an entity value that occurs in *query*.

        Matching is a plain substring test of each stored entity value
        against the query string.

        Args:
            query (str): User query.

        Returns:
            list: ``(doc, score)`` tuples, one per matching node, each with
            the fixed graph-only score.
        """
        matching_nodes = []
        for node_data in self.graph.values():
            values = node_data['entities'].values()
            # Guard against empty values: '' is a substring of every string.
            if any(value and value in query for value in values):
                matching_nodes.append((node_data['doc'], self._GRAPH_ONLY_SCORE))
        return matching_nodes

    def retrieve(self, query, k=10):
        """Retrieve documents by merging semantic and entity-graph hits.

        Semantic hits keep their similarity score; a document that is also a
        graph hit receives a fixed bonus, and a graph-only hit receives the
        fixed graph-only score.  The merged list is sorted by score
        (descending) and handed to ``self._deduplicate_results``.

        Args:
            query (str): User query.
            k (int): Number of results requested from the vector store and
                passed to ``self._deduplicate_results``.

        Returns:
            Whatever ``self._deduplicate_results(results, k)`` returns.
        """
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)

        # Later entries with the same dedup key overwrite earlier ones.
        combined = {
            self._dedup_key(doc): (doc, score)
            for doc, score in semantic_results
        }

        for doc, score in self._query_graph(query):
            key = self._dedup_key(doc)
            if key in combined:
                combined[key] = (doc, combined[key][1] + self._GRAPH_BOOST)
            else:
                combined[key] = (doc, score)

        results = sorted(combined.values(), key=lambda item: item[1], reverse=True)
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Build the answer prompt from *context* and *query* and call the LLM.

        Args:
            query (str): User question, inserted into the prompt.
            context (str): Reference text, inserted into the prompt.

        Returns:
            Whatever ``self._call_llm(prompt)`` returns.
        """
        prompt = (
            "根据以下参考文档(知识图谱增强检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|