| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- # -*- coding: utf-8 -*-
- """Query Routing RAG - classify query and boost retrieval by category keywords."""
- from .base import BaseRAG
- from loguru import logger
class QueryRoutingRAG(BaseRAG):
    """RAG variant that routes a query to a topic category and reranks hits.

    The query is classified by the LLM into one of the categories in
    ``_CATEGORY_KEYWORDS``. Semantic-search hits are then re-scored by adding
    ``_KEYWORD_BOOST`` for every keyword of that category found in the
    document text. The re-sorted list is de-duplicated and returned.
    """

    # Category name -> keywords whose presence in a chunk boosts its score when
    # the query is routed to that category. The keywords appear to be
    # tender/bidding-document terms (e.g. 开标, 评标, 保证金).
    # Dict order matters: it fixes the category list shown in the prompt.
    _CATEGORY_KEYWORDS = {
        "budget": ["预算", "金额", "价格", "费用", "报价"],
        "deadline": ["截止", "时间", "日期", "开标"],
        "qualification": ["资格", "要求", "证书", "业绩"],
        "evaluation": ["评标", "评价", "分数", "方法"],
        "payment": ["付款", "结算", "进度", "保证金"],
        "warranty": ["质保", "维修", "售后", "服务"],
        "delivery": ["交货", "工期", "交付", "地点"],
        "contact": ["联系人", "电话", "邮箱"],
        "scope": ["范围", "内容", "清单", "设备"],
    }

    # Score added to a hit for each category keyword it contains.
    _KEYWORD_BOOST = 0.1

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _classify_query(self, query):
        """Ask the LLM which category *query* belongs to.

        Returns:
            A lower-cased category name. An exact match of the stripped,
            lower-cased LLM reply is returned as-is. Otherwise the category
            name that occurs earliest in the reply is returned, because LLMs
            often wrap the label in extra text or punctuation. If no category
            name occurs at all, the stripped, lower-cased reply is returned
            unchanged.
        """
        categories = list(self._CATEGORY_KEYWORDS)
        prompt = (
            "将以下问题分类到以下类别之一:{}\n\n"
            "问题:{}\n\n"
            "类别:".format(", ".join(categories), query)
        )
        answer = self._call_llm(prompt).strip().lower()
        if answer in self._CATEGORY_KEYWORDS:
            return answer
        # Verbose reply such as "budget." or "类别:budget": take the category
        # mentioned earliest so the keyword lookup in retrieve() still works.
        hits = [(answer.find(name), name) for name in categories if name in answer]
        return min(hits)[1] if hits else answer

    def _keyword_boost(self, text, keywords):
        """Return ``_KEYWORD_BOOST`` times the number of *keywords* found in *text*."""
        content_lower = text.lower()
        return sum(1 for kw in keywords if kw in content_lower) * self._KEYWORD_BOOST

    def retrieve(self, query, k=10):
        """Return up to *k* documents for *query*, reranked by category keywords.

        Args:
            query: The user question.
            k: Number of candidates requested from the vector store. Also
                passed to ``_deduplicate_results`` as the result limit.

        Returns:
            The result of ``self._deduplicate_results`` applied to the
            ``(doc, boosted_score)`` pairs, sorted by boosted score in
            descending order.
        """
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)
        if not semantic_results:
            # Nothing to rerank: skip the classification LLM round-trip.
            return self._deduplicate_results([], k)

        category = self._classify_query(query)
        logger.info("QueryRouting: category={}", category)

        keywords = self._CATEGORY_KEYWORDS.get(category, [])
        # Boost each hit from its own content in a single pass. A side table
        # keyed by _dedup_key would let hits sharing a key overwrite each
        # other's boost, and a second pass would drain one-shot iterables.
        enhanced_results = [
            (doc, score + self._keyword_boost(doc.page_content, keywords))
            for doc, score in semantic_results
        ]

        # NOTE(review): descending sort assumes a higher score means more
        # similar; confirm for the configured vector store (a distance metric
        # would invert both the ranking and the meaning of the boost).
        enhanced_results.sort(key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(enhanced_results, k)

    def generate(self, query, context):
        """Answer *query* with the LLM, grounded in the retrieved *context* text.

        Args:
            query: The user question.
            context: Reference text inserted verbatim into the prompt.

        Returns:
            Whatever ``self._call_llm`` returns for the prompt.
        """
        prompt = (
            "根据以下经过查询路由的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|