| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- # -*- coding: utf-8 -*-
- """Bid-field extraction RAG tuned for tender/bidding HTML text recall."""
- import json
- import re
- import time
- from collections import OrderedDict
- import numpy as np
- from loguru import logger
- from .base import BaseRAG, RAGResult
- from .bm25_backend import get_bm25_okapi
- from .dedup import content_dedup_key
- from .tokenization import bm25_tokenize
# Catalogue of extractable fields, in prompt/output order. Each entry maps the
# English JSON key to its Chinese label and to the surface forms (Chinese and
# English) that the retriever combines into a per-field search query.
FIELD_SPECS = OrderedDict(
    (key, {"label": label, "aliases": list(aliases)})
    for key, label, aliases in (
        ("project_name", "项目名称",
         ("项目名称", "采购项目名称", "招标项目名称", "Project Name")),
        ("project_code", "项目编号",
         ("项目编号", "采购编号", "招标编号", "Project Code", "Tender No")),
        ("budget_amount", "预算金额",
         ("预算金额", "采购预算", "项目预算", "最高限价", "Budget Amount", "Project Budget")),
        ("currency", "币种",
         ("币种", "货币", "人民币", "Currency", "RMB")),
        ("bid_deadline", "投标截止时间",
         ("投标截止时间", "递交截止时间", "提交投标文件截止时间",
          "Bid Submission Deadline", "Bid Deadline")),
        ("bid_opening_time", "开标时间",
         ("开标时间", "开启时间", "Bid Opening Time", "Bid Opening")),
        ("bid_location", "投标地点",
         ("投标地点", "递交地点", "开标地点", "Bid Location", "Venue")),
        ("purchaser_name", "采购人名称",
         ("采购人", "招标人", "采购单位", "Purchaser", "Tenderer")),
        ("purchaser_contact", "采购人联系人",
         ("采购人联系人", "联系人", "Contact Person", "Purchaser Contact")),
        ("purchaser_phone", "采购人电话",
         ("采购人电话", "联系电话", "Contact Phone", "Purchaser Phone")),
        ("agency_name", "代理机构名称",
         ("代理机构", "采购代理机构", "招标代理", "Agency Name", "Bidding Agency")),
        ("agency_contact", "代理机构联系人",
         ("代理机构联系人", "Agency Contact")),
        ("agency_phone", "代理机构电话",
         ("代理机构电话", "Agency Phone")),
        ("qualification_requirements", "资格要求",
         ("资格要求", "资质要求", "投标人资格", "Qualification Requirements")),
        ("bid_bond_amount", "投标保证金",
         ("投标保证金", "Bid Bond", "Bid Bond Amount")),
        ("performance_bond_amount", "履约保证金",
         ("履约保证金", "Performance Bond", "Performance Bond Amount")),
        ("warranty_period", "质保期",
         ("质保期", "保修期", "免费保修", "Warranty Period", "Warranty")),
        ("delivery_time", "交货时间",
         ("交货时间", "交付时间", "工期", "建设周期", "Delivery Time", "Construction Period")),
        ("delivery_location", "交货地点",
         ("交货地点", "交付地点", "安装地点", "Delivery Location", "Installation Location")),
        ("payment_terms", "付款方式",
         ("付款方式", "支付方式", "Payment Terms", "Payment")),
        ("evaluation_method", "评标方法",
         ("评标方法", "评审方法", "评分办法", "Evaluation Method")),
        ("scope_of_work", "工作范围",
         ("工作范围", "采购内容", "招标范围", "建设内容", "Scope of Work", "Scope")),
    )
)
class BidFieldExtractionRAG(BaseRAG):
    """Retrieve field-specific context before asking the LLM to extract JSON.

    Retrieval is hybrid. A BM25 index is built over each document's metadata
    values plus ``page_content``. It is fused with vector-store similarity,
    but only when both ``embedding_model`` and ``vector_store`` are set.
    Each target field from ``FIELD_SPECS`` is searched with its own query,
    built from the field key, label and aliases. The merged hits form the
    context of a single JSON-extraction prompt.
    """

    # Metadata keys written onto documents by ``retrieve_for_fields``. They
    # are skipped when building the BM25 corpus. Otherwise, re-indexing
    # already-annotated document objects would inject field aliases (via
    # ``retrieval_query``) into the searchable text.
    _RETRIEVAL_METADATA_KEYS = frozenset(("matched_fields", "retrieval_query"))

    # A comma directly before a closing brace/bracket. This is invalid JSON,
    # but LLM replies contain it often enough to be worth repairing.
    _TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")

    def __init__(self, extraction_prompt_template=None, bm25_weight=0.45,
                 vector_weight=0.55, max_context_chars=6000, **kwargs):
        """Configure the prompt, hybrid-retrieval weights and context budget.

        Args:
            extraction_prompt_template: ``str.format`` template with
                ``{fields}`` and ``{context}`` placeholders. When falsy, a
                built-in Chinese prompt is used.
            bm25_weight: Weight applied to the normalized BM25 score.
            vector_weight: Weight applied to the clamped vector score.
            max_context_chars: Maximum number of context characters placed in
                the prompt. ``None`` disables truncation. Defaults to 6000.
            **kwargs: Forwarded to the base class constructor.
        """
        super(BidFieldExtractionRAG, self).__init__(**kwargs)
        # Default prompt (Chinese). It tells the model to:
        #   - extract only from the given context;
        #   - return a single JSON object keyed by the English field names;
        #   - use null for fields it cannot confirm;
        #   - keep amounts, times, contacts and phone numbers verbatim.
        self.extraction_prompt_template = extraction_prompt_template or (
            "你是招投标领域的信息抽取助手。请只依据给定上下文抽取字段。\n\n"
            "目标字段:\n{fields}\n\n"
            "上下文:\n{context}\n\n"
            "要求:\n"
            "1. 仅返回一个 JSON 对象,不要输出解释。\n"
            "2. JSON key 必须使用目标字段英文名。\n"
            "3. 无法从上下文确认的字段返回 null。\n"
            "4. 金额、时间、联系人、电话等值要保留原文表述。"
        )
        self.target_fields = list(FIELD_SPECS.keys())
        self.field_specs = FIELD_SPECS
        self.bm25_weight = float(bm25_weight)
        self.vector_weight = float(vector_weight)
        self.max_context_chars = None if max_context_chars is None else int(max_context_chars)
        self._documents = []
        self._bm25 = None
        self._bm25_corpus = []

    def index_documents(self, documents):
        """Index *documents* for BM25 and, when configured, for vector search.

        Args:
            documents: Iterable of document objects exposing ``page_content``
                and ``metadata`` (a ``None`` metadata is treated as empty).
                ``None`` is treated as an empty collection.
        """
        self._documents = list(documents or [])
        self._build_bm25(self._documents)
        if self.embedding_model is not None and self.vector_store is not None:
            super(BidFieldExtractionRAG, self).index_documents(self._documents)

    def _build_bm25(self, documents):
        """(Re)build the BM25 index from each document's metadata values and text."""
        self._bm25_corpus = []
        for doc in documents:
            metadata = doc.metadata or {}
            meta_text = " ".join(
                str(value)
                for key, value in metadata.items()
                if key not in self._RETRIEVAL_METADATA_KEYS
            )
            self._bm25_corpus.append(bm25_tokenize("{}\n{}".format(meta_text, doc.page_content)))
        BM25Okapi = get_bm25_okapi()
        self._bm25 = BM25Okapi(self._bm25_corpus) if self._bm25_corpus else None

    def retrieve(self, query, k=10):
        """Hybrid retrieval for a free-form query; returns ``(doc, score)`` pairs."""
        return self._hybrid_retrieve(query, k=k)

    def retrieve_for_fields(self, fields=None, k_per_field=3, max_docs=12):
        """Retrieve and deduplicate contexts for each requested field.

        Each field in *fields* is searched with its own query. *fields*
        defaults to all target fields, and unknown names are ignored.
        A document hit by several fields keeps its best score and records
        every field that matched it.

        Side effect: each returned document's ``metadata`` is updated in place
        with two keys:
            - ``matched_fields``: sorted list of the fields that hit it.
            - ``retrieval_query``: the query of the first field that
              retrieved it.

        Returns:
            ``(doc, score)`` pairs sorted by descending score, passed through
            the inherited ``_deduplicate_results`` with *max_docs*
            (presumably a cap on the result count; defined outside this
            class).
        """
        selected_fields = [f for f in (fields or self.target_fields) if f in self.field_specs]
        merged = OrderedDict()
        for field in selected_fields:
            field_query = self._field_query(field)
            for doc, score in self._hybrid_retrieve(field_query, k=k_per_field):
                key = self._dedup_key(doc)
                entry = merged.get(key)
                if entry is None:
                    metadata = dict(doc.metadata or {})
                    metadata["matched_fields"] = [field]
                    metadata["retrieval_query"] = field_query
                    merged[key] = [doc, float(score), metadata]
                else:
                    entry[1] = max(entry[1], float(score))
                    entry[2]["matched_fields"].append(field)
        results = []
        for doc, score, metadata in merged.values():
            metadata["matched_fields"] = sorted(set(metadata["matched_fields"]))
            if doc.metadata is None:
                # Documents without metadata still need a dict to annotate.
                doc.metadata = {}
            doc.metadata.update(metadata)
            results.append((doc, score))
        results.sort(key=lambda item: item[1], reverse=True)
        return self._deduplicate_results(results, max_docs)

    def _field_query(self, field):
        """Build the retrieval query for *field*: key, label, then all aliases."""
        spec = self.field_specs[field]
        return "{} {} {}".format(field, spec["label"], " ".join(spec["aliases"]))

    def _hybrid_retrieve(self, query, k=10):
        """Return up to *k* ``(doc, score)`` pairs ranked by weighted vector+BM25 score.

        Each channel contributes at most its best score per dedup key, so
        duplicate documents cannot stack a channel's weight more than once.
        Documents whose fused score is not positive are dropped.
        """
        k = max(0, int(k or 0))
        if k == 0:
            return []
        docs_by_key = OrderedDict()
        vector_part = {}
        bm25_part = {}
        # Over-fetch vector candidates (4x) before fusing and truncating to k.
        for doc, score in self._vector_scores(query, k=k * 4):
            key = self._dedup_key(doc)
            docs_by_key.setdefault(key, doc)
            vector_part[key] = max(vector_part.get(key, 0.0), self._normalize_score(score))
        for doc, score in self._bm25_scores(query):
            key = self._dedup_key(doc)
            docs_by_key.setdefault(key, doc)
            bm25_part[key] = max(bm25_part.get(key, 0.0), self._normalize_score(score))
        results = []
        for key, doc in docs_by_key.items():
            fused = (vector_part.get(key, 0.0) * self.vector_weight
                     + bm25_part.get(key, 0.0) * self.bm25_weight)
            if fused > 0:
                results.append((doc, fused))
        results.sort(key=lambda item: item[1], reverse=True)
        return self._deduplicate_results(results, k)

    def _vector_scores(self, query, k):
        """Return ``(doc, score)`` pairs from the vector store.

        Returns ``[]`` when vector search is not configured. Also returns
        ``[]`` when the search raises; the failure is logged and retrieval
        continues on BM25 alone. Score semantics depend on the vector store
        (assumed: higher means more similar).
        """
        if self.embedding_model is None or self.vector_store is None:
            return []
        try:
            query_embedding = self.embedding_model.embed_query(query)
            return self.vector_store.similarity_search(query_embedding, k)
        except Exception as exc:
            logger.warning("Vector retrieval failed, falling back to BM25: {}".format(exc))
            return []

    def _bm25_scores(self, query):
        """Score indexed documents with BM25, normalized by the best score.

        If no BM25 index exists yet, one is built lazily from
        ``vector_store.documents`` when that attribute is non-empty.

        Returns:
            Positive-scoring ``(doc, score)`` pairs, sorted descending, with
            scores in ``(0, 1]``.
        """
        if self._bm25 is None:
            store_docs = getattr(self.vector_store, "documents", None)
            if store_docs:
                self._documents = list(store_docs)
                self._build_bm25(self._documents)
        if self._bm25 is None:
            return []
        query_tokens = bm25_tokenize(query)
        if not query_tokens:
            return []
        scores = self._bm25.get_scores(query_tokens)
        max_score = float(np.max(scores)) if len(scores) else 0.0
        if max_score <= 0:
            return []
        scored = [
            (doc, float(score) / max_score)
            for doc, score in zip(self._documents, scores)
            if float(score) > 0
        ]
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored

    @staticmethod
    def _normalize_score(score):
        """Coerce *score* to float and clamp it to ``[0, 1]``.

        Non-numeric and negative values become ``0.0``.
        """
        try:
            value = float(score)
        except (TypeError, ValueError):
            return 0.0
        if value < 0:
            return 0.0
        return min(value, 1.0)

    def generate(self, query, context, fields=None):
        """Format the extraction prompt and return the raw LLM reply.

        Args:
            query: Not used in the prompt; accepted for interface
                compatibility.
            context: Retrieved context text, truncated to
                ``max_context_chars`` characters.
            fields: Optional field keys to list in the prompt. Defaults to all
                ``target_fields``. Only fields that are also in
                ``target_fields`` are listed, in ``FIELD_SPECS`` order.
        """
        wanted = set(self.target_fields)
        if fields:
            wanted &= set(fields)
        fields_str = "\n".join(
            # Only the first four aliases are shown to keep the prompt compact.
            "- {}: {} ({})".format(field, spec["label"], ", ".join(spec["aliases"][:4]))
            for field, spec in self.field_specs.items()
            if field in wanted
        )
        prompt = self.extraction_prompt_template.format(
            fields=fields_str,
            context=(context or "")[:self.max_context_chars],
        )
        return self._call_llm(prompt)

    def extract_fields(self, query=None, k=5, fields=None):
        """Retrieve context for the requested fields and extract them with the LLM.

        Args:
            query: Optional free-form query. When given, its hybrid hits are
                merged into the field-specific hits. It is not placed in the
                prompt.
            k: Hits per field query (at least 1), and hits for *query*.
            fields: Field keys to extract (default: all target fields).
                Duplicates and unknown keys are dropped. If no known key
                remains, all target fields are used and a warning is logged.

        Returns:
            A ``RAGResult`` whose ``answer`` is the extracted ``{field: value}``
            dict as pretty-printed JSON. ``metadata["fields"]`` holds the
            same dict.
        """
        start_total = time.perf_counter()
        candidates = fields or self.target_fields
        requested_fields = [f for f in OrderedDict.fromkeys(candidates) if f in self.field_specs]
        if not requested_fields:
            logger.warning(
                "No known fields requested ({}); extracting all target fields".format(fields)
            )
            requested_fields = [f for f in self.target_fields if f in self.field_specs]
        docs = self.retrieve_for_fields(
            requested_fields,
            k_per_field=max(1, k),
            max_docs=max(k, len(requested_fields)),
        )
        if query:
            docs = self._merge_ranked(docs, self.retrieve(query, k=k))
        retrieval_time = time.perf_counter() - start_total

        context = self._format_context(docs)
        t_generate = time.perf_counter()
        raw_response = self.generate(query or "", context, fields=requested_fields)
        generation_time = time.perf_counter() - t_generate

        fields_result = self._parse_extraction(raw_response, fields=requested_fields)
        total_time = time.perf_counter() - start_total
        return RAGResult(
            answer=json.dumps(fields_result, indent=2, ensure_ascii=False),
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            metadata={
                "method": self.name,
                "fields": fields_result,
                "requested_fields": requested_fields,
                "num_context_chars": len(context),
            },
        )

    @staticmethod
    def _merge_ranked(primary, secondary):
        """Merge two ``(doc, score)`` lists, keeping the best score per content key.

        Returns the merged pairs sorted by descending score.
        """
        merged = OrderedDict()
        for doc, score in list(primary or []) + list(secondary or []):
            key = content_dedup_key(doc)
            if key not in merged or float(score) > merged[key][1]:
                merged[key] = [doc, float(score)]
        results = [(doc, score) for doc, score in merged.values()]
        results.sort(key=lambda item: item[1], reverse=True)
        return results

    def _parse_extraction(self, raw_response, fields=None):
        """Turn the LLM reply into ``{field: value}`` for the requested fields.

        Args:
            raw_response: LLM reply. The JSON object may be surrounded by
                prose or code fences.
            fields: Field keys to return (default: ``self.target_fields``).

        Returns:
            A dict with one entry per field (``None`` when absent). If no JSON
            object can be decoded:
                - a warning is logged;
                - every field is ``None``;
                - the reply text is kept under ``"raw"``.
        """
        output_fields = list(fields) if fields else list(self.target_fields)
        if not raw_response:
            return {field: None for field in output_fields}
        text = raw_response if isinstance(raw_response, str) else str(raw_response)
        parsed = self._load_json_object(text)
        if parsed is None:
            logger.warning("Failed to parse JSON from LLM response")
            result = {field: None for field in output_fields}
            result["raw"] = text
            return result
        return {field: self._lookup_field(parsed, field) for field in output_fields}

    def _lookup_field(self, parsed, field):
        """Read *field* from *parsed*, falling back to the field's Chinese label key."""
        if field in parsed:
            return parsed[field]
        # The prompt demands English keys, but tolerate replies keyed by label.
        spec = self.field_specs.get(field)
        return parsed.get(spec["label"]) if spec else None

    @classmethod
    def _load_json_object(cls, text):
        """Decode the outermost ``{...}`` span of *text*.

        Returns the decoded dict, or ``None`` if no object can be decoded.
        """
        start = text.find("{")
        end = text.rfind("}") + 1
        if start < 0 or end <= start:
            return None
        candidate = cls._TRAILING_COMMA_RE.sub(r"\1", text[start:end])
        try:
            parsed = json.loads(candidate)
        except ValueError:
            # The greedy slice can run into later prose that also contains
            # braces; fall back to decoding only the first complete object.
            try:
                parsed, _ = json.JSONDecoder().raw_decode(candidate)
            except ValueError:
                return None
        return parsed if isinstance(parsed, dict) else None
|