Bladeren bron

时间分类新增类别

znj 3 jaren geleden
bovenliggende
commit
f529113ce6

+ 1 - 1
BiddingKG/dl/interface/Preprocessing.py

@@ -1084,7 +1084,7 @@ def segment(soup,final=True):
     # 感叹号替换为中文句号
     text = re.sub("(?<=[\u4e00-\u9fa5])[!!]|[!!](?=[\u4e00-\u9fa5])","。",text)
     #替换"?"为 " " ,update:2021/7/20
-    text = re.sub("?"," ",text)
+    text = re.sub("?{1,}"," ",text)
 
 
     #替换"""为"“",否则导入deepdive出错

+ 199 - 30
BiddingKG/dl/interface/getAttributes.py

@@ -1129,10 +1129,10 @@ def findAttributeAfterEntity(PackDict,roleSet,PackageList,PackageSet,list_entity
                 if not re.search("电,?话", phone_left):
                     last_phone_mask = False
                     continue
-            if re.search("注册[证号]|帐,?号|编,?[号码]|报,?价|证,?号|价,?格|[\(\(]万?元[\)\)]|[a-zA-Z]+\d*$", phone_left):
+            if re.search("注册[证号]|帐,?号|编,?[号码]|报,?价|标,?价|证,?号|价,?格|[\(\(]万?元[\)\)]|[a-zA-Z]+\d*$", phone_left):
                 last_phone_mask = False
                 continue
-            if re.search("^\d{0,4}[.,]\d{2,}|^[0-9a-zA-Z\.]*@|^\d*[a-zA-Z]+", phone_right):
+            if re.search("^\d{0,4}[.,]\d{2,}|^[0-9a-zA-Z\.]*@|^\d*[a-zA-Z]+|元", phone_right):
                 last_phone_mask = False
                 continue
             # if:上一个phone实体不符合条件
@@ -2195,22 +2195,191 @@ def turnBidWay(bidway):
     else:
         return "其他"
 
+my_time_format_pattern = re.compile("((?P<year>\d{4}|\d{2})\s*[-\/年\.]\s*(?P<month>\d{1,2})\s*[-\/月\.]\s*(?P<day>\d{1,2}))")
+import time
+def my_timeFormat(_time):
+    current_year = time.strftime("%Y",time.localtime())
+    all_match = re.finditer(my_time_format_pattern,_time)
+    time_list = []
+    for _match in all_match:
+        if len(_match.group())>0:
+            legal = True
+            year = ""
+            month = ""
+            day = ""
+            for k,v in _match.groupdict().items():
+                if k=="year":
+                    year = v
+                if k=="month":
+                    month = v
+                if k=="day":
+                    day = v
+            if year!="":
+                if len(year)==2:
+                    year = "20"+year
+                if int(year)>int(current_year):
+                    legal = False
+            else:
+                legal = False
+            if month!="":
+                if int(month)>12:
+                    legal = False
+            else:
+                legal = False
+            if day!="":
+                if int(day)>31:
+                    legal = False
+            else:
+                legal = False
+            if legal:
+                # return "%s-%s-%s"%(year,month.rjust(2,"0"),day.rjust(2,"0"))
+                time_list.append("%s-%s-%s"%(year,month.rjust(2,"0"),day.rjust(2,"0")))
+    return time_list
+
+def getTimeAttributes(list_entity,list_sentence):
+    # list_entity = [i for i in list_entity if i.entity_type=='time']
+    list_sentence = sorted(list_sentence,key=lambda x:x.sentence_index)
+    dict_time = {
+        "time_release": [],
+        "time_bidopen": [],
+        "time_bidclose": [],
+        'time_bidstart': [],  # 12 投标(开始)时间、响应文件接收(开始)时间
+
+        'time_publicityStart': [],  # 4 公示开始时间(公示时间、公示期)
+        'time_publicityEnd': [],  # 5 公示截止时间
+        'time_getFileStart': [],  # 6 文件获取开始时间(文件获取时间)
+        'time_getFileEnd': [],  # 7 文件获取截止时间
+        'time_registrationStart': [],  # 8 报名开始时间(报名时间)
+        'time_registrationEnd': [],  # 9 报名截止时间
+        'time_earnestMoneyStart': [], #10 保证金递交开始时间(保证金递交时间)
+        'time_earnestMoneyEnd': [] , # 11 保证金递交截止时间
+        'time_commencement':[] , #13 开工日期
+        'time_completion': []  # 14 竣工日期
+    }
+    for entity in list_entity:
+        if entity.label!=0:
+            entity_text = entity.entity_text
+            extract_time = my_timeFormat(entity_text)
+            if extract_time:
+                sentence_text = list_sentence[entity.sentence_index].sentence_text
+                entity_left = sentence_text[max(0,entity.wordOffset_begin-2):entity.wordOffset_begin]
+                entity_right = sentence_text[entity.wordOffset_end:entity.wordOffset_end+3]
+                label_prob = entity.values[entity.label]
+                if entity.label==1 and label_prob>0.5:
+                    dict_time['time_release'].append((extract_time[0],label_prob))
+                elif entity.label==2 and label_prob>0.5:
+                    dict_time['time_bidopen'].append((extract_time[0],label_prob))
+                elif entity.label==3 and label_prob>0.5:
+                    dict_time['time_bidclose'].append((extract_time[0],label_prob))
+                elif entity.label==12 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                            dict_time['time_bidclose'].append((extract_time[0], label_prob))
+                        else:
+                            dict_time['time_bidstart'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_bidstart'].append((extract_time[0],label_prob))
+                        dict_time['time_bidclose'].append((extract_time[1],label_prob))
+                elif entity.label==4 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                            dict_time['time_publicityEnd'].append((extract_time[0], label_prob))
+                        else:
+                            dict_time['time_publicityStart'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_publicityStart'].append((extract_time[0],label_prob))
+                        dict_time['time_publicityEnd'].append((extract_time[1],label_prob))
+                elif entity.label==5 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        dict_time['time_publicityEnd'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_publicityStart'].append((extract_time[0],label_prob))
+                        dict_time['time_publicityEnd'].append((extract_time[1],label_prob))
+                elif entity.label==6 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                            dict_time['time_getFileEnd'].append((extract_time[0], label_prob))
+                        else:
+                            dict_time['time_getFileStart'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_getFileStart'].append((extract_time[0],label_prob))
+                        dict_time['time_getFileEnd'].append((extract_time[1],label_prob))
+                elif entity.label==7 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        dict_time['time_getFileEnd'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_getFileStart'].append((extract_time[0],label_prob))
+                        dict_time['time_getFileEnd'].append((extract_time[1],label_prob))
+                elif entity.label==8 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                            dict_time['time_registrationEnd'].append((extract_time[0], label_prob))
+                        else:
+                            dict_time['time_registrationStart'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_registrationStart'].append((extract_time[0],label_prob))
+                        dict_time['time_registrationEnd'].append((extract_time[1],label_prob))
+                elif entity.label==9 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        dict_time['time_registrationEnd'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_registrationStart'].append((extract_time[0],label_prob))
+                        dict_time['time_registrationEnd'].append((extract_time[1],label_prob))
+                elif entity.label==10 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                            dict_time['time_earnestMoneyEnd'].append((extract_time[0], label_prob))
+                        else:
+                            dict_time['time_earnestMoneyStart'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_earnestMoneyStart'].append((extract_time[0],label_prob))
+                        dict_time['time_earnestMoneyEnd'].append((extract_time[1],label_prob))
+                elif entity.label==11 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        dict_time['time_earnestMoneyEnd'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_earnestMoneyStart'].append((extract_time[0],label_prob))
+                        dict_time['time_earnestMoneyEnd'].append((extract_time[1],label_prob))
+                elif entity.label==13 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        if re.search("前|止|截止",entity_right) or re.search("至|止",entity_left) or re.search("前",entity_text[-2:]):
+                            dict_time['time_completion'].append((extract_time[0], label_prob))
+                        else:
+                            dict_time['time_commencement'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_commencement'].append((extract_time[0],label_prob))
+                        dict_time['time_completion'].append((extract_time[1],label_prob))
+                elif entity.label==14 and label_prob>0.5:
+                    if len(extract_time)==1:
+                        dict_time['time_completion'].append((extract_time[0], label_prob))
+                    else:
+                        dict_time['time_commencement'].append((extract_time[0],label_prob))
+                        dict_time['time_completion'].append((extract_time[1],label_prob))
+
+
+    result_dict = dict((key,"") for key in dict_time.keys())
+    for time_type,value in dict_time.items():
+        list_time = dict_time[time_type]
+        if list_time:
+            list_time.sort(key=lambda x:x[1],reverse=True)
+            result_dict[time_type] = list_time[0][0]
+    return result_dict
+
 def getOtherAttributes(list_entity):
     dict_other = {"moneysource":"",
                   "person_review":[],
-                  "time_release":"",
-                  "time_bidopen":"",
-                  "time_bidclose":"",
+                  # "time_release":"",
+                  # "time_bidopen":"",
+                  # "time_bidclose":"",
                   "serviceTime":"",
                   "product":[],
                   "total_tendereeMoney":0,
-                  "total_tendereeMoneyUnit":''
-                   }
-    dict_time = {
-        "time_release": [],
-        "time_bidopen": [],
-        "time_bidclose": []
-    }
+                  "total_tendereeMoneyUnit":''}
+    # dict_time = {
+    #     "time_release": [],
+    #     "time_bidopen": [],
+    #     "time_bidclose": []
+    # }
     for entity in list_entity:
         if entity.entity_type == 'bidway':
             dict_other["bidway"] = turnBidWay(entity.entity_text)
@@ -2218,18 +2387,18 @@ def getOtherAttributes(list_entity):
             dict_other["moneysource"] = entity.entity_text
         elif entity.entity_type=='serviceTime':
             dict_other["serviceTime"] = entity.entity_text
-        elif entity.entity_type == 'time' and entity.label==1:
-            if entity.values[entity.label]>0.6:
-                dict_time['time_release'].append((timeFormat(entity.entity_text),entity.values[entity.label]))
-            # dict_other["time_release"] = timeFormat(entity.entity_text)
-        elif entity.entity_type == 'time' and entity.label==2:
-            if entity.values[entity.label]>0.6:
-                dict_time['time_bidopen'].append((timeFormat(entity.entity_text),entity.values[entity.label]))
-            # dict_other["time_bidopen"] = timeFormat(entity.entity_text)
-        elif entity.entity_type == 'time' and entity.label == 3:
-            if entity.values[entity.label]>0.6:
-                dict_time['time_bidclose'].append((timeFormat(entity.entity_text),entity.values[entity.label]))
-            # dict_other["time_bidclose"] = timeFormat(entity.entity_text)
+        # elif entity.entity_type == 'time' and entity.label==1:
+        #     if entity.values[entity.label]>0.6:
+        #         dict_time['time_release'].append((timeFormat(entity.entity_text),entity.values[entity.label]))
+        #     # dict_other["time_release"] = timeFormat(entity.entity_text)
+        # elif entity.entity_type == 'time' and entity.label==2:
+        #     if entity.values[entity.label]>0.6:
+        #         dict_time['time_bidopen'].append((timeFormat(entity.entity_text),entity.values[entity.label]))
+        #     # dict_other["time_bidopen"] = timeFormat(entity.entity_text)
+        # elif entity.entity_type == 'time' and entity.label == 3:
+        #     if entity.values[entity.label]>0.6:
+        #         dict_time['time_bidclose'].append((timeFormat(entity.entity_text),entity.values[entity.label]))
+        #     # dict_other["time_bidclose"] = timeFormat(entity.entity_text)
         elif entity.entity_type=="person" and entity.label ==4:
             dict_other["person_review"].append(entity.entity_text)
         elif entity.entity_type=='product':
@@ -2238,11 +2407,11 @@ def getOtherAttributes(list_entity):
                 dict_other["total_tendereeMoney"] = float(entity.entity_text)
                 dict_other["total_tendereeMoneyUnit"] = entity.money_unit
     # 时间类别
-    for time_type,value in dict_time.items():
-        list_time = dict_time[time_type]
-        if list_time:
-            list_time.sort(key=lambda x:x[1],reverse=True)
-            dict_other[time_type] = list_time[0][0]
+    # for time_type,value in dict_time.items():
+    #     list_time = dict_time[time_type]
+    #     if list_time:
+    #         list_time.sort(key=lambda x:x[1],reverse=True)
+    #         dict_other[time_type] = list_time[0][0]
     dict_other["product"] = list(set(dict_other["product"]))
     return dict_other
 
@@ -2259,7 +2428,7 @@ def getPREMs(list_sentences,list_entitys,list_articles):
     result = []
     for list_sentence,list_entity,list_article in zip(list_sentences,list_entitys,list_articles):
         RoleList = getPackageRoleMoney(list_sentence,list_entity)
-        result.append(dict({"prem":RoleList,"docid":list_article.id},**getOtherAttributes(list_entity),
+        result.append(dict({"prem":RoleList,"docid":list_article.doc_id},**getOtherAttributes(list_entity),**getTimeAttributes(list_entity,list_sentence),
                            **{"fingerprint":list_article.fingerprint,"match_enterprise":list_article.match_enterprise,
                               "match_enterprise_type":list_article.match_enterprise_type,"process_time":getCurrent_date(),
                               "attachmentTypes":list_article.attachmentTypes, "bidway": list_article.bidway}))

BIN
BiddingKG/dl/interface/timesplit_model/saved_model.pb


BIN
BiddingKG/dl/interface/timesplit_model/variables/variables.data-00000-of-00001


BIN
BiddingKG/dl/interface/timesplit_model/variables/variables.index


+ 4 - 2
BiddingKG/dl/test/测试整个要素提取流程.py

@@ -147,6 +147,8 @@ def predict(doc_id,text):
                     # print(_sentence.sentence_text,_sentence.tokens[entity.begin_index:entity.end_index+1])
                     pass
                 # print(entity.pointer_pack)
+            # elif entity.entity_type =='serviceTime':
+            #     print(entity.entity_text)
             #     if entity.pointer_pack:
             #         print('pointer_pack_name:',entity.pointer_pack.entity_text)
             # elif entity.entity_type in ['package']:
@@ -439,8 +441,8 @@ if __name__=="__main__":
     a = time.time()
     print("start")
     # print(predict("12",content))
-    # result = predict("12",text)
-    result = predict("12",content)
+    result = predict("12",text)
+    # result = predict("12",content)
     # print(json.loads(result))
     #test("12",text)
     print("takes",time.time()-a)

BIN
BiddingKG/dl/time/model_time_classify.weights


+ 367 - 16
BiddingKG/dl/time/train_2.py

@@ -13,10 +13,32 @@ from BiddingKG.dl.common.models import *
 from sklearn.metrics import classification_report
 from sklearn.utils import shuffle,class_weight
 import matplotlib.pyplot as plt
+import random
 
 input_shape = (2,30,60)
 input_shape2 = (2,40,128)
-output_shape = [4]
+# output_shape = [4]
+
+time_label_dict = {
+             'time': 0,
+            'time_release': 1, #发布时间
+            'time_bidopen': 2, #开标时间
+            'time_bidclose': 3, #截标时间
+            'time_bidstart': 12, #投标(开始)时间、响应文件接收(开始)时间
+
+            'time_publicityStart': 4, #公示开始时间(公示时间、公示期)
+            'time_publicityEnd': 5, #公示截止时间
+            'time_getFileStart': 6, #文件获取开始时间(文件获取时间)
+            'time_getFileEnd': 7, #文件获取截止时间
+            'time_registrationStart': 8, #报名开始时间(报名时间)
+            'time_registrationEnd': 9, #报名截止时间
+            'time_earnestMoneyStart': 10, #保证金递交开始时间(保证金递交时间)
+            'time_earnestMoneyEnd': 11, #保证金递交截止时间
+            'time_commencement': 13, #开工日期
+            'time_completion': 14 #竣工日期
+        }
+output_shape = [len(time_label_dict)]
+
 
 def get_data():
     data_load = pd.read_csv("newdata_30_prc.csv", index_col=0)
@@ -91,16 +113,23 @@ def getModel2():
     R_input = layers.Input(shape=input_shape2[1:], dtype='float32')
     R_mask = Lambda(lambda x: K.cast(K.not_equal(K.sum(x,axis=-1,keepdims=True), 0), 'float32'))(R_input)
 
-    L_input_drop = Dropout(0.2)(L_input)
-    R_input_drop = Dropout(0.2)(R_input)
+    L_input_drop = Dropout(0.3)(L_input)
+    R_input_drop = Dropout(0.3)(R_input)
     # L_lstm = layers.Bidirectional(layers.GRU(40,return_sequences=True,dropout=0.1))(L_input)
     L_lstm = OurBidirectional(GRU(64, return_sequences=True))([L_input_drop,L_mask])
     L_att = Attention02()(L_lstm,mask=K.squeeze(L_mask,axis=-1))
     # R_lstm = layers.Bidirectional(layers.GRU(40,return_sequences=True,dropout=0.1))(R_input)
     R_lstm = OurBidirectional(GRU(64, return_sequences=True))([R_input_drop,R_mask])
     R_att = Attention02()(R_lstm,mask=K.squeeze(R_mask,axis=-1))
+    L_R = layers.merge([L_lstm, R_lstm],concat_axis=1, mode='concat')
+    L_R_mask = layers.merge([L_mask, R_mask],concat_axis=1, mode='concat')
+    L_R_att = Attention02()(L_R,mask=K.squeeze(L_R_mask,axis=-1))
+
+    L_att = layers.add([L_att,L_R_att])
+    R_att = layers.add([R_att,L_R_att])
     concat = layers.merge([L_att, R_att], mode='concat')
-    concat = Dropout(0.3)(concat)
+
+    concat = Dropout(0.2)(concat)
     output = layers.Dense(output_shape[0],activation="softmax")(concat)
 
     model = models.Model(inputs=[L_input,R_input], outputs=output)
@@ -111,6 +140,36 @@ def getModel2():
                   metrics=[precision,recall,f1_score])
     model.summary()
     return model
+# def getModel2():
+#     '''
+#     @summary: 时间分类模型
+#     '''
+#     L_input = layers.Input(shape=input_shape2[1:], dtype='float32')
+#     L_mask = Lambda(lambda x: K.cast(K.not_equal(K.sum(x,axis=-1,keepdims=True), 0), 'float32'))(L_input)
+#     R_input = layers.Input(shape=input_shape2[1:], dtype='float32')
+#     R_mask = Lambda(lambda x: K.cast(K.not_equal(K.sum(x,axis=-1,keepdims=True), 0), 'float32'))(R_input)
+#
+#     L_input_drop = Dropout(0.3)(L_input)
+#     R_input_drop = Dropout(0.3)(R_input)
+#     # L_lstm = layers.Bidirectional(layers.GRU(40,return_sequences=True,dropout=0.1))(L_input)
+#     L_lstm = OurBidirectional(GRU(64, return_sequences=True))([L_input_drop,L_mask])
+#     L_att = Attention02()(L_lstm,mask=K.squeeze(L_mask,axis=-1))
+#     # R_lstm = layers.Bidirectional(layers.GRU(40,return_sequences=True,dropout=0.1))(R_input)
+#     R_lstm = OurBidirectional(GRU(64, return_sequences=True))([R_input_drop,R_mask])
+#     R_att = Attention02()(R_lstm,mask=K.squeeze(R_mask,axis=-1))
+#     concat = layers.merge([L_att, R_att], mode='concat')
+#
+#     concat = Dropout(0.2)(concat)
+#     output = layers.Dense(output_shape[0],activation="softmax")(concat)
+#
+#     model = models.Model(inputs=[L_input,R_input], outputs=output)
+#
+#     learn_rate = 0.00005
+#     model.compile(optimizer=optimizers.Adam(lr=learn_rate),
+#                   loss=losses.binary_crossentropy,
+#                   metrics=[precision,recall,f1_score])
+#     model.summary()
+#     return model
 
 def getModel3():
     '''
@@ -121,8 +180,8 @@ def getModel3():
     R_input = layers.Input(shape=input_shape2[1:], dtype='float32')
     R_mask = Lambda(lambda x: K.cast(K.not_equal(K.sum(x,axis=-1,keepdims=True), 0), 'float32'))(R_input)
 
-    L_input_drop = Dropout(0.2)(L_input)
-    R_input_drop = Dropout(0.2)(R_input)
+    L_input_drop = Dropout(0.3)(L_input)
+    R_input_drop = Dropout(0.3)(R_input)
     # L_lstm = layers.Bidirectional(layers.GRU(40,return_sequences=True,dropout=0.1))(L_input)
     L_lstm = OurBidirectional(GRU(64, return_sequences=True))([L_input_drop,L_mask])
     # L_att = Attention02()(L_lstm,mask=K.squeeze(L_mask,axis=-1))
@@ -133,7 +192,7 @@ def getModel3():
     att = Attention02()(concat,mask=K.squeeze(concat_mask,axis=-1))
     # R_att = Attention02()(R_lstm,mask=K.squeeze(R_mask,axis=-1))
     # concat = layers.merge([L_att, R_att], mode='concat')
-    att = Dropout(0.3)(att)
+    att = Dropout(0.2)(att)
     output = layers.Dense(output_shape[0],activation="softmax")(att)
 
     model = models.Model(inputs=[L_input,R_input], outputs=output)
@@ -145,6 +204,72 @@ def getModel3():
     model.summary()
     return model
 
+class Attention(Layer):
+    """多头注意力机制
+    """
+    def __init__(self, nb_head, size_per_head, **kwargs):
+        self.nb_head = nb_head
+        self.size_per_head = size_per_head
+        self.out_dim = nb_head * size_per_head
+        super(Attention, self).__init__(**kwargs)
+    def build(self, input_shape):
+        super(Attention, self).build(input_shape)
+        q_in_dim = input_shape[0][-1]
+        k_in_dim = input_shape[1][-1]
+        v_in_dim = input_shape[2][-1]
+        self.q_kernel = self.add_weight(name='q_kernel',
+                                        shape=(q_in_dim, self.out_dim),
+                                        initializer='glorot_normal')
+        self.k_kernel = self.add_weight(name='k_kernel',
+                                        shape=(k_in_dim, self.out_dim),
+                                        initializer='glorot_normal')
+        self.v_kernel = self.add_weight(name='w_kernel',
+                                        shape=(v_in_dim, self.out_dim),
+                                        initializer='glorot_normal')
+    def mask(self, x, mask, mode='mul'):
+        if mask is None:
+            return x
+        else:
+            for _ in range(K.ndim(x) - K.ndim(mask)):
+                mask = K.expand_dims(mask, K.ndim(mask))
+            if mode == 'mul':
+                return x * mask
+            else:
+                return x - (1 - mask) * 1e10
+    def call(self, inputs):
+        q, k, v = inputs[:3]
+        v_mask, q_mask = None, None
+        if len(inputs) > 3:
+            v_mask = inputs[3]
+            if len(inputs) > 4:
+                q_mask = inputs[4]
+        # 线性变换
+        qw = K.dot(q, self.q_kernel)
+        kw = K.dot(k, self.k_kernel)
+        vw = K.dot(v, self.v_kernel)
+        # 形状变换
+        qw = K.reshape(qw, (-1, K.shape(qw)[1], self.nb_head, self.size_per_head))
+        kw = K.reshape(kw, (-1, K.shape(kw)[1], self.nb_head, self.size_per_head))
+        vw = K.reshape(vw, (-1, K.shape(vw)[1], self.nb_head, self.size_per_head))
+        # 维度置换
+        qw = K.permute_dimensions(qw, (0, 2, 1, 3))
+        kw = K.permute_dimensions(kw, (0, 2, 1, 3))
+        vw = K.permute_dimensions(vw, (0, 2, 1, 3))
+        # Attention
+        a = K.batch_dot(qw, kw, [3, 3]) / self.size_per_head**0.5
+        a = K.permute_dimensions(a, (0, 3, 2, 1))
+        a = self.mask(a, v_mask, 'add')
+        a = K.permute_dimensions(a, (0, 3, 2, 1))
+        a = K.softmax(a)
+        # 完成输出
+        o = K.batch_dot(a, vw, [3, 2])
+        o = K.permute_dimensions(o, (0, 2, 1, 3))
+        o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim))
+        o = self.mask(o, q_mask, 'mul')
+        return o
+    def compute_output_shape(self, input_shape):
+        return (input_shape[0][0], input_shape[0][1], self.out_dim)
+
 class Attention02(Layer):
     def __init__(self, **kwargs):
         self.init = initializers.get('normal')
@@ -530,11 +655,216 @@ def train3():
     # # y_pre2 = load_model.predict(train_x[0])
     # res2 = classification_report(np.argmax(train_y, axis=1), np.argmax(y_pre2, axis=1))
     # print(res2)
+
+def train4():
+    # data_load = pd.read_excel("tokens_tolabel_data1.xlsx", index_col=0)
+    data_load = pd.read_excel("tokens_tolabel_data1_res13New.xlsx", index_col=0)
+    # data_load = pd.concat([data_load[data_load['re_label']==0],data_load])
+    # data_load = data_load[data_load['pre_label_prob']>0.97]
+    # data_load = data_load[data_load['is_same']==1]
+    data_zero = pd.read_excel("time_entity5.xlsx")
+    data_zero = data_zero[(data_zero['viewed']==1)|(data_zero['is_same']==2)]
+    # data_old = pd.read_excel("tokens_data_02.xlsx")
+    data_old = pd.read_excel("tokens_data_02_res7New.xlsx")
+    data_delay1 = pd.read_excel("delayTime_entity1.xlsx")
+    data_delay1 = data_delay1[data_delay1['label']!=0]
+    data_delay2 = pd.read_excel("delayTime_entity2.xlsx")
+
+    # data_zero = pd.concat([data_zero,data_zero])
+    # data_zero = pd.concat([data_zero[(data_zero['label']!=0)|(data_zero['is_same']==2)],data_zero.sample(n=3000)])
+    # data_zero = data_zero.sample(n=80000)
+    print("输入shape:",input_shape2)
+    data_x = []
+    data_y = []
+    import random
+    for left, right, label,_label in zip(data_load['context_left'], data_load['context_right'], data_load['re_label'], data_load['label']):
+        # if label==_label:
+
+        y = np.zeros(output_shape)
+        y[label] = 1
+        left = eval(left)
+        left = left[-40:]
+        right = eval(right)
+        right = right[:40]
+        context = [left, right]
+        # x = embedding(context, shape=input_shape2)
+        data_x.append(context)
+        data_y.append(y)
+    # data_load2 = data_load[data_load['re_label']==0]
+    # for left, right, label,_label in zip(data_load2['context_left'], data_load2['context_right'], data_load2['re_label'], data_load2['label']):
+    #         if label==_label:
+    #             y = np.zeros(output_shape)
+    #             y[label] = 1
+    #             left = eval(left)
+    #             left = left[-40:]
+    #             if len(left)>30:
+    #                 left = left[2:]
+    #             elif len(left)>15:
+    #                 left = left[1:]
+    #             right = eval(right)
+    #             right = right[:40]
+    #             if len(right)>15:
+    #                 right = right[:-1]
+    #             context = [left, right]
+    #             # x = embedding(context, shape=input_shape2)
+    #             data_x.append(context)
+    #             data_y.append(y)
+
+    for left, right, label in zip(data_zero['context_left'], data_zero['context_right'], data_zero['re_label']):
+
+        y = np.zeros(output_shape)
+        y[label] = 1
+        left = eval(left)
+        left = left[-40:]
+        right = eval(right)
+        right = right[:40]
+        context = [left, right]
+        # x = embedding(context, shape=input_shape2)
+        data_x.append(context)
+        data_y.append(y)
+
+    for left, right, label in zip(data_delay1['context_left'], data_delay1['context_right'], data_delay1['label']):
+            y = np.zeros(output_shape)
+            y[label] = 1
+            left = eval(left)
+            left = left[-40:]
+            right = eval(right)
+            right = right[:40]
+            context = [left, right]
+            # x = embedding(context, shape=input_shape2)
+            data_x.append(context)
+            data_y.append(y)
+    for left, right, label in zip(data_delay2['context_left'], data_delay2['context_right'], data_delay2['re_label']):
+                y = np.zeros(output_shape)
+                y[label] = 1
+                left = eval(left)
+                left = left[-40:]
+                right = eval(right)
+                right = right[:40]
+                context = [left, right]
+                # x = embedding(context, shape=input_shape2)
+                data_x.append(context)
+                data_y.append(y)
+
+    # for left, right, label in zip(data_zero['context_left'], data_zero['context_right'], data_zero['label']):
+    #         y = np.zeros(output_shape)
+    #         y[label] = 1
+    #         left = eval(left)
+    #         left = left[-40:]
+    #         if len(left) > 30:
+    #             left = left[2:]
+    #         elif len(left) > 15:
+    #             left = left[1:]
+    #         right = eval(right)
+    #         right = right[:40]
+    #         if len(right) > 15:
+    #             right = right[:-1]
+    #         context = [left, right]
+    #         # x = embedding(context, shape=input_shape2)
+    #         data_x.append(context)
+    #         data_y.append(y)
+
+    # for left, right, label in zip(data_old['context_left'], data_old['context_right'], data_old['label']):
+    #         y = np.zeros(output_shape)
+    #         y[label] = 1
+    #         left = eval(left)
+    #         left = left[-40:]
+    #         right = eval(right)
+    #         right = right[:40]
+    #         context = [left, right]
+    #         # x = embedding(context, shape=input_shape2)
+    #         data_x.append(context)
+    #         data_y.append(y)
+    for left, right, label,pre_label,is_same in zip(data_old['context_left'], data_old['context_right'], data_old['label'],
+                                                    data_old['pre_label'],data_old['is_same']):
+        if label==0:
+            if is_same==1:
+                pass
+            else:
+                if pre_label>3:
+                    label = pre_label
+                else:
+                    continue
+        y = np.zeros(output_shape)
+        y[label] = 1
+        left = eval(left)
+        left = left[-40:]
+        right = eval(right)
+        right = right[:40]
+        context = [left, right]
+        # x = embedding(context, shape=input_shape2)
+        data_x.append(context)
+        data_y.append(y)
+
+    _data = [d for d in zip(data_x,data_y)]
+    random.shuffle(_data)
+    data_x = [i[0] for i in _data]
+    data_y = [i[1] for i in _data]
+    test_len = int(len(data_x) * 0.11)
+    test_x = data_x[:test_len]
+    test_y = data_y[:test_len]
+    print("测试数据量:", len(test_x))
+    train_x = data_x[test_len:]
+    train_y = data_y[test_len:]
+
+    # for left, right, label,pre_label,is_same in zip(data_old['context_left'], data_old['context_right'], data_old['label'],
+    #                                                 data_old['pre_label'],data_old['is_same']):
+    #     # if label==0:
+    #     #     if random.random()>0.25:
+    #     #         continue
+    #     if label==0:
+    #         if is_same==1:
+    #             pass
+    #         else:
+    #             if pre_label>3:
+    #                 label = pre_label
+    #             else:
+    #                 continue
+    #     y = np.zeros(output_shape)
+    #     y[label] = 1
+    #     left = eval(left)
+    #     left = left[-40:]
+    #     right = eval(right)
+    #     right = right[:40]
+    #     context = [left, right]
+    #     # x = embedding(context, shape=input_shape2)
+    #     train_x.append(context)
+    #     train_y.append(y)
+    print("训练数据量:", len(train_x))
+
+    # train_y, test_y = np.array(train_y), np.array(test_y)
+    # train_x = np.array(train_x)
+    # test_x = np.array(test_x)
+    # test_x = np.transpose(test_x, (1, 0, 2, 3))
+    # train_x, test_x = (np.transpose(train_x, (1, 0, 2, 3)), np.transpose(test_x, (1, 0, 2, 3)))
+    training_generator = DataGenerator(train_x, train_y,is_train=True)
+    # training_generator = DataGenerator(data_x, data_y)
+    validation_generator = DataGenerator(test_x, test_y,is_train=False,shuffle=False)
+
+    # model = getModel3()
+    model = getModel2()
+    epochs = 100
+    # batch_size = 256
+    checkpoint = ModelCheckpoint("model_time_classify.weights",save_weights_only=True, monitor="val_loss", verbose=1,
+                                 save_best_only=True, mode='min')
+    # checkpoint = ModelCheckpoint("model_time_classify2.weights",save_weights_only=True, monitor="loss", verbose=1,
+    #                                  save_best_only=True, mode='min')
+
+    history = model.fit_generator(
+        generator=training_generator,
+        validation_data=validation_generator,
+        use_multiprocessing=True, workers=2,
+        epochs=epochs,
+        shuffle=True,
+        callbacks=[checkpoint],
+        class_weight='auto'
+    )
+
 from keras.utils import Sequence,to_categorical
 class DataGenerator(Sequence):
     'Generates data for Keras'
-    def __init__(self, texts, labels, batch_size=256,
-                 n_classes=4, shuffle=True):
+    def __init__(self, texts, labels, is_train=True,batch_size=256,
+                 n_classes=len(time_label_dict), shuffle=True):
         'Initialization'
         # self.dim = dim
         self.batch_size = batch_size
@@ -542,6 +872,7 @@ class DataGenerator(Sequence):
         self.texts = texts
         self.n_classes = n_classes
         self.shuffle = shuffle
+        self.is_train = is_train
         self.on_epoch_end()
 
     def __len__(self):
@@ -583,8 +914,22 @@ class DataGenerator(Sequence):
         # Generate data
         for i, context in enumerate(list_texts):
             # Store sample
-            # tokens = preprocess2(text)
-            # tokens = tokens[:maxlen]
+            if self.is_train:
+                left = context[0]
+                if len(left) > 30:
+                    if random.random() > 0.5:
+                        left = left[2:]
+                elif len(left) > 15:
+                    if random.random() > 0.5:
+                        left = left[1:]
+                right = context[1]
+                if len(right) > 30:
+                    if random.random() > 0.5:
+                        right = right[:-2]
+                elif len(right) > 15:
+                    if random.random() > 0.5:
+                        right = right[:-1]
+                context = [left, right]
             words_matrix = embedding_mywords(context, shape=input_shape2)
             # Store class
             # y[i] = _label[i]
@@ -647,7 +992,11 @@ def predict3():
     new_data.to_excel("new_tokens_data1_res.xlsx")
 
 def predict4():
-    data = pd.read_csv("tokens_tolabel_data1_res11.csv", chunksize=3000)
+    data = pd.read_csv("tokens_data_02_res6New.csv", chunksize=3000)
+    # data = pd.read_excel("C:\\Users\\Administrator\\Desktop\\time_entity4.xlsx")
+    # data.to_csv("C:\\Users\\Administrator\\Desktop\\time_entity4.csv")
+    # data = pd.read_csv("C:\\Users\\Administrator\\Desktop\\time_entity4.csv", chunksize=3000)
+
     model1 = getModel2()
     model1.load_weights("model_time_classify.weights")
     new_data = pd.DataFrame()
@@ -671,14 +1020,15 @@ def predict4():
         pre_y = model1.predict([test_x[0], test_x[1]])
         _data['pre_label'] = [np.argmax(item) for item in pre_y]
         _data['pre_label_prob'] = [max(item) for item in pre_y]
-        _data['is_same'] = [1 if int(_label)==_pre else 0 for _label,_pre in zip(_data['re_label'],_data['pre_label'])]
+        _data['is_same'] = [1 if int(_label)==_pre else 0 for _label,_pre in zip(_data['label'],_data['pre_label'])]
         # _data['is_same'] = [1 if int(_re)==int(_pre) and int(_re)==int(_label) else 0 for _label,_re,_pre in zip(_data['label'],_data['re_label'],_data['pre_label'])]
         # data['label'] = label
         new_data = pd.concat([new_data, _data])
         idx += 3000
         print(idx)
-    # data.to_csv("new_tokens_data1.csv")
-    new_data.to_excel("tokens_tolabel_data1_res12.xlsx")
+    # new_data.to_csv("tokens_data_02_res7New.csv")
+    new_data.to_excel("tokens_data_02_res7New.xlsx")
+    # new_data.to_excel("C:\\Users\\Administrator\\Desktop\\tokens_data_02_res7New.xlsx")
 
 
 def predict():
@@ -863,7 +1213,7 @@ def save_model():
             test_model = getModel2()
             test_model.load_weights("model_time_classify.weights")
             tf.saved_model.simple_save(sess,
-                                       "models/timesplit_model/",
+                                       "models/timesplit_model2/",
                                        inputs={"input0": test_model.input[0],
                                                "input1":test_model.input[1]
                                                },
@@ -879,6 +1229,7 @@ if __name__ == '__main__':
     # training()
     # train2()
     # train3()
+    # train4()
     # data_process()
     # data_process2()
     # data_process3()