channel_bert.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. # coding: UTF-8
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import time
  6. import re
  7. import os
  8. import transformers
  9. from transformers import ElectraTokenizer
  10. import numpy as np
  11. # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  12. device = torch.device("cpu") # 线上用CPU
  13. class PositionalEncoding(nn.Module):
  14. def __init__(self,dim_hid):
  15. super(PositionalEncoding,self).__init__()
  16. base_array = np.array([np.power(10000,2*(hid_j//2)/dim_hid) for hid_j in range(dim_hid)])
  17. self.base_tensor = torch.from_numpy(base_array).to(torch.float32).to(device) #[1,D]
  18. def forward(self,x):
  19. # x(B,N,d)
  20. B,N,d = x.shape
  21. pos = torch.arange(N).unsqueeze(-1).to(torch.float32).to(device) #[N,1]
  22. pos = pos/self.base_tensor
  23. pos = pos.unsqueeze(0)
  24. pos[:,:,0::2] = torch.sin(pos[:,:,0::2])
  25. pos[:,:,1::2] = torch.cos(pos[:,:,1::2])
  26. return x+pos
  27. class ScaledDotProductAttention(nn.Module):
  28. ''' Scaled Dot-Product Attention '''
  29. def __init__(self, temperature, attn_dropout=0.1):
  30. super().__init__()
  31. self.temperature = temperature
  32. self.dropout = nn.Dropout(attn_dropout)
  33. def forward(self, q, k, v, mask=None):
  34. # print(q.shape,k.shape)
  35. attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
  36. if mask is not None:
  37. attn = attn.masked_fill(mask == 0, -1e9)
  38. # t1 = time.time()
  39. attn = self.dropout(torch.softmax(attn, dim=-1))
  40. # print('cost',time.time()-t1) # 主要时间花费
  41. output = torch.matmul(attn, v)
  42. return output, attn
  43. class MultiHeadAttention(nn.Module):
  44. ''' Multi-Head Attention module '''
  45. def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
  46. super().__init__()
  47. self.n_head = n_head
  48. self.d_k = d_k
  49. self.d_v = d_v
  50. self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
  51. self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
  52. self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
  53. self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
  54. self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
  55. self.dropout = nn.Dropout(dropout)
  56. self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
  57. self.rotaryEmbedding = RotaryEmbedding(d_k)
  58. def forward(self, q, k, v, mask=None):
  59. d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
  60. sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
  61. residual = q
  62. # Pass through the pre-attention projection: b x lq x (n*dv)
  63. # Separate different heads: b x lq x n x dv
  64. q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
  65. k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
  66. v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
  67. # Transpose for attention dot product: b x n x lq x dv
  68. q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
  69. # RoPE embed
  70. target_tensor = torch.zeros((q.size(0), q.size(2)))
  71. position_ids = torch.arange(q.size(2), dtype=torch.long).unsqueeze(0).expand_as(target_tensor)
  72. _cos, _sin = self.rotaryEmbedding(q, position_ids)
  73. q, k = apply_rotary_pos_emb(q, k, _cos, _sin)
  74. if mask is not None:
  75. # mask = mask.unsqueeze(1) # For head axis broadcasting.
  76. mask = mask.unsqueeze(1).unsqueeze(2) # For head axis broadcasting.
  77. q, attn = self.attention(q, k, v, mask=mask)
  78. #q (sz_b,n_head,N=len_q,d_k)
  79. #k (sz_b,n_head,N=len_k,d_k)
  80. #v (sz_b,n_head,N=len_v,d_v)
  81. # Transpose to move the head dimension back: b x lq x n x dv
  82. # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
  83. q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
  84. #q (sz_b,len_q,n_head,N * d_k)
  85. q = self.dropout(self.fc(q))
  86. q += residual
  87. q = self.layer_norm(q)
  88. return q, attn
  89. class PositionwiseFeedForward(nn.Module):
  90. ''' A two-feed-forward-layer module '''
  91. def __init__(self, d_in, d_hid, dropout=0.1):
  92. super().__init__()
  93. self.w_1 = nn.Linear(d_in, d_hid) # position-wise
  94. self.w_2 = nn.Linear(d_hid, d_in) # position-wise
  95. self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
  96. self.dropout = nn.Dropout(dropout)
  97. def forward(self, x):
  98. residual = x
  99. x = self.w_2(torch.relu(self.w_1(x)))
  100. x = self.dropout(x)
  101. x += residual
  102. x = self.layer_norm(x)
  103. return x
  104. class RotaryEmbedding(nn.Module):
  105. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  106. super().__init__()
  107. self.dim = dim # it is set to the head_dim
  108. self.max_position_embeddings = max_position_embeddings
  109. self.base = base
  110. # Calculate the theta according to the formula theta_i = base^(2i/dim) where i = 0, 1, 2, ..., dim // 2
  111. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
  112. self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
  113. @torch.no_grad()
  114. def forward(self, x, position_ids, seq_len=None):
  115. # x: [bs, num_attention_heads, seq_len, head_size]
  116. self.inv_freq = self.inv_freq.to(device)
  117. position_ids = position_ids.to(device)
  118. # Copy the inv_freq tensor for batch in the sequence
  119. # inv_freq_expanded: [Batch_Size, Head_Dim // 2, 1]
  120. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  121. # position_ids_expanded: [Batch_Size, 1, Seq_Len]
  122. position_ids_expanded = position_ids[:, None, :].float()
  123. # Multiply each theta by the position (which is the argument of the sin and cos functions)
  124. # freqs: [Batch_Size, Head_Dim // 2, 1] @ [Batch_Size, 1, Seq_Len] --> [Batch_Size, Seq_Len, Head_Dim // 2]
  125. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  126. # emb: [Batch_Size, Seq_Len, Head_Dim]
  127. emb = torch.cat((freqs, freqs), dim=-1)
  128. # cos, sin: [Batch_Size, Seq_Len, Head_Dim]
  129. cos = emb.cos()
  130. sin = emb.sin()
  131. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  132. def rotate_half(x):
  133. # Build the [-x2, x1, -x4, x3, ...] tensor for the sin part of the positional encoding.
  134. x1 = x[..., : x.shape[-1] // 2] # Takes the first half of the last dimension
  135. x2 = x[..., x.shape[-1] // 2 :] # Takes the second half of the last dimension
  136. return torch.cat((-x2, x1), dim=-1)
  137. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  138. cos = cos.unsqueeze(unsqueeze_dim) # Add the head dimension
  139. sin = sin.unsqueeze(unsqueeze_dim) # Add the head dimension
  140. # Apply the formula (34) of the Rotary Positional Encoding paper.
  141. q_embed = (q * cos) + (rotate_half(q) * sin)
  142. k_embed = (k * cos) + (rotate_half(k) * sin)
  143. return q_embed, k_embed
  144. class EncoderLayer(nn.Module):
  145. ''' Compose with two layers '''
  146. def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
  147. super(EncoderLayer, self).__init__()
  148. self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
  149. self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
  150. def forward(self, enc_input, slf_attn_mask=None):
  151. enc_output, enc_slf_attn = self.slf_attn(
  152. enc_input, enc_input, enc_input, mask=slf_attn_mask)
  153. enc_output = self.pos_ffn(enc_output)
  154. return enc_output, enc_slf_attn
  155. class Encoder(nn.Module):
  156. ''' A encoder model with self attention mechanism. '''
  157. def __init__(
  158. self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
  159. d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False,embedding=None):
  160. super().__init__()
  161. if embedding is None:
  162. self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
  163. else:
  164. self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec,padding_idx=pad_idx,
  165. # _weight=torch.from_numpy(embedding ,))
  166. # _weight=torch.tensor(embedding ,dtype=torch.float64).to(device))
  167. _weight=torch.tensor(embedding))
  168. self.position_enc = PositionalEncoding(d_word_vec)
  169. self.dropout = nn.Dropout(p=dropout)
  170. self.layer_stack = nn.ModuleList([
  171. EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
  172. for _ in range(n_layers)])
  173. self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
  174. self.scale_emb = scale_emb
  175. self.d_model = d_model
  176. def forward(self, src_seq, src_mask, return_attns=False):
  177. enc_slf_attn_list = []
  178. # -- Forward
  179. enc_output = self.src_word_emb(src_seq)
  180. if self.scale_emb:
  181. enc_output *= self.d_model ** 0.5
  182. # enc_output = self.dropout(self.position_enc(enc_output))
  183. enc_output = self.dropout(enc_output)
  184. enc_output = self.layer_norm(enc_output)
  185. for enc_layer in self.layer_stack:
  186. enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
  187. enc_slf_attn_list += [enc_slf_attn] if return_attns else []
  188. if return_attns:
  189. return enc_output, enc_slf_attn_list
  190. return enc_output
  191. class bidiBert(nn.Module):
  192. def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
  193. d_model, d_inner, pad_idx,n_class,embedding = None):
  194. super(bidiBert, self).__init__()
  195. self.encoder = Encoder(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
  196. d_model, d_inner, pad_idx,embedding = embedding)
  197. # self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx,
  198. # # _weight=torch.from_numpy(embedding ,))
  199. # # _weight=torch.tensor(embedding ,dtype=torch.float64).to(device))
  200. # _weight=torch.tensor(embedding))
  201. # self.encoder = nn.LSTM(128, 256, 2,
  202. # bidirectional=True, batch_first=True, dropout=0.3)
  203. self.dropout = nn.Dropout(p=0.1)
  204. self.pooler = nn.Linear(d_inner,d_inner)
  205. self.liner = nn.Linear(d_inner,n_class)
  206. # self.liner = nn.Linear(256*4,n_class)
  207. self.avg_pool1 = nn.AdaptiveAvgPool1d(1) # Max Pooling: nn.AdaptiveMaxPool1d(1)
  208. self.avg_pool2 = nn.AdaptiveAvgPool1d(1)
  209. def forward(self, input_title,input_doctext):
  210. # out = self.encoder(inputs, attention_mask)
  211. # input_title = self.src_word_emb(input_title)
  212. # input_title,_ = self.encoder(input_title)
  213. # input_title = self.encoder(src_seq=input_title[0], src_mask=input_title[1])
  214. # input_title = input_title[:, 1, :]
  215. # input_title = torch.mean(input_title,dim=-2)
  216. # input_title = self.avg_pool1(input_title.transpose(1, 2)).squeeze(-1)
  217. # input_title = torch.tanh(self.pooler(input_title[:,0]))
  218. # input_doctext = self.src_word_emb(input_doctext)
  219. # input_doctext,_ = self.encoder(input_doctext)
  220. input_doctext = self.encoder(src_seq=input_doctext[0], src_mask=input_doctext[1])
  221. # input_doctext = input_doctext[:, 1, :]
  222. # input_doctext = torch.mean(input_doctext,dim=-2)
  223. # input_doctext = self.avg_pool2(input_doctext.transpose(1, 2)).squeeze(-1)
  224. input_doctext = self.pooler(input_doctext[:,0])
  225. input_doctext = self.dropout(input_doctext)
  226. input_doctext = torch.tanh(input_doctext)
  227. # print('size:',input_title.size(),input_doctext.size())
  228. # out = torch.cat((input_title, input_doctext), dim=-1)
  229. out = input_doctext
  230. # bs, n, m = out.size()
  231. # out = out.view(bs,n*m)
  232. out = self.liner(out)
  233. out = F.softmax(out, dim=-1)
  234. return out
  235. phone = re.compile('1[3-9][0-9][-—-―]?\d{4}[-—-―]?\d{4}|'
  236. '\+86.?1[3-9]\d{9}|'
  237. # '0[^0]\d{1,2}[-—-―][1-9]\d{6,7}/[1-9]\d{6,10}|'
  238. '0[1-9]\d{1,2}[-—-―][2-9]\d{6}\d?[-—-―]\d{1,4}|'
  239. '0[1-9]\d{1,2}[-—-―]{0,2}[2-9]\d{6}\d?(?=1[3-9]\d{9})|'
  240. '0[1-9]\d{1,2}[-—-―]{0,2}[2-9]\d{6}\d?(?=0[1-9]\d{1,2}[-—-―]?[2-9]\d{6}\d?)|'
  241. '0[1-9]\d{1,2}[-—-―]{0,2}[2-9]\d{6}\d?(?=[2-9]\d{6,7})|'
  242. '0[1-9]\d{1,2}[-—-―]{0,2}[2-9]\d{6}\d?|'
  243. '[\(|\(]0[1-9]\d{1,2}[\)|\)]-?[2-9]\d{6}\d?-?\d{,4}|'
  244. '400\d{7}转\d{1,4}|'
  245. '[2-9]\d{6,7}')
  246. def text_process(text):
  247. text = text.strip()
  248. text = re.sub(r'[\000-\010]|[\013-\014]|[\016-\037]',"",text) # 非法字符
  249. text = re.sub("extractJson:|fullTextSeg:","",text)
  250. # text = re.sub("[??]{1,}", "", text)
  251. text = re.sub("[??]{2,}", "", text)
  252. text = re.sub(r'(http[s]?://|www\.)(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', "", text) # 网站
  253. text = re.sub(r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+', "", text)# 邮箱
  254. text = re.sub('[0-9a-zA-Z@#*%=?&~()_|<>/.(){}【】{}\[\]\-]{6,}', "", text) # 编号
  255. text = re.sub(phone,"",text) # 号码
  256. text = re.sub(r'\b\d[-.\s]?\d{3}[-.\s]?\d{4}\b', "", text) # 座机
  257. text = re.sub('0[1-9]\d{1,2}[-—-―]{0,2}[2-9]\d{6}\d?', "", text) # 座机
  258. text = re.sub('1(3[0-9]|4[01456879]|5[0-35-9]|6[2567]|7[0-8]|8[0-9]|9[0-35-9])\d{8}', "", text)# 手机号
  259. text = re.sub('&?nbsp;?|&?ensp;?|&?emsp;?', "", text)
  260. text = re.sub('\\\\n|\\\\r|\\\\t', "", text)
  261. # text = re.sub("\s+", "", text)
  262. text = re.sub("\s+", " ", text)
  263. # 优化部分未识别表达
  264. text = re.sub("中止", "终止", text)
  265. text = re.sub("遴选", "招标", text)
  266. return text
  267. label2class_dict = {
  268. 0: 51, 1:52 , 2:101,
  269. 3:102, 4:103, 5:105,
  270. 6:114, 7:118, 8:119,
  271. 9:120, 10:121, 11:122
  272. }
  273. def channel_predict(title,text):
  274. if globals().get("channel_pytorch_model") is None or globals().get("channel_tokenizer") is None:
  275. # config
  276. config = {
  277. # 'n_src_vocab': len(vocab),
  278. 'd_word_vec': 128,
  279. 'n_layers': 3,
  280. 'n_head': 3,
  281. 'd_k': 128,
  282. 'd_v': 128,
  283. 'd_model': 128,
  284. 'd_inner': 128,
  285. 'pad_idx': 0
  286. }
  287. # n_src_vocab = config['n_src_vocab']
  288. d_word_vec = config['d_word_vec']
  289. n_layers = config['n_layers']
  290. n_head = config['n_head']
  291. d_k = config['d_k']
  292. d_v = config['d_v']
  293. d_model = config['d_model']
  294. d_inner = config['d_inner']
  295. pad_idx = config['pad_idx']
  296. n_class = 12
  297. # tokenizer
  298. base_model_name = os.path.abspath(os.path.dirname(__file__)) + "/pytorch_model/tokenizer"
  299. tokenizer = ElectraTokenizer.from_pretrained(base_model_name)
  300. n_src_vocab = len(tokenizer.get_vocab())
  301. # 实例化模型
  302. model_path = os.path.abspath(os.path.dirname(__file__)) + '/pytorch_model/channel.pth'
  303. model = bidiBert(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
  304. d_model, d_inner, pad_idx, n_class, embedding=None)
  305. model.to(device)
  306. model_state = torch.load(model_path, map_location=device)
  307. model_state_dict = model.state_dict()
  308. pretrained_state_dict = model_state
  309. # missing_keys = set(model_state_dict.keys()) - set(pretrained_state_dict.keys())
  310. unexpected_keys = set(pretrained_state_dict.keys()) - set(model_state_dict.keys())
  311. add_kv = []
  312. for k, v in model_state.items():
  313. if k in unexpected_keys:
  314. # model_state[k.replace("module.","")] = v
  315. add_kv.append([k.replace("module.", ""), v])
  316. for i in add_kv:
  317. model_state[i[0]] = i[1]
  318. for k in list(unexpected_keys):
  319. del model_state[k]
  320. model.load_state_dict(model_state)
  321. # 将模型设置为评估模式
  322. model.eval()
  323. globals()["channel_pytorch_model"] = model
  324. globals()["channel_tokenizer"] = tokenizer
  325. else:
  326. model = globals().get("channel_pytorch_model")
  327. tokenizer = globals().get("channel_tokenizer")
  328. # process text
  329. if title in text:
  330. text = text.replace(title, '', 1)
  331. if "##attachment##" in text:
  332. main_text,attachment_text = text.split("##attachment##",maxsplit=1)
  333. # print('main_text',main_text)
  334. if len(main_text)>=500: # 正文有足够的内容时不需要使用附件预测
  335. text = main_text
  336. text = re.sub("##attachment##。?","",text)
  337. text = text_process(text)
  338. if len(text)<=100:
  339. # 正文内容过短时,不预测
  340. return
  341. elif len(text)<=150:
  342. # 正文内容过短时,重复正文
  343. text = text * 2
  344. text = text[:2000]
  345. title = text_process(title)
  346. title = title[:100]
  347. text = "公告标题:" + title + "。" + "公告内容:" + text
  348. text = text[:2000]
  349. # print('predict text:',text)
  350. # to torch data
  351. text = [text]
  352. text_max_len = 2000
  353. text = [tokenizer.encode_plus(
  354. _t,
  355. add_special_tokens=True, # 添加特殊标记,如[CLS]和[SEP]
  356. max_length=text_max_len, # 设置最大长度
  357. padding='max_length', # 填充到最大长度
  358. truncation=True, # 截断超过最大长度的文本
  359. return_attention_mask=True, # 返回attention_mask
  360. return_tensors='pt' # 返回PyTorch张量
  361. ) for _t in text]
  362. text = [torch.LongTensor(np.array([_t['input_ids'].numpy()[0] for _t in text])).to(device),
  363. torch.LongTensor(np.array([_t['attention_mask'].numpy()[0] for _t in text])).to(device)]
  364. # predict
  365. with torch.no_grad():
  366. outputs = model(None, text)
  367. predic = torch.max(outputs.data, 1)[1].cpu().numpy()
  368. pred_prob = torch.max(outputs.data, 1)[0].cpu().numpy()
  369. # print('pred_prob',pred_prob)
  370. if pred_prob>0.5:
  371. pred_label = predic[0]
  372. pred_class = label2class_dict[pred_label]
  373. else:
  374. return
  375. # print('check rule before',pred_class)
  376. # check rule
  377. if pred_class==101 and re.search("((资格|资质)(审查|预审|后审|审核)|资审)结果(公告|公示)?|(资质|资格)(预审|后审)公示|资审及业绩公示",title): # 纠正部分‘资审结果’模型错误识别为中标
  378. pred_class = 105
  379. elif pred_class==122 and re.search("验收服务",title):
  380. pred_class = None
  381. # elif pred_class==118 and re.search("重新招标",title): #重新招标类公告,因之前公告的废标原因而错识别为废标公告
  382. # pred_class = 52
  383. return pred_class
  384. class_dict = {51: '公告变更',
  385. 52: '招标公告',
  386. 101: '中标信息',
  387. 102: '招标预告',
  388. 103: '招标答疑',
  389. 104: '招标文件',
  390. 105: '资审结果',
  391. 106: '法律法规',
  392. 107: '新闻资讯',
  393. 108: '拟建项目',
  394. 109: '展会推广',
  395. 110: '企业名录',
  396. 111: '企业资质',
  397. 112: '全国工程',
  398. 113: '业主采购',
  399. 114: '采购意向',
  400. 115: '拍卖出让',
  401. 116: '土地矿产',
  402. 117: '产权交易',
  403. 118: '废标公告',
  404. 119: '候选人公示',
  405. 120: '合同公告',
  406. 121: '开标记录',
  407. 122: '验收合同'
  408. }
  409. tenderee_type = ['公告变更','招标公告','招标预告','招标答疑','资审结果','采购意向']
  410. win_type = ['中标信息','废标公告','候选人公示','合同公告','开标记录','验收合同']
  411. def merge_channel(list_articles,channel_dic,original_docchannel):
  412. def merge_rule(title,text,docchannel,pred_channel,channel_dic,original_docchannel):
  413. front_text_len = len(text)//3 if len(text)>300 else 100
  414. front_text = text[:front_text_len]
  415. pred_channel = class_dict[pred_channel]
  416. if pred_channel == docchannel:
  417. channel_dic['docchannel']['use_original_docchannel'] = 0
  418. else:
  419. if pred_channel in ['采购意向','招标预告'] and docchannel in ['采购意向','招标预告']:
  420. merge_res = '采购意向' if re.search("意向|意愿",title) or re.search("意向|意愿",front_text) else "招标预告"
  421. channel_dic['docchannel']['docchannel'] = merge_res
  422. channel_dic['docchannel']['use_original_docchannel'] = 0
  423. elif pred_channel in ['公告变更','招标答疑'] and docchannel in ['公告变更','招标答疑']:
  424. channel_dic['docchannel']['docchannel'] = docchannel
  425. channel_dic['docchannel']['use_original_docchannel'] = 0
  426. elif pred_channel=='公告变更' and docchannel in ['中标信息','废标公告','候选人公示','合同公告']: #中标类的变更还是中标类公告
  427. channel_dic['docchannel']['docchannel'] = docchannel
  428. channel_dic['docchannel']['use_original_docchannel'] = 0
  429. elif docchannel=='公告变更' and pred_channel in ['中标信息','废标公告','候选人公示','合同公告']:
  430. channel_dic['docchannel']['docchannel'] = pred_channel
  431. channel_dic['docchannel']['use_original_docchannel'] = 0
  432. else:
  433. original_type = class_dict.get(original_docchannel, '原始类别')
  434. if pred_channel in tenderee_type and docchannel in tenderee_type and original_type not in tenderee_type:
  435. # pred_channel和docchannel都是同一(招标/中标)类型时,original_docchannel不一致时不使用原网类型
  436. channel_dic['docchannel']['use_original_docchannel'] = 0
  437. elif pred_channel in win_type and docchannel in win_type and original_type not in win_type:
  438. # pred_channel和docchannel都是同一(招标/中标)类型时,original_docchannel不一致时不使用原网类型
  439. channel_dic['docchannel']['use_original_docchannel'] = 0
  440. else:
  441. channel_dic = {'docchannel': {'doctype': '采招数据',
  442. 'docchannel': original_type,
  443. 'life_docchannel': original_type}}
  444. channel_dic['docchannel']['use_original_docchannel'] = 1
  445. return channel_dic
  446. article = list_articles[0]
  447. title = article.title
  448. text = article.content
  449. doctype = channel_dic['docchannel']['doctype']
  450. docchannel = channel_dic['docchannel']['docchannel']
  451. # print('doctype',doctype,'docchannel',docchannel,'original_docchannel',original_docchannel)
  452. compare_type = ['公告变更','招标公告','中标信息','招标预告','招标答疑','资审结果','采购意向','废标公告','候选人公示',
  453. '合同公告','开标记录','验收合同']
  454. # 仅比较部分数据
  455. if doctype=='采招数据' and docchannel in compare_type:
  456. if not re.search("单一来源",title) and not re.search("单一来源",text[:100]):
  457. pred = channel_predict(title, text)
  458. # print('pred_res', pred)
  459. if pred is not None and original_docchannel: # 无original_docchannel时不进行对比校正
  460. channel_dic = merge_rule(title,text,docchannel,pred,channel_dic,original_docchannel)
  461. elif doctype=='采招数据' and docchannel=="":
  462. pred = channel_predict(title, text)
  463. if pred is not None:
  464. pred = class_dict[pred]
  465. channel_dic['docchannel']['docchannel'] = pred
  466. channel_dic['docchannel']['use_original_docchannel'] = 0
  467. # '招标预告'类 规则纠正
  468. if channel_dic['docchannel']['doctype']=='采招数据' and channel_dic['docchannel']['docchannel']=="招标公告":
  469. if "##attachment##" in text:
  470. main_text, attachment_text = text.split("##attachment##", maxsplit=1)
  471. else:
  472. main_text = text
  473. main_text = text_process(main_text)
  474. if re.search("采购实施月份|采购月份|预计(招标|采购|发标|发包)(时间|月份)|招标公告预计发布时间",main_text[:len(main_text)//2]):
  475. front_text_len = len(main_text) // 3 if len(main_text) > 300 else 100
  476. front_text = main_text[:front_text_len]
  477. if re.search("意向|意愿",title) or re.search("意向|意愿",front_text):
  478. channel_dic['docchannel']['docchannel'] = "采购意向"
  479. else:
  480. channel_dic['docchannel']['docchannel'] = "招标预告"
  481. channel_dic['docchannel']['use_original_docchannel'] = 0
  482. return channel_dic
  483. if __name__ == '__main__':
  484. title = '关于【2024年四好农村路大中村药红路、空坦路延伸段设计服务】无效项目的公示'
  485. text = '''关于【2024年四好农村路大中村药红路、空坦路延伸段设计服务】无效项目的公示 点击查看招标公告 关于【2024年四好农村路大中村药红路、空坦路延伸段设计服务】无效项目的公示 项目名称 2024年四好农村路大中村药红路、空坦路延伸段设计服务, 采购人 重庆市巴南区人民政府莲花街道办事处, 选取方式 直接选取, 是否重新发布招标公告 是 ,无效类型 项目取消, 无效原因 资质设置错误,附件已盖章上传 ,无效时间 2024-10-21 ,公示附件 大中村设计变更.jpg'''
  486. pred_class = channel_predict(title,text)
  487. print(pred_class)
  488. # pred_class2 = channel_predict(title,text)
  489. # print(pred_class2)
  490. pass