
Table-header recognition model and code optimization

fangjiasheng committed 3 years ago
commit c5be04ea2f

+ 12 - 12
BiddingKG/dl/entityLink/entityLink.py

@@ -48,18 +48,18 @@ def get_place_list():
     place_list.append('台湾')
     place_list.append('澳门')
     place_list.append('香港')
-    place_list.append('東莞')
-    place_list.append('廣州')
-    place_list.append('韩国')
-    place_list.append('德国')
-    place_list.append('英国')
-    place_list.append('日本')
-    place_list.append('意大利')
-    place_list.append('新加坡')
-    place_list.append('加拿大')
-    place_list.append('西班牙')
-    place_list.append('澳大利亚')
-    place_list.append('美国')
+    # place_list.append('東莞')
+    # place_list.append('廣州')
+    # place_list.append('韩国')
+    # place_list.append('德国')
+    # place_list.append('英国')
+    # place_list.append('日本')
+    # place_list.append('意大利')
+    # place_list.append('新加坡')
+    # place_list.append('加拿大')
+    # place_list.append('西班牙')
+    # place_list.append('澳大利亚')
+    # place_list.append('美国')
 
     place_list = list(set(place_list))
     return place_list

+ 74 - 11
BiddingKG/dl/table_head/check_user_label_accuracy.py

@@ -1,19 +1,24 @@
-import difflib
 import time
-
 from BiddingKG.dl.table_head.pre_process import postgresql_util
 from BiddingKG.dl.table_head.predict import predict
 
 
 def user_label_accuracy(update_user):
-    sql = """
-    select table_text, pre_label, post_label, id
-    from label_table_head_info 
-    where update_user='""" + update_user + "' "
+    if update_user != 'test':
+        sql = """
+        select table_text, pre_label, post_label, id
+        from label_table_head_info 
+        where update_user='""" + update_user + "' order by id desc limit 3000"
+    else:
+        sql = """
+        select table_text, pre_label, post_label, id
+        from label_table_head_info 
+        where update_user='""" + update_user + "' and status = 1 and update_time >= '2022-01-17'"
 
     result_list = postgresql_util(sql, limit=1000000)
     right_cnt = 0
     error_id_list = []
+    right_id_list = []
     i = 0
     start_time = time.time()
     for table in result_list:
@@ -48,6 +53,7 @@ def user_label_accuracy(update_user):
         if predict_label_list:
             if str(label_list) == str(predict_label_list):
                 right_cnt += 1
+                right_id_list.append(str(_id)+"\n")
             else:
                 # cnt = 0
                 # for j in range(len(label_list)):
@@ -60,13 +66,20 @@ def user_label_accuracy(update_user):
                 #         break
                 error_id_list.append(str(_id)+"\n")
 
-    print(update_user + " accuracy:", right_cnt / len(result_list), 'total:', len(result_list))
+    accuracy = right_cnt / len(result_list)
+    print(update_user + " accuracy:", accuracy, 'total:', len(result_list))
     print("error_id_list", len(error_id_list))
 
-    save_path = "check_user_result/"+update_user+".txt"
+    save_path = "check_user_result/"+update_user+"_error.txt"
     with open(save_path, 'w') as f:
         f.writelines(error_id_list)
 
+    save_path = "check_user_result/"+update_user+"_right.txt"
+    with open(save_path, 'w') as f:
+        f.writelines(right_id_list)
+
+    return accuracy
+
 
 def get_single_result(_id):
     sql = """
@@ -83,9 +96,59 @@ def get_single_result(_id):
     table_text = table_text.replace('\\', '/')
     table_text = eval(table_text)
     label_list = predict(table_text)
-    print(label_list)
+    for i in range(len(label_list)):
+        print(i+1, label_list[i])
 
 
 if __name__ == '__main__':
-    user_label_accuracy("test9")
-    # get_single_result(171520)
+    # users = ["test9", "test11", "test12", "test20", "test25", "test26", "test27"]
+    users = ["test20", "test27"]
+    # users = ['test']
+    acc_list = []
+    for user in users:
+        acc = user_label_accuracy(user)
+        acc_list.append([user, acc])
+
+    print(acc_list)
+
+    # get_single_result(161927)
+
+    # import pandas as pd
+    # df = pd.read_csv("C:\\Users\\Administrator\\Desktop\\4.csv")
+    # _dict = {
+    #     51: "公告变更",
+    #     52: "招标公告",
+    #     101: "中标信息",
+    #     102: "招标预告",
+    #     103: "招标答疑",
+    #     104: "招标文件",
+    #     105: "资审结果",
+    #     106: "法律法规",
+    #     107: "新闻资讯",
+    #     108: "拟建项目",
+    #     109: "展会推广",
+    #     110: "企业名录",
+    #     111: "企业资质",
+    #     112: "全国工程人员",
+    #     113: "业主采购",
+    #     114: "采购意向",
+    #     115: "拍卖出让",
+    #     116: "土地矿产",
+    #     117: "产权交易",
+    #     118: "废标公告",
+    #     119: "候选人公示",
+    #     120: "合同公告",
+    # }
+    # data_list = []
+    # for index, row in df.iterrows():
+    #     if index % 100000 == 0:
+    #         print("Loop", index)
+    #         print(_dict[int(row['docchannel'])])
+    #         print(df.iloc[index, 2])
+    #     data = row.tolist()
+    #     data[2] = _dict[int(data[2])]
+    #     data_list.append(data)
+    # df = pd.DataFrame(data_list)
+    # df.columns = ['docid', '项目名称', '信息类型', '发布时间', '地区', '业主',
+    #               '预算金额', '中标供应商', '成交金额', '代理机构']
+    # df.to_csv("C:\\Users\\Administrator\\Desktop\\4-1.csv", index=False)

BIN
BiddingKG/dl/table_head/checkpoints/best.hdf5


+ 6 - 1
BiddingKG/dl/table_head/loss.py

@@ -2,7 +2,12 @@ import tensorflow as tf
 import keras as K
 
 
-def focal_loss(gamma=3., alpha=.25):
+def focal_loss(gamma=2., alpha=.5):
+    # 3-0.5 precision-low recall-high
+    # 3-0.37 precision-high recall-low
+    # 3-0.25 precision-high recall-low
+    # 2-0.5 precision-low recall-high
+    # 2-0.25 precision-high recall-low
     def f_loss(y_true, y_pred):
         pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
         pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
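
For reference, f_loss above implements the standard binary focal loss; a sketch of the formulation, assuming y_pred is a sigmoid probability and p_t is the probability the model assigns to the true class:

    FL(p_t) = -\alpha (1 - p_t)^{\gamma} \log(p_t)

Raising gamma down-weights already-easy cells and alpha rebalances the positive class; the comment block above records the precision/recall trade-off observed for each (gamma, alpha) pair, motivating the move from (3., .25) to (2., .5).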

+ 44 - 19
BiddingKG/dl/table_head/models/model.py

@@ -1,44 +1,69 @@
 import sys
 import os
+
+from keras.layers import Lambda
+
 sys.path.append(os.path.abspath("../.."))
 from keras import layers, models
-import tensorflow as tf
+import keras.backend as K
 from BiddingKG.dl.table_head.models.my_average_pooling import MyAveragePooling1D
 from BiddingKG.dl.table_head.models.self_attention import SeqSelfAttention
 
 
 def get_model(input_shape, output_shape):
-    # Input
+    # Input (batch, 10, 60)
     input_1 = layers.Input(shape=input_shape[1:], dtype="float32")
     input_2 = layers.Input(shape=input_shape[1:], dtype="float32")
     input_3 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_4 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_5 = layers.Input(shape=input_shape[1:], dtype="float32")
+    input_6 = layers.Input(shape=input_shape[1:], dtype="float32")
 
-    # Embedding
-    # embed_1 = layers.Embedding(input_dim=6624, output_dim=32,
-    #                            input_length=input_shape[1], mask_zero=True)(input_1)
-    # embed_2 = layers.Embedding(input_dim=6624, output_dim=32,
-    #                            input_length=input_shape[1], mask_zero=True)(input_2)
+    # ----------- Three box sequence -----------
+    # Concat (batch, 30, 60)
+    concat_1 = layers.concatenate([input_1, input_2, input_3], axis=-2, name='seq_concat')
+    concat_2 = layers.concatenate([input_4, input_5, input_6], axis=-2)
 
-    # Bi-LSTM
-    bi_lstm_1 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_1)
-    bi_lstm_2 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_2)
-    bi_lstm_3 = layers.Bidirectional(layers.LSTM(16, return_sequences=True))(input_3)
+    # Bi-LSTM (batch, 30, 128)
+    bi_lstm_1 = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(concat_1)
+    bi_lstm_2 = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(concat_2)
 
-    # Self-Attention
+    # Self-Attention (batch, 30, 128)
     self_attention_1 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_1)
     self_attention_2 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_2)
-    self_attention_3 = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm_3)
 
-    # Concat
-    concat = layers.concatenate([self_attention_1, self_attention_2, self_attention_3])
+    # Dense (batch, 30, 1)
+    dense_1 = layers.Dense(output_shape[0], activation="relu")(self_attention_1)
+    dense_2 = layers.Dense(output_shape[0], activation="relu")(self_attention_2)
+
+    # Squeeze (batch, 30)
+    squeeze_1 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_1)
+    squeeze_2 = Lambda(lambda x: K.squeeze(x, axis=-1))(dense_2)
+
+    # ----------- One box feature -----------
+    # Bi-LSTM (batch, 10, 128)
+    bi_lstm = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(input_2)
 
-    # Dense + Softmax
-    output = layers.Dense(output_shape[0], activation="sigmoid")(concat)
+    # Self-Attention (batch, 10, 128)
+    self_attention = SeqSelfAttention(attention_activation='sigmoid')(bi_lstm)
 
     # mask mean pooling
-    output = MyAveragePooling1D(axis=1, name='output')(output)
+    # pool_1 = MyAveragePooling1D(axis=-1)(self_attention_1)
+
+    # Dense (batch, 10, 1)
+    dense = layers.Dense(output_shape[0], activation="relu")(self_attention)
+
+    # Squeeze (batch, 10) - one box feature
+    squeeze = Lambda(lambda x: K.squeeze(x, axis=-1))(dense)
+
+    # ----------- Three box sequence & One box feature -----------
+    # Dense (batch, 1)
+    concat = layers.concatenate([squeeze, squeeze_1, squeeze_2])
+    output = layers.Dense(64, activation='relu')(concat)
+    output = layers.Dense(1, activation="sigmoid", name='output')(output)
 
-    model = models.Model(inputs=[input_1, input_2, input_3], outputs=output)
+    model = models.Model(inputs=[input_1, input_2, input_3, input_4, input_5, input_6],
+                         outputs=output)
 
     model.summary()
     return model
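
A minimal smoke test of the rebuilt graph, assuming the shapes train.py passes in (input_shape=(6, 10, 60), output_shape=(1,)); the six inputs are the cell plus its row and column neighbours, each 10 characters of 60-dim embeddings:

    import numpy as np
    from BiddingKG.dl.table_head.models.model import get_model

    model = get_model((6, 10, 60), (1,))
    # six inputs, toy batch of 2 cells
    batch = [np.zeros((2, 10, 60), dtype='float32') for _ in range(6)]
    print(model.predict(batch).shape)  # (2, 1): one head/non-head score per cell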

+ 3 - 4
BiddingKG/dl/table_head/post_process.py

@@ -3,10 +3,9 @@
 def table_post_process(table_text_list, predict_result, threshold=0.5):
     predict_result = predict_result.tolist()
     predict_list = []
-    for i in range(0, len(predict_result), 2):
-        row_predict = predict_result[i][1]
-        col_predict = predict_result[i+1][1]
-        if row_predict + col_predict >= 2 * threshold:
+    for i in range(0, len(predict_result)):
+        predict = predict_result[i][0]
+        if predict >= threshold:
             predict = "1"
         else:
             predict = "0"

+ 72 - 0
BiddingKG/dl/table_head/postgresql2csv.py

@@ -0,0 +1,72 @@
+import os
+import sys
+import time
+import pandas as pd
+import numpy as np
+import psycopg2
+sys.path.append(os.path.abspath("../../.."))
+from BiddingKG.dl.table_head.predict import predict
+
+
+def eval_text_list(table_text):
+    if table_text[0] == '"':
+        table_text = eval(table_text)
+    else:
+        table_text = table_text
+    table_text = table_text.replace('\\', '/')
+    table_text = eval(table_text)
+    return table_text
+
+
+def read_postgresql(txt_name, start_id, _time):
+    conn = psycopg2.connect(database="table_head_label", user="postgres",
+                            password="postgres", host="192.168.2.103", port="5432")
+
+    with open('check_user_result/' + txt_name, "r") as f:
+        id_list = f.readlines()
+    # with open('check_user_result/test27.txt', "r") as f:
+    #     id_list += f.readlines()
+
+    _list = []
+    for _id in id_list:
+        _id = _id[:-1]
+        sql = 'select * from label_table_head_info where id =' + _id
+        df = pd.read_sql(sql=sql, con=conn)
+        # df = df[0]
+        for index, row in df.iterrows():
+            _list.append([x for x in row])
+    cnt = 0
+    new_list = []
+    for line in _list:
+        try:
+            table_text = eval_text_list(line[2])
+        except:
+            print("无法识别")
+            continue
+        if cnt % 1000 == 0:
+            print("Loop", cnt)
+        line[0] = start_id + cnt
+        cnt += 1
+        line[6] = 'test'
+        line[9] = 1
+        if _time == '':
+            line[7] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+        else:
+            line[7] = _time
+        # predict table_head with the model
+        label_list = predict(table_text)
+        line[3] = str(label_list)
+        new_list.append(line)
+    df = pd.DataFrame(new_list)
+    new_csv_path = "data_new.csv"
+
+    df.to_csv(new_csv_path, index=False)
+    conn.close()
+    return new_csv_path
+
+
+if __name__ == '__main__':
+    new_csv_path = read_postgresql('test20_error.txt', 203995, '2022-01-01 00:00:00')
+    # new_csv_path = read_postgresql('test20_right.txt', 203995, '')
+    # df = pd.read_csv('data_new.csv')
+    # print(df.iloc[:, 4])
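
eval_text_list evals twice because table_text is sometimes stored as a quoted repr of the list string; a small illustration of the path it handles:

    s = '"[[\'a\']]"'          # raw value as read from the table_text column
    s = eval(s)                # first eval strips the outer quotes -> "[['a']]"
    s = s.replace('\\', '/')   # normalise backslashes before the real parse
    table = eval(s)            # second eval yields the list -> [['a']]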

+ 65 - 82
BiddingKG/dl/table_head/pre_process.py

@@ -1,3 +1,5 @@
+import random
+
 import psycopg2
 import numpy as np
 from BiddingKG.dl.common.Utils import embedding_word
@@ -41,17 +43,19 @@ def postgresql_util(sql, limit):
 
 
 def get_data_from_sql(dim=10):
+    # sql = """
+    # select table_text, pre_label, post_label, id
+    # from label_table_head_info
+    # where update_user <> 'test27' and update_user <> 'test20' and table_box_cnt >= 4 and table_box_cnt <= 200
+    # ;
+    # """
     sql = """
     select table_text, pre_label, post_label, id
     from label_table_head_info 
-    where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200 
+    where status = 1 and update_time >= '2022-01-17'
     ;
     """
-    # sql = """
-    # select table_text, pre_label, post_label, id
-    # from label_table_head_info
-    # where id = 843
-    # """
+
     result_list = postgresql_util(sql, limit=1000000)
 
     all_data_list = []
@@ -92,50 +96,6 @@ def get_data_from_sql(dim=10):
         all_data_label_list += data_label_list
 
     print("len(all_data_list)", len(all_data_list))
-
-        #
-        # new_data_list = []
-        # for data in data_list:
-        #     # map Chinese characters to indices
-        #     # data[0] = get_sentence_index_list(data[0])
-        #     # data[1] = get_sentence_index_list(data[1])
-        #     # pad with mask 0 when shorter than dim
-        #     # if len(data[0]) < dim:
-        #     #     data[0] = data[0] + [0]*(dim-len(data[0]))
-        #     # elif len(data[0]) > dim:
-        #     #     data[0] = data[0][:dim]
-        #     # if len(data[1]) < dim:
-        #     #     data[1] = data[1] + [0]*(dim-len(data[1]))
-        #     # elif len(data[1]) > dim:
-        #     #     data[1] = data[1][:dim]
-        #
-        #     # map Chinese characters to embeddings
-        #     data = embedding_word(data, input_shape)
-        #     new_data_list.append(data)
-        #
-        # new_data_list = np.array(new_data_list)
-        # data_label_list = np.array(data_label_list)
-        # if np.array(new_data_list).shape[1:] == input_shape:
-        #     all_data_list.append(new_data_list)
-        #     all_data_label_list.append(data_label_list)
-
-    # # avoid overly slow concat
-    # split_len = 1000
-    # _len = int(len(all_data_list) / split_len)
-    # all_data_list_1 = []
-    # all_data_list_2 = []
-    # for i in range(_len):
-    #     if i == _len - 1:
-    #         array1 = np.concatenate(all_data_list[i*split_len:])
-    #         array2 = np.concatenate(all_data_label_list[i*split_len:])
-    #     else:
-    #         array1 = np.concatenate(all_data_list[i*split_len:i*split_len+split_len])
-    #         array2 = np.concatenate(all_data_label_list[i*split_len:i*split_len+split_len])
-    #     all_data_list_1.append(array1)
-    #     all_data_list_2.append(array2)
-    # all_data_list = np.concatenate(all_data_list_1)
-    # all_data_label_list = np.concatenate(all_data_list_2)
-
     return all_data_list, all_data_label_list
 
 
@@ -191,49 +151,55 @@ def table_pre_process(text_list, label_list, _id, is_train=True):
                 if is_train:
                     last_col_label = row_label[j-1]
             else:
-                last_col = None
+                last_col = col
                 if is_train:
-                    last_col_label = 0
+                    last_col_label = col_label
 
             if j < len(row) - 1:
                 next_col = row[j+1]
                 if is_train:
                     next_col_label = row_label[j+1]
             else:
-                next_col = None
+                next_col = col
                 if is_train:
-                    next_col_label = 0
+                    next_col_label = col_label
 
             if last_row:
                 last_row_col = last_row[j]
                 if is_train:
                     last_row_col_label = last_row_label[j]
             else:
-                last_row_col = None
+                last_row_col = col
                 if is_train:
-                    last_row_col_label = 0
+                    last_row_col_label = col_label
 
             if next_row:
                 next_row_col = next_row[j]
                 if is_train:
                     next_row_col_label = next_row_label[j]
             else:
-                next_row_col = None
+                next_row_col = col
                 if is_train:
-                    next_row_col_label = 0
+                    next_row_col_label = col_label
+
+            # data_list.append([last_col, col, next_col])
+            # if is_train:
+            #     data_label_list.append([int(last_col_label), int(col_label),
+            #                             int(next_col_label)])
+            #
+            # data_list.append([last_row_col, col, next_row_col])
+            # if is_train:
+            #     data_label_list.append([int(last_row_col_label), int(col_label),
+            #                             int(next_row_col_label)])
 
-            # use the triple as a sample when any pair differs
-            # if col != next_col or col != last_col:
-            data_list.append([last_col, col, next_col])
             if is_train:
-                data_label_list.append([int(last_col_label), int(col_label),
-                                        int(next_col_label)])
-
-            # if col != next_row_col or col != last_row_col:
-            data_list.append([last_row_col, col, next_row_col])
-            if is_train:
-                data_label_list.append([int(last_row_col_label), int(col_label),
-                                        int(next_row_col_label)])
+                dup_list = [str(x) for x in data_list]
+                data = [last_col, col, next_col, last_row_col, col, next_row_col]
+                if str(data) not in dup_list:
+                    data_list.append([last_col, col, next_col, last_row_col, col, next_row_col])
+                    data_label_list.append(int(col_label))
+            else:
+                data_list.append([last_col, col, next_col, last_row_col, col, next_row_col])
     if is_train:
         return data_list, data_label_list
     else:
@@ -249,8 +215,8 @@ def get_data_from_file(file_type):
         array2 = np.load(data_label_path)
         return array1, array2
     elif file_type == 'txt':
-        data_path = 'train_data/data.txt'
-        data_label_path = 'train_data/data_label.txt'
+        data_path = 'train_data/data3.txt'
+        data_label_path = 'train_data/data_label3.txt'
 
         with open(data_path, 'r') as f:
             data_list = f.readlines()
@@ -281,22 +247,27 @@ def processed_save_to_np():
 
 def processed_save_to_txt():
     list1, list2 = get_data_from_sql()
-    with open('train_data/data.txt', 'w') as f:
+
+    # shuffle
+    zip_list = list(zip(list1, list2))
+    random.shuffle(zip_list)
+    list1[:], list2[:] = zip(*zip_list)
+
+    with open('train_data/data3.txt', 'w') as f:
         for line in list1:
             f.write(str(line) + "\n")
-    with open('train_data/data_label.txt', 'w') as f:
+    with open('train_data/data_label3.txt', 'w') as f:
         for line in list2:
             f.write(str(line) + "\n")
 
 
 def data_balance():
-    array1, array2 = get_data_from_file()
-    data_list = array2.tolist()
-    all_cnt = len(data_list)
+    data_list, data_label_list = get_data_from_file('txt')
+    all_cnt = len(data_label_list)
     cnt_0 = 0
     cnt_1 = 0
-    for data in data_list:
-        if data[0] == 1 or data[1] == 1:
+    for data in data_label_list:
+        if eval(data[:-1])[1] == 1:
             cnt_1 += 1
         else:
             cnt_0 += 1
@@ -316,7 +287,7 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
     data_num = len(data_list)
 
     # define the Embedding output shape
-    output_shape = (3, 10, 60)
+    output_shape = (6, 10, 60)
 
     # fetch data batch by batch
     i = 0
@@ -343,8 +314,19 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
             Y = new_data_label_list
             # (table_num, 6 cells, dim characters, embedding) -> (6, table_num, dim, embedding)
             X = np.transpose(X, (1, 0, 2, 3))
-            # print("input_1", X[0].shape, "input_2", X[1].shape, "input_3", X[2].shape, "Y", Y.shape)
-            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}, {'output': Y}
+
+            if (X[0] == X[1]).all():
+                X[0] = np.zeros_like(X[1], dtype='float32')
+            if (X[2] == X[1]).all():
+                X[2] = np.zeros_like(X[1], dtype='float32')
+            if (X[3] == X[1]).all():
+                X[3] = np.zeros_like(X[1], dtype='float32')
+            if (X[5] == X[1]).all():
+                X[5] = np.zeros_like(X[1], dtype='float32')
+
+            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
+                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}, \
+                  {'output': Y}
 
     else:
         while True:
@@ -363,7 +345,8 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
             new_data_list = np.array(new_data_list)
             X = new_data_list
             X = np.transpose(X, (1, 0, 2, 3))
-            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}
+            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
+                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5], }
 
 
 if __name__ == '__main__':
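
Two behavioural notes on the hunks above: boundary neighbours now duplicate the centre cell (and its label) instead of becoming None, and my_data_loader later zeroes any neighbour input identical to the centre, so those duplicates act as padding masks. A toy sketch of the six-tuple built for cell (0, 0) of a table [['a', 'b'], ['c', 'd']]:

    col = 'a'
    last_col, next_col = col, 'b'          # no cell to the left -> duplicate centre
    last_row_col, next_row_col = col, 'c'  # no row above -> duplicate centre
    sample = [last_col, col, next_col, last_row_col, col, next_row_col]
    print(sample)  # ['a', 'a', 'b', 'a', 'a', 'c']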

+ 77 - 11
BiddingKG/dl/table_head/predict.py

@@ -1,10 +1,15 @@
+import codecs
+import re
+
+from bs4 import BeautifulSoup
 from BiddingKG.dl.table_head.models.model import get_model
 from BiddingKG.dl.table_head.post_process import table_post_process
 from BiddingKG.dl.table_head.pre_process import my_data_loader, table_pre_process
+from BiddingKG.dl.interface.Preprocessing import tableToText, segment
 
 # init model
-input_shape = (3, 10, 60)
-output_shape = (3,)
+input_shape = (6, 10, 60)
+output_shape = (1,)
 model = get_model(input_shape, output_shape)
 
 # load weights
@@ -30,13 +35,74 @@ def predict(table_text_list):
     return table_label_list
 
 
+def predict_html():
+    def get_trs(tbody):
+        #获取所有的tr
+        trs = []
+        objs = tbody.find_all(recursive=False)
+        for obj in objs:
+            if obj.name=="tr":
+                trs.append(obj)
+            if obj.name=="tbody":
+                for tr in obj.find_all("tr",recursive=False):
+                    trs.append(tr)
+        return trs
+
+    def get_table(tbody):
+        trs = get_trs(tbody)
+        inner_table = []
+        for tr in trs:
+            tr_line = []
+            tds = tr.findChildren(['td', 'th'], recursive=False)
+            if len(tds) == 0:
+                tr_line.append(re.sub('\xa0', '', segment(tr,final=False)))  # 2021/12/21 fix data loss for tables whose rows have no td
+            for td in tds:
+                tr_line.append(re.sub('\xa0', '', segment(td,final=False)))
+            inner_table.append(tr_line)
+        return inner_table
+
+    def fix_table(inner_table, fix_value=""):
+        maxWidth = 0
+        for item in inner_table:
+            if len(item)>maxWidth:
+                maxWidth = len(item)
+        for i in range(len(inner_table)):
+            if len(inner_table[i])<maxWidth:
+                for j in range(maxWidth-len(inner_table[i])):
+                    inner_table[i].append(fix_value)
+        return inner_table
+
+    text = codecs.open("C:\\Users\\Administrator\\Desktop\\2.html", "r", encoding="utf8").read()
+    content = str(BeautifulSoup(text).find("div",id="pcontent"))
+    soup = BeautifulSoup(content, 'lxml')
+
+    table_list = []
+    tbodies = soup.find_all('tbody')
+    for tbody_index in range(1,len(tbodies)+1):
+        tbody = tbodies[len(tbodies)-tbody_index]
+        table_list.append(tbody)
+
+    table_fix_list = []
+    for tbody in table_list:
+        inner_table = get_table(tbody)
+        inner_table = fix_table(inner_table)
+        table_fix_list.append(inner_table)
+
+    for table in table_fix_list:
+        print("="*30)
+        print(table)
+        print(predict(table))
+
+
 if __name__ == '__main__':
-    _str = "[['序号', '投标人名称', '价格得分', '技术得分', '商务得分', '综合得分', '排名'], " \
-           "['序号', '投标人名称', '比例(20%),', '比例(45%),', '比例(35%),', '100%', '排名'], " \
-           "['1', '广州中科雅图信息技术有限公司', '19.71', '11.50', '11.00', '42.21', '3'], " \
-           "['2', '核工业赣州工程勘察院', '19.64', '15.00', '11.00', '45.64', '2'], " \
-           "['3', '广东晟腾地信科技有限公司', '20.00', '16.17', '14.00', '50.17', '1']]"
-
-    data_list = eval(_str)
-    print("len(data_list)", len(data_list))
-    predict(data_list)
+    # _str = "[['序号', '投标人名称', '价格得分', '技术得分', '商务得分', '综合得分', '排名'], " \
+    #        "['序号', '投标人名称', '比例(20%),', '比例(45%),', '比例(35%),', '100%', '排名'], " \
+    #        "['1', '广州中科雅图信息技术有限公司', '19.71', '11.50', '11.00', '42.21', '3'], " \
+    #        "['2', '核工业赣州工程勘察院', '19.64', '15.00', '11.00', '45.64', '2'], " \
+    #        "['3', '广东晟腾地信科技有限公司', '20.00', '16.17', '14.00', '50.17', '1']]"
+    #
+    # data_list = eval(_str)
+    # print("len(data_list)", len(data_list))
+    # predict(data_list)
+
+    predict_html()
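
The nested fix_table helper pads ragged rows out to the widest row so predict always sees a rectangular table; a re-statement of its behaviour (it is defined inside predict_html and not importable on its own):

    def fix_table(inner_table, fix_value=""):
        max_width = max(len(row) for row in inner_table)
        for row in inner_table:
            row.extend([fix_value] * (max_width - len(row)))
        return inner_table

    print(fix_table([['a'], ['b', 'c', 'd']]))  # [['a', '', ''], ['b', 'c', 'd']]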

+ 19 - 8
BiddingKG/dl/table_head/train.py

@@ -2,7 +2,6 @@ import sys
 import os
 sys.path.append(os.path.abspath("../../.."))
 os.environ['KERAS_BACKEND'] = 'tensorflow'
-import keras
 from keras.metrics import categorical_accuracy
 from BiddingKG.dl.table_head.metrics import precision, recall, f1
 from keras import optimizers, Model
@@ -10,17 +9,16 @@ from BiddingKG.dl.table_head.models.model import get_model
 from BiddingKG.dl.table_head.loss import focal_loss
 from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
 from BiddingKG.dl.table_head.pre_process import get_data_from_file, get_data_from_sql, my_data_loader
-import numpy as np
 from keras import backend as K
 
 
-input_shape = (3, 10, 60)
-output_shape = (3,)
-batch_size = 1024
+input_shape = (6, 10, 60)
+output_shape = (1,)
+batch_size = 32
 epochs = 1000
 pretrained_path = "checkpoints/best.hdf5"
 checkpoint_path = "checkpoints/"
-PRETRAINED = False
+PRETRAINED = True
 CHECKPOINT = False
 
 
@@ -30,6 +28,8 @@ def train():
 
     # Data
     data_x, data_y = get_data_from_file('txt')
+    # data_x = data_x[:60000]
+    # data_y = data_y[:60000]
     print("finish read data", len(data_x))
 
     # Split -> Train, Test
@@ -54,12 +54,13 @@ def train():
     else:
         print("no checkpoint")
 
-    filepath = 'e{epoch:02d}-loss{val_loss:.2f}'
+    filepath = 'e{epoch:02d}-f1{val_f1:.2f}'
     checkpoint = ModelCheckpoint(checkpoint_path+filepath+".hdf5", monitor='val_f1',
                                  verbose=1, save_best_only=True, mode='max')
 
     model.compile(optimizer=optimizers.Adam(lr=0.005), loss=focal_loss(),
-                  metrics=[categorical_accuracy,
+    # model.compile(optimizer=optimizers.Adam(lr=0.005), loss='binary_crossentropy',
+                  metrics=['acc',
                            precision, recall, f1])
 
     rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=5,
@@ -91,3 +92,13 @@ def print_layer_output(model, data):
 
 if __name__ == '__main__':
     model, data = train()
+
+    # place_list = get_place_list()
+    # _str1 = '中国电信'
+    # _str2 = '分公司'
+    # _list = []
+    # for place in place_list:
+    #     _list.append(_str1 + place + _str2 + "\n")
+    # # print(_list)
+    # with open("电信分公司.txt", "w") as f:
+    #     f.writelines(_list)