Browse Source

Update and optimize the encoding scheme of the time classification model

admin 4 years ago
Parent
Commit
7dcc9fa751

BIN
BiddingKG/dl/interface/timesplit_model/saved_model.pb


BIN
BiddingKG/dl/interface/timesplit_model/variables/variables.data-00000-of-00001


BIN
BiddingKG/dl/interface/timesplit_model/variables/variables.index


BIN
BiddingKG/dl/time/model_label_time_classify.model.hdf5


+ 45 - 16
BiddingKG/dl/time/train_2.py

@@ -21,23 +21,18 @@ def getModel():
     '''
     L_input = layers.Input(shape=input_shape[1:], dtype='float32')
     R_input = layers.Input(shape=input_shape[1:], dtype='float32')
-    L_lstm = layers.Bidirectional(layers.LSTM(32,return_sequences=True,dropout=0.1))(L_input)
+    L_lstm = layers.Bidirectional(layers.LSTM(40,return_sequences=True,dropout=0.1))(L_input)
     # L_lstm = layers.LSTM(32,return_sequences=True,dropout=0.2)(L_input)
     avg_l = layers.GlobalAveragePooling1D()(L_lstm)
-    R_lstm = layers.Bidirectional(layers.LSTM(32,return_sequences=True,dropout=0.1))(R_input)
+    R_lstm = layers.Bidirectional(layers.LSTM(40,return_sequences=True,dropout=0.1))(R_input)
     # R_lstm = layers.LSTM(32, return_sequences=True, dropout=0.2)(R_input)
     avg_r = layers.GlobalAveragePooling1D()(R_lstm)
     concat = layers.merge([avg_l, avg_r], mode='concat')
-    # concat = layers.merge([L_lstm, R_lstm], mode='concat')
     # lstm = layers.LSTM(24,return_sequences=False,dropout=0.2)(concat)
     output = layers.Dense(output_shape[0],activation="softmax")(concat)
 
-    # L_lstm = layers.LSTM(32,return_sequences=True,dropout=0.2)(L_input)
-    # avg = layers.GlobalAveragePooling1D()(L_GRU)
-    # output = layers.Dense(output_shape[0],activation="softmax")(avg)
-
     model = models.Model(inputs=[L_input,R_input], outputs=output)
-    # model = models.Model(inputs=L_input, outputs=output)
+
     learn_rate = 0.0005
     model.compile(optimizer=optimizers.Adam(lr=learn_rate),
                   loss=losses.binary_crossentropy,
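For reference, here is a minimal standalone sketch of the updated two-branch model. It assumes the input_shape and output_shape globals defined at the top of train_2.py (not visible in this diff; the values below are placeholder assumptions), and it swaps the Keras 1 layers.merge(..., mode='concat') call for its tf.keras equivalent, layers.concatenate:

from tensorflow.keras import layers, models, optimizers, losses

# Placeholder shapes -- the real values live at the top of train_2.py (assumption).
input_shape = (2, 40, 60)   # (contexts, sequence length, embedding dim) -- assumption
output_shape = (2,)         # one-hot label width -- assumption

def build_time_classifier():
    L_input = layers.Input(shape=input_shape[1:], dtype='float32')
    R_input = layers.Input(shape=input_shape[1:], dtype='float32')
    # 40 units per direction, matching the 32 -> 40 change in this commit
    L_lstm = layers.Bidirectional(layers.LSTM(40, return_sequences=True, dropout=0.1))(L_input)
    R_lstm = layers.Bidirectional(layers.LSTM(40, return_sequences=True, dropout=0.1))(R_input)
    avg_l = layers.GlobalAveragePooling1D()(L_lstm)
    avg_r = layers.GlobalAveragePooling1D()(R_lstm)
    # tf.keras replacement for the deprecated layers.merge(..., mode='concat')
    concat = layers.concatenate([avg_l, avg_r])
    output = layers.Dense(output_shape[0], activation='softmax')(concat)
    model = models.Model(inputs=[L_input, R_input], outputs=output)
    model.compile(optimizer=optimizers.Adam(learning_rate=0.0005),
                  loss=losses.binary_crossentropy,
                  metrics=['accuracy'])  # the repo also tracks custom precision/recall/f1_score
    return model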
@@ -73,8 +68,8 @@ def getModel_center():
 
 
 def training():
-    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30.csv", index_col=0)
-    test_data = data_load.sample(frac=0.25, random_state=7)
+    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30_prc.csv", index_col=0)
+    test_data = data_load.sample(frac=0.2, random_state=7)
     train_data = data_load.drop(test_data.index, axis=0)
     train_data =train_data.reset_index(drop=True)
 
@@ -83,8 +78,12 @@ def training():
     for left, right, label in zip(train_data['context_left'], train_data['context_right'], train_data['re_label']):
         y = np.zeros(output_shape)
         y[label] = 1
-        left = ''.join(str(left))
-        right = ''.join(str(right))
+        left = str(left)
+        right = str(right)
+        if left=='nan': left = ''
+        if right=='nan': right = ''
+        left = list(left)
+        right = list(right)
         context = [left, right]
         x = embedding_word(context, shape=input_shape)
         train_x.append(x)
@@ -95,8 +94,12 @@ def training():
     for left, right, label in zip(test_data['context_left'], test_data['context_right'], test_data['re_label']):
         y = np.zeros(output_shape)
         y[label] = 1
-        left = ''.join(str(left))
-        right = ''.join(str(right))
+        left = str(left)
+        right = str(right)
+        if left == 'nan': left = ''
+        if right == 'nan': right = ''
+        left = list(left)
+        right = list(right)
         context = [left, right]
         x = embedding_word(context, shape=input_shape)
         test_x.append(x)
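The same 'nan' normalization now appears in both the train and test loops: pandas renders a missing CSV cell as NaN, str(NaN) is the literal string 'nan', and embedding_word expects a list of characters rather than a string. A small helper capturing the intent (normalize_context is a hypothetical name, not part of the repo):

def normalize_context(value):
    # str() handles both real strings and float('nan') from empty CSV cells
    text = str(value)
    if text == 'nan':   # missing cell -> empty context
        text = ''
    return list(text)   # embedding_word consumes a character list

# usage inside either loop:
# context = [normalize_context(left), normalize_context(right)]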
@@ -107,7 +110,7 @@ def training():
     train_x, test_x = (np.transpose(train_x, (1, 0, 2, 3)), np.transpose(test_x, (1, 0, 2, 3)))
 
     model = getModel()
-    epochs = 100
+    epochs = 150
     batch_size = 256
     checkpoint = ModelCheckpoint("model_label_time_classify.model.hdf5", monitor="val_loss", verbose=1,
                                  save_best_only=True, mode='min')
@@ -123,7 +126,7 @@ def training():
         callbacks=[checkpoint],
         class_weight='auto'
     )
-    plot_loss(history=history)
+    # plot_loss(history=history)
     load_model = models.load_model("model_label_time_classify.model.hdf5",
                                    custom_objects={'precision': precision, 'recall': recall, 'f1_score': f1_score})
     y_pre = load_model.predict([test_x[0], test_x[1]])
@@ -246,6 +249,30 @@ def predict_center():
     # print(error_data.info())
     error_data.to_csv("C:\\Users\\admin\\Desktop\\test\\error_center.csv")
 
+def data_process():
+    data_load = pd.read_csv("C:\\Users\\admin\\Desktop\\newdata_30.csv", index_col=0)
+    re_left = re.compile("。[^。]*?$")
+    re_right = re.compile("^[^。]*?。")
+    left_list = []
+    right_list = []
+    for left, right in zip(data_load['context_left'], data_load['context_right']):
+        left = str(left)
+        right = str(right)
+        if right=='nan':
+            right = ''
+            # print(1)
+        if re.search("。",left):
+            left = re_left.search(left)
+            left = left.group()[1:]
+        if re.search("。",right):
+            right = re_right.search(right)
+            right = right.group()
+        left_list.append(left)
+        right_list.append(right)
+    data_load['context_left'] = left_list
+    data_load['context_right'] = right_list
+    data_load.to_csv("C:\\Users\\admin\\Desktop\\newdata_30_prc.csv")
+
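The new data_process() pass trims each context to its nearest sentence boundary: context_left keeps only the text after its last full stop (。), and context_right keeps only the text up to and including its first. A standalone illustration of the two regexes on made-up strings:

import re

re_left = re.compile("。[^。]*?$")   # matches from the last 。 to the end of the string
re_right = re.compile("^[^。]*?。")  # matches from the start up to the first 。

left = "上一句无关内容。投标截止时间为"
right = "2021年5月1日9时。其他无关内容。"

m = re_left.search(left)
if m:
    left = m.group()[1:]   # [1:] drops the leading 。 itself
m = re_right.search(right)
if m:
    right = m.group()      # the trailing 。 is kept

print(left)    # 投标截止时间为
print(right)   # 2021年5月1日9时。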
 def plot_loss(history):
     plt.plot(history.history['loss'])
     plt.plot(history.history['val_loss'])
@@ -259,6 +286,7 @@ if __name__ == '__main__':
     # getModel()
     # getModel_center()
     # training()
+    # data_process()
     # training_center()
     # predict()
     # predict_center()
@@ -275,4 +303,5 @@ if __name__ == '__main__':
     pre_y = model1.predict([test_x[0],test_x[1]])
     rs = [np.argmax(item) for item in pre_y]
     print(pre_y, rs)
+
     pass