# -*- coding: utf-8 -*-
"""Graph RAG - knowledge graph enhanced retrieval for bidding domain."""
import re

from .base import BaseRAG
from loguru import logger


class GraphRAG(BaseRAG):
    """RAG variant that boosts semantic hits with regex-extracted entity matches.

    At index time a fixed list of regular expressions pulls key bidding
    fields (budget amount, bid bond, warranty period, delivery time,
    qualification requirements, evaluation method) out of each document.
    At query time, any indexed document whose extracted value occurs
    verbatim in the query is added to, or boosted within, the semantic
    search results.

    ``self.graph`` is a flat mapping ``id(doc) -> {'doc', 'entities'}``;
    no edges between documents are built.
    """

    # Score assigned to a document found only through an entity match.
    _GRAPH_ONLY_SCORE = 1.0
    # Bonus added to a semantic hit that also matches an entity.
    _GRAPH_BOOST = 0.2

    def __init__(self, **kwargs):
        """Initialise the base RAG components and the entity patterns.

        Args:
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.graph = {}
        # (regex, entity_type) pairs. Separators accept both ASCII and
        # full-width CJK punctuation: ':' or '\uff1a' (full-width colon),
        # and values stop at ',', '\uff0c' (full-width comma), newline or
        # '\u3002' (ideographic full stop). Escapes are used so the
        # full-width characters cannot be silently normalised away.
        self._build_graph_patterns = [
            (r'预算金额[:\uff1a]?\s*([0-9.]+)\s*万', 'budget_amount'),
            (r'投标保证金[:\uff1a]?\s*([0-9.]+)\s*万', 'bid_bond'),
            (r'质保期[:\uff1a]?\s*([^,\uff0c\n\u3002]+)', 'warranty'),
            (r'交货时间[:\uff1a]?\s*([^,\uff0c\n\u3002]+)', 'delivery_time'),
            (r'资质要求[:\uff1a]?\s*([^,\uff0c\n\u3002]+)', 'qualification'),
            (r'评标方法[:\uff1a]?\s*([^,\uff0c\n\u3002]+)', 'evaluation_method'),
        ]

    def _extract_entities(self, text):
        """Extract the first value of each configured entity type from *text*.

        Args:
            text: Document text to scan.

        Returns:
            dict mapping entity_type -> stripped matched value. Entity types
            with no match, or whose match is blank after stripping, are
            omitted.
        """
        entities = {}
        for pattern, entity_type in self._build_graph_patterns:
            match = re.search(pattern, text)
            if not match:
                continue
            value = match.group(1).strip()
            # A blank value would make ``'' in query`` true for every query
            # in _query_graph, pulling this document into all results.
            if value:
                entities[entity_type] = value
        return entities

    def index_documents(self, documents):
        """Rebuild the entity index and add *documents* to the vector store.

        Args:
            documents: Iterable of document objects exposing ``page_content``.
                It is materialised into a list first, so one-shot iterators
                and generators are supported.
        """
        # The documents are walked twice (entity pass, embedding pass); a
        # generator would otherwise be exhausted after the first walk.
        documents = list(documents)

        # NOTE(review): the entity index is reset here, but the vector store
        # is only appended to, so re-indexing may leave stale vectors whose
        # documents no longer have graph entries. Confirm this is intended.
        self.graph = {}
        for doc in documents:
            entities = self._extract_entities(doc.page_content)
            if entities:
                # Keyed by object identity; the stored reference keeps the
                # doc alive, so the id stays unique while the entry exists.
                self.graph[id(doc)] = {'doc': doc, 'entities': entities}

        # Skip the embedding round-trip entirely when there is nothing to add.
        if documents:
            texts = [doc.page_content for doc in documents]
            embeddings = self.embedding_model.embed_documents(texts)
            self.vector_store.add_documents(documents, embeddings)

        logger.info("GraphRAG built graph with {} nodes", len(self.graph))

    def _query_graph(self, query):
        """Return indexed documents that have an entity value occurring in *query*.

        Matching is plain substring containment of the extracted value in the
        query text. Each document appears at most once, in indexing order.

        Args:
            query: Raw query string.

        Returns:
            list of ``(doc, self._GRAPH_ONLY_SCORE)`` tuples.
        """
        return [
            (node['doc'], self._GRAPH_ONLY_SCORE)
            for node in self.graph.values()
            if any(value and value in query for value in node['entities'].values())
        ]

    def retrieve(self, query, k=10):
        """Merge semantic search results with entity-match results.

        Documents found by both paths keep their semantic score plus
        ``_GRAPH_BOOST``. Documents found only via the entity index get
        ``_GRAPH_ONLY_SCORE``. The merged list is sorted by descending
        score and handed to ``_deduplicate_results`` with *k*.

        NOTE(review): the descending sort treats higher scores as better.
        Confirm that ``vector_store.similarity_search`` returns similarities
        rather than distances, otherwise the boost works in reverse.

        Args:
            query: Raw query string.
            k: Number of semantic results requested; also passed to
                ``_deduplicate_results``.

        Returns:
            Whatever ``_deduplicate_results`` returns for the merged,
            score-sorted ``(doc, score)`` list.
        """
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)
        graph_results = self._query_graph(query)

        combined = {}
        for doc, score in semantic_results:
            combined[self._dedup_key(doc)] = (doc, score)
        for doc, score in graph_results:
            key = self._dedup_key(doc)
            if key in combined:
                combined[key] = (doc, combined[key][1] + self._GRAPH_BOOST)
            else:
                combined[key] = (doc, score)

        results = sorted(combined.values(), key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Ask the LLM to answer *query* using the retrieved *context*.

        Args:
            query: User question inserted into the prompt.
            context: Reference-document text inserted into the prompt.

        Returns:
            The result of ``self._call_llm`` on the built prompt.
        """
        # The adjacent literals are concatenated at compile time, so
        # .format() fills both placeholders in the whole template.
        prompt = (
            "根据以下参考文档(知识图谱增强检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)