Explorar o código

字典替换部分修改,新增表头后处理

fangjiasheng %!s(int64=3) %!d(string=hai) anos
pai
achega
aa735fb153

+ 39 - 34
BiddingKG/dl/entityLink/entityLink.py

@@ -1,3 +1,4 @@
+#coding:UTF8
 '''
 Created on 2019年5月21日
 
@@ -35,6 +36,43 @@ def jaccard_score(source,target):
         return 0
     return max(len(source_set&target_set)/len(source_set),len(source_set&target_set)/len(target_set))
 
+
+def get_place_list():
+    with open(os.path.abspath(__file__) + '/../../place_info.csv', 'r') as f:
+        reader = csv.reader(f)
+        place_list = []
+        for r in reader:
+            place_list.append(r[1:3])
+        place_list = place_list[1:]
+        place_list.append(['台湾', '台湾'])
+        place_list.append(['澳门', '澳门'])
+        place_list.append(['香港', '香港'])
+        place_list.append(['東莞', '東莞'])
+        place_list.append(['廣州', '廣州'])
+        place_list.append(['韩国', '韩国'])
+        place_list.append(['德国', '德国'])
+        place_list.append(['英国', '英国'])
+        place_list.append(['日本', '日本'])
+        place_list.append(['意大利', '意大利'])
+        place_list.append(['新加坡', '新加坡'])
+        place_list.append(['加拿大', '加拿大'])
+        place_list.append(['西班牙', '西班牙'])
+        place_list.append(['澳大利亚', '澳大利亚'])
+        place_list.append(['美国', '美国'])
+    # 去重
+    place_list_str = []
+    for place in place_list:
+        place_list_str.append(str(place))
+    place_list_str = list(set(place_list_str))
+    place_list = []
+    for place in place_list_str:
+        place_list.append(eval(place))
+    return place_list
+
+
+place_list = get_place_list()
+
+
 def link_entitys(list_entitys,on_value=0.8):
     for list_entity in list_entitys:
         range_entity = []
@@ -62,7 +100,6 @@ def link_entitys(list_entitys,on_value=0.8):
                             _entity.entity_text = _ent.entity_text
 
         # 2021/12/21 替换通过字典识别到的取长度最大的相似实体
-        place_list = get_place_list()
         for _entity in range_entity:
             used_linked_entitys = []
             if not _entity.linked_entitys:
@@ -87,41 +124,9 @@ def link_entitys(list_entitys,on_value=0.8):
                             _entity.origin_entity_text = _entity.entity_text
                             _entity.entity_text = _ent.entity_text
                             used_linked_entitys.append(_ent)
-            # print(_entity.origin_entity_text, _entity.entity_text)
-
+                # print(_entity.entity_text, _entity.if_dict_match, _ent.entity_text, _ent.if_dict_match)
 
-def get_place_list():
-    with open('../place_info.csv', 'r') as f:
-        reader = csv.reader(f)
 
-        place_list = []
-        for r in reader:
-            place_list.append(r[1:3])
-        place_list = place_list[1:]
-        place_list.append(['台湾', '台湾'])
-        place_list.append(['澳门', '澳门'])
-        place_list.append(['香港', '香港'])
-        place_list.append(['東莞', '東莞'])
-        place_list.append(['廣州', '廣州'])
-        place_list.append(['韩国', '韩国'])
-        place_list.append(['德国', '德国'])
-        place_list.append(['英国', '英国'])
-        place_list.append(['日本', '日本'])
-        place_list.append(['意大利', '意大利'])
-        place_list.append(['新加坡', '新加坡'])
-        place_list.append(['加拿大', '加拿大'])
-        place_list.append(['西班牙', '西班牙'])
-        place_list.append(['澳大利亚', '澳大利亚'])
-        place_list.append(['美国', '美国'])
-    # 去重
-    place_list_str = []
-    for place in place_list:
-        place_list_str.append(str(place))
-    place_list_str = list(set(place_list_str))
-    place_list = []
-    for place in place_list_str:
-        place_list.append(eval(place))
-    return place_list
 
 
 def getEnterprisePath():

+ 91 - 0
BiddingKG/dl/table_head/check_user_label_accuracy.py

@@ -0,0 +1,91 @@
+import difflib
+import time
+
+from BiddingKG.dl.table_head.pre_process import postgresql_util
+from BiddingKG.dl.table_head.predict import predict
+
+
+def user_label_accuracy(update_user):
+    sql = """
+    select table_text, pre_label, post_label, id
+    from label_table_head_info 
+    where update_user='""" + update_user + "' "
+
+    result_list = postgresql_util(sql, limit=1000000)
+    right_cnt = 0
+    error_id_list = []
+    i = 0
+    start_time = time.time()
+    for table in result_list:
+        i += 1
+        if i % 1000 == 0:
+            print("Loop", i, right_cnt, time.time()-start_time)
+            start_time = time.time()
+
+        pre_label = eval(table[1])
+        post_label = eval(table[2])
+        _id = table[3]
+
+        # table_text需要特殊处理
+        try:
+            table_text = table[0]
+            if table_text[0] == '"':
+                table_text = eval(table_text)
+            else:
+                table_text = table_text
+            table_text = table_text.replace('\\', '/')
+            table_text = eval(table_text)
+        except:
+            print("无法识别table_text", _id)
+            continue
+
+        if post_label:
+            label_list = post_label
+        else:
+            label_list = pre_label
+
+        predict_label_list = predict(table_text)
+        if predict_label_list:
+            if str(label_list) == str(predict_label_list):
+                right_cnt += 1
+            else:
+                # cnt = 0
+                # for j in range(len(label_list)):
+                #     row1 = label_list[j]
+                #     row2 = predict_label_list[j]
+                #     if str(row1) != str(row2):
+                #         cnt += 1
+                #     if cnt >= 2:
+                #         error_id_list.append(str(_id)+"\n")
+                #         break
+                error_id_list.append(str(_id)+"\n")
+
+    print(update_user + " accuracy:", right_cnt / len(result_list), 'total:', len(result_list))
+    print("error_id_list", len(error_id_list))
+
+    save_path = "check_user_result/"+update_user+".txt"
+    with open(save_path, 'w') as f:
+        f.writelines(error_id_list)
+
+
+def get_single_result(_id):
+    sql = """
+    select table_text, pre_label, post_label, id
+    from label_table_head_info 
+    where id=""" + str(_id)
+
+    result_list = postgresql_util(sql, limit=1000000)
+    table_text = result_list[0][0]
+    if table_text[0] == '"':
+        table_text = eval(table_text)
+    else:
+        table_text = table_text
+    table_text = table_text.replace('\\', '/')
+    table_text = eval(table_text)
+    label_list = predict(table_text)
+    print(label_list)
+
+
+if __name__ == '__main__':
+    user_label_accuracy("test9")
+    # get_single_result(171520)

BIN=BIN
BiddingKG/dl/table_head/checkpoints/best.hdf5


BIN=BIN
BiddingKG/dl/table_head/checkpoints/binary_loss/best.hdf5


BIN=BIN
BiddingKG/dl/table_head/checkpoints/focal_loss/best.hdf5


+ 27 - 0
BiddingKG/dl/table_head/post_process.py

@@ -0,0 +1,27 @@
+
+
+def table_post_process(table_text_list, predict_result, threshold=0.5):
+    predict_result = predict_result.tolist()
+    predict_list = []
+    for i in range(0, len(predict_result), 2):
+        row_predict = predict_result[i][1]
+        col_predict = predict_result[i+1][1]
+        if row_predict + col_predict >= 2 * threshold:
+            predict = "1"
+        else:
+            predict = "0"
+        predict_list.append(predict)
+
+    table_label_list = []
+    i = 0
+    for row in table_text_list:
+        row_label = []
+        for col in row:
+            row_label.append(predict_list[i])
+            i += 1
+        table_label_list.append(row_label)
+
+    if len(table_label_list) != len(table_text_list):
+        print("table_post_process 输出label维度与text不一致!")
+        table_label_list = []
+    return table_label_list

+ 100 - 53
BiddingKG/dl/table_head/pre_process.py

@@ -1,6 +1,5 @@
 import psycopg2
 import numpy as np
-
 from BiddingKG.dl.common.Utils import embedding_word
 
 
@@ -83,9 +82,9 @@ def get_data_from_sql(dim=10):
 
         # 只有一行的也不要
         if len(post_label) >= 2:
-            data_list, data_label_list = table_process(table_text, post_label, _id)
+            data_list, data_label_list = table_pre_process(table_text, post_label, _id)
         elif len(pre_label) >= 2:
-            data_list, data_label_list = table_process(table_text, pre_label, _id)
+            data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
         else:
             data_list, data_label_list = [], []
 
@@ -140,77 +139,105 @@ def get_data_from_sql(dim=10):
     return all_data_list, all_data_label_list
 
 
-def table_process(text_list, label_list, _id):
-    if len(text_list) != len(label_list):
-        print("文字单元格与标注单元格数量不匹配!", _id)
-        print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
-        return [], []
+def table_pre_process(text_list, label_list, _id, is_train=True):
+    """
+    表格处理,每个单元格生成2条数据,横竖各1条
+
+    :param text_list:
+    :param label_list:
+    :param _id:
+    :param is_train:
+    :return:
+    """
+    if is_train:
+        if len(text_list) != len(label_list):
+            print("文字单元格与标注单元格数量不匹配!", _id)
+            print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
+            return [], []
 
     data_list = []
     data_label_list = []
     for i in range(len(text_list)):
         row = text_list[i]
-        row_label = label_list[i]
+        if is_train:
+            row_label = label_list[i]
 
         if i > 0:
             last_row = text_list[i-1]
-            last_row_label = label_list[i-1]
+            if is_train:
+                last_row_label = label_list[i-1]
         else:
             last_row = []
-            last_row_label = []
+            if is_train:
+                last_row_label = []
 
         if i < len(text_list) - 1:
             next_row = text_list[i+1]
-            next_row_label = label_list[i+1]
+            if is_train:
+                next_row_label = label_list[i+1]
         else:
             next_row = []
-            next_row_label = []
+            if is_train:
+                next_row_label = []
 
         for j in range(len(row)):
             col = row[j]
-            col_label = row_label[j]
+            if is_train:
+                col_label = row_label[j]
 
             # 超出表格置为None, 0
             if j > 0:
                 last_col = row[j-1]
-                last_col_label = row_label[j-1]
+                if is_train:
+                    last_col_label = row_label[j-1]
             else:
                 last_col = None
-                last_col_label = 0
+                if is_train:
+                    last_col_label = 0
 
             if j < len(row) - 1:
                 next_col = row[j+1]
-                next_col_label = row_label[j+1]
+                if is_train:
+                    next_col_label = row_label[j+1]
             else:
                 next_col = None
-                next_col_label = 0
+                if is_train:
+                    next_col_label = 0
 
             if last_row:
                 last_row_col = last_row[j]
-                last_row_col_label = last_row_label[j]
+                if is_train:
+                    last_row_col_label = last_row_label[j]
             else:
                 last_row_col = None
-                last_row_col_label = 0
+                if is_train:
+                    last_row_col_label = 0
 
             if next_row:
                 next_row_col = next_row[j]
-                next_row_col_label = next_row_label[j]
+                if is_train:
+                    next_row_col_label = next_row_label[j]
             else:
                 next_row_col = None
-                next_row_col_label = 0
+                if is_train:
+                    next_row_col_label = 0
 
             # 三元组有一对不相等就作为数据
             # if col != next_col or col != last_col:
             data_list.append([last_col, col, next_col])
-            data_label_list.append([int(last_col_label), int(col_label),
-                                    int(next_col_label)])
+            if is_train:
+                data_label_list.append([int(last_col_label), int(col_label),
+                                        int(next_col_label)])
 
             # if col != next_row_col or col != last_row_col:
             data_list.append([last_row_col, col, next_row_col])
-            data_label_list.append([int(last_row_col_label), int(col_label),
-                                    int(next_row_col_label)])
-
-    return data_list, data_label_list
+            if is_train:
+                data_label_list.append([int(last_row_col_label), int(col_label),
+                                        int(next_row_col_label)])
+    if is_train:
+        return data_list, data_label_list
+    else:
+        return data_list
 
 
 def get_data_from_file(file_type):
@@ -285,7 +312,7 @@ def test_embedding():
     print(result)
 
 
-def my_data_loader(data_list, data_label_list, batch_size):
+def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
     data_num = len(data_list)
 
     # 定义Embedding输出
@@ -293,30 +320,50 @@ def my_data_loader(data_list, data_label_list, batch_size):
 
     # batch循环取数据
     i = 0
-    while True:
-        new_data_list = []
-        new_data_label_list = []
-        for j in range(batch_size):
-            if i >= data_num:
-                i = 0
-
-            # 中文字符映射为Embedding
-            data = eval(data_list[i][:-1])
-            data_label = eval(data_label_list[i][:-1])
-            data = embedding_word(data, output_shape)
-            if data.shape == output_shape:
-                new_data_list.append(data)
-                new_data_label_list.append(data_label)
-            i += 1
-
-        new_data_list = np.array(new_data_list)
-        new_data_label_list = np.array(new_data_label_list)
-        X = new_data_list
-        Y = new_data_label_list
-        # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
-        X = np.transpose(X, (1, 0, 2, 3))
-        # print("input_1", X[0].shape, "input_2", X[1].shape, "input_3", X[2].shape, "Y", Y.shape)
-        yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}, {'output': Y}
+    if is_train:
+        while True:
+            new_data_list = []
+            new_data_label_list = []
+            for j in range(batch_size):
+                if i >= data_num:
+                    i = 0
+
+                # 中文字符映射为Embedding
+                data = eval(data_list[i][:-1])
+                data_label = eval(data_label_list[i][:-1])
+                data = embedding_word(data, output_shape)
+                if data.shape == output_shape:
+                    new_data_list.append(data)
+                    new_data_label_list.append(data_label)
+                i += 1
+
+            new_data_list = np.array(new_data_list)
+            new_data_label_list = np.array(new_data_label_list)
+            X = new_data_list
+            Y = new_data_label_list
+            # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
+            X = np.transpose(X, (1, 0, 2, 3))
+            # print("input_1", X[0].shape, "input_2", X[1].shape, "input_3", X[2].shape, "Y", Y.shape)
+            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}, {'output': Y}
+
+    else:
+        while True:
+            new_data_list = []
+            for j in range(batch_size):
+                if i >= data_num:
+                    i = 0
+
+                # 中文字符映射为Embedding
+                data = data_list[i]
+                data = embedding_word(data, output_shape)
+                if data.shape == output_shape:
+                    new_data_list.append(data)
+                i += 1
+
+            new_data_list = np.array(new_data_list)
+            X = new_data_list
+            X = np.transpose(X, (1, 0, 2, 3))
+            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}
 
 
 if __name__ == '__main__':

+ 42 - 12
BiddingKG/dl/table_head/predict.py

@@ -1,12 +1,42 @@
-def predict():
-    """
-    @summary: 预测数据
-    """
-    test_x,doc_id,ent_id,sen,ent_text,dianhua = get_data()
-    model = models.load_model("model_person.model",
-                              custom_objects={'precision': precision,
-                                              'recall': recall,
-                                              'f1_score': f1_score}
-                              )
-    predict_y = model.predict([test_x[0], test_x[1]])
-    label = [np.argmax(y) for y in predict_y]
+from BiddingKG.dl.table_head.models.model import get_model
+from BiddingKG.dl.table_head.post_process import table_post_process
+from BiddingKG.dl.table_head.pre_process import my_data_loader, table_pre_process
+
+# init model
+input_shape = (3, 10, 60)
+output_shape = (3,)
+model = get_model(input_shape, output_shape)
+
+# load weights
+model_path = "checkpoints/best.hdf5"
+model.load_weights(model_path)
+
+
+def predict(table_text_list):
+    # 表格单元格数*2 即为单次预测batch_size
+    data_list = table_pre_process(table_text_list, [], 0, is_train=False)
+    batch_size = len(data_list)
+    # print("batch_size", batch_size)
+
+    # 数据预处理
+    predict_x = my_data_loader(data_list, [], batch_size, is_train=False)
+
+    # 预测
+    predict_result = model.predict_generator(predict_x, steps=1)
+    # print("predict_result", predict_result.shape)
+
+    # 数据后处理
+    table_label_list = table_post_process(table_text_list, predict_result)
+    return table_label_list
+
+
+if __name__ == '__main__':
+    _str = "[['序号', '投标人名称', '价格得分', '技术得分', '商务得分', '综合得分', '排名'], " \
+           "['序号', '投标人名称', '比例(20%),', '比例(45%),', '比例(35%),', '100%', '排名'], " \
+           "['1', '广州中科雅图信息技术有限公司', '19.71', '11.50', '11.00', '42.21', '3'], " \
+           "['2', '核工业赣州工程勘察院', '19.64', '15.00', '11.00', '45.64', '2'], " \
+           "['3', '广东晟腾地信科技有限公司', '20.00', '16.17', '14.00', '50.17', '1']]"
+
+    data_list = eval(_str)
+    print("len(data_list)", len(data_list))
+    predict(data_list)

+ 4 - 3
BiddingKG/dl/table_head/train.py

@@ -3,6 +3,7 @@ import os
 sys.path.append(os.path.abspath("../../.."))
 os.environ['KERAS_BACKEND'] = 'tensorflow'
 import keras
+from keras.metrics import categorical_accuracy
 from BiddingKG.dl.table_head.metrics import precision, recall, f1
 from keras import optimizers, Model
 from BiddingKG.dl.table_head.models.model import get_model
@@ -53,12 +54,12 @@ def train():
     else:
         print("no checkpoint")
 
-    filepath = 'e-{epoch:02d}-loss-{val_loss:.2f}'
+    filepath = 'e{epoch:02d}-loss{val_loss:.2f}'
     checkpoint = ModelCheckpoint(checkpoint_path+filepath+".hdf5", monitor='val_f1',
                                  verbose=1, save_best_only=True, mode='max')
 
-    model.compile(optimizer=optimizers.Adam(lr=0.0005), loss='binary_crossentropy',
-                  metrics=['binary_crossentropy', 'acc',
+    model.compile(optimizer=optimizers.Adam(lr=0.005), loss=focal_loss(),
+                  metrics=[categorical_accuracy,
                            precision, recall, f1])
 
     rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=5,