Product model: merge product-name and failure-reason extraction; fix preprocessing punctuation handling; optimize role regexes

lishimin 3 years ago
parent
commit
0414d16fb5

+ 2 - 1
BiddingKG/dl/interface/Preprocessing.py

@@ -1091,7 +1091,8 @@ def segment(soup,final=True):
 
 
     #替换"""为"“",否则导入deepdive出错
-    text = text.replace('"',"“").replace("\r","").replace("\n",",")
+    # text = text.replace('"',"“").replace("\r","").replace("\n",",")
+    text = text.replace('"',"“").replace("\r","").replace("\n","")  # 2022/1/4 fix: replacing non-paragraph "\n" with a comma split company spans, e.g. \n南航\n上海\n分公司
     # print('==1',text)
     # text = re.sub("\s{4,}",",",text)
     # 解决公告中的" "空格替换问题
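A minimal standalone sketch (hypothetical sample text, not from the repository) of why the old replacement split entities and the new one keeps them contiguous:

raw = '中标人:\n南航\n上海\n分公司'

old = raw.replace('"', '“').replace('\r', '').replace('\n', ',')   # previous behaviour
new = raw.replace('"', '“').replace('\r', '').replace('\n', '')    # behaviour after this commit

print(old)  # 中标人:,南航,上海,分公司 -> the company name is split across commas
print(new)  # 中标人:南航上海分公司    -> the company name stays contiguous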

+ 4 - 2
BiddingKG/dl/interface/extract.py

@@ -68,7 +68,9 @@ def predict(doc_id,text,title="",page_time="",**kwargs):
     cost_time["prem"] = round(time.time()-start_time,2)
 
     start_time = time.time() # 产品名称及废标原因提取
-    predictor.getPredictor("product").predict(list_sentences,list_entitys)
+    fail = list_channel_dic[0]['docchannel'] == "废标公告"
+    fail_reason = predictor.getPredictor("product").predict(list_sentences,list_entitys,list_articles, fail) # only returns the failure reason; products are already appended to the Entity objects
+    # predictor.getPredictor("product").predict(list_sentences, list_entitys)
     log("get product done of doc_id%s"%(doc_id))
     cost_time["product"] = round(time.time()-start_time,2)
 
@@ -138,7 +140,7 @@ def predict(doc_id,text,title="",page_time="",**kwargs):
 
     # data_res = Preprocessing.union_result(Preprocessing.union_result(codeName, prem),list_punish_dic)[0]
     # data_res = Preprocessing.union_result(Preprocessing.union_result(Preprocessing.union_result(codeName, prem),list_punish_dic), list_channel_dic)[0]
-    data_res = dict(codeName[0], **prem[0], **list_channel_dic[0], **product_attrs[0], **product_attrs[1], **payment_way_dic)
+    data_res = dict(codeName[0], **prem[0], **list_channel_dic[0], **product_attrs[0], **product_attrs[1], **payment_way_dic, **fail_reason)
     data_res["cost_time"] = cost_time
     data_res["success"] = True
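For reference, a toy sketch (hypothetical field values) of how the final result dict is assembled; dict(a, **b, ...) merges the per-module dicts, so the new **fail_reason spread simply adds a fail_reason key alongside the existing fields:

codeName = [{'code': ['XM-2022-001'], 'name': '设备采购项目'}]
prem = [{'prem': {}}]
list_channel_dic = [{'docchannel': '废标公告'}]
fail_reason = {'fail_reason': '有效投标人不足三家'}

data_res = dict(codeName[0], **prem[0], **list_channel_dic[0], **fail_reason)
data_res['success'] = True
print(data_res)
# {'code': ['XM-2022-001'], 'name': '设备采购项目', 'prem': {}, 'docchannel': '废标公告',
#  'fail_reason': '有效投标人不足三家', 'success': True}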
 

+ 92 - 22
BiddingKG/dl/interface/predictor.py

@@ -110,8 +110,8 @@ class CodeNamePredict():
         id_PN_E = self.class_labels.index("PN_E")
         self.PC_pattern = re.compile(str(id_PC_B)+str(id_PC_M)+"*"+str(id_PC_E))
         self.PN_pattern = re.compile(str(id_PN_B)+str(id_PN_M)+"*"+str(id_PN_E))
-        print("pc",self.PC_pattern)
-        print("pn",self.PN_pattern)
+        # print("pc",self.PC_pattern)
+        # print("pn",self.PN_pattern)
         self.word2index = dict((w,i) for i,w in enumerate(np.array(self.vocab)))
         
         self.inputs = None
@@ -1097,10 +1097,10 @@ class RoleRulePredictor():
         # self.pattern_tenderee_left = "(?P<tenderee_left>((遴选|采购|招标|项目|竞价|议价|需求|最终|建设|业主|转让|招租|甲|议标|合同主体|比选|委托|询价)(?:人|公司|单位|组织|用户|业主|方|部门)|文章来源|需方)(名称)?(是|为|信息|:|:|\s*)$)"
         self.pattern_tenderee_left = "(?P<tenderee_left>((遴选|采购|招标|项目|竞价|议价|需求|最终|建设|业主|转让|招租|甲|议标|合同主体|比选|委托|询价|评选|挂牌|出租|出让|谈判|邀标|邀请|洽谈|约谈|买受|选取|抽取|抽选|出售|标卖|比价)(人|公司|单位|组织|用户|业主|主体|方|部门)|文章来源|委托机构|产权所有人|需方|买方|业主|权属人|甲方当事人)[))]?(名称|信息)?([((](全称|盖章)[))])?(是|为|:|:|,|\s*)+$)"
         self.pattern_tenderee_center = "(?P<tenderee_center>(受.{,20}委托))"
-        self.pattern_tenderee_right = "(?P<tenderee_right>^([((](以下简称)?[,\"“]*(招标|采购)(人|单位|机构)\)?))|^委托"  #|(^[^.。,,::](采购|竞价|招标|施工|监理|中标|物资)(公告|公示|项目|结果|招标))|的.*正在进行询比价)
+        self.pattern_tenderee_right = "(?P<tenderee_right>^([((](以下简称)?[,\"“]*(招标|采购)(人|单位|机构)[,\"”]*[))])|^委托|^拟对|^现就|^现委托)"  #|(^[^.。,,::](采购|竞价|招标|施工|监理|中标|物资)(公告|公示|项目|结果|招标))|的.*正在进行询比价)
 
         self.pattern_agency_left = "(?P<agency_left>(代理(?:人|机构|公司|单位|组织)|专业采购机构|集中采购机构|集采机构|[招议))]+标机构)(.{,4}名,?称|全称|是|为|:|:|[,,]?\s*)$|(受.{,20}委托))"
-        self.pattern_agency_right = "(?P<agency_right>^([((](以下简称)?[,\"“]*(代理)(人|单位|机构)\))|受.{,15}委托)|^受托"
+        self.pattern_agency_right = "(?P<agency_right>^([((](以下简称)?[,\"“]*(代理)(人|单位|机构)[,\"”]*[))])|受.{,15}委托|^受托)"
         # 2020//11/24 大网站规则 中标关键词添加 选定单位|指定的中介服务机构
         self.pattern_winTenderer_left = "(?P<winTenderer_left>((中标|中选|中价|乙|成交|承做|施工|供货|承包|竞得|受让)(候选)?(人|单位|机构|各?供应商|方|公司|厂商|商)[::是为]+$|(选定单位|指定的中介服务机构))[::是为,]+$|(第[一1](名|((中标|中选|中价|成交)?(候选)?(人|单位|机构|供应商))))[::是为]+$|((评审结果|名次|排名)[::]第?[一1]名?)$|单一来源(采购)?方式向$|((中标|成交)(结果|信息))(是|为|:|:)$|(单一来源采购(供应商|供货商|服务商))$|[^候选]((分包|标包){,5}供应商|供货商|服务商|供应商名称|服务机构|供方)[::]$)"
         # self.pattern_winTenderer_center = "(?P<winTenderer_center>第[一1].{,20}[是为]((中标|中选|中价|成交|施工)(人|单位|机构|供应商|公司)|供应商)[::是为])"
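A quick standalone check (made-up right-context strings) of the widened tenderee_right pattern; the new ^拟对/^现就/^现委托 alternatives and the closing-quote handling now match contexts the old pattern missed:

import re

pattern_tenderee_right = "(?P<tenderee_right>^([((](以下简称)?[,\"“]*(招标|采购)(人|单位|机构)[,\"”]*[))])|^委托|^拟对|^现就|^现委托)"
for ctx in ['(以下简称“招标人”)', '拟对本项目进行公开招标', '现就以下项目进行询价', '其他无关文本']:
    print(ctx, bool(re.search(pattern_tenderee_right, ctx)))
# the first three print True, the last prints False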
@@ -1156,10 +1156,14 @@ class RoleRulePredictor():
         return text
         
     def predict(self,list_articles,list_sentences,list_entitys,list_codenames,on_value = 0.5):
-        
 
         for article,list_entity,list_sentence,list_codename in zip(list_articles,list_entitys,list_sentences,list_codenames):
-            list_name = list_codename["name"]
+            list_sentence.sort(key=lambda x: x.sentence_index)  # 2022/1/5 sort sentences by sentence_index
+            # list_name = list_codename["name"]
+            list_name = []  # 2022/1/5 use all project-name entities from the entity list instead
+            for entity in list_entity:
+                if entity.entity_type == 'name':
+                    list_name.append(entity.entity_text)
             list_name = self._check_input(list_name)+[article.title]
             for p_entity in list_entity:
 
@@ -1188,13 +1192,30 @@ class RoleRulePredictor():
                         #将标题中的实体置为招标人
                         _list_name = self._check_input(list_name,ignore=True)
                         find_flag = False
-                        for _name in _list_name:
-                            if str(_name).find(p_entity.entity_text)>=0:
-                                find_flag = True
-                                _label = 0
-                                p_entity.label = _label
-                                p_entity.values[int(_label)] = on_value
+                        for _name in _list_name: # 2022/1/5 fix: previously any role entity contained in a project name was labeled tenderee at every position
+                            if str(_name).find(p_entity.entity_text) >= 0 and p_entity.sentence_index<4:
+                                for _sentence in list_sentence:
+                                    if _sentence.sentence_index == p_entity.sentence_index:
+                                        _span = spanWindow(tokens=_sentence.tokens, begin_index=p_entity.begin_index,
+                                                           end_index=p_entity.end_index, size=20, center_include=True,
+                                                           word_flag=True, text=p_entity.entity_text)
+                                        if str(_span[1] + _span[2][:len(str(_name))]).find(
+                                            _name) >= 0:
+                                            find_flag = True
+                                            _label = 0
+                                            p_entity.label = _label
+                                            p_entity.values[int(_label)] = on_value
+                                            break
+                                    if p_entity.sentence_index>=4:
+                                        break
+                            if find_flag:
                                 break
+                            # if str(_name).find(p_entity.entity_text)>=0:
+                            #     find_flag = True
+                            #     _label = 0
+                            #     p_entity.label = _label
+                            #     p_entity.values[int(_label)] = on_value
+                            #     break
                         #若是实体在标题中,默认为招标人,不进行以下的规则匹配
                         if find_flag:
                             continue
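A simplified standalone sketch of the tightened title/name rule (spanWindow replaced by plain string slicing; the sample sentence and name are made up): an organisation mention is now labeled tenderee only when the project name actually continues from that mention, instead of every mention that happens to be a substring of the name:

import re

name = '某某小学教学楼改造工程'
sentence = '受某某小学委托,现对某某小学教学楼改造工程进行公开招标。'
entity_text = '某某小学'

for m in re.finditer(entity_text, sentence):
    right = sentence[m.end(): m.end() + len(name)]        # stand-in for the spanWindow right context
    print(m.start(), (entity_text + right).find(name) >= 0)
# 1 False   (stray mention inside "受...委托" is no longer forced to tenderee by this rule)
# 10 True   (the mention that is actually part of the project name)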
@@ -1236,6 +1257,7 @@ class RoleRulePredictor():
                                                         _distance = _iter.span()[0]
                                                         list_distance[int(_label)] = min(_distance,list_distance[int(_label)])
 
+
                                 # print(list_distance)
 
                                 # for _key in self.dict_list_pattern.keys():
@@ -1386,10 +1408,11 @@ class RoleRulePredictor():
 '''正则补充最后一句实体日期格式为招标或代理 2021/12/30'''
 class RoleRuleFinalAdd():
     def predict(self, list_articles, list_entitys):
-        text_end = list_articles[0].content[-30:]
-        sear_ent = re.search('[,。]([\u4e00-\u9fa5()()]{5,20}),\s*.{2,4}年.{1,2}月.{1,2}日', text_end)
+        text_end = list_articles[0].content[-40:]
+        # sear_ent = re.search('[,。]([\u4e00-\u9fa5()()]{5,20}),?\s*[.]{2,4}年.{1,2}月.{1,2}日', text_end)
+        sear_ent = re.search('[,。]([\u4e00-\u9fa5()()]{5,20}(,?[\u4e00-\u9fa5]{,6}(分公司|部))?),?\s*[0-9零一二三四五六七八九十]{2,4}年.{1,2}月.{1,2}日', text_end)
         if sear_ent:
-            ent_re = sear_ent.group(1)
+            ent_re = sear_ent.group(1).replace(',', '')
             tenderee_notfound = True
             agency_notfound = True
             ents = []
@@ -1407,7 +1430,7 @@ class RoleRuleFinalAdd():
                     n += 1
                     if n > 3:
                         break
-                    if ents[i].entity_text == ent_re:
+                    if ents[i].entity_text == ent_re or (ents[i].entity_text in ent_re and len(ents[i].entity_text)/len(ent_re)>0.6):
                         ents[i].label = 1
                         ents[i].values[1] = 0.5
                         break
@@ -1418,7 +1441,7 @@ class RoleRuleFinalAdd():
                     n += 1
                     if n > 3:
                         break
-                    if ents[i].entity_text == ent_re:
+                    if ents[i].entity_text == ent_re or (ents[i].entity_text in ent_re and len(ents[i].entity_text)/len(ent_re)>0.6):
                         ents[i].label = 0
                         ents[i].values[0] = 0.5
                         break
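A small standalone check (made-up closing text) of the new end-of-text pattern and the relaxed entity comparison; the pattern now also accepts Chinese-numeral dates and branch suffixes, and an entity is accepted when it covers more than 60% of the matched name:

import re

text_end = ',特此公告。广州某某咨询有限公司,二零二二年一月五日'
sear_ent = re.search('[,。]([\u4e00-\u9fa5()()]{5,20}(,?[\u4e00-\u9fa5]{,6}(分公司|部))?),?\s*[0-9零一二三四五六七八九十]{2,4}年.{1,2}月.{1,2}日', text_end)
ent_re = sear_ent.group(1).replace(',', '')
print(ent_re)  # 广州某某咨询有限公司

entity_text = '某某咨询有限公司'
print(entity_text == ent_re or (entity_text in ent_re and len(entity_text) / len(ent_re) > 0.6))  # True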
@@ -1532,11 +1555,15 @@ class TimePredictor():
 # 产品字段提取
 class ProductPredictor():
     def __init__(self):
+        vocabpath = os.path.dirname(__file__) + "/codename_vocab.pk"
+        self.vocab = load(vocabpath)
+        self.word2index = dict((w, i) for i, w in enumerate(np.array(self.vocab)))
         self.sess = tf.Session(graph=tf.Graph())
         self.load_model()
 
     def load_model(self):
-        model_path = os.path.dirname(__file__)+'/product_savedmodel/product.pb'
+        # model_path = os.path.dirname(__file__)+'/product_savedmodel/product.pb'
+        model_path = os.path.dirname(__file__)+'/product_savedmodel/productAndfailreason.pb'
         with self.sess.as_default():
             with self.sess.graph.as_default():
                 output_graph_def = tf.GraphDef()
@@ -1550,7 +1577,21 @@ class ProductPredictor():
                     self.logit = self.sess.graph.get_tensor_by_name("logits/Reshape:0")
                     self.tran = self.sess.graph.get_tensor_by_name("crf_loss/transitions:0")
 
-    def predict(self, list_sentences,list_entitys=None, MAX_AREA=5000):
+    def decode(self,logits, lengths, matrix):
+        paths = []
+        small = -1000.0
+        # start = np.asarray([[small] * 4 + [0]])
+        start = np.asarray([[small]*7+[0]])
+        for score, length in zip(logits, lengths):
+            score = score[:length]
+            pad = small * np.ones([length, 1])
+            logits = np.concatenate([score, pad], axis=1)
+            logits = np.concatenate([start, logits], axis=0)
+            path, _ = viterbi_decode(logits, matrix)
+            paths.append(path[1:])
+        return paths
+
+    def predict(self, list_sentences,list_entitys=None,list_articles=[], fail=False, MAX_AREA=5000):
         '''
         预测实体代码,每个句子最多取MAX_AREA个字,超过截断
         :param list_sentences: 多篇公告句子列表,[[一篇公告句子列表],[公告句子列表]]
@@ -1561,6 +1602,32 @@ class ProductPredictor():
         with self.sess.as_default() as sess:
             with self.sess.graph.as_default():
                 result = []
+                if fail and list_articles!=[]:
+                    text_list = [list_articles[0].content[:MAX_AREA]]
+                    chars = [[self.word2index.get(it, self.word2index.get('<unk>')) for it in text] for text in text_list]
+                    lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran],
+                                                      feed_dict={
+                                                          self.char_input: np.asarray(chars),
+                                                          self.dropout: 1.0
+                                                      })
+                    batch_paths = self.decode(scores, lengths, tran_)
+                    for text, path, length in zip(text_list, batch_paths, lengths):
+                        tags = ''.join([str(it) for it in path[:length]])
+                        for it in re.finditer("45*6", tags):
+                            start = it.start()
+                            end = it.end()
+                            result.append(text[start:end].replace('?', '').strip())
+                    reasons = []
+                    for it in result:
+                        if "(√)" in it or "(√)" in it:
+                            reasons = [it]
+                            break
+                        if reasons != [] and (it not in reasons[-1] and it not in reasons):
+                            reasons.append(it)
+                        elif reasons == []:
+                            reasons.append(it)
+                    return {'fail_reason':';'.join(reasons)}
+
                 if list_entitys is None:
                     list_entitys = [[] for _ in range(len(list_sentences))]
                 for list_sentence, list_entity in zip(list_sentences,list_entitys):
@@ -1576,13 +1643,15 @@ class ProductPredictor():
                         if MAX_LEN > MAX_AREA:
                             MAX_LEN = MAX_AREA
                         _LEN = MAX_AREA//MAX_LEN
-                        chars = process_data([sentence.sentence_text[:MAX_LEN] for sentence in list_sentence[_begin_index:_begin_index+_LEN]])
+                        chars = [sentence.sentence_text[:MAX_LEN] for sentence in list_sentence[_begin_index:_begin_index+_LEN]]
+                        chars = [[self.word2index.get(it, self.word2index.get('<unk>')) for it in l] for l in chars]
+                        chars = pad_sequences(chars, maxlen=MAX_LEN, padding="post", truncating="post")
                         lengths, scores, tran_ = sess.run([self.length, self.logit, self.tran],
                                                           feed_dict={
                                                                     self.char_input: np.asarray(chars),
                                                                     self.dropout: 1.0
                                                                     })
-                        batch_paths = decode(scores, lengths, tran_)
+                        batch_paths = self.decode(scores, lengths, tran_)
                         for sentence, path, length in zip(list_sentence[_begin_index:_begin_index+_LEN],batch_paths, lengths):
                             tags = ''.join([str(it) for it in path[:length]])
                             for it in re.finditer("12*3", tags):
@@ -1603,7 +1672,8 @@ class ProductPredictor():
                         _begin_index += _LEN
                     item["product"] = list(set(temp_list))
                     result.append(item) # 修正bug
-                return result
+                return {'fail_reason': ""}
+
 
 # 产品数量单价品牌规格提取 #2021/11/10 添加表格中的项目、需求、预算、时间要素提取
 class ProductAttributesPredictor():
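For reference, a toy illustration (hand-written Viterbi path, not real model output) of the "45*6" post-processing used in the new fail branch; with tag ids {S:0, B-pro:1, I-pro:2, E-pro:3, B-rea:4, I-rea:5, E-rea:6}, failure-reason spans are the runs matching 4 5* 6 and product spans the runs matching 1 2* 3:

import re

text = '废标原因:有效投标人不足三家。'
path = [0, 0, 0, 0, 0, 4, 5, 5, 5, 5, 5, 5, 5, 6, 0]   # hypothetical decoded tag path
tags = ''.join(str(t) for t in path)

reasons = [text[m.start():m.end()] for m in re.finditer('45*6', tags)]
print(reasons)  # ['有效投标人不足三家']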

BIN
BiddingKG/dl/interface/product_savedmodel/productAndfailreason.pb


+ 344 - 0
BiddingKG/dl/product/data_tfrecord.py

@@ -0,0 +1,344 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# @Author  : bidikeji
+# @Time    : 2021/12/13 18:28
+import tensorflow as tf
+import pandas as pd
+import random
+import json
+import os
+import re
+import collections
+from BiddingKG.dl.product.data_util import word2id, max_id
+max_len = 500
+
+def create_int_feature(values):
+    feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) #list(values)
+    return feature
+
+class TrainingInstance(object):
+    def __init__(self, word_list, tag_list):
+        self.word_list = word_list
+        self.tag_list = tag_list
+
+def fix_label_ner_句号分开(sentence, product_list, reasons_list):
+    tag_list = ['S'] * len(sentence)
+    word_list = list(sentence)
+    for product in product_list:
+        b = sentence.find(product)
+        while b != -1:
+            e = b + len(product)
+            if tag_list[b] == 'S' and tag_list[e - 1] == 'S':
+                tag_list[b] = 'B-pro'
+                tag_list[e - 1] = 'E-pro'
+                for i in range(b + 1, e - 1):
+                    tag_list[i] = 'I-pro'
+            b = sentence.find(product, e)
+    for reason in reasons_list:
+        if '。' in reason:
+            for reason in reason.split('。'):
+                # print('分句reason: ', reason)
+                b = sentence.find(reason)
+                while b != -1:
+                    e = b + len(reason)
+                    if tag_list[b] == 'S' and tag_list[e - 1] == 'S':
+                        tag_list[b] = 'B-rea'
+                        tag_list[e - 1] = 'E-rea'
+                        for i in range(b + 1, e - 1):
+                            tag_list[i] = 'I-rea'
+                    b = sentence.find(reason, e)
+        else:
+            b = sentence.find(reason)
+            while b != -1:
+                e = b + len(reason)
+                if tag_list[b] == 'S' and tag_list[e - 1] == 'S':
+                    tag_list[b] = 'B-rea'
+                    tag_list[e - 1] = 'E-rea'
+                    for i in range(b + 1, e - 1):
+                        tag_list[i] = 'I-rea'
+                b = sentence.find(reason, e)
+    return tag_list, word_list
+
+
+def create_instances_from_document_句号分开(docid, document_text, product_list, reasons_list):
+    for it in ['一','二','三','四','五','六','七','八','九','十','十一','十二','十三','十四','十五']:
+        document_text = document_text.replace(',%s、'%it, '。%s、'%it)
+
+    if docid in ['docid']:
+        pass
+    product_list = sorted(product_list, key=lambda x:len(x), reverse=True)
+    reasons_list = sorted(reasons_list, key=lambda x:len(x), reverse=True)
+    kw_re = re.search('(流标|废标|终止|中止|失败|异常)的?原因', document_text)
+    if reasons_list == [] and kw_re:
+        kw = kw_re.group(0)
+        idx = document_text.find(kw)
+        if idx!=-1:
+            document_text = document_text[:idx]
+    # instances = []
+    pos = []
+    neg = []
+    for sentence in document_text.split('。'):
+        if len(sentence)<2:
+            # print("句子长度小于5")
+            # print(sentence)
+            continue
+        if len(sentence)>=2*max_len: # 超过1000字的取前1000字
+            sentences = sentence.split(',')
+            i = 0
+            sentence = ""
+            while i < len(sentences):
+                if len(sentences[i]) > max_len:
+                    sentence = sentences[i][:max_len]
+                    print('len(sentences[i]) > max_len:', len(sentence))
+                    tag_list, word_list = fix_label_ner(sentence, product_list, reasons_list)
+                    instance = TrainingInstance(word_list, tag_list)
+                    if 'B-pro' in tag_list or 'B-rea' in tag_list:
+                        pos.append(instance)
+                    else:
+                        neg.append(instance)
+                    sentence = ""
+                elif len(sentence)<max_len*0.8 :
+                    sentence += sentences[i]+','
+                else:
+                    sentence =sentence[:max_len]
+                    print('else: ', len(sentence))
+                    tag_list, word_list = fix_label_ner(sentence, product_list, reasons_list)
+                    instance = TrainingInstance(word_list, tag_list)
+                    if 'B-pro' in tag_list or 'B-rea' in tag_list:
+                        pos.append(instance)
+                    else:
+                        neg.append(instance)
+                    sentence = sentences[i]+','
+                i += 1
+            if len(sentence)>=10:
+                sentence = sentence[:max_len]
+                tag_list, word_list = fix_label_ner(sentence, product_list, reasons_list)
+                instance = TrainingInstance(word_list, tag_list)
+                if 'B-pro' in tag_list or 'B-rea' in tag_list:
+                    pos.append(instance)
+                else:
+                    neg.append(instance)
+                print('len(sentence)>=10: ',len(sentence))
+        else:
+            sentence = sentence[:max_len]
+            tag_list, word_list = fix_label_ner(sentence, product_list, reasons_list)
+            instance = TrainingInstance(word_list, tag_list)
+            if 'B-pro' in tag_list or 'B-rea' in tag_list:
+                pos.append(instance)
+            else:
+                neg.append(instance)
+    random.shuffle(neg)
+    neg = neg[:min(5, 10*len(pos))]
+    instances = pos+neg
+    random.shuffle(instances)
+    return instances
+
+def fix_label_ner(sentence, product_list, reasons_list):
+    tag_list = ['S'] * len(sentence)
+    word_list = list(sentence)
+    for product in product_list:
+        b = sentence.find(product)
+        while b != -1:
+            e = b + len(product)
+            if tag_list[b] == 'S' and tag_list[e - 1] == 'S':
+                tag_list[b] = 'B-pro'
+                tag_list[e - 1] = 'E-pro'
+                for i in range(b + 1, e - 1):
+                    tag_list[i] = 'I-pro'
+            b = sentence.find(product, e)
+    for reason in reasons_list:
+        b = sentence.find(reason)
+        while b != -1:
+            e = b + len(reason)
+            if tag_list[b] == 'S' and tag_list[e - 1] == 'S':
+                tag_list[b] = 'B-rea'
+                tag_list[e - 1] = 'E-rea'
+                for i in range(b + 1, e - 1):
+                    tag_list[i] = 'I-rea'
+            b = sentence.find(reason, e)
+    return tag_list, word_list
+
+def create_instances_from_document(docid, document_text, product_list, reasons_list):
+    product_list = sorted(product_list, key=lambda x:len(x), reverse=True)
+    reasons_list = sorted(reasons_list, key=lambda x:len(x), reverse=True)
+    kw_re = re.search('(流标|废标|终止|中止|失败|异常)的?原因', document_text)
+    if reasons_list == [] and kw_re:
+        kw = kw_re.group(0)
+        idx = document_text.find(kw)
+        if idx!=-1:
+            document_text = document_text[:idx]
+    # instances = []
+    pos = []
+    neg = []
+    if len(document_text)<= max_len:
+        document_text = document_text[:max_len]
+        tag_list, word_list = fix_label_ner(document_text, product_list, reasons_list)
+        if len(reasons_list)>0 and 'B-rea' not in tag_list:
+            print("少于%d字的文章废标原因标注未找到:%s"%(max_len, docid))
+        instance = TrainingInstance(word_list, tag_list)
+        if 'B-pro' in tag_list or 'E-rea' in tag_list:
+            pos.append(instance)
+        else:
+            neg.append(instance)
+    elif len(reasons_list)>0:
+        b = document_text.find(reasons_list[0])
+        if b != -1:
+            document_text = document_text[max(0, b-8):][:max_len]
+        else:
+            document_text = document_text[:max_len]
+            print("多于%d字的文章废标原因标注未找到:%s," % (max_len, docid))
+        tag_list, word_list = fix_label_ner(document_text, product_list, reasons_list)
+        if 'E-rea' not in tag_list:
+            print("文章废标原因标注未找到:%s, 开始位置:%d"%(docid, b))
+        instance = TrainingInstance(word_list, tag_list)
+        if 'B-pro' in tag_list or 'B-rea' in tag_list:
+            pos.append(instance)
+        else:
+            neg.append(instance)
+    else:
+        for it in ['一', '二', '三', '四', '五', '六', '七', '八', '九', '十', '十一', '十二', '十三', '十四', '十五']:
+            document_text = document_text.replace(',%s、' % it, '。%s、' % it)
+        for sentence in document_text.split('。'):
+            if len(sentence)<2:
+                # print("句子长度小于5")
+                # print(sentence)
+                continue
+            if len(sentence)>=2*max_len: # 超过1000字的取前1000字
+                sentences = sentence.split(',')
+                i = 0
+                sentence = ""
+                while i < len(sentences):
+                    if len(sentences[i]) > max_len:
+                        sentence = sentences[i][:max_len]
+                        # print('len(sentences[i]) > max_len:', len(sentence))
+                        tag_list, word_list = fix_label_ner(sentence, product_list, reasons_list)
+                        instance = TrainingInstance(word_list, tag_list)
+                        if 'B-pro' in tag_list or 'B-rea' in tag_list:
+                            pos.append(instance)
+                        else:
+                            neg.append(instance)
+                        sentence = ""
+                    elif len(sentence)<max_len*0.8 :
+                        sentence += sentences[i]+','
+                    else:
+                        sentence =sentence[:max_len]
+                        # print('else: ', len(sentence))
+                        tag_list, word_list = fix_label_ner(sentence, product_list, reasons_list)
+                        instance = TrainingInstance(word_list, tag_list)
+                        if 'B-pro' in tag_list or 'B-rea' in tag_list:
+                            pos.append(instance)
+                        else:
+                            neg.append(instance)
+                        sentence = sentences[i]+','
+                    i += 1
+                if len(sentence)>=10:
+                    sentence = sentence[:max_len]
+                    tag_list, word_list = fix_label_ner(sentence, product_list, reasons_list)
+                    instance = TrainingInstance(word_list, tag_list)
+                    if 'B-pro' in tag_list or 'B-rea' in tag_list:
+                        pos.append(instance)
+                    else:
+                        neg.append(instance)
+                    # print('len(sentence)>=10: ',len(sentence))
+            else:
+                sentence = sentence[:max_len]
+                tag_list, word_list = fix_label_ner(sentence, product_list, reasons_list)
+                instance = TrainingInstance(word_list, tag_list)
+                if 'B-pro' in tag_list or 'B-rea' in tag_list:
+                    pos.append(instance)
+                else:
+                    neg.append(instance)
+    random.shuffle(neg)
+    neg = neg[:min(5, 10*len(pos))]
+    instances = pos+neg
+    random.shuffle(instances)
+    return instances
+
+def create_training_instances(df):
+    instances = []
+    # df = pd.read_excel(xlsx)
+    df.fillna('', inplace=True)
+    for i in df.index:
+        try:
+            docid = df.loc[i, 'docid']
+            document_text = df.loc[i, 'text']
+            product_list = json.loads(df.loc[i, 'lbset'])
+            reasons_list = json.loads(df.loc[i, 'reasons_list'])
+            # if reasons_list == []:
+            #     continue
+            instances.extend(
+                create_instances_from_document(
+                    docid, document_text, product_list, reasons_list
+                ))
+        except Exception as e:
+            print('json出错',i,  df.loc[i, 'lbset'], type(df.loc[i, 'lbset']), e)
+    return instances
+
+def write_instance_to_example_files(instances, word2index, tag2index, output_dir):
+    # writers = []
+    # instances = sorted(instances, key=lambda x: len(x.word_list))
+    i = 0
+    # for max_len in [200, 500, 1000]:
+    writer = tf.python_io.TFRecordWriter(output_dir + '/maxlen_%s_addunk_product_reason.tfrecode'%max_len)
+    # print('排序前:', [len(x.word_list) for x in instances[:5]])
+    # instances.sort(key=lambda x:len(x.word_list), reverse=True)
+    # print('排序后:', [len(x.word_list) for x in instances[:5]])
+    while i < len(instances):
+        instance = instances[i]
+        if len(instance.word_list)>max_len:
+            writer.close()
+            break
+        i += 1
+        # word_ids = [word2index.get(word, max_id) for word in instance.word_list]
+        word_ids = [word2index.get(word, word2index.get('<unk>')) for word in instance.word_list]
+        tag_ids = [tag2index.get(tag, 0) for tag in instance.tag_list]
+        while len(word_ids)<max_len:
+            word_ids.append(0)
+            tag_ids.append(0)
+        features = collections.OrderedDict()
+        features["word_ids"] = create_int_feature(word_ids)
+        features["tag_ids"] = create_int_feature(tag_ids)
+        features['text_len'] = create_int_feature([len(instance.word_list)])
+        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+        writer.write(tf_example.SerializeToString())
+    writer.close()
+
+if __name__ == "__main__":
+    df = pd.read_excel(os.path.dirname(__file__) + '/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
+    df['pos'] = df.apply(lambda x:1 if re.search('(流标|废标|终止|中止|失败|异常)(公告|公示)', x['text']) and x['reasons_list']=='[]' else 0, axis=1)
+    df = df[df.loc[:, 'pos']==0]  # 过滤掉未标注废标原因文章
+    df.reset_index(drop=True, inplace=True)
+    print('总文章数:',len(df))
+    df.fillna('', inplace=True)
+    print('读取完毕')
+    df['lbs'] = df['lbset'].apply(lambda x: json.loads(x))
+    lbset = [it for l in df['lbs'] for it in l]
+    c = collections.Counter(lbset)
+    m = c.most_common()
+    m3 = [it[0] for it in m if it[1] > 2]
+    df['pos'] = df['lbs'].apply(lambda x: 1 if len(set(m3) & set(x)) >= 1 else 0)
+    df_dev = df[df.loc[:, 'pos'] == 1].sample(frac=0.1, random_state=8)
+    print('len_df_dev:', len(df_dev))
+    df_reason = df[df.loc[:, 'reasons_list'] != '[]'].sample(frac=0.1, random_state=8)
+    print('len(df_reason)', len(df_reason))
+    df_dev = df_dev.append(df_reason)  # DataFrame.append returns a new frame; keep the result
+    df_dev.drop_duplicates(subset=['docid'], inplace=True)
+    print('len_df_dev:', len(df_dev))
+    df_train = df[~df.index.isin(df_dev.index)]
+    print(len(df), len(df_dev), len(df_train))
+    df_train = df_train.sample(frac=1)
+    df_dev = df_dev.sample(frac=1)
+
+    # file = 'data/traindata.xlsx'
+    instances = create_training_instances(df_train)
+    # word2index = {'<unk>':0, '我':1, '们':2, '中':3, '国':4, '人':5}
+    tag2index = {'S':0,'B-pro':1, 'I-pro':2, 'E-pro':3, 'B-rea':4, 'I-rea':5, 'E-rea':6}
+    output_dir = 'data/train_data'
+    write_instance_to_example_files(instances, word2id, tag2index, output_dir)
+
+    instances = create_training_instances(df_dev)
+    # word2index = {'<unk>':0, '我':1, '们':2, '中':3, '国':4, '人':5}
+    tag2index = {'S':0,'B-pro':1, 'I-pro':2, 'E-pro':3, 'B-rea':4, 'I-rea':5, 'E-rea':6}
+    output_dir = 'data/test_data'
+    write_instance_to_example_files(instances, word2id, tag2index, output_dir)
+    print('全部写入成功!')
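A standalone illustration of the BIOE labeling that fix_label_ner produces (toy sentence, no project imports); product spans get B/I/E-pro, failure-reason spans get B/I/E-rea, everything else stays S:

tag2index = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3, 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}

sentence = '采购台式电脑,因投标人不足三家废标'
tags = ['S'] * len(sentence)

def mark(span, prefix):
    b = sentence.find(span)
    e = b + len(span)
    tags[b], tags[e - 1] = 'B-' + prefix, 'E-' + prefix
    for i in range(b + 1, e - 1):
        tags[i] = 'I-' + prefix

mark('台式电脑', 'pro')        # product span
mark('投标人不足三家', 'rea')   # failure-reason span
print([tag2index[t] for t in tags])
# [0, 0, 1, 2, 2, 3, 0, 0, 4, 5, 5, 5, 5, 5, 6, 0, 0]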

+ 25 - 11
BiddingKG/dl/product/data_util.py

@@ -3,22 +3,30 @@
 # @Author  : bidikeji
 # @Time    : 2021/1/13 0013 14:19
 import re
+import os
 import math
 import json
 import random
 import numpy as np
 import pandas as pd
-from BiddingKG.dl.common.Utils import getVocabAndMatrix,getModel_word,viterbi_decode
+from BiddingKG.dl.common.Utils import getVocabAndMatrix,getModel_word,viterbi_decode, load
 
 tag2index = {'S':0,'B-pro':1, 'I-pro':2, 'E-pro':3, 'B-rea':4, 'I-rea':5, 'E-rea':6}
 id_to_tag = {v:k for k,v in tag2index.items()}
 # id_to_tag = {0:'O',1:'B',2:'I',3:'E'}
-word_model = getModel_word()
 
+word_model = getModel_word()
 vocab, matrix = getVocabAndMatrix(word_model, Embedding_size=60)
 word2id = {k: v for v, k in enumerate(vocab)}
 max_id = len(vocab)
 
+# path1 = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))+"/interface/codename_vocab.pk"
+# path2 = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))+"/interface/codename_w2v_matrix.pk"
+# vocab = load(path1)
+# matrix = load(path2)
+# max_id = len(vocab)
+# word2id = {k: v for v, k in enumerate(vocab)}
+
 # vocab = ["<pad>"] + word_model.index2word+ ["<unk>"]
 # matrix = np.zeros((len(vocab), 60))
 # for i in range(1, len(vocab)-1):
@@ -51,7 +59,8 @@ def df2data(df):
         tag_sentence = []
         for i in range(len(text)):
             text_sentence.append(text[i])
-            ids_sentence.append(word2id.get(text[i], max_id))
+            # ids_sentence.append(word2id.get(text[i], max_id))
+            ids_sentence.append(word2id.get(text[i], word2id.get('<unk>')))
             tag_sentence.append(tags[i])
             if text[i] in ['。','!']:
                 if text_sentence:
@@ -104,7 +113,8 @@ def get_feature(text, lbs):
         if len(sentence) > 1000:
             sentence = sentence[:1000]
         tags = [0] * len(sentence)
-        ids = [word2id.get(word, max_id) for word in sentence]
+        # ids = [word2id.get(word, max_id) for word in sentence]
+        ids = [word2id.get(word, word2id.get('<unk>')) for word in sentence]
         for lb in lbs:
             kw_indexs = find_kw_from_text(lb, sentence)
             for indexs in kw_indexs:
@@ -161,7 +171,8 @@ def get_label_data():
         tag_sentence = []
         for i in range(len(text)):
             text_sentence.append(text[i])
-            ids_sentence.append(word2id.get(text[i], max_id))
+            # ids_sentence.append(word2id.get(text[i], max_id))
+            ids_sentence.append(word2id.get(text[i], word2id.get('<unk>')))
             tag_sentence.append(tags[i])
             writer.write("%s\t%s\n"%(text[i],tags[i]))
             if text[i] in ['。','?','!',';']:
@@ -188,7 +199,8 @@ def get_label_data():
 
 def input_from_line(line):
     string = list(line)
-    ids = [word2id.get(k, max_id) for k in string]
+    # ids = [word2id.get(k, max_id) for k in string]
+    ids = [word2id.get(k, word2id.get('<unk>')) for k in string]
     tags = []
     return [[string], [ids], [tags]]
 
@@ -200,6 +212,7 @@ def process_data(sentences):
     '''
     maxLen = max([len(sentence) for sentence in sentences])
     tags = [[word2id.get(k, max_id) for k in sentence] for sentence in sentences]
+    # tags = [[word2id.get(k, word2id.get('<unk>')) for k in sentence] for sentence in sentences]
     pad_tags = [tag[:maxLen]+[0]*(maxLen-len(tag)) for tag in tags]
     return pad_tags
 
@@ -213,6 +226,7 @@ def decode(logits, lengths, matrix):
     paths = []
     small = -1000.0
     start = np.asarray([[small]*4+[0]])
+    # start = np.asarray([[small]*7+[0]])
     for score, length in zip(logits, lengths):
         score = score[:length]
         pad = small * np.ones([length, 1])
@@ -226,14 +240,14 @@ def result_to_json(line, tags):
     result = []
     ner = []
     tags = ''.join([str(it) for it in tags])
-    # for it in re.finditer("12*3", tags):
-    #     start = it.start()
-    #     end = it.end()
-    #     ner.append([line[start:end], (start, end)])
-    for it in re.finditer("45*6", tags):
+    for it in re.finditer("12*3", tags):
         start = it.start()
         end = it.end()
         ner.append([line[start:end], (start, end)])
+    # for it in re.finditer("45*6", tags):
+    #     start = it.start()
+    #     end = it.end()
+    #     ner.append([line[start:end], (start, end)])
     result.append([line, ner])
     # print(tags)
     return result
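A toy sketch (miniature made-up vocab) of the OOV-fallback change threaded through this file: out-of-vocabulary characters used to map to max_id, which points one row past the embedding matrix, and now map to the '<unk>' id:

vocab = ['<pad>', '我', '们', '<unk>']
word2id = {w: i for i, w in enumerate(vocab)}
max_id = len(vocab)                       # 4, one past the last embedding row

old_ids = [word2id.get(ch, max_id) for ch in '我们龘']                # [1, 2, 4]
new_ids = [word2id.get(ch, word2id.get('<unk>')) for ch in '我们龘']  # [1, 2, 3]
print(old_ids, new_ids)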

+ 73 - 13
BiddingKG/dl/product/main.py

@@ -93,7 +93,7 @@ def evaluate_line():
         sess.run(tf.global_variables_initializer())
         # model.saver.restore(sess, 'model/ner1215.ckpt')
         # model.saver.restore(sess, 'model/ner_f10.7039_loss1.2353.ckpt')
-        model.saver.restore(sess, 'model/ner_epoch10_f10.6875_loss1.5230.ckpt')
+        model.saver.restore(sess, 'model/ner_epoch5_f10.6855_loss1.3800.ckpt')
         while True:
             line = input("请输入测试句子:")
             result = model.evaluate_line(sess, line)
@@ -107,8 +107,61 @@ def evaluate_line():
         #         line = input("请输入测试句子:")
         #         result = model.evaluate_line(sess, line)
         #         print(result)
+
+def save_model_pb():
+    from tensorflow.python.framework import graph_util
+
+    model_folder = r"D:\Bidi\BIDI_ML_INFO_EXTRACTION\BiddingKG\dl\product\model"
+    output_graph = r"D:\Bidi\BIDI_ML_INFO_EXTRACTION\BiddingKG\dl\product\model\product.pb"
+    #
+    # 把cpkt转为pb
+
+    input_checkpoint = "model/ner_epoch5_f10.6855_loss1.3800.ckpt"
+    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
+    graph = tf.get_default_graph()  # 获得默认的图
+    input_graph_def = graph.as_graph_def()  # 返回一个序列号
+    with tf.Session() as sess:
+        saver.restore(sess, input_checkpoint)  # 恢复图并获得数据
+        output_graph_def = graph_util.convert_variables_to_constants(
+            sess=sess,
+            input_graph_def=input_graph_def,
+            output_node_names=["CharInputs", "Sum", "Dropout", "logits/Reshape", "crf_loss/transitions"]
+        )
+        with tf.gfile.GFile("model/productAndfailreason.pb", "wb") as f:
+            f.write(output_graph_def.SerializeToString())
+        print("%d ops in the final graph" % len(output_graph_def.node))
+    #
+    # graph = tf.get_default_graph()  # 获得默认的图
+    # input_graph_def = graph.as_graph_def()  # 返回一个序列号
+    # with tf.Session() as sess:
+    #     model = Product_Model()
+    #     sess.run(tf.global_variables_initializer())
+    #     model.saver.restore(sess, 'model/ner_epoch5_f10.6855_loss1.3800.ckpt') # 恢复图并获得数据
+    #     output_graph_def = graph_util.convert_variables_to_constants(
+    #         sess=sess,
+    #         input_graph_def=input_graph_def,
+    #         output_node_names=["CharInputs", "Sum", "Dropout", "logits/Reshape", "crf_loss/transitions"]
+    #     )
+    #     with tf.gfile.GFile("model/productAndfailreason.pb", "wb") as f:
+    #         f.write(output_graph_def.SerializeToString())
+    # with tf.Session() as sess:
+    #     model = Product_Model()
+    #     sess.run(tf.global_variables_initializer())
+    #     model.saver.restore(sess, 'model/ner_epoch5_f10.6855_loss1.3800.ckpt')
+    #     tf.saved_model.simple_save(sess, 'model/productAndfailreason',
+    #                                inputs={
+    #                                    "CharInputs":model.char_inputs,
+    #                                    "Dropout":model.dropout,
+    #                                },
+    #                                outputs={
+    #                                    "Sum:": model.lengths,
+    #                                    "logits/Reshape": model.logits,
+    #                                    "crf_loss/transitions": model.trans
+    #                                })
+        print('保存pb文件')
 def predict():
-    pb_path = "model/product.pb"
+    # pb_path = "model/product.pb"
+    pb_path = "model/productAndfailreason.pb"
     with tf.Graph().as_default():
         output_graph_def = tf.GraphDef()
         with open(pb_path, 'rb') as f:
@@ -146,23 +199,27 @@ def predict_df():
         # model.saver.restore(sess, 'model/ner1201_find_lb.ckpt')  # f1:0.6972, precision:0.7403, recall:0.6588, evl_loss:1.2983  model saved, val_loss is: 1.32706
         # model.saver.restore(sess, 'model/ner1208_find_lb.ckpt')  # f1:0.7038, precision:0.7634, recall:0.6528, evl_loss:1.3046 model saved, val_loss is: 1.29316
         # model.saver.restore(sess, 'model/ner_f10.7039_loss1.2353.ckpt')  # f1:0.70 ner1215
-        model.saver.restore(sess, 'model/ner_epoch4_f10.6952_loss1.2512.ckpt')  # f1:0.70 ner1215
+        model.saver.restore(sess, 'model/ner_epoch5_f10.6855_loss1.3800.ckpt')  # f1:0.70 ner1215
 
         print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
         # df = pd.read_excel('../test/data/贵州数据新字段提取信息_predict.xlsx')
         # df = pd.read_excel('../test/data/所有产品标注数据_补充筛选废标原因数据.xlsx')
         # df = pd.read_excel('../test/data/所有产品标注数据筛选_废标_predict.xlsx')
-        df = pd.read_excel('../test/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
+        # df = pd.read_excel('data/所有产品标注数据筛选20211125_ProductAndReason.xlsx') #../test/
+        df = pd.read_excel('data/产品数据自己人标注的原始数据_pred.xlsx') #../test/
+        df.fillna('', inplace=True)
         # df = pd.read_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
         df.reset_index(drop=True, inplace=True)
         rs = []
         for i in df.index:
             line = df.loc[i, 'text']
-            pos = df.loc[i, 'pos']
-            reason = df.loc[i, 'reasons_label']
-            if pos==0 or reason!='[]':
-                rs.append('')
-                continue
+
+            # pos = df.loc[i, 'feibiao']
+            # reason = df.loc[i, 'reasons_label']
+            # if pos==0 and reason=='[]':
+            #     rs.append('')
+            #     continue
+
             # if i > 200:
             #     rs.append('')
             #     continue
@@ -171,17 +228,20 @@ def predict_df():
             print(result[0][1])
             rs.append(json.dumps(result[0][1], ensure_ascii=False))
         # df['pred_new1202'] = pd.Series(rs)
-        df['reson_model'] = pd.Series(rs)
+        # df['reson_model2'] = pd.Series(rs)
+        df['product_pred'] = pd.Series(rs)
         # df.to_excel('../test/data/贵州数据新字段提取信息_predict.xlsx')
         # df.to_excel('../test/data/所有产品标注数据_补充筛选废标原因数据_predict.xlsx')
         # df.to_excel('../test/data/所有产品标注数据筛选_废标_predict.xlsx')
-        df.to_excel('../test/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
+        # df.to_excel('../test/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
+        df.to_excel('data/产品数据自己人标注的原始数据_pred.xlsx')
         # df.to_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
 
 if __name__ == "__main__":
     # train()
-    evaluate_line()
-    # predict()
+    # evaluate_line()
+    # save_model_pb()
+    predict()
     # predict_df()
     # import json
     # df = pd.read_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')

+ 7 - 5
BiddingKG/dl/product/product_model.py

@@ -21,12 +21,14 @@ class Product_Model(object):
         self.clip = 5.0
         self.dropout_rate = 0.5
         # vocab, matrix = getVocabAndMatrix(word_model, Embedding_size=60)
-        self.matrix = matrix
-        # self.word2id = {k:v for v,k in enumerate(self.vocab)}
-        self.num_chars = len(vocab)+1
-        self.emb_matrix = np.random.random((self.num_chars, self.char_dim))
-        self.emb_matrix[:self.num_chars-1:,:] = self.matrix
 
+        # self.matrix = matrix
+        # # self.word2id = {k:v for v,k in enumerate(self.vocab)}
+        # self.num_chars = len(vocab)+1
+        # self.emb_matrix = np.random.random((self.num_chars, self.char_dim))
+        # self.emb_matrix[:self.num_chars-1:,:] = self.matrix
+
+        self.emb_matrix = matrix
 
         self.globel_step = tf.Variable(0, trainable=False)
         self.best_dev_f1 = tf.Variable(0.0, trainable=False)
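A toy numpy sketch (made-up 3-word vocab, dim 4) of what this change does: previously an extra random row was appended to the pretrained matrix for the out-of-vocab id, now that the vocab itself carries '<unk>' the matrix is used as-is:

import numpy as np

matrix = np.random.random((3, 4))          # pretrained rows, one per vocab id 0..2

# old approach: append one random row for the OOV id == len(vocab)
num_chars = matrix.shape[0] + 1
emb_old = np.random.random((num_chars, 4))
emb_old[:num_chars - 1, :] = matrix

# new approach: the pretrained matrix is the embedding table
emb_new = matrix
print(emb_old.shape, emb_new.shape)        # (4, 4) (3, 4)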

+ 169 - 0
BiddingKG/dl/product/train.py

@@ -0,0 +1,169 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# @Author  : bidikeji
+# @Time    : 2021/12/14 11:08
+
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# @Author  : bidikeji
+# @Time    : 2021/1/13 0013 10:12
+import os
+import re
+print('准备导入tf')
+import tensorflow as tf
+print('准备导入np')
+import numpy as np
+print('准备导入Product_model')
+from BiddingKG.dl.product.product_model import Product_Model
+print('准备导入max_len')
+from BiddingKG.dl.product.data_tfrecord import max_len
+# from BiddingKG.dl.common.Utils import viterbi_decode
+print('准备设置CUDA环境')
+os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
+# max_len = 500
+
+def read_tfRecord(sess, file_tfRecord):
+    queue = tf.train.string_input_producer([file_tfRecord])
+    reader = tf.TFRecordReader()
+    filename_, serialized_example = reader.read(queue)
+    features = tf.parse_single_example(
+        serialized_example,
+        features = {
+            "word_ids": tf.FixedLenFeature([max_len], tf.int64),
+            "text_len":tf.FixedLenFeature([1], tf.int64),
+            "tag_ids": tf.FixedLenFeature([max_len], tf.int64)
+        }
+    )
+    word_ids = tf.cast(features['word_ids'], tf.int64)
+    tag_ids = tf.cast(features['tag_ids'], tf.int64)
+    text_len = tf.cast(features['text_len'], tf.int64)
+    return text_len, word_ids , tag_ids
+
+def get_batch_record(sess,filename, batch_size):
+    text_len, word_ids, tag_ids = read_tfRecord(sess, filename)
+    text_len, word_ids, tag_ids = tf.train.shuffle_batch([text_len, word_ids , tag_ids],
+                                                     batch_size=batch_size,
+                                                     capacity=200+batch_size*3,
+                                                     min_after_dequeue=1,
+                                                     num_threads=5)
+    text_len = tf.squeeze(text_len, squeeze_dims=1)
+    return text_len, word_ids , tag_ids
+
+def total_sample(file_name):
+    '''
+    返回样本总数
+    :param file_name: tfrecord 文件名
+    :return:
+    '''
+    sample_num = 0
+    for record in tf.python_io.tf_record_iterator(file_name):
+        sample_num += 1
+    return sample_num
+
+if __name__ == "__main__":
+    print('进入main ')
+    filename = os.path.dirname(__file__)+'/data/train_data/maxlen_500_addunk_product_reason.tfrecode'
+    filename_dev = os.path.dirname(__file__)+'/data/test_data/maxlen_500_addunk_product_reason.tfrecode'
+    assert os.path.exists(filename)
+    assert os.path.exists(filename_dev)
+    print('确保文件存在')
+    batch_size = 100
+    # id_to_tag = {0: 'O', 1: 'B', 2: 'I', 3: 'E'}
+    tag2index = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3, 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}
+    id_to_tag = {v:k for k,v in tag2index.items()}
+
+    def get_ner(id_tag):
+        ner = set()
+        for it in re.finditer('12*3', id_tag):
+            ner.add((it.start(), it.end(), 'pro'))
+        for it in re.finditer('45*6', id_tag):
+            ner.add((it.start(), it.end(), 'rea'))
+        return ner
+
+    print('准备构建模型')
+    model = Product_Model()
+    # saver = tf.train.Saver()
+    print('模型构建完毕')
+    with tf.Session() as sess:
+        init_op = tf.global_variables_initializer()
+        sess.run(init_op)
+        print('参数初始化')
+        text_len, word_ids, tag_ids = get_batch_record(sess, filename, batch_size=batch_size)
+        print('get_batch_record')
+        text_len_dev, word_ids_dev, tag_ids_dev = get_batch_record(sess, filename_dev, batch_size=batch_size)
+        print('get_batch_record_dev')
+        coord = tf.train.Coordinator()
+        threads = tf.train.start_queue_runners(coord=coord)
+        print('total_sample(filename)', total_sample(filename))
+
+        total_num = total_sample(filename)
+        batch_num = total_num//batch_size
+        batch_num_dev = total_sample(filename_dev)//batch_size
+        num = 0
+        l = []
+        max_f1 = 0
+
+        # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model','ner_epoch10_f10.6875_loss1.5230.ckpt'))
+        # print('模型加载成功')
+
+        for epoch in range(50):
+            for batch in range(batch_num):
+                text_len_, word_ids_, tag_ids_ = sess.run([text_len, word_ids, tag_ids])
+                # print(text_len_.shape, word_ids_.shape, tag_ids_.shape)
+                loss_, train_op_, globel_step_ = sess.run([model.loss, model.train_op, model.globel_step],
+                                                                 feed_dict={model.char_inputs:word_ids_,
+                                                        # model.lengths:text_len_,
+                                                        model.dropout:0.5,
+                                                        model.targets:tag_ids_})
+
+
+                if batch % 100==0:
+                    print('loss_:', loss_, '\tglobel_step_:',globel_step_)
+                num += text_len_.shape[0]
+            print('训练总数:%d, 样本总数:%d'%(num, total_num))
+
+            results = []
+            trans = model.trans.eval()
+            Precision = []
+            Recall = []
+            F1 = []
+            loss = []
+            pred_num = 0
+            gold_num = 0
+            equal_num = 0
+            for batch in range(batch_num_dev):
+                text_len_dev_, word_ids_dev_, tag_ids_dev_ = sess.run([text_len_dev, word_ids_dev, tag_ids_dev])
+                lengths, scores, batch_loss = sess.run([model.lengths, model.logits, model.loss], feed_dict={model.char_inputs:word_ids_dev_,
+                                                        # model.lengths:text_len_dev_,
+                                                        model.dropout:1.0,  # no dropout at evaluation time (inference also feeds 1.0)
+                                                        model.targets:tag_ids_dev_})
+                loss.append(batch_loss)
+                batch_paths = model.decode(scores, lengths, trans)
+                for i in range(len(tag_ids_dev_)):
+                    result = []
+                    gold = [str(x) for x in tag_ids_dev_[i][:lengths[i]]]
+                    pred = [str(x) for x in batch_paths[i][:lengths[i]]]
+                    gold_ner = get_ner("".join(gold))
+                    pred_ner = get_ner("".join(pred))
+                    # print('标签实体:',gold_ner)
+                    # print('预测实体:',pred_ner)
+                    pred_num += len(pred_ner)
+                    gold_num += len(gold_ner)
+                    equal_num += len(gold_ner & pred_ner)
+            precision = equal_num / (pred_num + 1e-10)
+            recall = equal_num / (gold_num + 1e-10)
+            f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
+            val_loss = np.mean(loss)
+            print('epoch: %d, f1:%.4f, acc:%.4f, recall:%.4f, val_loss:%.4f'%(epoch, f1, precision, recall, val_loss))
+            if f1>max_f1:
+                max_f1 = f1
+                model.saver.save(sess, os.path.join(os.path.dirname(__file__)+'/model', "ner_epoch%d_f1%.4f_loss%.4f.ckpt"%(epoch,f1, val_loss)))
+                print('save model, max_f1:%.4f' %f1)
+
+        coord.request_stop()
+        coord.join(threads)
+
+
+
+
+
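A standalone toy of the span-level metrics the eval loop computes (made-up gold/pred span sets); a span counts as correct only on an exact (start, end, type) match:

gold_ner = {(0, 4, 'pro'), (10, 17, 'rea')}
pred_ner = {(0, 4, 'pro'), (11, 17, 'rea'), (20, 23, 'pro')}

equal_num = len(gold_ner & pred_ner)                      # 1 exact match
precision = equal_num / (len(pred_ner) + 1e-10)
recall = equal_num / (len(gold_ner) + 1e-10)
f1 = 2 * precision * recall / (precision + recall + 1e-10)
print(round(precision, 4), round(recall, 4), round(f1, 4))  # 0.3333 0.5 0.4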