query_routing_rag.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # -*- coding: utf-8 -*-
  2. """Query Routing RAG - classify query and boost retrieval by category keywords."""
  3. from .base import BaseRAG
  4. from loguru import logger
  5. class QueryRoutingRAG(BaseRAG):
  6. def __init__(self, **kwargs):
  7. super().__init__(**kwargs)
  8. def _classify_query(self, query):
  9. categories = [
  10. "budget", "deadline", "qualification", "evaluation",
  11. "payment", "warranty", "delivery", "contact", "scope"
  12. ]
  13. categories_str = ", ".join(categories)
  14. prompt = (
  15. "将以下问题分类到以下类别之一:{}\n\n"
  16. "问题:{}\n\n"
  17. "类别:".format(categories_str, query)
  18. )
  19. return self._call_llm(prompt).strip().lower()
  20. def retrieve(self, query, k=10):
  21. query_embedding = self.embedding_model.embed_query(query)
  22. semantic_results = self.vector_store.similarity_search(query_embedding, k)
  23. category = self._classify_query(query)
  24. logger.info("QueryRouting: category={}".format(category))
  25. category_boost = {}
  26. for doc, score in semantic_results:
  27. content_lower = doc.page_content.lower()
  28. category_keywords = {
  29. "budget": ["预算", "金额", "价格", "费用", "报价"],
  30. "deadline": ["截止", "时间", "日期", "开标"],
  31. "qualification": ["资格", "要求", "证书", "业绩"],
  32. "evaluation": ["评标", "评价", "分数", "方法"],
  33. "payment": ["付款", "结算", "进度", "保证金"],
  34. "warranty": ["质保", "维修", "售后", "服务"],
  35. "delivery": ["交货", "工期", "交付", "地点"],
  36. "contact": ["联系人", "电话", "邮箱"],
  37. "scope": ["范围", "内容", "清单", "设备"],
  38. }
  39. keywords = category_keywords.get(category, [])
  40. boost = sum(1 for kw in keywords if kw in content_lower) * 0.1
  41. category_boost[self._dedup_key(doc)] = boost
  42. enhanced_results = []
  43. for doc, score in semantic_results:
  44. boost = category_boost.get(self._dedup_key(doc), 0)
  45. enhanced_results.append((doc, score + boost))
  46. enhanced_results.sort(key=lambda x: x[1], reverse=True)
  47. return self._deduplicate_results(enhanced_results, k)
  48. def generate(self, query, context):
  49. prompt = (
  50. "根据以下经过查询路由的参考文档,回答问题。\n\n"
  51. "参考文档:\n{}\n\n"
  52. "问题:{}\n\n"
  53. "请详细回答。".format(context, query)
  54. )
  55. return self._call_llm(prompt)