bid_field_extraction_rag.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # -*- coding: utf-8 -*-
  2. """Bid-field extraction RAG tuned for tender/bidding HTML text recall."""
  3. import json
  4. import re
  5. import time
  6. from collections import OrderedDict
  7. import numpy as np
  8. from loguru import logger
  9. from .base import BaseRAG, RAGResult
  10. from .bm25_backend import get_bm25_okapi
  11. from .dedup import content_dedup_key
  12. from .tokenization import bm25_tokenize
  13. FIELD_SPECS = OrderedDict([
  14. ("project_name", {
  15. "label": "项目名称",
  16. "aliases": ["项目名称", "采购项目名称", "招标项目名称", "Project Name"],
  17. }),
  18. ("project_code", {
  19. "label": "项目编号",
  20. "aliases": ["项目编号", "采购编号", "招标编号", "Project Code", "Tender No"],
  21. }),
  22. ("budget_amount", {
  23. "label": "预算金额",
  24. "aliases": ["预算金额", "采购预算", "项目预算", "最高限价", "Budget Amount", "Project Budget"],
  25. }),
  26. ("currency", {
  27. "label": "币种",
  28. "aliases": ["币种", "货币", "人民币", "Currency", "RMB"],
  29. }),
  30. ("bid_deadline", {
  31. "label": "投标截止时间",
  32. "aliases": ["投标截止时间", "递交截止时间", "提交投标文件截止时间", "Bid Submission Deadline", "Bid Deadline"],
  33. }),
  34. ("bid_opening_time", {
  35. "label": "开标时间",
  36. "aliases": ["开标时间", "开启时间", "Bid Opening Time", "Bid Opening"],
  37. }),
  38. ("bid_location", {
  39. "label": "投标地点",
  40. "aliases": ["投标地点", "递交地点", "开标地点", "Bid Location", "Venue"],
  41. }),
  42. ("purchaser_name", {
  43. "label": "采购人名称",
  44. "aliases": ["采购人", "招标人", "采购单位", "Purchaser", "Tenderer"],
  45. }),
  46. ("purchaser_contact", {
  47. "label": "采购人联系人",
  48. "aliases": ["采购人联系人", "联系人", "Contact Person", "Purchaser Contact"],
  49. }),
  50. ("purchaser_phone", {
  51. "label": "采购人电话",
  52. "aliases": ["采购人电话", "联系电话", "Contact Phone", "Purchaser Phone"],
  53. }),
  54. ("agency_name", {
  55. "label": "代理机构名称",
  56. "aliases": ["代理机构", "采购代理机构", "招标代理", "Agency Name", "Bidding Agency"],
  57. }),
  58. ("agency_contact", {
  59. "label": "代理机构联系人",
  60. "aliases": ["代理机构联系人", "Agency Contact"],
  61. }),
  62. ("agency_phone", {
  63. "label": "代理机构电话",
  64. "aliases": ["代理机构电话", "Agency Phone"],
  65. }),
  66. ("qualification_requirements", {
  67. "label": "资格要求",
  68. "aliases": ["资格要求", "资质要求", "投标人资格", "Qualification Requirements"],
  69. }),
  70. ("bid_bond_amount", {
  71. "label": "投标保证金",
  72. "aliases": ["投标保证金", "Bid Bond", "Bid Bond Amount"],
  73. }),
  74. ("performance_bond_amount", {
  75. "label": "履约保证金",
  76. "aliases": ["履约保证金", "Performance Bond", "Performance Bond Amount"],
  77. }),
  78. ("warranty_period", {
  79. "label": "质保期",
  80. "aliases": ["质保期", "保修期", "免费保修", "Warranty Period", "Warranty"],
  81. }),
  82. ("delivery_time", {
  83. "label": "交货时间",
  84. "aliases": ["交货时间", "交付时间", "工期", "建设周期", "Delivery Time", "Construction Period"],
  85. }),
  86. ("delivery_location", {
  87. "label": "交货地点",
  88. "aliases": ["交货地点", "交付地点", "安装地点", "Delivery Location", "Installation Location"],
  89. }),
  90. ("payment_terms", {
  91. "label": "付款方式",
  92. "aliases": ["付款方式", "支付方式", "Payment Terms", "Payment"],
  93. }),
  94. ("evaluation_method", {
  95. "label": "评标方法",
  96. "aliases": ["评标方法", "评审方法", "评分办法", "Evaluation Method"],
  97. }),
  98. ("scope_of_work", {
  99. "label": "工作范围",
  100. "aliases": ["工作范围", "采购内容", "招标范围", "建设内容", "Scope of Work", "Scope"],
  101. }),
  102. ])
  103. class BidFieldExtractionRAG(BaseRAG):
  104. """Retrieve field-specific context before asking the LLM to extract JSON."""
  105. def __init__(self, extraction_prompt_template=None, bm25_weight=0.45, vector_weight=0.55, **kwargs):
  106. super(BidFieldExtractionRAG, self).__init__(**kwargs)
  107. self.extraction_prompt_template = extraction_prompt_template or (
  108. "你是招投标领域的信息抽取助手。请只依据给定上下文抽取字段。\n\n"
  109. "目标字段:\n{fields}\n\n"
  110. "上下文:\n{context}\n\n"
  111. "要求:\n"
  112. "1. 仅返回一个 JSON 对象,不要输出解释。\n"
  113. "2. JSON key 必须使用目标字段英文名。\n"
  114. "3. 无法从上下文确认的字段返回 null。\n"
  115. "4. 金额、时间、联系人、电话等值要保留原文表述。"
  116. )
  117. self.target_fields = list(FIELD_SPECS.keys())
  118. self.field_specs = FIELD_SPECS
  119. self.bm25_weight = float(bm25_weight)
  120. self.vector_weight = float(vector_weight)
  121. self._documents = []
  122. self._bm25 = None
  123. self._bm25_corpus = []
  124. def index_documents(self, documents):
  125. self._documents = list(documents or [])
  126. self._build_bm25(self._documents)
  127. if self.embedding_model is not None and self.vector_store is not None:
  128. super(BidFieldExtractionRAG, self).index_documents(self._documents)
  129. def _build_bm25(self, documents):
  130. self._bm25_corpus = []
  131. for doc in documents:
  132. meta_text = " ".join(str(v) for v in (doc.metadata or {}).values())
  133. self._bm25_corpus.append(bm25_tokenize("{}\n{}".format(meta_text, doc.page_content)))
  134. BM25Okapi = get_bm25_okapi()
  135. self._bm25 = BM25Okapi(self._bm25_corpus) if self._bm25_corpus else None
  136. def retrieve(self, query, k=10):
  137. """Hybrid retrieval for a free-form query."""
  138. return self._hybrid_retrieve(query, k=k)
  139. def retrieve_for_fields(self, fields=None, k_per_field=3, max_docs=12):
  140. """Retrieve and deduplicate contexts for each requested field."""
  141. selected_fields = [f for f in (fields or self.target_fields) if f in self.field_specs]
  142. merged = OrderedDict()
  143. for field in selected_fields:
  144. field_query = self._field_query(field)
  145. for doc, score in self._hybrid_retrieve(field_query, k=k_per_field):
  146. key = self._dedup_key(doc)
  147. if key not in merged:
  148. metadata = dict(doc.metadata or {})
  149. metadata["matched_fields"] = [field]
  150. metadata["retrieval_query"] = field_query
  151. merged[key] = [doc, float(score), metadata]
  152. else:
  153. merged[key][1] = max(merged[key][1], float(score))
  154. merged[key][2]["matched_fields"].append(field)
  155. results = []
  156. for doc, score, metadata in merged.values():
  157. doc.metadata.update(metadata)
  158. doc.metadata["matched_fields"] = sorted(set(doc.metadata["matched_fields"]))
  159. results.append((doc, score))
  160. results.sort(key=lambda item: item[1], reverse=True)
  161. return self._deduplicate_results(results, max_docs)
  162. def _field_query(self, field):
  163. spec = self.field_specs[field]
  164. return "{} {} {}".format(field, spec["label"], " ".join(spec["aliases"]))
  165. def _hybrid_retrieve(self, query, k=10):
  166. k = max(0, int(k or 0))
  167. if k == 0:
  168. return []
  169. vector_scores = self._vector_scores(query, k=max(k * 4, k))
  170. bm25_scores = self._bm25_scores(query)
  171. merged = {}
  172. for doc, score in vector_scores:
  173. merged[self._dedup_key(doc)] = [doc, self._normalize_score(score) * self.vector_weight]
  174. for doc, score in bm25_scores:
  175. key = self._dedup_key(doc)
  176. if key not in merged:
  177. merged[key] = [doc, 0.0]
  178. merged[key][1] += self._normalize_score(score) * self.bm25_weight
  179. results = [(doc, score) for doc, score in merged.values() if score > 0]
  180. results.sort(key=lambda item: item[1], reverse=True)
  181. return self._deduplicate_results(results, k)
  182. def _vector_scores(self, query, k):
  183. if self.embedding_model is None or self.vector_store is None:
  184. return []
  185. try:
  186. query_embedding = self.embedding_model.embed_query(query)
  187. return self.vector_store.similarity_search(query_embedding, k)
  188. except Exception as exc:
  189. logger.warning("Vector retrieval failed, falling back to BM25: {}".format(exc))
  190. return []
  191. def _bm25_scores(self, query):
  192. if self._bm25 is None:
  193. store_docs = getattr(self.vector_store, "documents", None)
  194. if store_docs:
  195. self._documents = list(store_docs)
  196. self._build_bm25(self._documents)
  197. if self._bm25 is None:
  198. return []
  199. query_tokens = bm25_tokenize(query)
  200. if not query_tokens:
  201. return []
  202. scores = self._bm25.get_scores(query_tokens)
  203. max_score = float(np.max(scores)) if len(scores) else 0.0
  204. if max_score <= 0:
  205. return []
  206. scored = []
  207. for doc, score in zip(self._documents, scores):
  208. score = float(score)
  209. if score > 0:
  210. scored.append((doc, score / max_score))
  211. scored.sort(key=lambda item: item[1], reverse=True)
  212. return scored
  213. @staticmethod
  214. def _normalize_score(score):
  215. try:
  216. value = float(score)
  217. except (TypeError, ValueError):
  218. return 0.0
  219. if value < 0:
  220. return 0.0
  221. return min(value, 1.0)
  222. def generate(self, query, context):
  223. fields_str = "\n".join(
  224. "- {}: {} ({})".format(field, spec["label"], ", ".join(spec["aliases"][:4]))
  225. for field, spec in self.field_specs.items()
  226. if field in self.target_fields
  227. )
  228. prompt = self.extraction_prompt_template.format(
  229. fields=fields_str,
  230. context=context[:6000],
  231. )
  232. return self._call_llm(prompt)
  233. def extract_fields(self, query=None, k=5, fields=None):
  234. start_total = time.time()
  235. t0 = time.time()
  236. requested_fields = [f for f in (fields or self.target_fields) if f in self.field_specs]
  237. docs = self.retrieve_for_fields(requested_fields, k_per_field=max(1, k), max_docs=max(k, len(requested_fields)))
  238. if query:
  239. docs = self._merge_ranked(docs, self.retrieve(query, k=k))
  240. retrieval_time = time.time() - t0
  241. context = self._format_context(docs)
  242. t1 = time.time()
  243. raw_response = self.generate(query or "", context)
  244. generation_time = time.time() - t1
  245. fields_result = self._parse_extraction(raw_response)
  246. total_time = time.time() - start_total
  247. return RAGResult(
  248. answer=json.dumps(fields_result, indent=2, ensure_ascii=False),
  249. retrieved_docs=docs,
  250. latency_retrieval=retrieval_time,
  251. latency_generation=generation_time,
  252. latency_total=total_time,
  253. metadata={
  254. "method": self.name,
  255. "fields": fields_result,
  256. "requested_fields": requested_fields,
  257. "num_context_chars": len(context),
  258. },
  259. )
  260. @staticmethod
  261. def _merge_ranked(primary, secondary):
  262. merged = OrderedDict()
  263. for doc, score in list(primary or []) + list(secondary or []):
  264. key = content_dedup_key(doc)
  265. if key not in merged or float(score) > merged[key][1]:
  266. merged[key] = [doc, float(score)]
  267. results = [(doc, score) for doc, score in merged.values()]
  268. results.sort(key=lambda item: item[1], reverse=True)
  269. return results
  270. def _parse_extraction(self, raw_response):
  271. if not raw_response:
  272. return {field: None for field in self.target_fields}
  273. try:
  274. start = raw_response.find("{")
  275. end = raw_response.rfind("}") + 1
  276. if start >= 0 and end > start:
  277. json_str = raw_response[start:end]
  278. json_str = re.sub(r",\s*}", "}", json_str)
  279. parsed = json.loads(json_str)
  280. return {field: parsed.get(field) for field in self.target_fields}
  281. except (TypeError, ValueError, json.JSONDecodeError):
  282. logger.warning("Failed to parse JSON from LLM response")
  283. result = {field: None for field in self.target_fields}
  284. result["raw"] = raw_response
  285. return result