fangjiasheng 3 жил өмнө
parent
commit
9fc86bd2b0

+ 63 - 12
data_process/create_labelme_data.py

@@ -14,12 +14,15 @@ def create_table_image(if_add_seal=False):
 
     # 空图
     # weight = random.randint(512, 896)
-    height = random.randint(512, 640)
-    weight = random.randint(512, 768)
+    height = random.randint(512, 2500)
+    weight = random.randint(512, 1500)
     weight_origion = weight
     height_origion = height
     img = np.zeros((weight, height), np.uint8)
     img.fill(255)
+    # 计算交点用,黑图
+    row_img = np.zeros((weight, height), np.uint8)
+    col_img = np.zeros((weight, height), np.uint8)
 
     # 画矩形表格轮廓
     num = random.choice([1, 1, 2])
@@ -33,6 +36,10 @@ def create_table_image(if_add_seal=False):
         x2 = scale_x + (weight - 2 * scale_x)
         y2 = scale_y + (height - 2 * scale_y)
         img = cv2.rectangle(img, (y1, x1), (y2, x2), (0, 0, 0), 1)
+        row_img = cv2.line(row_img, (y1, x1), (y2, x1), (255, 255, 255), 1)
+        row_img = cv2.line(row_img, (y2, x2), (y1, x2), (255, 255, 255), 1)
+        col_img = cv2.line(col_img, (y2, x1), (y2, x2), (255, 255, 255), 1)
+        col_img = cv2.line(col_img, (y1, x2), (y1, x1), (255, 255, 255), 1)
         lines_of_table.extend([[(y1, x1), (y2, x1)], [(y2, x1), (y2, x2)],
                                [(y2, x2), (y1, x2)], [(y1, x2), (y1, x1)]])
     if num == 2:
@@ -43,6 +50,10 @@ def create_table_image(if_add_seal=False):
         y2 = (y1 + (height - 2 * scale_y))
         # print(y1, x1, y2, x2)
         img = cv2.rectangle(img, (y1, x1), (y2, x2), (0, 0, 0), 1)
+        row_img = cv2.line(row_img, (y1, x1), (y2, x1), (255, 255, 255), 1)
+        row_img = cv2.line(row_img, (y2, x2), (y1, x2), (255, 255, 255), 1)
+        col_img = cv2.line(col_img, (y2, x1), (y2, x2), (255, 255, 255), 1)
+        col_img = cv2.line(col_img, (y1, x2), (y1, x1), (255, 255, 255), 1)
         lines_of_table.extend([[(y1, x1), (y2, x1)], [(y2, x1), (y2, x2)],
                                [(y2, x2), (y1, x2)], [(y1, x2), (y1, x1)]])
 
@@ -58,15 +69,24 @@ def create_table_image(if_add_seal=False):
         # print(y3, x3, y4, x4)
         if x2 + 10 <= x3 or y4 - y3 >= 20 or x4 - 10 <= weight_origin:
             img = cv2.rectangle(img, (y3, x3), (y4, x4), (0, 0, 0), 1)
+            row_img = cv2.line(row_img, (y3, x3), (y4, x3), (255, 255, 255), 1)
+            row_img = cv2.line(row_img, (y4, x4), (y3, x4), (255, 255, 255), 1)
+            col_img = cv2.line(col_img, (y4, x3), (y4, x4), (255, 255, 255), 1)
+            col_img = cv2.line(col_img, (y3, x4), (y3, x3), (255, 255, 255), 1)
             lines_of_table.extend([[(y3, x3), (y4, x3)], [(y4, x3), (y4, x4)],
                                    [(y4, x4), (y3, x4)], [(y3, x4), (y3, x3)]])
 
     # 画表格内单元格线
     # 第一个表格
     row_num = random.randint(23, 25)
-    col_num = random.randint(3, 4)
+    col_num = random.randint(3, 6)
     margin_x = int((x2 - x1)/row_num)
     margin_y = int((y2 - y1)/col_num)
+    print("margin_x", margin_x)
+    print("margin_y", margin_y)
+    if margin_x < 20:
+        row_num = random.randint(11, 12)
+        margin_x = int((x2 - x1)/row_num)
     col_points = []
     row_points = []
     # for rn in range(0, row_num):
@@ -99,17 +119,21 @@ def create_table_image(if_add_seal=False):
         # 不跨行
         if random.choice([1, 1, 1, 1]):
             img = cv2.line(img, (col, x1), (col, x2), (0, 0, 0), 1)
+            col_img = cv2.line(col_img, (col, x1), (col, x2), (255, 255, 255), 1)
             lines_of_table.append([(col, x1), (col, x2)])
         else:
             img = cv2.line(img, (col, x1+margin_x), (col, x2), (0, 0, 0), 1)
+            col_img = cv2.line(col_img, (col, x1+margin_x), (col, x2), (255, 255, 255), 1)
             lines_of_table.append([(col, x1+margin_x), (col, x2)])
     for row in row_points:
         # if random.choice([0, 1, 1, 1]):
         if random.choice([1, 1, 1, 1]):
             img = cv2.line(img, (y1, row), (y2, row), (0, 0, 0), 1)
+            row_img = cv2.line(row_img, (y1, row), (y2, row), (255, 255, 255), 1)
             lines_of_table.append([(y1, row), (y2, row)])
         else:
             img = cv2.line(img, (y1+margin_y, row), (y2, row), (0, 0, 0), 1)
+            row_img = cv2.line(row_img, (y1+margin_y, row), (y2, row), (255, 255, 255), 1)
             lines_of_table.append([(y1+margin_y, row), (y2, row)])
 
     # 第二个表格
@@ -131,16 +155,20 @@ def create_table_image(if_add_seal=False):
         for col in col_points2:
             if random.choice([0, 1, 1, 1]):
                 img = cv2.line(img, (col, x3), (col, x4), (0, 0, 0), 1)
+                col_img = cv2.line(col_img, (col, x3), (col, x4), (255, 255, 255), 1)
                 lines_of_table.append([(col, x3), (col, x4)])
             else:
                 img = cv2.line(img, (col, x3+margin_x2), (col, x4), (0, 0, 0), 1)
+                col_img = cv2.line(col_img, (col, x3+margin_x2), (col, x4), (255, 255, 255), 1)
                 lines_of_table.append([(col, x3+margin_x2), (col, x4)])
         for row in row_points2:
             if random.choice([0, 1, 1, 1]):
                 img = cv2.line(img, (y3, row), (y4, row), (0, 0, 0), 1)
+                row_img = cv2.line(row_img, (y3, row), (y4, row), (255, 255, 255), 1)
                 lines_of_table.append([(y3, row), (y4, row)])
             else:
                 img = cv2.line(img, (y3+margin_y2, row), (y4, row), (0, 0, 0), 1)
+                row_img = cv2.line(row_img, (y3+margin_y2, row), (y4, row), (255, 255, 255), 1)
                 lines_of_table.append([(y3+margin_y2, row), (y4, row)])
 
     # 画轮廓
@@ -155,6 +183,21 @@ def create_table_image(if_add_seal=False):
         # img = cv2.rectangle(img, point3, point4, (0, 0, 255), 2)
         outline_of_table.append([point3, point4])
 
+    # 计算交点
+    point_img = np.bitwise_and(row_img, col_img)
+    ys, xs = np.where(point_img > 0)
+    points = []
+    for i in range(len(xs)):
+        points.append((xs[i], ys[i]))
+    points.sort(key=lambda x: (x[0], x[1]))
+    print("len(points)", len(points))
+
+    # 查看交点
+    # for p in points:
+    #     cv2.circle(img, p, 2, (0, 0, 255))
+    # cv2.imshow("points", img)
+    # cv2.waitKey(0)
+
     # 表格中填字
     font_list = ["卡", "啊", "的", "我", "你", "吧", "为", "看", "他", "个", "来", "哦", "啊"]
     font_num = random.randint(7, 10)
@@ -176,8 +219,9 @@ def create_table_image(if_add_seal=False):
             font = ImageFont.truetype("msyh.ttc", int(font_size))
             dr = ImageDraw.Draw(image_pil)
 
-            bias_x = random.choice([0, 0, 0, 15, 30])
-            dr.text((text_x+bias_x, text_y), text, font=font, fill="#000000")
+            bias_x = random.choice([0, 0, 0, 15, 30, 50, 70])
+            bias_y = random.choice([0, 10, 10, 15, 20, 30])
+            dr.text((text_x+bias_x, text_y+bias_y), text, font=font, fill="#000000")
 
     img = cv2.cvtColor(np.asarray(image_pil), cv2.COLOR_RGB2BGR)
 
@@ -199,6 +243,12 @@ def create_table_image(if_add_seal=False):
 
     img = cv2.cvtColor(np.asarray(image_pil), cv2.COLOR_RGB2BGR)
 
+    # 高斯模糊
+    if random.choice([0, 1]):
+        sigmaX = random.randint(1, 10)
+        sigmaY = random.randint(1, 10)
+        img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)
+
     # 加印章
     if if_add_seal:
         img = create_official_seal(img)
@@ -213,7 +263,7 @@ def create_table_image(if_add_seal=False):
     # 显示
     # cv2.imshow("image", img)
     # cv2.waitKey(0)
-    return lines_of_table, outline_of_table, image_bytes, weight_origion, height_origion
+    return lines_of_table, outline_of_table, image_bytes, weight_origion, height_origion, points
 
 
 def create_outline_labelme(outline_list, image_bytes):
@@ -231,7 +281,7 @@ def create_outline_labelme(outline_list, image_bytes):
     return labelme_data
 
 
-def create_lines_labelme(line_list, image_bytes, weight, height):
+def create_lines_labelme(line_list, image_bytes, weight, height, cross_points):
     labelme_data = {}
     shapes_list = []
     for line in line_list:
@@ -253,6 +303,7 @@ def create_lines_labelme(line_list, image_bytes, weight, height):
     labelme_data["imageData"] = image_base64_string
     labelme_data["imageHeight"] = height
     labelme_data["imageWidth"] = weight
+    labelme_data["cross_points"] = str(cross_points)
     return labelme_data
 
 
@@ -324,12 +375,12 @@ def create_official_seal(main_image_np):
 
 if __name__ == '__main__':
     # 生成 单元格线 数据
-    for i in range(2601, 2801):
+    for i in range(2500, 5000):
         if i % 100 == 0:
-            print(i)
-        lines, outlines, image_data, weight, height = create_table_image(if_add_seal=True)
-        labelme = create_lines_labelme(lines, image_data, weight, height)
-        with open('../train/dataset-line/6/train_' + str(i) + '.json', 'w') as f:
+            print("Loop", i)
+        lines, outlines, image_data, weight, height, cross_points = create_table_image(if_add_seal=False)
+        labelme = create_lines_labelme(lines, image_data, weight, height, cross_points)
+        with open('../train/dataset-line/7/train_' + str(i) + '.json', 'w') as f:
             json.dump(labelme, f)
 
     # main_image = cv2.imread("../13.png")

+ 24 - 9
image.py

@@ -433,6 +433,7 @@ def gen2(paths, batchsize=2, linetype=2):
         yield X, Y
 
 
+
 def gen(paths, batchsize=2, linetype=2):
     num = len(paths)
 
@@ -443,8 +444,15 @@ def gen(paths, batchsize=2, linetype=2):
         # size = np.random.choice(sizes, 1)[0]
         # height = np.random.choice(sizes, 1)[0]
         # width = int(height/1.4)
-        heights = [1024, 896, 768, 640, 512, 384, 256, 128]
-        widths = [1024, 896, 768, 640, 512, 384, 256, 128]
+        # heights = [1024, 896, 768, 640, 512, 384, 256, 128]
+        # widths = [1024, 896, 768, 640, 512, 384, 256, 128]
+        heights = [3008, 2944, 2880, 2816, 2752, 2688, 2624, 2560, 2496, 2432, 2368,
+                   2304, 2240, 2176, 2112, 2048, 1984, 1920, 1856, 1792, 1728, 1664,
+                   1600, 1536, 1472, 1408, 1344, 1280, 1216, 1152, 1088, 1024, 960,
+                   896, 832, 768, 704, 640, 576, 512]
+        widths = [2048, 1984, 1920, 1856, 1792, 1728, 1664,
+                  1600, 1536, 1472, 1408, 1344, 1280, 1216, 1152, 1088, 1024, 960,
+                  896, 832, 768, 704, 640, 576, 512]
         height = np.random.choice(heights, 1)[0]
         width = np.random.choice(widths, 1)[0]
         # height = 1024
@@ -464,7 +472,7 @@ def gen(paths, batchsize=2, linetype=2):
             # linetype=2
             # print("gen input size", (height, width))
             img, lines, labelImg, size = get_img_label(p, size=(height, width),
-                                                 linetype=linetype)
+                                                       linetype=linetype)
 
             # if not img.any():
             #     print("image too large, jump")
@@ -474,12 +482,12 @@ def gen(paths, batchsize=2, linetype=2):
             # X = np.zeros((batchsize, height, width, 3))
             # Y = np.zeros((batchsize, height, width, 2))
 
-            if_blur = np.random.choice([0, 1], 1)[0]
-            if if_blur:
-                # 高斯模糊
-                sigmaX = random.randint(1, 3)
-                sigmaY = random.randint(1, 3)
-                img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)
+            # if_blur = np.random.choice([0, 1], 1)[0]
+            # if if_blur:
+            #     # 高斯模糊
+            #     sigmaX = random.randint(1, 3)
+            #     sigmaY = random.randint(1, 3)
+            #     img = cv2.GaussianBlur(img, (5, 5), sigmaX, sigmaY)
 
             # cv2.imshow("gen", img)
             # cv2.waitKey(0)
@@ -662,3 +670,10 @@ def img_resize_by_padding_crop(im, lines, target_size):
         lines[i] = [p1, p2]
 
     return im, lines
+
+
+if __name__ == '__main__':
+    _list = []
+    for i in range(100, 1, -1):
+        _list.append(i*64)
+    print(_list)

+ 85 - 0
metrics.py

@@ -0,0 +1,85 @@
+from keras import backend as K
+
+
+def mcor(y_true, y_pred):
+    # matthews_correlation
+    y_pred_pos = K.round(K.clip(y_pred, 0, 1))
+    y_pred_neg = 1 - y_pred_pos
+
+    y_pos = K.round(K.clip(y_true, 0, 1))
+    y_neg = 1 - y_pos
+
+    tp = K.sum(y_pos * y_pred_pos)
+    tn = K.sum(y_neg * y_pred_neg)
+
+    fp = K.sum(y_neg * y_pred_pos)
+    fn = K.sum(y_pos * y_pred_neg)
+
+    numerator = (tp * tn - fp * fn)
+    denominator = K.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
+    return numerator / (denominator + K.epsilon())
+
+
+def precision(y_true, y_pred):
+    """Precision metric.
+
+    Only computes a batch-wise average of precision.
+
+    Computes the precision, a metric for multi-label classification of
+    how many selected items are relevant.
+    """
+    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
+    _precision = true_positives / (predicted_positives + K.epsilon())
+    return _precision
+
+
+def recall(y_true, y_pred):
+    """Recall metric.
+
+    Only computes a batch-wise average of recall.
+
+    Computes the recall, a metric for multi-label classification of
+    how many relevant items are selected.
+    """
+    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
+    _recall = true_positives / (possible_positives + K.epsilon())
+    return _recall
+
+
+def f1(y_true, y_pred):
+    def recall(y_true, y_pred):
+        """Recall metric.
+
+        Only computes a batch-wise average of recall.
+
+        Computes the recall, a metric for multi-label classification of
+        how many relevant items are selected.
+        """
+        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
+        _recall = true_positives / (possible_positives + K.epsilon())
+        return _recall
+
+    def precision(y_true, y_pred):
+        """Precision metric.
+
+        Only computes a batch-wise average of precision.
+
+        Computes the precision, a metric for multi-label classification of
+        how many selected items are relevant.
+        """
+        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
+        _precision = true_positives / (predicted_positives + K.epsilon())
+        return _precision
+    _precision = precision(y_true, y_pred)
+    _recall = recall(y_true, y_pred)
+    return 2*((_precision*_recall)/(_precision+_recall+K.epsilon()))
+
+
+#you can use it like this
+# model.compile(loss='binary_crossentropy',
+#               optimizer= "adam",
+#               metrics=[mcor,recall, f1])

BIN
models/table-line-471-1755.32.h5


BIN
test_files.rar


BIN
test_files/1.png


BIN
test_files/QQ截图20210907102912.jpg


BIN
test_files/error1.jpg


+ 11 - 10
train.py

@@ -8,7 +8,8 @@ Created on Thu Sep 9 23:11:51 2020
 import json
 import os
 import sys
-sys.path.append('train')
+sys.path.append(os.path.dirname(__file__))
+print(os.path.dirname(__file__))
 from table_line import model, focal_loss, dice_coef, dice_coef_loss
 from tensorflow.keras.optimizers import Adam
 from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
@@ -17,6 +18,7 @@ from glob import glob
 from image import gen
 from config import tableModeLinePath
 import tensorflow.keras as keras
+from metrics import f1
 
 PRETRAINED = True
 CHECKPOINT = False
@@ -36,22 +38,21 @@ if __name__ == '__main__':
         print("no checkpoint")
 
     # 模型权重存放位置
-    filepath = 'models/table-line.h5'
-    # filepath = 'table-line-{epoch:02d}-{val_loss:.2f}.h5'
-
-    checkpointer = ModelCheckpoint(filepath=filepath, monitor='loss', verbose=0,
+    # filepath = 'models/table-line.h5'
+    filepath = 'e{epoch:02d}-f1{val_f1:.2f}.h5'
+    checkpointer = ModelCheckpoint(filepath=filepath, monitor='val_f1', verbose=0,
                                    save_weights_only=True, save_best_only=True)
-    rlu = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=5,
-                            verbose=0, mode='auto', cooldown=0, min_lr=0)
+    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=10,
+                            verbose=0, mode='max', cooldown=0, min_lr=0)
     model.compile(optimizer=Adam(lr=0.0003), loss=focal_loss(),
-                  metrics=['acc', keras.metrics.Precision(), keras.metrics.Recall()])
+                  metrics=['acc', keras.metrics.Precision(), keras.metrics.Recall(), f1])
 
     # table line dataset label with labelme
-    paths = glob('train/dataset-line/6/*.json')
+    paths = glob('train/dataset-line/7/*.json')
     # print("paths", paths)
     print("len(paths)", len(paths))
 
-    paths = paths[:3000]
+    # paths = paths[:3000]
 
     # 限制长度
     # print('len(paths)', len(paths))