# predictor.py
'''
Created on 2018-12-26
@author: User
'''
import os
import sys
sys.path.append(os.path.abspath("../.."))
# from keras.engine import topology
# from keras import models
# from keras import layers
# from keras_contrib.layers.crf import CRF
# from keras.preprocessing.sequence import pad_sequences
# from keras import optimizers,losses,metrics
from BiddingKG.dl.common.Utils import *
from BiddingKG.dl.interface.modelFactory import *
import tensorflow as tf
from tensorflow.python.framework import graph_util
from BiddingKG.dl.product.data_util import decode, process_data, result_to_json
from BiddingKG.dl.interface.Entitys import Entity
from threading import RLock

dict_predictor = {"codeName":{"predictor":None,"Lock":RLock()},
                  "prem":{"predictor":None,"Lock":RLock()},
                  "epc":{"predictor":None,"Lock":RLock()},
                  "roleRule":{"predictor":None,"Lock":RLock()},
                  "form":{"predictor":None,"Lock":RLock()}}

def getPredictor(_type):
    if _type in dict_predictor:
        with dict_predictor[_type]["Lock"]:
            if dict_predictor[_type]["predictor"] is None:
                if _type=="codeName":
                    dict_predictor[_type]["predictor"] = CodeNamePredict()
                if _type=="prem":
                    dict_predictor[_type]["predictor"] = PREMPredict()
                if _type=="epc":
                    dict_predictor[_type]["predictor"] = EPCPredict()
                if _type=="roleRule":
                    dict_predictor[_type]["predictor"] = RoleRulePredictor()
                if _type=="form":
                    dict_predictor[_type]["predictor"] = FormPredictor()
            return dict_predictor[_type]["predictor"]
    raise NameError("no this type of predictor")
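# Usage sketch (illustrative, assuming the rest of the BiddingKG pipeline has already
# produced list_sentences / list_entitys in the expected format):
#   codename_predictor = getPredictor("codeName")
#   results = codename_predictor.predict(list_sentences, list_entitys)
# Each predictor is built lazily on first request and cached in dict_predictor, guarded
# by its own RLock, so concurrent callers share a single instance per type.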

# Project code / project name extraction model
class CodeNamePredict():

    def __init__(self, EMBED_DIM=None, BiRNN_UNITS=None, lazyLoad=getLazyLoad()):
        self.model = None
        self.MAX_LEN = None
        self.model_code = None
        if EMBED_DIM is None:
            self.EMBED_DIM = 60
        else:
            self.EMBED_DIM = EMBED_DIM
        if BiRNN_UNITS is None:
            self.BiRNN_UNITS = 200
        else:
            self.BiRNN_UNITS = BiRNN_UNITS
        self.filepath = os.path.dirname(__file__)+"/../projectCode/models/model_project_"+str(self.EMBED_DIM)+"_"+str(self.BiRNN_UNITS)+".hdf5"
        #self.filepath = "../projectCode/models/model_project_60_200_200ep017-loss6.456-val_loss7.852-val_acc0.969.hdf5"
        self.filepath_code = os.path.dirname(__file__)+"/../projectCode/models/model_code.hdf5"
        vocabpath = os.path.dirname(__file__)+"/codename_vocab.pk"
        classlabelspath = os.path.dirname(__file__)+"/codename_classlabels.pk"
        self.vocab = load(vocabpath)
        self.class_labels = load(classlabelspath)
        # Build the regexes used to extract project codes and names from predicted label sequences
        id_PC_B = self.class_labels.index("PC_B")
        id_PC_M = self.class_labels.index("PC_M")
        id_PC_E = self.class_labels.index("PC_E")
        id_PN_B = self.class_labels.index("PN_B")
        id_PN_M = self.class_labels.index("PN_M")
        id_PN_E = self.class_labels.index("PN_E")
        self.PC_pattern = re.compile(str(id_PC_B)+str(id_PC_M)+"*"+str(id_PC_E))
        self.PN_pattern = re.compile(str(id_PN_B)+str(id_PN_M)+"*"+str(id_PN_E))
        print("pc", self.PC_pattern)
        print("pn", self.PN_pattern)
        self.word2index = dict((w, i) for i, w in enumerate(np.array(self.vocab)))
        self.inputs = None
        self.outputs = None
        self.sess_codename = tf.Session(graph=tf.Graph())
        self.sess_codesplit = tf.Session(graph=tf.Graph())
        self.inputs_code = None
        self.outputs_code = None
        if not lazyLoad:
            self.getModel()
            self.getModel_code()

    def getModel(self):
        '''
        @summary: load the project code and name model
        '''
        if self.inputs is None:
            log("get model of codename")
            with self.sess_codename.as_default():
                with self.sess_codename.graph.as_default():
                    meta_graph_def = tf.saved_model.loader.load(self.sess_codename, ["serve"], export_dir=os.path.dirname(__file__)+"/codename_savedmodel_tf")
                    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
                    signature_def = meta_graph_def.signature_def
                    self.inputs = self.sess_codename.graph.get_tensor_by_name(signature_def[signature_key].inputs["inputs"].name)
                    self.inputs_length = self.sess_codename.graph.get_tensor_by_name(signature_def[signature_key].inputs["inputs_length"].name)
                    self.keepprob = self.sess_codename.graph.get_tensor_by_name(signature_def[signature_key].inputs["keepprob"].name)
                    self.logits = self.sess_codename.graph.get_tensor_by_name(signature_def[signature_key].outputs["logits"].name)
                    self.trans = self.sess_codename.graph.get_tensor_by_name(signature_def[signature_key].outputs["trans"].name)
                    return self.inputs, self.inputs_length, self.keepprob, self.logits, self.trans
        else:
            return self.inputs, self.inputs_length, self.keepprob, self.logits, self.trans
        '''
        if self.model is None:
            self.model = self.getBiLSTMCRFModel(self.MAX_LEN, self.vocab, self.EMBED_DIM, self.BiRNN_UNITS, self.class_labels, weights=None)
            self.model.load_weights(self.filepath)
        return self.model
        '''

    def getModel_code(self):
        if self.inputs_code is None:
            log("get model of code")
            with self.sess_codesplit.as_default():
                with self.sess_codesplit.graph.as_default():
                    meta_graph_def = tf.saved_model.loader.load(self.sess_codesplit, ["serve"], export_dir=os.path.dirname(__file__)+"/codesplit_savedmodel")
                    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
                    signature_def = meta_graph_def.signature_def
                    self.inputs_code = []
                    self.inputs_code.append(self.sess_codesplit.graph.get_tensor_by_name(signature_def[signature_key].inputs["input0"].name))
                    self.inputs_code.append(self.sess_codesplit.graph.get_tensor_by_name(signature_def[signature_key].inputs["input1"].name))
                    self.inputs_code.append(self.sess_codesplit.graph.get_tensor_by_name(signature_def[signature_key].inputs["input2"].name))
                    self.outputs_code = self.sess_codesplit.graph.get_tensor_by_name(signature_def[signature_key].outputs["outputs"].name)
                    self.sess_codesplit.graph.finalize()
                    return self.inputs_code, self.outputs_code
        else:
            return self.inputs_code, self.outputs_code
        '''
        if self.model_code is None:
            log("get model of model_code")
            with self.sess_codesplit.as_default():
                with self.sess_codesplit.graph.as_default():
                    self.model_code = models.load_model(self.filepath_code, custom_objects={'precision':precision,'recall':recall,'f1_score':f1_score})
        return self.model_code
        '''

    def getBiLSTMCRFModel(self, MAX_LEN, vocab, EMBED_DIM, BiRNN_UNITS, chunk_tags, weights):
        '''
        model = models.Sequential()
        model.add(layers.Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # Random embedding
        model.add(layers.Bidirectional(layers.LSTM(BiRNN_UNITS // 2, return_sequences=True)))
        crf = CRF(len(chunk_tags), sparse_target=True)
        model.add(crf)
        model.summary()
        model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
        return model
        '''
        input = layers.Input(shape=(None,))
        if weights is not None:
            embedding = layers.embeddings.Embedding(len(vocab), EMBED_DIM, mask_zero=True, weights=[weights], trainable=True)(input)
        else:
            embedding = layers.embeddings.Embedding(len(vocab), EMBED_DIM, mask_zero=True)(input)
        bilstm = layers.Bidirectional(layers.LSTM(BiRNN_UNITS//2, return_sequences=True))(embedding)
        bilstm_dense = layers.TimeDistributed(layers.Dense(len(chunk_tags)))(bilstm)
        crf = CRF(len(chunk_tags), sparse_target=True)
        crf_out = crf(bilstm_dense)
        model = models.Model(input=[input], output=[crf_out])
        model.summary()
        model.compile(optimizer='adam', loss=crf.loss_function, metrics=[crf.accuracy])
        return model

    # Complete unmatched brackets on either side of an extracted code or name by rule
    def fitDataByRule(self, data):
        symbol_dict = {"(": ")",
                       "(": ")",
                       "[": "]",
                       "【": "】",
                       ")": "(",
                       ")": "(",
                       "]": "[",
                       "】": "【"}
        leftSymbol_pattern = re.compile("[\((\[【]")
        rightSymbol_pattern = re.compile("[\))\]】]")
        leftfinds = re.findall(leftSymbol_pattern, data)
        rightfinds = re.findall(rightSymbol_pattern, data)
        result = data
        if len(leftfinds)+len(rightfinds)==0:
            return data
        elif len(leftfinds)==len(rightfinds):
            return data
        elif abs(len(leftfinds)-len(rightfinds))==1:
            if len(leftfinds)>len(rightfinds):
                if symbol_dict.get(data[0]) is not None:
                    result = data[1:]
                else:
                    #print(symbol_dict.get(leftfinds[0]))
                    result = data+symbol_dict.get(leftfinds[0])
            else:
                if symbol_dict.get(data[-1]) is not None:
                    result = data[:-1]
                else:
                    result = symbol_dict.get(rightfinds[0])+data
        return result
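    # Illustrative behaviour of fitDataByRule (example, not from the original source): with one
    # unmatched left bracket the matching right bracket is appended, e.g.
    #   fitDataByRule("ABC(2020第01标段") -> "ABC(2020第01标段)"
    # while a dangling bracket at the very start or end of the string is stripped instead.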

    def decode(self, logits, trans, sequence_lengths, tag_num):
        viterbi_sequences = []
        for logit, length in zip(logits, sequence_lengths):
            score = logit[:length]
            viterbi_seq, viterbi_score = viterbi_decode(score, trans)
            viterbi_sequences.append(viterbi_seq)
        return viterbi_sequences
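    # decode() runs a per-sequence Viterbi decode over the emission scores using the CRF
    # transition matrix returned by the saved model; each row of `logits` is truncated to its
    # true sentence length first, so padded positions never influence the decoded path.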

    def predict(self, list_sentences, list_entitys=None, MAX_AREA=5000):
        # @summary: extract the code and name of each document
        pattern_score = re.compile("工程|服务|采购|施工|项目|系统|招标|中标|公告|学校|[大中小]学校?|医院|公司|分公司|研究院|政府采购中心|学院|中心校?|办公室|政府|财[政务]局|办事处|委员会|[部总支]队|警卫局|幼儿园|党委|党校|银行|分行|解放军|发电厂|供电局|管理所|供电公司|卷烟厂|机务段|研究[院所]|油厂|调查局|调查中心|出版社|电视台|监狱|水厂|服务站|信用合作联社|信用社|交易所|交易中心|交易中心党校|科学院|测绘所|运输厅|管理处|局|中心|机关|部门?|处|科|厂|集团|图书馆|馆|所|厅|楼|区|酒店|场|基地|矿|餐厅|酒店")
        result = []
        index_unk = self.word2index.get("<unk>")
        # index_pad = self.word2index.get("<pad>")
        if list_entitys is None:
            list_entitys = [[] for _ in range(len(list_sentences))]
        for list_sentence, list_entity in zip(list_sentences, list_entitys):
            if len(list_sentence)==0:
                result.append([{"code":[], "name":""}])
                continue
            doc_id = list_sentence[0].doc_id
            # sentences = []
            # for sentence in list_sentence:
            #     if len(sentence.sentence_text)>MAX_AREA:
            #         for _sentence_comma in re.split("[;;,\n]",sentence):
            #             _comma_index = 0
            #             while(_comma_index<len(_sentence_comma)):
            #                 sentences.append(_sentence_comma[_comma_index:_comma_index+MAX_AREA])
            #                 _comma_index += MAX_AREA
            #     else:
            #         sentences.append(sentence+"。")
            list_sentence.sort(key=lambda x: len(x.sentence_text), reverse=True)
            _begin_index = 0
            item = {"code":[], "name":""}
            code_set = set()
            dict_name_freq_score = dict()
            while(True):
                MAX_LEN = len(list_sentence[_begin_index].sentence_text)
                if MAX_LEN > MAX_AREA:
                    MAX_LEN = MAX_AREA
                _LEN = MAX_AREA//MAX_LEN
                # run the sequence-labelling model
                x = [[self.word2index.get(word, index_unk) for word in sentence.sentence_text[:MAX_AREA]] for sentence in list_sentence[_begin_index:_begin_index+_LEN]]
                # x = [[getIndexOfWord(word) for word in sentence.sentence_text[:MAX_AREA]]for sentence in list_sentence[_begin_index:_begin_index+_LEN]]
                x_len = [len(_x) if len(_x) < MAX_LEN else MAX_LEN for _x in x]
                x = pad_sequences(x, maxlen=MAX_LEN, padding="post", truncating="post")
                if USE_PAI_EAS:
                    request = tf_predict_pb2.PredictRequest()
                    request.inputs["inputs"].dtype = tf_predict_pb2.DT_INT32
                    request.inputs["inputs"].array_shape.dim.extend(np.shape(x))
                    request.inputs["inputs"].int_val.extend(np.array(x, dtype=np.int32).reshape(-1))
                    request_data = request.SerializeToString()
                    list_outputs = ["outputs"]
                    _result = vpc_requests(codename_url, codename_authorization, request_data, list_outputs)
                    if _result is not None:
                        predict_y = _result["outputs"]
                    else:
                        with self.sess_codename.as_default():
                            t_input, t_output = self.getModel()
                            predict_y = self.sess_codename.run(t_output, feed_dict={t_input: x})
                else:
                    with self.sess_codename.as_default():
                        t_input, t_input_length, t_keepprob, t_logits, t_trans = self.getModel()
                        _logits, _trans = self.sess_codename.run([t_logits, t_trans], feed_dict={t_input: x,
                                                                                                 t_input_length: x_len,
                                                                                                 t_keepprob: 1.0})
                        predict_y = self.decode(_logits, _trans, x_len, 7)
                        # print('==========',_logits)
                '''
                for item11 in np.argmax(predict_y,-1):
                    print(item11)
                print(predict_y)
                '''
                # print(predict_y)
                for sentence, predict in zip(list_sentence[_begin_index:_begin_index+_LEN], np.array(predict_y)):
                    pad_sentence = sentence.sentence_text[:MAX_LEN]
                    join_predict = "".join([str(s) for s in predict])
                    # print(pad_sentence)
                    # print(join_predict)
                    code_x = []
                    code_text = []
                    temp_entitys = []
                    for iter in re.finditer(self.PC_pattern, join_predict):
                        get_len = 40
                        if iter.span()[0] < get_len:
                            begin = 0
                        else:
                            begin = iter.span()[0]-get_len
                        end = iter.span()[1]+get_len
                        code_x.append(embedding_word([pad_sentence[begin:iter.span()[0]], pad_sentence[iter.span()[0]:iter.span()[1]], pad_sentence[iter.span()[1]:end]], shape=(3, get_len, 60)))
                        code_text.append(pad_sentence[iter.span()[0]:iter.span()[1]])
                        _entity = Entity(doc_id=sentence.doc_id, entity_id="%s_%s_%s_%s"%(sentence.doc_id, sentence.sentence_index, iter.span()[0], iter.span()[1]), entity_text=pad_sentence[iter.span()[0]:iter.span()[1]], entity_type="code", sentence_index=sentence.sentence_index, begin_index=0, end_index=0, wordOffset_begin=iter.span()[0], wordOffset_end=iter.span()[1])
                        temp_entitys.append(_entity)
                    #print("code",code_text)
                    if len(code_x) > 0:
                        code_x = np.transpose(np.array(code_x, dtype=np.float32), (1, 0, 2, 3))
                        if USE_PAI_EAS:
                            request = tf_predict_pb2.PredictRequest()
                            request.inputs["input0"].dtype = tf_predict_pb2.DT_FLOAT
                            request.inputs["input0"].array_shape.dim.extend(np.shape(code_x[0]))
                            request.inputs["input0"].float_val.extend(np.array(code_x[0], dtype=np.float64).reshape(-1))
                            request.inputs["input1"].dtype = tf_predict_pb2.DT_FLOAT
                            request.inputs["input1"].array_shape.dim.extend(np.shape(code_x[1]))
                            request.inputs["input1"].float_val.extend(np.array(code_x[1], dtype=np.float64).reshape(-1))
                            request.inputs["input2"].dtype = tf_predict_pb2.DT_FLOAT
                            request.inputs["input2"].array_shape.dim.extend(np.shape(code_x[2]))
                            request.inputs["input2"].float_val.extend(np.array(code_x[2], dtype=np.float64).reshape(-1))
                            request_data = request.SerializeToString()
                            list_outputs = ["outputs"]
                            _result = vpc_requests(codeclasses_url, codeclasses_authorization, request_data, list_outputs)
                            if _result is not None:
                                predict_code = _result["outputs"]
                            else:
                                with self.sess_codesplit.as_default():
                                    with self.sess_codesplit.graph.as_default():
                                        predict_code = self.getModel_code().predict([code_x[0], code_x[1], code_x[2]])
                        else:
                            with self.sess_codesplit.as_default():
                                with self.sess_codesplit.graph.as_default():
                                    inputs_code, outputs_code = self.getModel_code()
                                    predict_code = limitRun(self.sess_codesplit, [outputs_code], feed_dict={inputs_code[0]: code_x[0], inputs_code[1]: code_x[1], inputs_code[2]: code_x[2]}, MAX_BATCH=2)[0]
                                    #predict_code = self.sess_codesplit.run(outputs_code,feed_dict={inputs_code[0]:code_x[0],inputs_code[1]:code_x[1],inputs_code[2]:code_x[2]})
                                    #predict_code = self.getModel_code().predict([code_x[0],code_x[1],code_x[2]])
                        for h in range(len(predict_code)):
                            if predict_code[h][0] > 0.5:
                                the_code = self.fitDataByRule(code_text[h])
                                # add code to entitys
                                list_entity.append(temp_entitys[h])
                                if the_code not in code_set:
                                    code_set.add(the_code)
                                item['code'] = list(code_set)
                    for iter in re.finditer(self.PN_pattern, join_predict):
                        _name = self.fitDataByRule(pad_sentence[iter.span()[0]:iter.span()[1]])
                        # add name to entitys
                        _entity = Entity(doc_id=sentence.doc_id, entity_id="%s_%s_%s_%s"%(sentence.doc_id, sentence.sentence_index, iter.span()[0], iter.span()[1]), entity_text=_name, entity_type="name", sentence_index=sentence.sentence_index, begin_index=0, end_index=0, wordOffset_begin=iter.span()[0], wordOffset_end=iter.span()[1])
                        list_entity.append(_entity)
                        w = 1 if re.search('(项目|工程|招标|合同|标项|标的|计划|询价|询价单|询价通知书|申购)(名称|标题|主题)[::\s]', pad_sentence[iter.span()[0]-10:iter.span()[0]]) != None else 0.5
                        if _name not in dict_name_freq_score:
                            # dict_name_freq_score[_name] = [1,len(re.findall(pattern_score,_name))+len(_name)*0.1]
                            dict_name_freq_score[_name] = [1, (len(re.findall(pattern_score, _name)) + len(_name) * 0.05)*w]
                        else:
                            dict_name_freq_score[_name][0] += 1
                    '''
                    for iter in re.finditer(self.PN_pattern,join_predict):
                        print("name-",self.fitDataByRule(pad_sentence[iter.span()[0]:iter.span()[1]]))
                    if item[1]['name']=="":
                        for iter in re.finditer(self.PN_pattern,join_predict):
                            #item[1]['name']=item[1]['name']+";"+self.fitDataByRule(pad_sentence[iter.span()[0]:iter.span()[1]])
                            item[1]['name']=self.fitDataByRule(pad_sentence[iter.span()[0]:iter.span()[1]])
                            break
                    '''
                if _begin_index+_LEN >= len(list_sentence):
                    break
                _begin_index += _LEN
            list_name_freq_score = []
            # 2020/11/23 rule adjustment for large websites
            if len(dict_name_freq_score) == 0:
                name_re1 = '(项目|工程|招标|合同|标项|标的|计划|询价|询价单|询价通知书|申购)(名称|标题|主题)[::\s]+([^,。:;]{2,60})[,。]'
                for sentence in list_sentence:
                    # pad_sentence = sentence.sentence_text
                    othername = re.search(name_re1, sentence.sentence_text)
                    if othername != None:
                        project_name = othername.group(3)
                        beg = find_index([project_name], sentence.sentence_text)[0]
                        end = beg + len(project_name)
                        _name = self.fitDataByRule(sentence.sentence_text[beg:end])
                        # add name to entitys
                        _entity = Entity(doc_id=sentence.doc_id, entity_id="%s_%s_%s_%s" % (
                            sentence.doc_id, sentence.sentence_index, beg, end), entity_text=_name,
                                         entity_type="name", sentence_index=sentence.sentence_index, begin_index=0,
                                         end_index=0, wordOffset_begin=beg, wordOffset_end=end)
                        list_entity.append(_entity)
                        w = 1
                        if _name not in dict_name_freq_score:
                            # dict_name_freq_score[_name] = [1,len(re.findall(pattern_score,_name))+len(_name)*0.1]
                            dict_name_freq_score[_name] = [1, (len(re.findall(pattern_score, _name)) + len(_name) * 0.05) * w]
                        else:
                            dict_name_freq_score[_name][0] += 1
                # othername = re.search(name_re1, sentence.sentence_text)
                # if othername != None:
                #     _name = othername.group(3)
                #     if _name not in dict_name_freq_score:
                #         dict_name_freq_score[_name] = [1, len(re.findall(pattern_score, _name)) + len(_name) * 0.1]
                #     else:
                #         dict_name_freq_score[_name][0] += 1
            for _name in dict_name_freq_score.keys():
                list_name_freq_score.append([_name, dict_name_freq_score[_name]])
            # print(list_name_freq_score)
            if len(list_name_freq_score) > 0:
                list_name_freq_score.sort(key=lambda x: x[1][0]*x[1][1], reverse=True)
                item['name'] = list_name_freq_score[0][0]
                # if list_name_freq_score[0][1][0]>1:
                #     item[1]['name'] = list_name_freq_score[0][0]
                # else:
                #     list_name_freq_score.sort(key=lambda x:x[1][1],reverse=True)
                #     item[1]["name"] = list_name_freq_score[0][0]
            # The regexes below pick up project codes that the model failed to recognize
            if item['code'] == []:
                for sentence in list_sentence:
                    # othercode = re.search('(采购计划编号|询价编号)[\))]?[::]?([\[\]a-zA-Z0-9\-]{5,30})', sentence.sentence_text)
                    # if othercode != None:
                    #     item[1]['code'].append(othercode.group(2))
                    # 2020/11/23 rule adjustment for large websites
                    othercode = re.search('(项目|采购|招标|品目|询价|竞价|询价单|磋商|订单|账单|交易|文件|计划|场次|标的|标段|标包|分包|标段\(包\)|招标文件|合同|通知书|公告)(单号|编号|标号|编码|代码|备案号|号)[::\s]+([^,。;:、]{8,30}[a-zA-Z0-9\号])[\),。]', sentence.sentence_text)
                    if othercode != None:
                        item['code'].append(othercode.group(3))
            result.append(item)
            list_sentence.sort(key=lambda x: x.sentence_index, reverse=False)
        return result
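    # predict() returns roughly one dict per document of the form {"code": [...], "name": "..."}:
    # candidate codes are kept only when the code-split classifier scores them above 0.5, and the
    # name is the candidate with the highest frequency * score product.
    # Rough usage sketch (assuming sentences and entities prepared by the upstream pipeline):
    #   results = CodeNamePredict().predict(list_sentences, list_entitys)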

    '''
    # this version raised errors when the amount of data was too large
    def predict(self,articles,MAX_LEN = None):
        sentences = []
        for article in articles:
            for sentence in article.content.split("。"):
                sentences.append([sentence,article.id])
        if MAX_LEN is None:
            sent_len = [len(sentence[0]) for sentence in sentences]
            MAX_LEN = max(sent_len)
        #print(MAX_LEN)
        # if empty, return an empty result directly
        result = []
        if MAX_LEN==0:
            for article in articles:
                result.append([article.id,{"code":[],"name":""}])
            return result
        index_unk = self.word2index.get("<unk>")
        index_pad = self.word2index.get("<pad>")
        x = [[self.word2index.get(word,index_unk)for word in sentence[0]]for sentence in sentences]
        x = pad_sequences(x,maxlen=MAX_LEN,padding="post",truncating="post")
        predict_y = self.getModel().predict(x)
        last_doc_id = ""
        item = []
        for sentence,predict in zip(sentences,np.argmax(predict_y,-1)):
            pad_sentence = sentence[0][:MAX_LEN]
            doc_id = sentence[1]
            join_predict = "".join([str(s) for s in predict])
            if doc_id!=last_doc_id:
                if last_doc_id!="":
                    result.append(item)
                item = [doc_id,{"code":[],"name":""}]
                code_set = set()
            code_x = []
            code_text = []
            for iter in re.finditer(self.PC_pattern,join_predict):
                get_len = 40
                if iter.span()[0]<get_len:
                    begin = 0
                else:
                    begin = iter.span()[0]-get_len
                end = iter.span()[1]+get_len
                code_x.append(embedding_word([pad_sentence[begin:iter.span()[0]],pad_sentence[iter.span()[0]:iter.span()[1]],pad_sentence[iter.span()[1]:end]],shape=(3,get_len,60)))
                code_text.append(pad_sentence[iter.span()[0]:iter.span()[1]])
            if len(code_x)>0:
                code_x = np.transpose(np.array(code_x),(1,0,2,3))
                predict_code = self.getModel_code().predict([code_x[0],code_x[1],code_x[2]])
                for h in range(len(predict_code)):
                    if predict_code[h][0]>0.5:
                        the_code = self.fitDataByRule(code_text[h])
                        if the_code not in code_set:
                            code_set.add(the_code)
                        item[1]['code'] = list(code_set)
            if item[1]['name']=="":
                for iter in re.finditer(self.PN_pattern,join_predict):
                    #item[1]['name']=item[1]['name']+";"+self.fitDataByRule(pad_sentence[iter.span()[0]:iter.span()[1]])
                    item[1]['name']=self.fitDataByRule(pad_sentence[iter.span()[0]:iter.span()[1]])
                    break
            last_doc_id = doc_id
        result.append(item)
        return result
    '''

# Role / money classification model
class PREMPredict():

    def __init__(self):
        #self.model_role_file = os.path.abspath("../role/models/model_role.model.hdf5")
        self.model_role_file = os.path.dirname(__file__)+"/../role/log/new_biLSTM-ep012-loss0.028-val_loss0.040-f10.954.h5"
        self.model_role = Model_role_classify_word()
        self.model_money = Model_money_classify()
        return

    def search_role_data(self, list_sentences, list_entitys):
        '''
        @summary: build the role-model input data from the sentence list and entity list
        @param:
            list_sentences: the sentences of each document
            list_entitys: the entitys of each document
        @return: the input data for the role model
        '''
        data_x = []
        points_entitys = []
        for list_entity, list_sentence in zip(list_entitys, list_sentences):
            p_entitys = 0
            p_sentences = 0
            while(p_entitys<len(list_entity)):
                entity = list_entity[p_entitys]
                if entity.entity_type in ['org', 'company']:
                    while(p_sentences<len(list_sentence)):
                        sentence = list_sentence[p_sentences]
                        if entity.doc_id==sentence.doc_id and entity.sentence_index==sentence.sentence_index:
                            #item_x = embedding(spanWindow(tokens=sentence.tokens,begin_index=entity.begin_index,end_index=entity.end_index,size=settings.MODEL_ROLE_INPUT_SHAPE[1]),shape=settings.MODEL_ROLE_INPUT_SHAPE)
                            item_x = self.model_role.encode(tokens=sentence.tokens, begin_index=entity.begin_index, end_index=entity.end_index, entity_text=entity.entity_text)
                            data_x.append(item_x)
                            points_entitys.append(entity)
                            break
                        p_sentences += 1
                p_entitys += 1
        if len(points_entitys)==0:
            return None
        return [data_x, points_entitys]

    def search_money_data(self, list_sentences, list_entitys):
        '''
        @summary: build the money-model input data from the sentence list and entity list
        @param:
            list_sentences: the sentences of each document
            list_entitys: the entitys of each document
        @return: the input data for the money model
        '''
        data_x = []
        points_entitys = []
        for list_entity, list_sentence in zip(list_entitys, list_sentences):
            p_entitys = 0
            while(p_entitys<len(list_entity)):
                entity = list_entity[p_entitys]
                if entity.entity_type=="money":
                    p_sentences = 0
                    while(p_sentences<len(list_sentence)):
                        sentence = list_sentence[p_sentences]
                        if entity.doc_id==sentence.doc_id and entity.sentence_index==sentence.sentence_index:
                            #item_x = embedding(spanWindow(tokens=sentence.tokens,begin_index=entity.begin_index,end_index=entity.end_index,size=settings.MODEL_MONEY_INPUT_SHAPE[1]),shape=settings.MODEL_MONEY_INPUT_SHAPE)
                            #item_x = embedding_word(spanWindow(tokens=sentence.tokens, begin_index=entity.begin_index, end_index=entity.end_index, size=10, center_include=True, word_flag=True),shape=settings.MODEL_MONEY_INPUT_SHAPE)
                            item_x = self.model_money.encode(tokens=sentence.tokens, begin_index=entity.begin_index, end_index=entity.end_index)
                            data_x.append(item_x)
                            points_entitys.append(entity)
                            break
                        p_sentences += 1
                p_entitys += 1
        if len(points_entitys)==0:
            return None
        return [data_x, points_entitys]

    def predict_role(self, list_sentences, list_entitys):
        datas = self.search_role_data(list_sentences, list_entitys)
        if datas is None:
            return
        points_entitys = datas[1]
        if USE_PAI_EAS:
            _data = datas[0]
            _data = np.transpose(np.array(_data), (1, 0, 2))
            request = tf_predict_pb2.PredictRequest()
            request.inputs["input0"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["input0"].array_shape.dim.extend(np.shape(_data[0]))
            request.inputs["input0"].float_val.extend(np.array(_data[0], dtype=np.float64).reshape(-1))
            request.inputs["input1"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["input1"].array_shape.dim.extend(np.shape(_data[1]))
            request.inputs["input1"].float_val.extend(np.array(_data[1], dtype=np.float64).reshape(-1))
            request.inputs["input2"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["input2"].array_shape.dim.extend(np.shape(_data[2]))
            request.inputs["input2"].float_val.extend(np.array(_data[2], dtype=np.float64).reshape(-1))
            request_data = request.SerializeToString()
            list_outputs = ["outputs"]
            _result = vpc_requests(role_url, role_authorization, request_data, list_outputs)
            if _result is not None:
                predict_y = _result["outputs"]
            else:
                predict_y = self.model_role.predict(datas[0])
        else:
            predict_y = self.model_role.predict(np.array(datas[0], dtype=np.float64))
        for i in range(len(predict_y)):
            entity = points_entitys[i]
            label = np.argmax(predict_y[i])
            values = []
            for item in predict_y[i]:
                values.append(item)
            entity.set_Role(label, values)

    def predict_money(self, list_sentences, list_entitys):
        datas = self.search_money_data(list_sentences, list_entitys)
        if datas is None:
            return
        points_entitys = datas[1]
        _data = datas[0]
        if USE_PAI_EAS:
            _data = np.transpose(np.array(_data), (1, 0, 2, 3))
            request = tf_predict_pb2.PredictRequest()
            request.inputs["input0"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["input0"].array_shape.dim.extend(np.shape(_data[0]))
            request.inputs["input0"].float_val.extend(np.array(_data[0], dtype=np.float64).reshape(-1))
            request.inputs["input1"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["input1"].array_shape.dim.extend(np.shape(_data[1]))
            request.inputs["input1"].float_val.extend(np.array(_data[1], dtype=np.float64).reshape(-1))
            request.inputs["input2"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["input2"].array_shape.dim.extend(np.shape(_data[2]))
            request.inputs["input2"].float_val.extend(np.array(_data[2], dtype=np.float64).reshape(-1))
            request_data = request.SerializeToString()
            list_outputs = ["outputs"]
            _result = vpc_requests(money_url, money_authorization, request_data, list_outputs)
            if _result is not None:
                predict_y = _result["outputs"]
            else:
                predict_y = self.model_money.predict(_data)
        else:
            predict_y = self.model_money.predict(_data)
        for i in range(len(predict_y)):
            entity = points_entitys[i]
            label = np.argmax(predict_y[i])
            values = []
            for item in predict_y[i]:
                values.append(item)
            entity.set_Money(label, values)

    def predict(self, list_sentences, list_entitys):
        self.predict_role(list_sentences, list_entitys)
        self.predict_money(list_sentences, list_entitys)
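    # Rough flow (sketch): search_*_data() encodes a context window around every org/company or
    # money entity, the corresponding classifier scores it, and the argmax label plus the full
    # probability vector are written back onto the entity via set_Role()/set_Money(); entities
    # without a matching sentence are simply skipped.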

# Contact person model
class EPCPredict():

    def __init__(self):
        self.model_person = Model_person_classify()

    def search_person_data(self, list_sentences, list_entitys):
        '''
        @summary: build the contact-person model input data from the sentence list and entity list
        @param:
            list_sentences: the sentences of each document
            list_entitys: the entitys of each document
        @return: the input data for the contact-person model
        '''
        def phoneFromList(phones):
            for phone in phones:
                if len(phone)==11:
                    return re.sub('电话[:|:]|联系方式[:|:]', '', phone)
            return phones[0]
        data_x = []
        dianhua = []
        points_entitys = []
        for list_entity, list_sentence in zip(list_entitys, list_sentences):
            p_entitys = 0
            p_sentences = 0
            key_word = re.compile('电话[:|:]\d{7,12}|联系方式[:|:]\d{7,12}')
            # phone = re.compile('1[3|4|5|7|8][0-9][-—-]?\d{4}[-—-]?\d{4}|\d{3,4}[-—-]\d{7,8}/\d{3,8}|\d{3,4}[-—-]\d{7,8}转\d{1,4}|\d{3,4}[-—-]\d{7,8}|[\(|\(]0\d{2,3}[\)|\)]-?\d{7,8}-?\d{,4}')  # contact phone
            # 2020/11/25 added newly observed mobile number prefixes
            phone = re.compile('1[3|4|5|6|7|8|9][0-9][-—-]?\d{4}[-—-]?\d{4}|\d{3,4}[-—-]\d{7,8}/\d{3,8}|\d{3,4}[-—-]\d{7,8}转\d{1,4}|\d{3,4}[-—-]\d{7,8}|[\(|\(]0\d{2,3}[\)|\)]-?\d{7,8}-?\d{,4}')  # contact phone
            dict_index_sentence = {}
            for _sentence in list_sentence:
                dict_index_sentence[_sentence.sentence_index] = _sentence
            dict_context_itemx = {}
            while(p_entitys<len(list_entity)):
                entity = list_entity[p_entitys]
                if entity.entity_type=="person":
                    sentence = dict_index_sentence[entity.sentence_index]
                    #item_x = embedding(spanWindow(tokens=sentence.tokens,begin_index=entity.begin_index,end_index=entity.end_index,size=settings.MODEL_PERSON_INPUT_SHAPE[1]),shape=settings.MODEL_PERSON_INPUT_SHAPE)
                    s = spanWindow(tokens=sentence.tokens, begin_index=entity.begin_index, end_index=entity.end_index, size=20)
                    _key = "".join(["".join(x) for x in s])
                    if _key in dict_context_itemx:
                        item_x = dict_context_itemx[_key][0]
                        _dianhua = dict_context_itemx[_key][1]
                    else:
                        item_x = self.model_person.encode(tokens=sentence.tokens, begin_index=entity.begin_index, end_index=entity.end_index)
                        s1 = ''.join(s[1])
                        #s1 = re.sub(',)', '-', s1)
                        s1 = re.sub('\s', '', s1)
                        have_key = re.findall(key_word, s1)
                        have_phone = re.findall(phone, s1)
                        s0 = ''.join(s[0])
                        #s0 = re.sub(',)', '-', s0)
                        s0 = re.sub('\s', '', s0)
                        have_key2 = re.findall(key_word, s0)
                        have_phone2 = re.findall(phone, s0)
                        s3 = ''.join(s[1])
                        #s0 = re.sub(',)', '-', s0)
                        s3 = re.sub(',|,|\s', '', s3)
                        have_key3 = re.findall(key_word, s3)
                        have_phone3 = re.findall(phone, s3)
                        s4 = ''.join(s[0])
                        #s0 = re.sub(',)', '-', s0)
                        s4 = re.sub(',|,|\s', '', s0)
                        have_key4 = re.findall(key_word, s4)
                        have_phone4 = re.findall(phone, s4)
                        _dianhua = ""
                        if have_phone:
                            _dianhua = phoneFromList(have_phone)
                        elif have_key:
                            _dianhua = phoneFromList(have_key)
                        elif have_phone2:
                            _dianhua = phoneFromList(have_phone2)
                        elif have_key2:
                            _dianhua = phoneFromList(have_key2)
                        elif have_phone3:
                            _dianhua = phoneFromList(have_phone3)
                        elif have_key3:
                            _dianhua = phoneFromList(have_key3)
                        elif have_phone4:
                            _dianhua = phoneFromList(have_phone4)
                        elif have_key4:
                            _dianhua = phoneFromList(have_key4)
                        else:
                            _dianhua = ""
                        dict_context_itemx[_key] = [item_x, _dianhua]
                    data_x.append(item_x)
                    points_entitys.append(entity)
                    dianhua.append(_dianhua)
                p_entitys += 1
        if len(points_entitys)==0:
            return None
        return [data_x, points_entitys, dianhua]
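    # Phone extraction above prefers an 11-character mobile number when several candidates are
    # found (phoneFromList) and searches the right-hand context before the left-hand one.
    # Illustrative strings the `phone` pattern accepts (examples, not taken from the source):
    #   "13912345678", "0571-87654321", "(010)12345678", "0571-87654321转803".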

    def predict_person(self, list_sentences, list_entitys):
        datas = self.search_person_data(list_sentences, list_entitys)
        if datas is None:
            return
        points_entitys = datas[1]
        phone = datas[2]
        if USE_PAI_EAS:
            _data = datas[0]
            _data = np.transpose(np.array(_data), (1, 0, 2, 3))
            request = tf_predict_pb2.PredictRequest()
            request.inputs["input0"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["input0"].array_shape.dim.extend(np.shape(_data[0]))
            request.inputs["input0"].float_val.extend(np.array(_data[0], dtype=np.float64).reshape(-1))
            request.inputs["input1"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["input1"].array_shape.dim.extend(np.shape(_data[1]))
            request.inputs["input1"].float_val.extend(np.array(_data[1], dtype=np.float64).reshape(-1))
            request_data = request.SerializeToString()
            list_outputs = ["outputs"]
            _result = vpc_requests(person_url, person_authorization, request_data, list_outputs)
            if _result is not None:
                predict_y = _result["outputs"]
            else:
                predict_y = self.model_person.predict(datas[0])
        else:
            predict_y = self.model_person.predict(datas[0])
        assert len(predict_y)==len(points_entitys)==len(phone)
        for i in range(len(predict_y)):
            entity = points_entitys[i]
            label = np.argmax(predict_y[i])
            values = []
            for item in predict_y[i]:
                values.append(item)
            phone_number = phone[i]
            entity.set_Person(label, values, phone_number)

    def predict(self, list_sentences, list_entitys):
        self.predict_person(list_sentences, list_entitys)

# Form (table) prediction
class FormPredictor():

    def __init__(self, lazyLoad=getLazyLoad()):
        self.model_file_line = os.path.dirname(__file__)+"/../form/model/model_form.model_line.hdf5"
        self.model_file_item = os.path.dirname(__file__)+"/../form/model/model_form.model_item.hdf5"
        self.model_form_item = Model_form_item()
        self.model_form_context = Model_form_context()
        self.model_dict = {"line": [None, self.model_file_line]}

    def getModel(self, type):
        if type=="item":
            return self.model_form_item
        elif type=="context":
            return self.model_form_context
        else:
            return self.getModel(type)

    def encode(self, data, **kwargs):
        return encodeInput([data], word_len=50, word_flag=True, userFool=False)[0]
        return encodeInput_form(data)

    def predict(self, form_datas, type):
        if type=="item":
            return self.model_form_item.predict(form_datas)
        elif type=="context":
            return self.model_form_context.predict(form_datas)
        else:
            return self.getModel(type).predict(form_datas)
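    # Usage sketch (illustrative): FormPredictor().predict(encoded_cells, type="item") scores
    # table cells with the "item" model, type="context" uses the context model; any other
    # type value falls through to getModel(type), which would recurse without terminating.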

# Role rules
# Assign a role by regex to every entity that has no role yet, giving it the lowest probability equal to the threshold
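# Label indices used below: 0 = tenderee, 1 = agency, 2 = winning bidder, 3 = second-place
# candidate, 4 = third-place candidate, 5 = no role. Patterns are grouped by where they sit
# relative to the entity ("L"eft, "C"enter, "R"ight, "W"hole span) and, when several match,
# the closest hit by character distance decides the role.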
class RoleRulePredictor():

    def __init__(self):
        self.pattern_tenderee_left = "(?P<tenderee_left>((采购|招标|项目|竞价|议价|需求|最终|建设|转让|招租|甲|议标|合同主体)(?:人|公司|单位|组织|用户|业主|方|部门)|文章来源|业主名称|需方)(是|为|信息|:|:|\s*$))"
        self.pattern_tenderee_center = "(?P<tenderee_center>(受.{,20}委托))"
        self.pattern_tenderee_right = "(?P<tenderee_right>(\((以下简称)?[\"”]?(招标|采购)(人|单位|机构)\)?)|(^[^.。,,::](采购|竞价|招标|施工|监理|中标|物资)(公告|公示|项目|结果|招标))|的.*正在进行询比价)"
        self.pattern_agency_left = "(?P<agency_left>(代理(?:人|机构|公司|单位|组织)|专业采购机构|集中采购机构|招标机构)(是|为|:|:|[,,]?\s*$)|(受.{,20}委托))"
        self.pattern_agency_right = "(?P<agency_right>(\((以下简称)?[\"”]?(代理)(人|单位|机构)\))|受.*委托)"
        # 2020/11/24 large-website rules: winner keywords 选定单位|指定的中介服务机构 added
        self.pattern_winTenderer_left = "(?P<winTenderer_left>((中标|中选|中价|乙|成交|承做|施工|供货|承包|竞得|受让)(候选)?(人|单位|机构|供应商|方|公司|厂商|商)|(供应商|供货商|服务商|选定单位|指定的中介服务机构)).{,4}[::是为].{,2}|(第[一1](名|((中标|中选|中价|成交)?(候选)?(人|单位|机构|供应商))))(是|为|:|:|\s*$)|((评审结果|名次|排名)[::]第?[一1]名?)|(单一来源(采购)?方式向.?$)|((中标|成交)(结果|信息))|(单一来源采购(供应商|供货商|服务商))|((分包|标包).*供应商|供应商名称|服务机构|供方[::]))"
        self.pattern_winTenderer_center = "(?P<winTenderer_center>第[一1].{,20}[是为]((中标|中选|中价|成交|施工)(人|单位|机构|供应商|公司)|供应商).{,4}[::是为])"
        self.pattern_winTenderer_right = "(?P<winTenderer_right>[是为\(]((采购(供应商|供货商|服务商)|(第[一1]|预)?(拟?(中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司|厂商)))))"
        self.pattern_winTenderer_whole = "(?P<winTenderer_whole>贵公司.*以.*中标|最终由.*竞买成功|经.*[以由].*中标|成交供应商,成交供应商名称:|谈判结果:由.{5,20}供货)"  # 2020/11/24 large-website rules: winner keyword 谈判结果:由.{5,20}供货 added
        self.pattern_winTenderer_location = "(中标|中选|中价|乙|成交|承做|施工|供货|承包|竞得|受让)(候选)?(人|单位|机构|供应商|方|公司|厂商|商)|(供应商|供货商|服务商).{,4}[::]?$|(第[一1](名|((中标|中选|中价|成交)?(候选)?(人|单位|机构|供应商))))(是|为|:|:|\s*$)|((评审结果|名次|排名)[::]第?[一1]名?)|(单一来源(采购)?方式向.?$)"
        self.pattern_secondTenderer_left = "(?P<secondTenderer_left>((第[二2](名|((中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司))))(是|为|:|:|\s*$))|((评审结果|名次|排名)[::]第?[二2]名?))"
        self.pattern_secondTenderer_right = "(?P<secondTenderer_right>[是为\(]第[二2](名|(中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司)))"
        self.pattern_thirdTenderer_left = "(?P<thirdTenderer_left>(第[三3](名|((中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司))))|((评审结果|名次|排名)[::]第?[三3]名?))"
        self.pattern_thirdTenderer_right = "(?P<thirdTenderer_right>[是为\(]第[三3](名|(中标|中选|中价|成交)(候选)?(人|单位|机构|供应商|公司)))"
        self.dict_list_pattern = {"0": [["L", self.pattern_tenderee_left],
                                        ["C", self.pattern_tenderee_center],
                                        ["R", self.pattern_tenderee_right]],
                                  "1": [["L", self.pattern_agency_left],
                                        ["R", self.pattern_agency_right]],
                                  "2": [["L", self.pattern_winTenderer_left],
                                        ["C", self.pattern_winTenderer_center],
                                        ["R", self.pattern_winTenderer_right],
                                        ["W", self.pattern_winTenderer_whole]],
                                  "3": [["L", self.pattern_secondTenderer_left],
                                        ["R", self.pattern_secondTenderer_right]],
                                  "4": [["L", self.pattern_thirdTenderer_left],
                                        ["R", self.pattern_thirdTenderer_right]]}
        list_pattern = []
        for _k, _v in self.dict_list_pattern.items():
            for _d, _p in _v:
                list_pattern.append(_p)
        self.pattern_whole = "|".join(list_pattern)
        self.SET_NOT_TENDERER = set(["人民政府","人民法院","中华人民共和国","人民检察院","评标委员会","中国政府","中国海关"])
        self.pattern_money_tenderee = re.compile("投标最高限价|采购计划金额|项目预算|招标金额|采购金额|项目金额|建安费用|采购(单位|人)委托价|限价|拦标价|预算金额")
        self.pattern_money_tenderer = re.compile("((合同|成交|中标|应付款|交易|投标)[)\)]?(总?金额|结果|[单报]?价))|总价|标的基本情况")
        self.pattern_money_tenderer_whole = re.compile("(以金额.*中标)|中标供应商.*单价|以.*元中标")
        self.pattern_money_other = re.compile("代理费|服务费")
        self.pattern_pack = "(([^承](包|标[段号的包]|分?包|包组)编?号?|项目)[::]?[\((]?[0-9A-Za-z一二三四五六七八九十]{1,4})[^至]?|(第?[0-9A-Za-z一二三四五六七八九十]{1,4}(包号|标[段号的包]|分?包))|[0-9]个(包|标[段号的包]|分?包|包组)"

    def _check_input(self, text, ignore=False):
        if not text:
            return []
        if not isinstance(text, list):
            text = [text]
        null_index = [i for i, t in enumerate(text) if not t]
        if null_index and not ignore:
            raise Exception("null text in input ")
        return text

    def predict(self, list_articles, list_sentences, list_entitys, list_codenames, on_value=0.5):
        for article, list_entity, list_sentence, list_codename in zip(list_articles, list_entitys, list_sentences, list_codenames):
            list_name = list_codename["name"]
            list_name = self._check_input(list_name)+[article.title]
            for p_entity in list_entity:
                if p_entity.entity_type in ["org", "company"]:
                    # Cap the probability of entities whose context contains the project name/title at 0.6,
                    # because an entity appearing in the title is not necessarily the tenderee
                    if str(p_entity.label)=="0":
                        find_flag = False
                        for _sentence in list_sentence:
                            if _sentence.sentence_index==p_entity.sentence_index:
                                _span = spanWindow(tokens=_sentence.tokens, begin_index=p_entity.begin_index, end_index=p_entity.end_index, size=20, center_include=True, word_flag=True, text=p_entity.entity_text)
                                for _name in list_name:
                                    if _name!="" and str(_span[1]+_span[2][:len(str(_name))]).find(_name)>=0:
                                        find_flag = True
                                        if p_entity.values[0]>on_value:
                                            p_entity.values[0] = 0.6+(p_entity.values[0]-0.6)/10
                        if find_flag:
                            continue
                    # Only re-resolve entities whose role is "none" or whose probability is below the threshold
                    if p_entity.label is None:
                        continue
                    role_prob = float(p_entity.values[int(p_entity.label)])
                    if role_prob<on_value or str(p_entity.label)=="5":
                        # Treat entities found in the title as the tenderee
                        _list_name = self._check_input(list_name, ignore=True)
                        find_flag = False
                        for _name in _list_name:
                            if str(_name).find(p_entity.entity_text)>=0:
                                find_flag = True
                                _label = 0
                                p_entity.label = _label
                                p_entity.values[int(_label)] = on_value
                                break
                        # If the entity appears in the title it defaults to tenderee; skip the rule matching below
                        if find_flag:
                            continue
                        for s_index in range(len(list_sentence)):
                            if p_entity.doc_id==list_sentence[s_index].doc_id and p_entity.sentence_index==list_sentence[s_index].sentence_index:
                                tokens = list_sentence[s_index].tokens
                                begin_index = p_entity.begin_index
                                end_index = p_entity.end_index
                                size = 15
                                spans = spanWindow(tokens, begin_index, end_index, size, center_include=True, word_flag=True, use_text=False)
                                # distance to each role pattern
                                list_distance = [100, 100, 100, 100, 100]
                                _flag = False
                                # resolve conflicts with regex matches plus distance
                                list_spans = [spans[0][-30:], spans[1], spans[2]]
                                for _i_span in range(len(list_spans)):
                                    # print(list_spans[_i_span],p_entity.entity_text)
                                    for _iter in re.finditer(self.pattern_whole, list_spans[_i_span]):
                                        for _group, _v_group in _iter.groupdict().items():
                                            if _v_group is not None and _v_group!="":
                                                # print(_group,_v_group)
                                                _role = _group.split("_")[0]
                                                _direct = _group.split("_")[1]
                                                _label = {"tenderee": 0, "agency": 1, "winTenderer": 2, "secondTenderer": 3, "thirdTenderer": 4}.get(_role)
                                                if _i_span==0 and _direct=="left":
                                                    _flag = True
                                                    _distance = abs((len(list_spans[_i_span])-_iter.span()[1]))
                                                    list_distance[int(_label)] = min(_distance, list_distance[int(_label)])
                                                if _i_span==1 and _direct=="center":
                                                    _flag = True
                                                    _distance = abs((len(list_spans[_i_span])-_iter.span()[1]))
                                                    list_distance[int(_label)] = min(_distance, list_distance[int(_label)])
                                                if _i_span==2 and _direct=="right":
                                                    _flag = True
                                                    _distance = _iter.span()[0]
                                                    list_distance[int(_label)] = min(_distance, list_distance[int(_label)])
                                # for _key in self.dict_list_pattern.keys():
                                #
                                #     for pattern in self.dict_list_pattern[_key]:
                                #         if pattern[0]=="L":
                                #             for _iter in re.finditer(pattern[1], spans[0][-30:]):
                                #                 _flag = True
                                #                 if len(spans[0])-_iter.span()[1]<list_distance[int(_key)]:
                                #                     list_distance[int(_key)] = len(spans[0])-_iter.span()[1]-(_iter.span()[1]-_iter.span()[0])
                                #
                                #         if pattern[0]=="C":
                                #             if re.search(pattern[1],spans[0]) is None and re.search(pattern[1],spans[2]) is None and re.search(pattern[1],spans[0]+spans[1]+spans[2]) is not None:
                                #                 _flag = True
                                #                 list_distance[int(_key)] = 0
                                #
                                #         if pattern[0]=="R":
                                #             for _iter in re.finditer(pattern[1], spans[2][:30]):
                                #                 _flag = True
                                #                 if _iter.span()[0]<list_distance[int(_key)]:
                                #                     list_distance[int(_key)] = _iter.span()[0]
                                #         if pattern[0]=="W":
                                #             spans = spanWindow(tokens, begin_index, end_index, size=20, center_include=True, word_flag=True, use_text=False)
                                #             for _iter in re.finditer(pattern[1], "".join(spans)):
                                #                 _flag = True
                                #                 if _iter.span()[0]<list_distance[int(_key)]:
                                #                     list_distance[int(_key)] = _iter.span()[0]
                                # print("==",list_distance)
                                # take the closest match as the result
                                _label = np.argmin(list_distance)
                                if _flag:
                                    # if _label==2 and min(list_distance[3:])<100:
                                    #     _label += np.argmin(list_distance[3:])+1
                                    if _label in [2, 3, 4]:
                                        if p_entity.entity_type in ["company", "org"]:
                                            p_entity.label = _label
                                            p_entity.values[int(_label)] = on_value+p_entity.values[int(_label)]/10
                                    else:
                                        p_entity.label = _label
                                        p_entity.values[int(_label)] = on_value+p_entity.values[int(_label)]/10
                        # if p_entity.entity_type=="location":
                        #     for _sentence in list_sentence:
                        #         if _sentence.sentence_index==p_entity.sentence_index:
                        #             _span = spanWindow(tokens=_sentence.tokens,begin_index=p_entity.begin_index,end_index=p_entity.end_index,size=5,center_include=True,word_flag=True,text=p_entity.entity_text)
                        #             if re.search(self.pattern_winTenderer_location,_span[0][-10:]) is not None and re.search("地址|地点",_span[0]) is None:
                        #                 p_entity.entity_type="company"
                        #                 _label = "2"
                        #                 p_entity.label = _label
                        #                 p_entity.values = [0]*6
                        #                 p_entity.values[int(_label)] = on_value
                        # high-confidence special-case adjustments
                        for s_index in range(len(list_sentence)):
                            if p_entity.doc_id==list_sentence[s_index].doc_id and p_entity.sentence_index==list_sentence[s_index].sentence_index:
                                tokens = list_sentence[s_index].tokens
                                begin_index = p_entity.begin_index
                                end_index = p_entity.end_index
                                size = 15
                                spans = spanWindow(tokens, begin_index, end_index, size, center_include=True, word_flag=True, use_text=False)
                                # distance to each role pattern
                                list_distance = [100, 100, 100, 100, 100]
                                _flag = False
                                for _key in self.dict_list_pattern.keys():
                                    for pattern in self.dict_list_pattern[_key]:
                                        if pattern[0]=="W":
                                            spans = spanWindow(tokens, begin_index, end_index, size=30, center_include=True, word_flag=True, use_text=False)
                                            for _iter in re.finditer(pattern[1], spans[0][-10:]+spans[1]+spans[2]):
                                                _flag = True
                                                if _iter.span()[0]<list_distance[int(_key)]:
                                                    list_distance[int(_key)] = _iter.span()[0]
                                # take the closest match as the result
                                _label = np.argmin(list_distance)
                                if _flag:
                                    if _label==2 and min(list_distance[3:])<100:
                                        _label += np.argmin(list_distance[3:])+1
                                    if _label in [2, 3, 4]:
                                        if p_entity.entity_type in ["company", "org"]:
                                            p_entity.label = _label
                                            p_entity.values[int(_label)] = on_value+p_entity.values[int(_label)]/10
                                    else:
                                        p_entity.label = _label
                                        p_entity.values[int(_label)] = on_value+p_entity.values[int(_label)]/10
                if p_entity.entity_type in ["money"]:
                    if str(p_entity.label)=="2":
                        for _sentence in list_sentence:
                            if _sentence.sentence_index==p_entity.sentence_index:
                                _span = spanWindow(tokens=_sentence.tokens, begin_index=p_entity.begin_index, end_index=p_entity.end_index, size=20, center_include=True, word_flag=True, text=p_entity.entity_text)
                                if re.search(self.pattern_money_tenderee, _span[0]) is not None and re.search(self.pattern_money_other, _span[0]) is None:
                                    p_entity.values[0] = 0.8+p_entity.values[0]/10
                                    p_entity.label = 0
                                if re.search(self.pattern_money_tenderer, _span[0]) is not None:
                                    if re.search(self.pattern_money_other, _span[0]) is not None:
                                        if re.search(self.pattern_money_tenderer, _span[0]).span()[1]>re.search(self.pattern_money_other, _span[0]).span()[1]:
                                            p_entity.values[1] = 0.8+p_entity.values[1]/10
                                            p_entity.label = 1
                                    else:
                                        p_entity.values[1] = 0.8+p_entity.values[1]/10
                                        p_entity.label = 1
                                if re.search(self.pattern_money_tenderer_whole, "".join(_span)) is not None and re.search(self.pattern_money_other, _span[0]) is None:
                                    p_entity.values[1] = 0.8+p_entity.values[1]/10
                                    p_entity.label = 1
            # Tender-budget extension: when a labelled tender budget is followed by consecutive unlabelled
            # money entities that all match lot/package information, label those amounts as budget as well
            list_p = []
            state = 0
            for p_entity in list_entity:
                for _sentence in list_sentence:
                    if _sentence.sentence_index==p_entity.sentence_index:
                        _span = spanWindow(tokens=_sentence.tokens, begin_index=p_entity.begin_index, end_index=p_entity.end_index, size=20, center_include=True, word_flag=True, text=p_entity.entity_text)
                        if state==2:
                            for _p in list_p[1:]:
                                _p.values[0] = 0.8+_p.values[0]/10
                                _p.label = 0
                            state = 0
                            list_p = []
                        if state==0:
                            if p_entity.entity_type in ["money"]:
                                if str(p_entity.label)=="0" and re.search(self.pattern_pack, _span[0]+"-"+_span[2]) is not None:
                                    state = 1
                                    list_p.append(p_entity)
                        elif state==1:
                            if p_entity.entity_type in ["money"]:
                                if str(p_entity.label) in ["0", "2"] and re.search(self.pattern_pack, _span[0]+"-"+_span[2]) is not None and re.search(self.pattern_money_other, _span[0]+"-"+_span[2]) is None and p_entity.sentence_index==list_p[0].sentence_index:
                                    list_p.append(p_entity)
                                else:
                                    state = 2
            if len(list_p)>1:
                for _p in list_p[1:]:
                    #print("==",_p.entity_text,_p.sentence_index,_p.label)
                    _p.values[0] = 0.8+_p.values[0]/10
                    _p.label = 0
                state = 0
                list_p = []
            for p_entity in list_entity:
                # Entities in the blacklist set cannot be the winning bidder; reset their label to "none"
                if p_entity.entity_text in self.SET_NOT_TENDERER:
                    p_entity.label = 5
  1004. # 时间类别
class TimePredictor():

    def __init__(self):
        self.sess = tf.Session(graph=tf.Graph())
        self.inputs_code = None
        self.outputs_code = None
        self.input_shape = (2, 10, 128)
        self.load_model()

    def load_model(self):
        model_path = os.path.dirname(__file__) + '/timesplit_model'
        if self.inputs_code is None:
            log("get model of time")
            with self.sess.as_default():
                with self.sess.graph.as_default():
                    meta_graph_def = tf.saved_model.loader.load(self.sess, tags=["serve"], export_dir=model_path)
                    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
                    signature_def = meta_graph_def.signature_def
                    self.inputs_code = []
                    self.inputs_code.append(
                        self.sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["input0"].name))
                    self.inputs_code.append(
                        self.sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["input1"].name))
                    self.outputs_code = self.sess.graph.get_tensor_by_name(signature_def[signature_key].outputs["outputs"].name)
                    return self.inputs_code, self.outputs_code
        else:
            return self.inputs_code, self.outputs_code

    def search_time_data(self, list_sentences, list_entitys):
        data_x = []
        points_entitys = []
        for list_sentence, list_entity in zip(list_sentences, list_entitys):
            p_entitys = 0
            p_sentences = 0
            while p_entitys < len(list_entity):
                entity = list_entity[p_entitys]
                if entity.entity_type in ['time']:
                    while p_sentences < len(list_sentence):
                        sentence = list_sentence[p_sentences]
                        if entity.doc_id == sentence.doc_id and entity.sentence_index == sentence.sentence_index:
                            # left = sentence.sentence_text[max(0,entity.wordOffset_begin-self.input_shape[1]):entity.wordOffset_begin]
                            # right = sentence.sentence_text[entity.wordOffset_end:entity.wordOffset_end+self.input_shape[1]]
                            s = spanWindow(tokens=sentence.tokens, begin_index=entity.begin_index, end_index=entity.end_index, size=self.input_shape[1])
                            left = s[0]
                            right = s[1]
                            context = [left, right]
                            x = embedding(context, shape=self.input_shape)
                            data_x.append(x)
                            points_entitys.append(entity)
                            break
                        p_sentences += 1
                p_entitys += 1
        if len(points_entitys) == 0:
            return None
        data_x = np.transpose(np.array(data_x), (1, 0, 2, 3))
        return [data_x, points_entitys]

    def predict(self, list_sentences, list_entitys):
        datas = self.search_time_data(list_sentences, list_entitys)
        if datas is None:
            return
        points_entitys = datas[1]
        with self.sess.as_default():
            predict_y = limitRun(self.sess, [self.outputs_code],
                                 feed_dict={self.inputs_code[0]: datas[0][0],
                                            self.inputs_code[1]: datas[0][1]})[0]
            for i in range(len(predict_y)):
                entity = points_entitys[i]
                label = np.argmax(predict_y[i])
                values = []
                for item in predict_y[i]:
                    values.append(item)
                entity.set_Role(label, values)
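
# Illustrative usage sketch (not part of the original pipeline): shows how TimePredictor is
# meant to be driven. The list_sentences / list_entitys arguments are assumed to be the
# per-announcement lists produced by the upstream preprocessing and NER steps of this module.
def _example_time_predict(list_sentences, list_entitys):
    time_predictor = TimePredictor()
    # Time entities in list_entitys get their label/values set via entity.set_Role(...).
    time_predictor.predict(list_sentences, list_entitys)
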
# Product field extraction
class ProductPredictor():

    def __init__(self):
        self.sess = tf.Session(graph=tf.Graph())
        self.load_model()

    def load_model(self):
        model_path = os.path.dirname(__file__) + '/product_savedmodel/product.pb'
        with self.sess.as_default():
            with self.sess.graph.as_default():
                output_graph_def = tf.GraphDef()
                with open(model_path, 'rb') as f:
                    output_graph_def.ParseFromString(f.read())
                    tf.import_graph_def(output_graph_def, name='')
                    self.sess.run(tf.global_variables_initializer())
                    self.char_input = self.sess.graph.get_tensor_by_name('CharInputs:0')
                    self.length = self.sess.graph.get_tensor_by_name("Sum:0")
                    self.dropout = self.sess.graph.get_tensor_by_name("Dropout:0")
                    self.logit = self.sess.graph.get_tensor_by_name("logits/Reshape:0")
                    self.tran = self.sess.graph.get_tensor_by_name("crf_loss/transitions:0")

    def predict(self, list_sentences, list_entitys=None, MAX_AREA=5000):
        '''
        Predict product entities; each sentence is truncated to at most MAX_AREA characters.
        :param list_sentences: sentence lists for multiple announcements, e.g. [[sentences of announcement 1], [sentences of announcement 2]]
        :param list_entitys: entity lists for multiple announcements; predicted product entities are appended to them
        :param MAX_AREA: maximum number of characters taken from each sentence
        :return: a list of {"product": [...]} dicts, one entry per processed batch of sentences
        '''
        with self.sess.as_default() as sess:
            with self.sess.graph.as_default():
                result = []
                if list_entitys is None:
                    list_entitys = [[] for _ in range(len(list_sentences))]
                for list_sentence, list_entity in zip(list_sentences, list_entitys):
                    if len(list_sentence) == 0:
                        result.append({"product": []})
                        continue
                    list_sentence.sort(key=lambda x: len(x.sentence_text), reverse=True)
                    _begin_index = 0
                    item = {"product": []}
                    temp_list = []
                    while True:
                        MAX_LEN = len(list_sentence[_begin_index].sentence_text)
                        if MAX_LEN > MAX_AREA:
                            MAX_LEN = MAX_AREA
                        _LEN = MAX_AREA // MAX_LEN
                        chars = process_data([sentence.sentence_text[:MAX_LEN] for sentence in list_sentence[_begin_index:_begin_index + _LEN]])
                        lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran],
                                                          feed_dict={
                                                              self.char_input: np.asarray(chars),
                                                              self.dropout: 1.0
                                                          })
                        batch_paths = decode(scores, lengths, tran_)
                        for sentence, path, length in zip(list_sentence[_begin_index:_begin_index + _LEN], batch_paths, lengths):
                            tags = ''.join([str(it) for it in path[:length]])
                            for it in re.finditer("12*3", tags):
                                start = it.start()
                                end = it.end()
                                _entity = Entity(doc_id=sentence.doc_id, entity_id="%s_%s_%s_%s" % (
                                    sentence.doc_id, sentence.sentence_index, start, end),
                                                 entity_text=sentence.sentence_text[start:end],
                                                 entity_type="product", sentence_index=sentence.sentence_index,
                                                 begin_index=0, end_index=0, wordOffset_begin=start,
                                                 wordOffset_end=end)
                                list_entity.append(_entity)
                                temp_list.append(sentence.sentence_text[start:end])
                        item["product"] = list(set(temp_list))
                        result.append(item)
                        if _begin_index + _LEN >= len(list_sentence):
                            break
                        _begin_index += _LEN
                return result
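
# Illustrative usage sketch (not part of the original pipeline): ProductPredictor.predict
# appends product entities to list_entitys and also returns {"product": [...]} dicts.
# list_sentences / list_entitys are assumed to come from the upstream preprocessing step.
def _example_product_predict(list_sentences, list_entitys):
    product_predictor = ProductPredictor()
    result = product_predictor.predict(list_sentences, list_entitys, MAX_AREA=5000)
    for item in result:
        print(item["product"])
    return result
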
def getSavedModel():
    # predictor = FormPredictor()
    graph = tf.Graph()
    with graph.as_default():
        model = tf.keras.models.load_model("../form/model/model_form.model_item.hdf5",
                                           custom_objects={"precision": precision, "recall": recall, "f1_score": f1_score})
        # print(tf.graph_util.remove_training_nodes(model))
        tf.saved_model.simple_save(
            tf.keras.backend.get_session(),
            "./h5_savedmodel/",
            inputs={"image": model.input},
            outputs={"scores": model.output}
        )

def getBiLSTMCRFModel(MAX_LEN, vocab, EMBED_DIM, BiRNN_UNITS, chunk_tags, weights):
    '''
    model = models.Sequential()
    model.add(layers.Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # Random embedding
    model.add(layers.Bidirectional(layers.LSTM(BiRNN_UNITS // 2, return_sequences=True)))
    crf = CRF(len(chunk_tags), sparse_target=True)
    model.add(crf)
    model.summary()
    model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
    return model
    '''
    input = layers.Input(shape=(None,), dtype="int32")
    if weights is not None:
        embedding = layers.embeddings.Embedding(len(vocab), EMBED_DIM, mask_zero=True, weights=[weights], trainable=True)(input)
    else:
        embedding = layers.embeddings.Embedding(len(vocab), EMBED_DIM, mask_zero=True)(input)
    bilstm = layers.Bidirectional(layers.LSTM(BiRNN_UNITS // 2, return_sequences=True))(embedding)
    bilstm_dense = layers.TimeDistributed(layers.Dense(len(chunk_tags)))(bilstm)
    crf = CRF(len(chunk_tags), sparse_target=True)
    crf_out = crf(bilstm_dense)
    model = models.Model(input=[input], output=[crf_out])
    model.summary()
    model.compile(optimizer='adam', loss=crf.loss_function, metrics=[crf.accuracy])
    return model
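
# Illustrative training sketch (assumptions, not from the original file): how the Keras
# BiLSTM-CRF returned by getBiLSTMCRFModel could be fitted. vocab, class_labels, w2v_weights,
# train_x and train_y are hypothetical placeholders; train_x is an int32 id matrix of shape
# (n_samples, seq_len) and train_y holds sparse tag ids of shape (n_samples, seq_len, 1).
# EMBED_DIM=60 / BiRNN_UNITS=200 mirror the commented call in save_codename_model below.
def _example_train_bilstm_crf(vocab, class_labels, w2v_weights, train_x, train_y):
    model = getBiLSTMCRFModel(None, vocab, EMBED_DIM=60, BiRNN_UNITS=200,
                              chunk_tags=class_labels, weights=w2v_weights)
    model.fit(train_x, train_y, batch_size=32, epochs=10, validation_split=0.1)
    return model
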
from tensorflow.contrib.crf import crf_log_likelihood
from tensorflow.contrib.layers.python.layers import initializers


def BiLSTM_CRF_tfmodel(sess, embedding_weights):
    '''
    Build the BiLSTM-CRF tagging graph used by save_codename_model.
    :param embedding_weights: pretrained character-embedding matrix
    '''
    BiRNN_Unit = 100
    chunk_tags = {
        'O': 0,
        'PN_B': 1,
        'PN_M': 2,
        'PN_E': 3,
        'PC_B': 4,
        'PC_M': 5,
        'PC_E': 6,
    }

    def embedding_layer(input, keepprob):
        # load the pretrained character-embedding matrix as a trainable variable
        embedding = tf.get_variable(name="embedding", initializer=np.array(embedding_weights, dtype=np.float32), dtype=tf.float32)
        embedding = tf.nn.embedding_lookup(params=embedding, ids=input)
        embedding_drop = tf.nn.dropout(embedding, keepprob)
        return embedding_drop

    def BiLSTM_Layer(input, length):
        with tf.variable_scope("BiLSTM"):
            forward_cell = tf.contrib.rnn.BasicLSTMCell(BiRNN_Unit, state_is_tuple=True)
            backward_cell = tf.contrib.rnn.BasicLSTMCell(BiRNN_Unit, state_is_tuple=True)
            output, _ = tf.nn.bidirectional_dynamic_rnn(forward_cell, backward_cell, input, dtype=tf.float32, sequence_length=length)
            output = tf.concat(output, 2)
            return output

    def CRF_layer(input, num_tags, BiRNN_Unit, time_step, keepprob):
        with tf.variable_scope("CRF"):
            with tf.variable_scope("hidden"):
                w_hidden = tf.get_variable(name='w_hidden', shape=(BiRNN_Unit * 2, BiRNN_Unit), dtype=tf.float32,
                                           initializer=initializers.xavier_initializer(), regularizer=tf.contrib.layers.l2_regularizer(0.001))
                b_hidden = tf.get_variable(name='b_hidden', shape=(BiRNN_Unit), dtype=tf.float32, initializer=tf.zeros_initializer())
                # print(input)
                input_reshape = tf.reshape(input, shape=(-1, BiRNN_Unit * 2))
                hidden = tf.tanh(tf.nn.xw_plus_b(input_reshape, w_hidden, b_hidden))
                hidden = tf.nn.dropout(hidden, keepprob)
            with tf.variable_scope("output"):
                w_output = tf.get_variable(name='w_output', shape=(BiRNN_Unit, num_tags), dtype=tf.float32, initializer=initializers.xavier_initializer(), regularizer=tf.contrib.layers.l2_regularizer(0.001))
                b_output = tf.get_variable(name='b_output', shape=(num_tags), dtype=tf.float32, initializer=tf.zeros_initializer())
                pred = tf.nn.xw_plus_b(hidden, w_output, b_output)
                logits_ = tf.reshape(pred, shape=(-1, time_step, num_tags), name='logits')
            return logits_

    def layer_loss(input, true_target, num_tags, length):
        with tf.variable_scope("crf_loss"):
            trans = tf.get_variable(name='transitons', shape=(num_tags, num_tags), dtype=tf.float32, initializer=initializers.xavier_initializer())
            log_likelihood, trans = crf_log_likelihood(inputs=input, tag_indices=true_target, transition_params=trans, sequence_lengths=length)
            return tf.reduce_mean(-log_likelihood), trans

    with sess.graph.as_default():
        char_input = tf.placeholder(name='char_input', shape=(None, None), dtype=tf.int32)
        target = tf.placeholder(name='target', shape=(None, None), dtype=tf.int32)
        length = tf.placeholder(name='length', shape=(None,), dtype=tf.int32)
        keepprob = tf.placeholder(name='keepprob', dtype=tf.float32)
        _embedding = embedding_layer(char_input, keepprob)
        _shape = tf.shape(char_input)
        batch_size = _shape[0]
        step_size = _shape[-1]
        bilstm = BiLSTM_Layer(_embedding, length)
        _logits = CRF_layer(bilstm, num_tags=len(chunk_tags), BiRNN_Unit=BiRNN_Unit, time_step=step_size, keepprob=keepprob)
        crf_loss, trans = layer_loss(_logits, true_target=target, num_tags=len(chunk_tags), length=length)
        global_step = tf.Variable(0, trainable=False)
        with tf.variable_scope("optimizer"):
            opt = tf.train.AdamOptimizer(0.002)
            grads_vars = opt.compute_gradients(crf_loss)
            capped_grads_vars = [[tf.clip_by_value(g, -5, 5), v] for g, v in grads_vars]
            train_op = opt.apply_gradients(capped_grads_vars, global_step)
        return char_input, _logits, target, keepprob, length, crf_loss, trans, train_op
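
# Illustrative training-step sketch (assumptions: batch_x / batch_y / batch_len are padded
# character-id sequences, tag-id sequences and true sequence lengths prepared elsewhere;
# w2v_matrix is the pretrained embedding matrix, as in save_codename_model below; the
# keep-prob value is arbitrary).
def _example_bilstm_crf_train_step(w2v_matrix, batch_x, batch_y, batch_len):
    sess = tf.Session(graph=tf.Graph())
    char_input, logits, target, keepprob, length, crf_loss, trans, train_op = BiLSTM_CRF_tfmodel(sess, w2v_matrix)
    with sess.graph.as_default():
        sess.run(tf.global_variables_initializer())
        _, loss = sess.run([train_op, crf_loss],
                           feed_dict={char_input: batch_x, target: batch_y,
                                      length: batch_len, keepprob: 0.7})
    return loss
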
import h5py


def h5_to_graph(sess, graph, h5file):
    f = h5py.File(h5file, 'r')  # open the h5 file

    def getValue(v):
        _value = f["model_weights"]
        list_names = str(v.name).split("/")
        for _index in range(len(list_names)):
            print(v.name)
            if _index == 1:
                _value = _value[list_names[0]]
            _value = _value[list_names[_index]]
        return _value.value

    def _load_attributes_from_hdf5_group(group, name):
        """Loads attributes of the specified name from the HDF5 group.

        This method deals with an inherent problem
        of HDF5 file which is not able to store
        data larger than HDF5_OBJECT_HEADER_LIMIT bytes.

        # Arguments
            group: A pointer to a HDF5 group.
            name: A name of the attributes to load.

        # Returns
            data: Attributes data.
        """
        if name in group.attrs:
            data = [n.decode('utf8') for n in group.attrs[name]]
        else:
            data = []
            chunk_id = 0
            while ('%s%d' % (name, chunk_id)) in group.attrs:
                data.extend([n.decode('utf8')
                             for n in group.attrs['%s%d' % (name, chunk_id)]])
                chunk_id += 1
        return data

    def readGroup(gr, parent_name, data):
        for subkey in gr:
            print(subkey)
            if parent_name != subkey:
                if parent_name == "":
                    _name = subkey
                else:
                    _name = parent_name + "/" + subkey
            else:
                _name = parent_name
            if str(type(gr[subkey])) == "<class 'h5py._hl.group.Group'>":
                readGroup(gr[subkey], _name, data)
            else:
                data.append([_name, gr[subkey].value])
                print(_name, gr[subkey].shape)

    layer_names = _load_attributes_from_hdf5_group(f["model_weights"], 'layer_names')
    list_name_value = []
    readGroup(f["model_weights"], "", list_name_value)
    '''
    for k, name in enumerate(layer_names):
        g = f["model_weights"][name]
        weight_names = _load_attributes_from_hdf5_group(g, 'weight_names')
        #weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
        for weight_name in weight_names:
            list_name_value.append([weight_name,np.asarray(g[weight_name])])
    '''
    for name_value in list_name_value:
        name = name_value[0]
        '''
        if re.search("dense",name) is not None:
            name = name[:7]+"_1"+name[7:]
        '''
        value = name_value[1]
        print(name, graph.get_tensor_by_name(name), np.shape(value))
        sess.run(tf.assign(graph.get_tensor_by_name(name), value))

def initialize_uninitialized(sess):
    global_vars = tf.global_variables()
    is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
    adam_vars = []
    for _vars in not_initialized_vars:
        if re.search("Adam", _vars.name) is not None:
            adam_vars.append(_vars)
    print([str(i.name) for i in adam_vars])  # only for testing
    if len(adam_vars):
        sess.run(tf.variables_initializer(adam_vars))
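
# Usage note (illustrative): initialize_uninitialized targets the case where a graph was
# restored from a checkpoint and an Adam optimizer was added afterwards, so only the Adam
# slot variables still need initializing, e.g.:
#
#     saver.restore(sess, checkpoint_path)
#     train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
#     initialize_uninitialized(sess)
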
def save_codename_model():
    # filepath = "../projectCode/models/model_project_"+str(60)+"_"+str(200)+".hdf5"
    filepath = "../projectCode/models_tf/59-L0.471516189943-F0.8802154826344823-P0.8789179683459191-R0.8815168335321886/model.ckpt"
    vocabpath = "../projectCode/models/vocab.pk"
    classlabelspath = "../projectCode/models/classlabels.pk"
    # vocab = load(vocabpath)
    # class_labels = load(classlabelspath)
    w2v_matrix = load('codename_w2v_matrix.pk')
    graph = tf.get_default_graph()
    with graph.as_default() as g:
        # model = getBiLSTMCRFModel(None, vocab, 60, 200, class_labels,weights=None)
        # model = models.load_model(filepath,custom_objects={'precision':precision,'recall':recall,'f1_score':f1_score,"CRF":CRF,"loss":CRF.loss_function})
        sess = tf.Session(graph=g)
        # sess = tf.keras.backend.get_session()
        char_input, logits, target, keepprob, length, crf_loss, trans, train_op = BiLSTM_CRF_tfmodel(sess, w2v_matrix)
        # with sess.as_default():
        sess.run(tf.global_variables_initializer())
        # print(sess.run("time_distributed_1/kernel:0"))
        # model.load_weights(filepath)
        saver = tf.train.Saver()
        saver.restore(sess, filepath)
        # print("logits",sess.run(logits))
        # print("#",sess.run("time_distributed_1/kernel:0"))
        # x = load("codename_x.pk")
        # y = model.predict(x)
        # y = sess.run(model.output,feed_dict={model.input:x})
        # for item in np.argmax(y,-1):
        #     print(item)
        tf.saved_model.simple_save(
            sess,
            "./codename_savedmodel_tf/",
            inputs={"inputs": char_input,
                    "inputs_length": length,
                    'keepprob': keepprob},
            outputs={"logits": logits,
                     "trans": trans}
        )

def save_role_model():
    '''
    @summary: export the model as a SavedModel for deployment and serving on the PAI platform
    '''
    model_role = PREMPredict().model_role
    with model_role.graph.as_default():
        model = model_role.getModel()
        sess = tf.Session(graph=model_role.graph)
        print(type(model.input))
        sess.run(tf.global_variables_initializer())
        h5_to_graph(sess, model_role.graph, model_role.model_role_file)
        model = model_role.getModel()
        tf.saved_model.simple_save(sess,
                                   "./role_savedmodel/",
                                   inputs={"input0": model.input[0],
                                           "input1": model.input[1],
                                           "input2": model.input[2]},
                                   outputs={"outputs": model.output}
                                   )

def save_money_model():
    model_money = PREMPredict().model_money
    with model_money.graph.as_default():
        model = model_money.getModel()
        sess = tf.Session(graph=model_money.graph)
        model.summary()
        sess.run(tf.global_variables_initializer())
        h5_to_graph(sess, model_money.graph, model_money.model_money_file)
        tf.saved_model.simple_save(sess,
                                   "./money_savedmodel/",
                                   inputs={"input0": model.input[0],
                                           "input1": model.input[1],
                                           "input2": model.input[2]},
                                   outputs={"outputs": model.output}
                                   )

def save_person_model():
    model_person = EPCPredict().model_person
    with model_person.graph.as_default():
        x = load("person_x.pk")
        _data = np.transpose(np.array(x), (1, 0, 2, 3))
        model = model_person.getModel()
        sess = tf.Session(graph=model_person.graph)
        with sess.as_default():
            sess.run(tf.global_variables_initializer())
            model_person.load_weights()
            # h5_to_graph(sess, model_person.graph, model_person.model_person_file)
            predict_y = sess.run(model.output, feed_dict={model.input[0]: _data[0], model.input[1]: _data[1]})
            # predict_y = model.predict([_data[0],_data[1]])
            print(np.argmax(predict_y, -1))
            tf.saved_model.simple_save(sess,
                                       "./person_savedmodel/",
                                       inputs={"input0": model.input[0],
                                               "input1": model.input[1]},
                                       outputs={"outputs": model.output})

def save_form_model():
    model_form = FormPredictor()
    with model_form.graph.as_default():
        model = model_form.getModel("item")
        sess = tf.Session(graph=model_form.graph)
        sess.run(tf.global_variables_initializer())
        h5_to_graph(sess, model_form.graph, model_form.model_file_item)
        tf.saved_model.simple_save(sess,
                                   "./form_savedmodel/",
                                   inputs={"inputs": model.input},
                                   outputs={"outputs": model.output})

def save_codesplit_model():
    filepath_code = "../projectCode/models/model_code.hdf5"
    graph = tf.Graph()
    with graph.as_default():
        model_code = models.load_model(filepath_code,
                                       custom_objects={'precision': precision, 'recall': recall, 'f1_score': f1_score})
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        h5_to_graph(sess, graph, filepath_code)
        tf.saved_model.simple_save(sess,
                                   "./codesplit_savedmodel/",
                                   inputs={"input0": model_code.input[0],
                                           "input1": model_code.input[1],
                                           "input2": model_code.input[2]},
                                   outputs={"outputs": model_code.output})

def save_timesplit_model():
    filepath = '../time/model_label_time_classify.model.hdf5'
    with tf.Graph().as_default() as graph:
        time_model = models.load_model(filepath,
                                       custom_objects={'precision': precision, 'recall': recall, 'f1_score': f1_score})
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            h5_to_graph(sess, graph, filepath)
            tf.saved_model.simple_save(sess,
                                       "./timesplit_model/",
                                       inputs={"input0": time_model.input[0],
                                               "input1": time_model.input[1]},
                                       outputs={"outputs": time_model.output})

if __name__ == "__main__":
    # save_role_model()
    # save_codename_model()
    # save_money_model()
    # save_person_model()
    # save_form_model()
    # save_codesplit_model()
    # save_timesplit_model()
    '''
    with tf.Session(graph=tf.Graph()) as sess:
        from tensorflow.python.saved_model import tag_constants
        meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], "./person_savedModel")
        graph = tf.get_default_graph()
        signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        signature = meta_graph_def.signature_def
        input0 = sess.graph.get_tensor_by_name(signature[signature_key].inputs["input0"].name)
        input1 = sess.graph.get_tensor_by_name(signature[signature_key].inputs["input1"].name)
        outputs = sess.graph.get_tensor_by_name(signature[signature_key].outputs["outputs"].name)
        x = load("person_x.pk")
        _data = np.transpose(x, [1, 0, 2, 3])
        y = sess.run(outputs, feed_dict={input0: _data[0], input1: _data[1]})
        print(np.argmax(y, -1))
    '''