# -*- coding: utf-8 -*-
"""Query Routing RAG - classify query and boost retrieval by category keywords."""

from .base import BaseRAG
from loguru import logger


class QueryRoutingRAG(BaseRAG):
    """RAG variant that routes a query to a topic category and re-ranks hits.

    The LLM classifies the query into one of ``CATEGORY_KEYWORDS``. Retrieved
    chunks containing that category's keywords get a small additive score
    boost before the results are re-sorted and deduplicated.
    """

    # Category name -> keywords whose presence in a chunk earns a boost.
    # Defined once at class level (previously rebuilt for every document) and
    # also used as the single source of the category names shown to the LLM.
    CATEGORY_KEYWORDS = {
        "budget": ["预算", "金额", "价格", "费用", "报价"],
        "deadline": ["截止", "时间", "日期", "开标"],
        "qualification": ["资格", "要求", "证书", "业绩"],
        "evaluation": ["评标", "评价", "分数", "方法"],
        "payment": ["付款", "结算", "进度", "保证金"],
        "warranty": ["质保", "维修", "售后", "服务"],
        "delivery": ["交货", "工期", "交付", "地点"],
        "contact": ["联系人", "电话", "邮箱"],
        "scope": ["范围", "内容", "清单", "设备"],
    }

    # Score added for each category keyword found in a chunk.
    KEYWORD_BOOST = 0.1

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _normalize_category(self, answer):
        """Map a raw (already lower-cased) LLM answer onto a known category.

        LLMs often wrap the label in extra words or punctuation (e.g.
        "类别:budget。"). An exact dictionary lookup would then miss and
        silently disable the keyword boost.

        Returns:
            The known category name that appears earliest in ``answer``, or
            ``answer`` unchanged if no known category name occurs in it.
        """
        if answer in self.CATEGORY_KEYWORDS:
            return answer
        mentioned = [
            (answer.find(name), name)
            for name in self.CATEGORY_KEYWORDS
            if name in answer
        ]
        return min(mentioned)[1] if mentioned else answer

    def _classify_query(self, query):
        """Ask the LLM which category ``query`` belongs to.

        Args:
            query: The user question.

        Returns:
            A key of ``CATEGORY_KEYWORDS`` when one is recognised in the LLM
            answer; otherwise the stripped, lower-cased raw answer (which then
            matches no keywords and yields no boost).
        """
        # Dict preserves insertion order, so the prompt lists categories in
        # the same order as before.
        categories_str = ", ".join(self.CATEGORY_KEYWORDS)
        prompt = (
            "将以下问题分类到以下类别之一:{}\n\n"
            "问题:{}\n\n"
            "类别:".format(categories_str, query)
        )
        answer = self._call_llm(prompt).strip().lower()
        return self._normalize_category(answer)

    def retrieve(self, query, k=10):
        """Retrieve up to ``k`` chunks, boosted by the query's category keywords.

        Args:
            query: The user question.
            k: Number of candidates requested from the vector store; also
                passed to ``_deduplicate_results`` (defined in BaseRAG,
                presumably capping the output at ``k`` — confirm there).

        Returns:
            Whatever ``_deduplicate_results`` returns for the re-scored
            ``(doc, score)`` list sorted by descending score.
        """
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)
        if not semantic_results:
            # Nothing to re-rank: skip the LLM classification round-trip.
            return self._deduplicate_results([], k)

        category = self._classify_query(query)
        logger.info("QueryRouting: category={}".format(category))
        # Unknown category -> no keywords -> zero boost for every chunk.
        keywords = self.CATEGORY_KEYWORDS.get(category, [])

        # NOTE(review): adding a boost and sorting descending assumes the
        # vector store returns similarities (higher = better), not distances
        # — confirm against the vector_store implementation.
        enhanced_results = []
        for doc, score in semantic_results:
            content_lower = doc.page_content.lower()
            hits = sum(1 for kw in keywords if kw in content_lower)
            enhanced_results.append((doc, score + hits * self.KEYWORD_BOOST))

        enhanced_results.sort(key=lambda x: x[1], reverse=True)
        return self._deduplicate_results(enhanced_results, k)

    def generate(self, query, context):
        """Answer ``query`` with the LLM, grounded on the routed ``context``.

        Args:
            query: The user question.
            context: Reference text assembled from the retrieved chunks.

        Returns:
            The raw LLM response from ``_call_llm``.
        """
        prompt = (
            "根据以下经过查询路由的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)