# -*- coding: utf-8 -*-
"""Bid-field extraction RAG tuned for tender/bidding HTML text recall."""

import json
import re
import time
from collections import OrderedDict

import numpy as np
from loguru import logger

from .base import BaseRAG, RAGResult
from .bm25_backend import get_bm25_okapi
from .dedup import content_dedup_key
from .tokenization import bm25_tokenize

# Ordered catalogue of extractable fields. Each key is the JSON key the LLM is
# asked to return; "label" is the Chinese display name and "aliases" are the
# Chinese/English surface forms used to build per-field retrieval queries.
FIELD_SPECS = OrderedDict([
    ("project_name", {
        "label": "项目名称",
        "aliases": ["项目名称", "采购项目名称", "招标项目名称", "Project Name"],
    }),
    ("project_code", {
        "label": "项目编号",
        "aliases": ["项目编号", "采购编号", "招标编号", "Project Code", "Tender No"],
    }),
    ("budget_amount", {
        "label": "预算金额",
        "aliases": ["预算金额", "采购预算", "项目预算", "最高限价", "Budget Amount", "Project Budget"],
    }),
    ("currency", {
        "label": "币种",
        "aliases": ["币种", "货币", "人民币", "Currency", "RMB"],
    }),
    ("bid_deadline", {
        "label": "投标截止时间",
        "aliases": ["投标截止时间", "递交截止时间", "提交投标文件截止时间", "Bid Submission Deadline", "Bid Deadline"],
    }),
    ("bid_opening_time", {
        "label": "开标时间",
        "aliases": ["开标时间", "开启时间", "Bid Opening Time", "Bid Opening"],
    }),
    ("bid_location", {
        "label": "投标地点",
        "aliases": ["投标地点", "递交地点", "开标地点", "Bid Location", "Venue"],
    }),
    ("purchaser_name", {
        "label": "采购人名称",
        "aliases": ["采购人", "招标人", "采购单位", "Purchaser", "Tenderer"],
    }),
    ("purchaser_contact", {
        "label": "采购人联系人",
        "aliases": ["采购人联系人", "联系人", "Contact Person", "Purchaser Contact"],
    }),
    ("purchaser_phone", {
        "label": "采购人电话",
        "aliases": ["采购人电话", "联系电话", "Contact Phone", "Purchaser Phone"],
    }),
    ("agency_name", {
        "label": "代理机构名称",
        "aliases": ["代理机构", "采购代理机构", "招标代理", "Agency Name", "Bidding Agency"],
    }),
    ("agency_contact", {
        "label": "代理机构联系人",
        "aliases": ["代理机构联系人", "Agency Contact"],
    }),
    ("agency_phone", {
        "label": "代理机构电话",
        "aliases": ["代理机构电话", "Agency Phone"],
    }),
    ("qualification_requirements", {
        "label": "资格要求",
        "aliases": ["资格要求", "资质要求", "投标人资格", "Qualification Requirements"],
    }),
    ("bid_bond_amount", {
        "label": "投标保证金",
        "aliases": ["投标保证金", "Bid Bond", "Bid Bond Amount"],
    }),
    ("performance_bond_amount", {
        "label": "履约保证金",
        "aliases": ["履约保证金", "Performance Bond", "Performance Bond Amount"],
    }),
    ("warranty_period", {
        "label": "质保期",
        "aliases": ["质保期", "保修期", "免费保修", "Warranty Period", "Warranty"],
    }),
    ("delivery_time", {
        "label": "交货时间",
        "aliases": ["交货时间", "交付时间", "工期", "建设周期", "Delivery Time", "Construction Period"],
    }),
    ("delivery_location", {
        "label": "交货地点",
        "aliases": ["交货地点", "交付地点", "安装地点", "Delivery Location", "Installation Location"],
    }),
    ("payment_terms", {
        "label": "付款方式",
        "aliases": ["付款方式", "支付方式", "Payment Terms", "Payment"],
    }),
    ("evaluation_method", {
        "label": "评标方法",
        "aliases": ["评标方法", "评审方法", "评分办法", "Evaluation Method"],
    }),
    ("scope_of_work", {
        "label": "工作范围",
        "aliases": ["工作范围", "采购内容", "招标范围", "建设内容", "Scope of Work", "Scope"],
    }),
])


class BidFieldExtractionRAG(BaseRAG):
    """Retrieve field-specific context before asking the LLM to extract JSON.

    Retrieval fuses a BM25 keyword index built by this class with the vector
    search of the configured ``vector_store`` (when one is available). The
    helpers ``_dedup_key``, ``_deduplicate_results``, ``_format_context`` and
    ``_call_llm`` are not defined here; presumably they come from ``BaseRAG``.
    """

    def __init__(self, extraction_prompt_template=None, bm25_weight=0.45, vector_weight=0.55, **kwargs):
        """Configure the extraction prompt and the hybrid-score weights.

        Args:
            extraction_prompt_template: ``str.format`` template with
                ``{fields}`` and ``{context}`` placeholders. Falls back to the
                built-in Chinese prompt when falsy.
            bm25_weight: Multiplier applied to normalized BM25 scores.
            vector_weight: Multiplier applied to normalized vector scores.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super(BidFieldExtractionRAG, self).__init__(**kwargs)
        self.extraction_prompt_template = extraction_prompt_template or (
            "你是招投标领域的信息抽取助手。请只依据给定上下文抽取字段。\n\n"
            "目标字段:\n{fields}\n\n"
            "上下文:\n{context}\n\n"
            "要求:\n"
            "1. 仅返回一个 JSON 对象,不要输出解释。\n"
            "2. JSON key 必须使用目标字段英文名。\n"
            "3. 无法从上下文确认的字段返回 null。\n"
            "4. 金额、时间、联系人、电话等值要保留原文表述。"
        )
        self.target_fields = list(FIELD_SPECS.keys())
        self.field_specs = FIELD_SPECS
        self.bm25_weight = float(bm25_weight)
        self.vector_weight = float(vector_weight)
        # BM25 state; (re)built by _build_bm25.
        self._documents = []
        self._bm25 = None
        self._bm25_corpus = []

    def index_documents(self, documents):
        """Index ``documents`` for BM25 and, when configured, for vector search.

        The parent indexer is only invoked when both an embedding model and a
        vector store are present, so the class also works in BM25-only mode.
        """
        self._documents = list(documents or [])
        self._build_bm25(self._documents)
        if self.embedding_model is not None and self.vector_store is not None:
            super(BidFieldExtractionRAG, self).index_documents(self._documents)

    def _build_bm25(self, documents):
        """Tokenize metadata values plus page content and build the BM25 index.

        Leaves ``self._bm25`` as ``None`` when there are no documents.
        """
        self._bm25_corpus = []
        for doc in documents:
            # Metadata values are indexed together with the body text so that
            # keyword matches on them also count.
            meta_text = " ".join(str(v) for v in (doc.metadata or {}).values())
            self._bm25_corpus.append(bm25_tokenize("{}\n{}".format(meta_text, doc.page_content)))
        BM25Okapi = get_bm25_okapi()
        self._bm25 = BM25Okapi(self._bm25_corpus) if self._bm25_corpus else None

    def retrieve(self, query, k=10):
        """Hybrid retrieval for a free-form query."""
        return self._hybrid_retrieve(query, k=k)

    def retrieve_for_fields(self, fields=None, k_per_field=3, max_docs=12):
        """Retrieve and deduplicate contexts for each requested field.

        Args:
            fields: Field names to retrieve for; names missing from
                ``self.field_specs`` are ignored. Defaults to all target fields.
            k_per_field: Number of hits requested per field query.
            max_docs: Limit passed to ``_deduplicate_results``.

        Returns:
            ``(doc, score)`` pairs sorted by descending score. Each returned
            document's metadata is updated in place with ``matched_fields``
            (sorted, unique) and ``retrieval_query`` (the query of the first
            field that matched it).
        """
        selected_fields = [f for f in (fields or self.target_fields) if f in self.field_specs]
        merged = OrderedDict()
        for field in selected_fields:
            field_query = self._field_query(field)
            for doc, score in self._hybrid_retrieve(field_query, k=k_per_field):
                key = self._dedup_key(doc)
                if key not in merged:
                    metadata = dict(doc.metadata or {})
                    metadata["matched_fields"] = [field]
                    metadata["retrieval_query"] = field_query
                    merged[key] = [doc, float(score), metadata]
                else:
                    # Same document hit by another field: keep the best score
                    # and remember every field that matched it.
                    merged[key][1] = max(merged[key][1], float(score))
                    merged[key][2]["matched_fields"].append(field)

        results = []
        for doc, score, metadata in merged.values():
            # Documents may carry metadata=None (guarded with "or {}" above);
            # updating it directly would raise AttributeError.
            if doc.metadata is None:
                doc.metadata = {}
            doc.metadata.update(metadata)
            doc.metadata["matched_fields"] = sorted(set(metadata["matched_fields"]))
            results.append((doc, score))
        results.sort(key=lambda item: item[1], reverse=True)
        return self._deduplicate_results(results, max_docs)

    def _field_query(self, field):
        """Build the retrieval query for ``field``: its key, label and all aliases."""
        spec = self.field_specs[field]
        return "{} {} {}".format(field, spec["label"], " ".join(spec["aliases"]))

    def _hybrid_retrieve(self, query, k=10):
        """Fuse vector and BM25 hits into one weighted ranking.

        Each score is clamped to [0, 1] and multiplied by its source weight;
        a document found by both sources gets the sum. Documents whose fused
        score is not positive are dropped.

        Returns:
            Result of ``_deduplicate_results`` on the descending-sorted
            ``(doc, score)`` list, or ``[]`` when ``k`` is zero/negative/None.
        """
        k = max(0, int(k or 0))
        if k == 0:
            return []

        # Over-fetch vector hits so that fusion with BM25 has candidates to
        # choose from.
        vector_scores = self._vector_scores(query, k=max(k * 4, k))
        bm25_scores = self._bm25_scores(query)
        merged = {}
        for doc, score in vector_scores:
            key = self._dedup_key(doc)
            weighted = self._normalize_score(score) * self.vector_weight
            # Several vector hits can share a dedup key; keep the strongest
            # instead of letting the last one overwrite a better earlier hit.
            if key not in merged or weighted > merged[key][1]:
                merged[key] = [doc, weighted]
        for doc, score in bm25_scores:
            key = self._dedup_key(doc)
            if key not in merged:
                merged[key] = [doc, 0.0]
            merged[key][1] += self._normalize_score(score) * self.bm25_weight

        results = [(doc, score) for doc, score in merged.values() if score > 0]
        results.sort(key=lambda item: item[1], reverse=True)
        return self._deduplicate_results(results, k)

    def _vector_scores(self, query, k):
        """Return the vector store's hits for ``query``, or ``[]`` if unavailable.

        Vector retrieval is best-effort: any failure is logged and the caller
        continues with BM25 only.
        """
        if self.embedding_model is None or self.vector_store is None:
            return []
        try:
            query_embedding = self.embedding_model.embed_query(query)
            return self.vector_store.similarity_search(query_embedding, k)
        except Exception as exc:
            logger.warning("Vector retrieval failed, falling back to BM25: {}".format(exc))
            return []

    def _bm25_scores(self, query):
        """Score all indexed documents against ``query`` with BM25.

        If no BM25 index exists yet, one is built lazily from
        ``self.vector_store.documents`` when that attribute is present and
        non-empty.

        Returns:
            ``(doc, score)`` pairs with positive score, divided by the best
            score so the top hit is 1.0, sorted descending. ``[]`` when there
            is no index, no query tokens, or no positive score.
        """
        if self._bm25 is None:
            store_docs = getattr(self.vector_store, "documents", None)
            if store_docs:
                self._documents = list(store_docs)
                self._build_bm25(self._documents)
        if self._bm25 is None:
            return []

        query_tokens = bm25_tokenize(query)
        if not query_tokens:
            return []
        scores = self._bm25.get_scores(query_tokens)
        max_score = float(np.max(scores)) if len(scores) else 0.0
        if max_score <= 0:
            return []

        scored = []
        for doc, score in zip(self._documents, scores):
            score = float(score)
            if score > 0:
                scored.append((doc, score / max_score))
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored

    @staticmethod
    def _normalize_score(score):
        """Clamp ``score`` to the range [0.0, 1.0].

        Values that cannot be converted to float, negative values and NaN all
        map to 0.0.
        """
        try:
            value = float(score)
        except (TypeError, ValueError):
            return 0.0
        # "not value > 0" also catches NaN (every comparison with NaN is
        # False); a NaN would otherwise poison the fused score.
        if not value > 0:
            return 0.0
        return min(value, 1.0)

    def generate(self, query, context):
        """Fill the extraction prompt and send it to the LLM.

        Args:
            query: Unused here; kept for interface compatibility.
            context: Context text; only the first 6000 characters are used.

        Returns:
            Whatever ``_call_llm`` returns for the rendered prompt.
        """
        fields_str = "\n".join(
            "- {}: {} ({})".format(field, spec["label"], ", ".join(spec["aliases"][:4]))
            for field, spec in self.field_specs.items()
            if field in self.target_fields
        )
        prompt = self.extraction_prompt_template.format(
            fields=fields_str,
            context=context[:6000],
        )
        return self._call_llm(prompt)

    def extract_fields(self, query=None, k=5, fields=None):
        """Run retrieval, generation and parsing, and return a ``RAGResult``.

        Args:
            query: Optional free-form query whose hits are merged with the
                per-field hits.
            k: Hits per field (at least 1) and hits for ``query``.
            fields: Field names that drive retrieval; unknown names are
                ignored. Defaults to all target fields.

        Returns:
            ``RAGResult`` whose ``answer`` is the parsed fields as pretty
            JSON, with retrieval/generation/total latencies in seconds.
        """
        # perf_counter is monotonic, so measured durations cannot be skewed
        # by wall-clock adjustments.
        start_total = time.perf_counter()
        requested_fields = [f for f in (fields or self.target_fields) if f in self.field_specs]
        docs = self.retrieve_for_fields(
            requested_fields,
            k_per_field=max(1, k),
            max_docs=max(k, len(requested_fields)),
        )
        if query:
            docs = self._merge_ranked(docs, self.retrieve(query, k=k))
        retrieval_time = time.perf_counter() - start_total

        context = self._format_context(docs)

        generation_start = time.perf_counter()
        raw_response = self.generate(query or "", context)
        generation_time = time.perf_counter() - generation_start

        fields_result = self._parse_extraction(raw_response)
        total_time = time.perf_counter() - start_total

        return RAGResult(
            answer=json.dumps(fields_result, indent=2, ensure_ascii=False),
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            metadata={
                "method": self.name,
                "fields": fields_result,
                "requested_fields": requested_fields,
                "num_context_chars": len(context),
            },
        )

    @staticmethod
    def _merge_ranked(primary, secondary):
        """Merge two ``(doc, score)`` lists, keeping the best score per document.

        Documents are identified by ``content_dedup_key``. The result is
        sorted by descending score.
        """
        merged = OrderedDict()
        for doc, score in list(primary or []) + list(secondary or []):
            key = content_dedup_key(doc)
            if key not in merged or float(score) > merged[key][1]:
                merged[key] = [doc, float(score)]
        results = [(doc, score) for doc, score in merged.values()]
        results.sort(key=lambda item: item[1], reverse=True)
        return results

    def _parse_extraction(self, raw_response):
        """Parse the LLM response into a ``{field: value}`` dict.

        The text between the first ``{`` and the last ``}`` is decoded as
        JSON after stripping trailing commas before ``}`` or ``]``.

        Returns:
            A dict with every target field (``None`` where missing). When the
            response is non-empty but no JSON object could be decoded, all
            fields are ``None`` and the response is attached under ``"raw"``.
        """
        if not raw_response:
            return {field: None for field in self.target_fields}

        try:
            start = raw_response.find("{")
            end = raw_response.rfind("}") + 1
            if start >= 0 and end > start:
                json_str = raw_response[start:end]
                # LLMs often emit trailing commas, which strict JSON rejects.
                json_str = re.sub(r",\s*([}\]])", r"\1", json_str)
                parsed = json.loads(json_str)
                return {field: parsed.get(field) for field in self.target_fields}
        except (AttributeError, TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; AttributeError covers a
            # non-string response object without .find/.rfind.
            logger.warning("Failed to parse JSON from LLM response")

        result = {field: None for field in self.target_fields}
        result["raw"] = raw_response
        return result