|
@@ -250,7 +250,6 @@ class CodeNamePredict():
|
|
|
|
|
|
def predict(self,list_sentences,list_entitys=None,MAX_AREA = 5000):
|
|
|
#@summary: 获取每篇文章的code和name
|
|
|
-
|
|
|
pattern_score = re.compile("工程|服务|采购|施工|项目|系统|招标|中标|公告|学校|[大中小]学校?|医院|公司|分公司|研究院|政府采购中心|学院|中心校?|办公室|政府|财[政务]局|办事处|委员会|[部总支]队|警卫局|幼儿园|党委|党校|银行|分行|解放军|发电厂|供电局|管理所|供电公司|卷烟厂|机务段|研究[院所]|油厂|调查局|调查中心|出版社|电视台|监狱|水厂|服务站|信用合作联社|信用社|交易所|交易中心|交易中心党校|科学院|测绘所|运输厅|管理处|局|中心|机关|部门?|处|科|厂|集团|图书馆|馆|所|厅|楼|区|酒店|场|基地|矿|餐厅|酒店")
|
|
|
|
|
|
result = []
|
|
@@ -291,20 +290,11 @@ class CodeNamePredict():
|
|
|
x_len = [len(_x) if len(_x) < MAX_LEN else MAX_LEN for _x in x]
|
|
|
x = pad_sequences(x,maxlen=MAX_LEN,padding="post",truncating="post")
|
|
|
|
|
|
- if USE_PAI_EAS:
|
|
|
- request = tf_predict_pb2.PredictRequest()
|
|
|
- request.inputs["inputs"].dtype = tf_predict_pb2.DT_INT32
|
|
|
- request.inputs["inputs"].array_shape.dim.extend(np.shape(x))
|
|
|
- request.inputs["inputs"].int_val.extend(np.array(x,dtype=np.int32).reshape(-1))
|
|
|
- request_data = request.SerializeToString()
|
|
|
- list_outputs = ["outputs"]
|
|
|
- _result = vpc_requests(codename_url, codename_authorization, request_data, list_outputs)
|
|
|
- if _result is not None:
|
|
|
- predict_y = _result["outputs"]
|
|
|
- else:
|
|
|
- with self.sess_codename.as_default():
|
|
|
- t_input,t_output = self.getModel()
|
|
|
- predict_y = self.sess_codename.run(t_output,feed_dict={t_input:x})
|
|
|
+ if USE_API:
|
|
|
+ requests_result = requests.post(API_URL + "/predict_codeName", json={"inouts": x.tolist(), "inouts_len": x_len},verify=True)
|
|
|
+ predict_y = json.loads(requests_result.text)['result']
|
|
|
+ # print("cost_time:", json.loads(requests_result.text)['cost_time'])
|
|
|
+ # print(MAX_LEN,_LEN,_begin_index)
|
|
|
else:
|
|
|
with self.sess_codename.as_default():
|
|
|
t_input,t_input_length,t_keepprob,t_logits,t_trans = self.getModel()
|
|
@@ -1816,12 +1806,18 @@ class ProductPredictor():
|
|
|
if fail and list_articles!=[]:
|
|
|
text_list = [list_articles[0].content[:MAX_AREA]]
|
|
|
chars = [[self.word2index.get(it, self.word2index.get('<unk>')) for it in text] for text in text_list]
|
|
|
- lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran],
|
|
|
- feed_dict={
|
|
|
- self.char_input: np.asarray(chars),
|
|
|
- self.dropout: 1.0
|
|
|
- })
|
|
|
- batch_paths = self.decode(scores, lengths, tran_)
|
|
|
+ if USE_API:
|
|
|
+ requests_result = requests.post(API_URL + "/predict_product",
|
|
|
+ json={"inputs": chars}, verify=True)
|
|
|
+ batch_paths = json.loads(requests_result.text)['result']
|
|
|
+ lengths = json.loads(requests_result.text)['lengths']
|
|
|
+ else:
|
|
|
+ lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran],
|
|
|
+ feed_dict={
|
|
|
+ self.char_input: np.asarray(chars),
|
|
|
+ self.dropout: 1.0
|
|
|
+ })
|
|
|
+ batch_paths = self.decode(scores, lengths, tran_)
|
|
|
for text, path, length in zip(text_list, batch_paths, lengths):
|
|
|
tags = ''.join([str(it) for it in path[:length]])
|
|
|
for it in re.finditer("12*3", tags):
|
|
@@ -1867,12 +1863,18 @@ class ProductPredictor():
|
|
|
chars = [sentence.sentence_text[:MAX_LEN] for sentence in list_sentence[_begin_index:_begin_index+_LEN]]
|
|
|
chars = [[self.word2index.get(it, self.word2index.get('<unk>')) for it in l] for l in chars]
|
|
|
chars = pad_sequences(chars, maxlen=MAX_LEN, padding="post", truncating="post")
|
|
|
- lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran],
|
|
|
- feed_dict={
|
|
|
- self.char_input: np.asarray(chars),
|
|
|
- self.dropout: 1.0
|
|
|
- })
|
|
|
- batch_paths = self.decode(scores, lengths, tran_)
|
|
|
+ if USE_API:
|
|
|
+ requests_result = requests.post(API_URL + "/predict_product",
|
|
|
+ json={"inputs": chars.tolist()}, verify=True)
|
|
|
+ batch_paths = json.loads(requests_result.text)['result']
|
|
|
+ lengths = json.loads(requests_result.text)['lengths']
|
|
|
+ else:
|
|
|
+ lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran],
|
|
|
+ feed_dict={
|
|
|
+ self.char_input: np.asarray(chars),
|
|
|
+ self.dropout: 1.0
|
|
|
+ })
|
|
|
+ batch_paths = self.decode(scores, lengths, tran_)
|
|
|
for sentence, path, length in zip(list_sentence[_begin_index:_begin_index+_LEN],batch_paths, lengths):
|
|
|
tags = ''.join([str(it) for it in path[:length]])
|
|
|
for it in re.finditer("12*3", tags):
|
|
@@ -2067,7 +2069,7 @@ class ProductAttributesPredictor():
|
|
|
order_end = "%s-%s-%s" % (y, m, num)
|
|
|
return order_begin, order_end
|
|
|
|
|
|
- t1 = re.search('^(\d{4})(年|/|.|-)(\d{1,2})月?$', text)
|
|
|
+ t1 = re.search('^(\d{4})(年|/|\.|-)(\d{1,2})月?$', text)
|
|
|
if t1:
|
|
|
year = t1.group(1)
|
|
|
month = t1.group(3)
|
|
@@ -2079,7 +2081,7 @@ class ProductAttributesPredictor():
|
|
|
order_begin = "%s-%s-01" % (year, month)
|
|
|
order_end = "%s-%s-%s" % (year, month, num)
|
|
|
return order_begin, order_end
|
|
|
- t2 = re.search('^(\d{4})(年|/|.|-)(\d{1,2})(月|/|.|-)(\d{1,2})日?$', text)
|
|
|
+ t2 = re.search('^(\d{4})(年|/|\.|-)(\d{1,2})(月|/|\.|-)(\d{1,2})日?$', text)
|
|
|
if t2:
|
|
|
y = t2.group(1)
|
|
|
m = t2.group(3)
|
|
@@ -2088,8 +2090,31 @@ class ProductAttributesPredictor():
|
|
|
d = '0'+d if len(d)<2 else d
|
|
|
order_begin = order_end = "%s-%s-%s"%(y,m,d)
|
|
|
return order_begin, order_end
|
|
|
- all_match = re.finditer('^(?P<y1>\d{4})(年|/|.)(?P<m1>\d{1,2})(?:(月|/|.)(?:(?P<d1>\d{1,2})日)?)?'
|
|
|
- '(到|至|-)(?:(?P<y2>\d{4})(年|/|.))?(?P<m2>\d{1,2})(?:(月|/|.)'
|
|
|
+ # 时间样式:"202105"
|
|
|
+ t3 = re.search("^(20\d{2})(\d{1,2})$",text)
|
|
|
+ if t3:
|
|
|
+ year = t3.group(1)
|
|
|
+ month = t3.group(2)
|
|
|
+ if int(month)>0 and int(month)<=12:
|
|
|
+ num = self.get_monthlen(year, month)
|
|
|
+ if len(month) < 2:
|
|
|
+ month = '0' + month
|
|
|
+ if len(num) < 2:
|
|
|
+ num = '0' + num
|
|
|
+ order_begin = "%s-%s-01" % (year, month)
|
|
|
+ order_end = "%s-%s-%s" % (year, month, num)
|
|
|
+ return order_begin, order_end
|
|
|
+ # 时间样式:"20210510"
|
|
|
+ t4 = re.search("^(20\d{2})(\d{2})(\d{2})$", text)
|
|
|
+ if t4:
|
|
|
+ year = t4.group(1)
|
|
|
+ month = t4.group(2)
|
|
|
+ day = t4.group(3)
|
|
|
+ if int(month) > 0 and int(month) <= 12 and int(day)>0 and int(day)<=31:
|
|
|
+ order_begin = order_end = "%s-%s-%s"%(year,month,day)
|
|
|
+ return order_begin, order_end
|
|
|
+ all_match = re.finditer('^(?P<y1>\d{4})(年|/|\.)(?P<m1>\d{1,2})(?:(月|/|\.)(?:(?P<d1>\d{1,2})日)?)?'
|
|
|
+ '(到|至|-)(?:(?P<y2>\d{4})(年|/|\.))?(?P<m2>\d{1,2})(?:(月|/|\.)'
|
|
|
'(?:(?P<d2>\d{1,2})日)?)?$', text)
|
|
|
y1 = m1 = d1 = y2 = m2 = d2 = ""
|
|
|
found_math = False
|
|
@@ -3373,3 +3398,31 @@ if __name__=="__main__":
|
|
|
# y = sess.run(outputs,feed_dict={input0:_data[0],input1:_data[1]})
|
|
|
# print(np.argmax(y,-1))
|
|
|
'''
|
|
|
+
|
|
|
+ MAX_LEN = 1000
|
|
|
+ vocabpath = os.path.dirname(__file__) + "/codename_vocab.pk"
|
|
|
+ vocab = load(vocabpath)
|
|
|
+ word2index = dict((w, i) for i, w in enumerate(np.array(vocab)))
|
|
|
+ index_unk = word2index.get("<unk>")
|
|
|
+ sentence = "招标人:广州市重点公共建设项目管理中心,联系人:李工,联系方式:020-22905689,招标代理:广东重工建设监理有限公司," \
|
|
|
+ "代理联系人:薛家伟,代理联系方式:13535014481,招标监督机构:广州市重点公共建设项目管理中心,监督电话:020-22905690," \
|
|
|
+ "备注:以上为招标公告简要描述,招标公告详细信息请查看“招标公告”附件,"
|
|
|
+ sentence = sentence*5
|
|
|
+ list_sentence = [sentence]*200
|
|
|
+ # print(list_sentence)
|
|
|
+ x = [[word2index.get(word, index_unk) for word in sentence] for sentence in
|
|
|
+ list_sentence]
|
|
|
+ x_len = [len(_x) if len(_x) < MAX_LEN else MAX_LEN for _x in x]
|
|
|
+ # print(x_len)
|
|
|
+ x = pad_sequences(x, maxlen=MAX_LEN, padding="post", truncating="post")
|
|
|
+
|
|
|
+ requests_result = requests.post(API_URL + "/predict_codeName", json={"inouts": x.tolist(), "inouts_len": x_len},
|
|
|
+ verify=True)
|
|
|
+ # predict_y = json.loads(requests_result.text)['result']
|
|
|
+ print("cost_time:", json.loads(requests_result.text)['cost_time'])
|
|
|
+ print(MAX_LEN, len(sentence), len(list_sentence))
|
|
|
+ requests_result = requests.post(API_URL + "/predict_codeName", json={"inouts": x.tolist(), "inouts_len": x_len},
|
|
|
+ verify=True)
|
|
|
+ # predict_y = json.loads(requests_result.text)['result']
|
|
|
+ print("cost_time:", json.loads(requests_result.text)['cost_time'])
|
|
|
+ print(MAX_LEN, len(sentence), len(list_sentence))
|