#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Sep 9 23:11:51 2020 table line detect @author: chineseocr """ import copy import io import logging import sys import traceback import tensorflow as tf import tensorflow.keras.backend as K from tensorflow.keras.models import Model from tensorflow.keras.layers import Input, concatenate, Conv2D, MaxPooling2D, BatchNormalization, UpSampling2D from tensorflow.keras.layers import LeakyReLU from otr.utils import letterbox_image, get_table_line, adjust_lines, line_to_line, draw_boxes import numpy as np import cv2 import time from format_convert import _global from format_convert.utils import log def dice_coef(y_true, y_pred, smooth=1e-5): y_true_f = K.flatten(y_true) y_pred_f = K.flatten(y_pred) intersection = K.sum(y_true_f * y_pred_f) return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) def dice_coef_loss(): def dice_coef_loss_fixed(y_true, y_pred): return -dice_coef(y_true, y_pred) return dice_coef_loss_fixed def focal_loss(gamma=3., alpha=.5): # 3 0.85 2000e acc-0.6 p-0.99 r-0.99 val_acc-0.56 val_p-0.86 val_r-0.95 # 2 0.85 double_gpu acc- # 3 0.25 gpu 50e acc-0.5 p-0.99 r-0.99 val_acc-0.45 val_p-0.96 val_r-0.88 # 2 0.25 gpu acc- # 3 0.5 double_gpu acc-0.6 p-0.99 r-0.99 val_acc-0.60 val_p-0.93 val_r-0.93 def focal_loss_fixed(y_true, y_pred): pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred)) pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred)) return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon()+pt_1))-K.sum((1-alpha) * K.pow( pt_0, gamma) * K.log(1. - pt_0 + K.epsilon())) return focal_loss_fixed def table_net_large(input_shape=(1152, 896, 3), num_classes=1): inputs = Input(shape=input_shape) # 512 use_bias = False down0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(inputs) down0a = BatchNormalization()(down0a) down0a = LeakyReLU(alpha=0.1)(down0a) down0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(down0a) down0a = BatchNormalization()(down0a) down0a = LeakyReLU(alpha=0.1)(down0a) down0a_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0a) # 256 down0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(down0a_pool) down0 = BatchNormalization()(down0) down0 = LeakyReLU(alpha=0.1)(down0) down0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(down0) down0 = BatchNormalization()(down0) down0 = LeakyReLU(alpha=0.1)(down0) down0_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0) # 128 down1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(down0_pool) down1 = BatchNormalization()(down1) down1 = LeakyReLU(alpha=0.1)(down1) down1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(down1) down1 = BatchNormalization()(down1) down1 = LeakyReLU(alpha=0.1)(down1) down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1) # 64 down2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(down1_pool) down2 = BatchNormalization()(down2) down2 = LeakyReLU(alpha=0.1)(down2) down2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(down2) down2 = BatchNormalization()(down2) down2 = LeakyReLU(alpha=0.1)(down2) down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2) # 32 down3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(down2_pool) down3 = BatchNormalization()(down3) down3 = LeakyReLU(alpha=0.1)(down3) down3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(down3) down3 = BatchNormalization()(down3) down3 = LeakyReLU(alpha=0.1)(down3) down3_pool = MaxPooling2D((2, 2), strides=(2, 2))(down3) # 16 down4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(down3_pool) down4 = BatchNormalization()(down4) down4 = LeakyReLU(alpha=0.1)(down4) down4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(down4) down4 = BatchNormalization()(down4) down4 = LeakyReLU(alpha=0.1)(down4) down4_pool = MaxPooling2D((2, 2), strides=(2, 2))(down4) # 8 center = Conv2D(1024, (3, 3), padding='same', use_bias=use_bias)(down4_pool) center = BatchNormalization()(center) center = LeakyReLU(alpha=0.1)(center) center = Conv2D(1024, (3, 3), padding='same', use_bias=use_bias)(center) center = BatchNormalization()(center) center = LeakyReLU(alpha=0.1)(center) # center up4 = UpSampling2D((2, 2))(center) up4 = concatenate([down4, up4], axis=3) up4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(up4) up4 = BatchNormalization()(up4) up4 = LeakyReLU(alpha=0.1)(up4) up4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(up4) up4 = BatchNormalization()(up4) up4 = LeakyReLU(alpha=0.1)(up4) up4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(up4) up4 = BatchNormalization()(up4) up4 = LeakyReLU(alpha=0.1)(up4) # 16 up3 = UpSampling2D((2, 2))(up4) up3 = concatenate([down3, up3], axis=3) up3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(up3) up3 = BatchNormalization()(up3) up3 = LeakyReLU(alpha=0.1)(up3) up3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(up3) up3 = BatchNormalization()(up3) up3 = LeakyReLU(alpha=0.1)(up3) up3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(up3) up3 = BatchNormalization()(up3) up3 = LeakyReLU(alpha=0.1)(up3) # 32 up2 = UpSampling2D((2, 2))(up3) up2 = concatenate([down2, up2], axis=3) up2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(up2) up2 = BatchNormalization()(up2) up2 = LeakyReLU(alpha=0.1)(up2) up2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(up2) up2 = BatchNormalization()(up2) up2 = LeakyReLU(alpha=0.1)(up2) up2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(up2) up2 = BatchNormalization()(up2) up2 = LeakyReLU(alpha=0.1)(up2) # 64 up1 = UpSampling2D((2, 2))(up2) up1 = concatenate([down1, up1], axis=3) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) # 128 up0 = UpSampling2D((2, 2))(up1) up0 = concatenate([down0, up0], axis=3) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) # 256 up0a = UpSampling2D((2, 2))(up0) up0a = concatenate([down0a, up0a], axis=3) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) # 512 classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0a) model = Model(inputs=inputs, outputs=classify) return model def table_net(input_shape=(1152, 896, 3), num_classes=1): inputs = Input(shape=input_shape) # 512 use_bias = False down0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(inputs) down0a = BatchNormalization()(down0a) down0a = LeakyReLU(alpha=0.1)(down0a) down0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(down0a) down0a = BatchNormalization()(down0a) down0a = LeakyReLU(alpha=0.1)(down0a) down0a_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0a) # 256 down0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(down0a_pool) down0 = BatchNormalization()(down0) down0 = LeakyReLU(alpha=0.1)(down0) down0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(down0) down0 = BatchNormalization()(down0) down0 = LeakyReLU(alpha=0.1)(down0) down0_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0) # 128 down1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(down0_pool) down1 = BatchNormalization()(down1) down1 = LeakyReLU(alpha=0.1)(down1) down1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(down1) down1 = BatchNormalization()(down1) down1 = LeakyReLU(alpha=0.1)(down1) down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1) # 64 down2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(down1_pool) down2 = BatchNormalization()(down2) down2 = LeakyReLU(alpha=0.1)(down2) down2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(down2) down2 = BatchNormalization()(down2) down2 = LeakyReLU(alpha=0.1)(down2) down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2) # 32 up1 = UpSampling2D((2, 2))(down2) up1 = concatenate([down1, up1], axis=3) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) # 128 up0 = UpSampling2D((2, 2))(up1) up0 = concatenate([down0, up0], axis=3) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) # 256 up0a = UpSampling2D((2, 2))(up0) up0a = concatenate([down0a, up0a], axis=3) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) # 512 classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0a) model = Model(inputs=inputs, outputs=classify) return model model = table_net((None, None, 3), 2) def draw_pixel(pred, prob=0.2, is_test=1): if not is_test: return else: import matplotlib.pyplot as plt _array = [] for _h in range(len(pred)): _line = [] for _w in range(len(pred[_h])): _prob = pred[_h][_w] if _prob[0] > prob: _line.append((0, 0, 255)) elif _prob[1] > prob: _line.append((255, 0, 0)) else: _line.append((255, 255, 255)) _array.append(_line) plt.axis('off') plt.imshow(np.array(_array)) plt.show() return def expansionAndShrinkage(pred,width=3): pred_array = np.array(pred) print("pred_array=====",pred_array.shape) _array = pred_array[...,0] _l = [_array] for _i in range(width): tmp_array = np.pad(_array[:-(_i+1),...],((_i+1,0),(0,0))) _l.append(tmp_array) for _i in range(width): tmp_array = np.pad(_array[_i+1:,...],((0,_i+1),(0,0))) _l.append(tmp_array) for _i in range(width): tmp_array = np.pad(_array[...,:-(_i+1)],((0,0),(_i+1,0))) _l.append(tmp_array) for _i in range(width): tmp_array = np.pad(_array[...,_i+1:],((0,0),(0,_i+1))) _l.append(tmp_array) for _a in _l: print(_a.shape) h_array = np.stack(_l,axis=0) h_array = np.max(h_array,axis=0,keepdims=False) _array = pred_array[...,1] _l = [_array] for _i in range(width): tmp_array = np.pad(_array[:-(_i+1),...],((_i+1,0),(0,0))) _l.append(tmp_array) for _i in range(width): tmp_array = np.pad(_array[_i+1:,...],((0,_i+1),(0,0))) _l.append(tmp_array) for _i in range(width): tmp_array = np.pad(_array[...,:-(_i+1)],((0,0),(_i+1,0))) _l.append(tmp_array) for _i in range(width): tmp_array = np.pad(_array[...,_i+1:],((0,0),(0,_i+1))) _l.append(tmp_array) v_array = np.stack(_l,axis=0) print("v_array=====",v_array.shape) v_array = np.max(v_array,axis=0,keepdims=False) print("h_array=====",h_array.shape) print("v_array=====",v_array.shape) last_array = np.stack([h_array,v_array],axis=-1) print("pred_array=====",last_array.shape) return last_array def getIOU(bbox0, bbox1): width = abs(max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0]))-(abs(bbox0[2]-bbox0[0])+abs(bbox1[2]-bbox1[0])) height = abs(max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1]))-(abs(bbox0[3]-bbox0[1])+abs(bbox1[3]-bbox1[1])) if width < 0 and height < 0: iou = abs(width*height/min(abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1])), abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1])))) # print("getIOU", iou) return iou return 0 def lines_cluster(list_lines,line_width): after_len = 0 log("len lines %d"%len(list_lines)) append_width = line_width//2 while 1: c_lines = [] first_len = after_len for _line in list_lines: bbox = _line["bbox"] _find = False for c_l_i in range(len(c_lines)): c_l = c_lines[len(c_lines)-c_l_i-1] bbox1 = c_l["bbox"] bboxa = [max(0,bbox[0]-append_width),max(0,bbox[1]-append_width),bbox[2]+append_width,bbox[3]+append_width] bboxb = [max(0,bbox1[0]-append_width),max(0,bbox1[1]-append_width),bbox1[2]+append_width,bbox1[3]+append_width] _iou = getIOU(bboxa,bboxb) if _iou>0: new_bbox = [min(bbox[0],bbox[2],bbox1[0],bbox1[2]),min(bbox[1],bbox[3],bbox1[1],bbox1[3]),max(bbox[0],bbox[2],bbox1[0],bbox1[2]),max(bbox[1],bbox[3],bbox1[1],bbox1[3])] _find = True c_l["bbox"] = new_bbox break if not _find: c_lines.append(_line) after_len = len(c_lines) if first_len==after_len: break list_lines = c_lines return c_lines def points2lines(pred,sourceP_LB=True, prob=0.2, line_width=8, padding=3, min_len=10, cell_width=13): _time = time.time() log("starting points2lines") height = len(pred) width = len(pred[0]) _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=1)) h_index = -1 h_lines = [] v_lines = [] _step = line_width while 1: h_index += 1 if h_index>=height: break w_index = -1 if sourceP_LB: h_i = height-1-h_index else: h_i = h_index _start = None if _sum[h_index]=width: break _h,_v = pred[h_i][w_index] if _h>prob: if _start is None: _start = w_index else: if _start is not None: _end = w_index-1 _bbox = [_start,h_i,_end,h_i] _dict = {"bbox":_bbox} h_lines.append(_dict) _start = None w_index -= _step//2 log("starting points2lines 1") w_index = -1 _sum = list(np.sum(np.array((pred[...,1]>prob)).astype(int),axis=0)) _step = line_width while 1: w_index += 1 if w_index>=width: break if _sum[w_index]=height: break if sourceP_LB: h_i = height-1-h_index else: h_i = h_index _h,_v = pred[h_index][w_index] if _v>prob: if _start is None: _start = h_i else: if _start is not None: _end = last_h _bbox = [w_index,_start,w_index,_end] _dict = {"bbox":_bbox} v_lines.append(_dict) _start = None h_index -= _step//2 last_h = h_i log("starting points2lines 2") for _line in h_lines: _bbox = _line["bbox"] _bbox = [max(_bbox[0]-2,0),(_bbox[1]+_bbox[3])/2,_bbox[2]+2,(_bbox[1]+_bbox[3])/2] _line["bbox"] = _bbox for _line in v_lines: _bbox = _line["bbox"] _bbox = [(_bbox[0]+_bbox[2])/2,max(_bbox[1]-2,0),(_bbox[0]+_bbox[2])/2,_bbox[3]+2] _line["bbox"] = _bbox h_lines = lines_cluster(h_lines,line_width=line_width) v_lines = lines_cluster(v_lines,line_width=line_width) list_line = [] for _line in h_lines: _bbox = _line["bbox"] _bbox = [max(_bbox[0]-1,0),(_bbox[1]+_bbox[3])/2,_bbox[2]+1,(_bbox[1]+_bbox[3])/2] list_line.append(_bbox) for _line in v_lines: _bbox = _line["bbox"] _bbox = [(_bbox[0]+_bbox[2])/2,max(_bbox[1]-1,0),(_bbox[0]+_bbox[2])/2,_bbox[3]+1] list_line.append(_bbox) log("points2lines cost %.2fs"%(time.time()-_time)) # import matplotlib.pyplot as plt # plt.figure() # for _line in list_line: # x0,y0,x1,y1 = _line # plt.plot([x0,x1],[y0,y1]) # for _line in list_line: # x0,y0,x1,y1 = _line.bbox # plt.plot([x0,x1],[y0,y1]) # for point in list_crosspoints: # plt.scatter(point.get("point")[0],point.get("point")[1]) # plt.show() return list_line def points2lines_bak(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_len=10, cell_width=13): def inBbox(bbox,point,line_width): x,y = point if x>=bbox[0]-line_width and x<=bbox[2]+line_width and y>=bbox[1]-line_width and y<=bbox[3]+line_width: return True,[min(x,bbox[0]),min(y,bbox[1]),max(x,bbox[2]),max(y,bbox[3])] return False,None _time = time.time() height = len(pred) width = len(pred[0]) clust_horizontal = [] clust_vertical = [] h_index = -1 _step = line_width _sum = list(np.sum(np.array((pred[...,1]>prob)).astype(int),axis=0)) _last = False _current = False while 1: h_index += 2 if h_index>=height: break w_index = -1 if sourceP_LB: h_i = height-1-h_index else: h_i = h_index while 1: w_index += 2 if w_index>=width: break if _sum[w_index]prob: _find = False _point = (w_index,h_i) for l_h_i in range(len(clust_vertical)): l_h = clust_vertical[len(clust_vertical)-l_h_i-1] bbox = l_h.get("bbox") b_in,_bbox = inBbox(bbox,_point,line_width) if b_in: _find = True l_h.get("points").append(_point) l_h["bbox"] = _bbox break if not _find: clust_vertical.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]}) w_index = -1 _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=1)) while 1: w_index += 2 if w_index>=width: break h_index = -1 while 1: h_index += 2 if h_index>=height: break if _sum[h_index]prob: _find = False _point = (w_index,h_i) for l_h_i in range(len(clust_horizontal)): l_h = clust_horizontal[len(clust_horizontal)-l_h_i-1] bbox = l_h.get("bbox") b_in,_bbox = inBbox(bbox,_point,line_width) if b_in: _find = True l_h.get("points").append(_point) l_h["bbox"] = _bbox break if not _find: clust_horizontal.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]}) tmp_vertical = [] for _dict in clust_vertical: _bbox = _dict.get("bbox") if _bbox[2]-_bbox[0]>=min_len or _bbox[3]-_bbox[1]>=min_len: tmp_vertical.append([(_bbox[0]+_bbox[2])/2,_bbox[1]-padding,(_bbox[0]+_bbox[2])/2,_bbox[3]+padding]) tmp_horizontal = [] for _dict in clust_horizontal: _bbox = _dict.get("bbox") if _bbox[2]-_bbox[0]>=min_len or _bbox[3]-_bbox[1]>=min_len: tmp_horizontal.append([_bbox[0]-padding,(_bbox[1]+_bbox[3])/2,_bbox[2]+padding,(_bbox[1]+_bbox[3])/2]) #merge lines tmp_vertical.sort(key=lambda x:x[3],reverse=True) tmp_horizontal.sort(key=lambda x:x[0]) pop_index = [] final_vertical = [] for _line in tmp_vertical: _find = False x0,y0,x1,y1 = _line for _line2 in final_vertical: x2,y2,x3,y3 = _line2 if abs(x0-x2)= point_value) points = [[xs[i], ys[i]] for i in range(len(xs))] # 提出所有相同x或相同y的点 # 提取行/列 axis_points = get_axis_points(points, axis) # 提取每行/列的横线/竖线 axis_lines = get_axis_lines(axis_points, axis) # print("axis_lines", axis_lines) return axis_lines def table_preprocess(img_data, prob=0.2): try: log("into table_preprocess, prob is " + str(prob)) start_time = time.time() # 二进制数据流转np.ndarray [np.uint8: 8位像素] img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR) # 将bgr转为rbg image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # 模型输入 inputs = np.array([image_np]) # # 压缩numpy # compressed_array = io.BytesIO() # np.savez_compressed(compressed_array, inputs) # compressed_array.seek(0) # inputs_compressed = compressed_array.read() log("otr preprocess time: " + str(round(float(time.time()-start_time), 4)) + "s") return image_np, inputs except Exception as e: log("table pre process failed!") traceback.print_exc() return [-1], [-1] def table_postprocess(img_new, pred, prob=0.2, is_test=0): try: # 横线预测结果 # row_pred = pred[..., 0] > hprob # row_pred = row_pred.astype(np.uint8) # # 竖线预测结果 # col_pred = pred[..., 1] > vprob # col_pred = col_pred.astype(np.uint8) # # 打印模型输出 # cv2.imshow("predict", (col_pred+row_pred)*255) # cv2.waitKey(0) start_time = time.time() list_line = points2lines(pred, False, prob=prob) mat_plot(list_line, "points2lines", is_test) log("points2lines " + str(time.time()-start_time)) # 清除短线 # print(img_new.shape) start_time = time.time() list_line = delete_short_lines(list_line, img_new.shape) mat_plot(list_line, "delete_short_lines", is_test) log("delete_short_lines " + str(time.time()-start_time)) # # 清除无交点线--无需清除,会影响后面的结果 # start_time = time.time() # list_line = delete_no_cross_lines(list_line) # mat_plot(list_line, "delete_no_cross_lines", is_test) # log("delete_no_cross_lines " + str(time.time()-start_time)) # 分成横竖线 start_time = time.time() list_rows = [] list_cols = [] for line in list_line: if line[0] == line[2]: list_cols.append(line) elif line[1] == line[3]: list_rows.append(line) log("divide rows and cols " + str(time.time()-start_time)) # 合并错开线 start_time = time.time() list_rows = merge_line(list_rows, axis=0) list_cols = merge_line(list_cols, axis=1) mat_plot(list_rows+list_cols, "merge_line", is_test) log("merge_line " + str(time.time()-start_time)) # 计算交点、分割线 start_time = time.time() cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1])) if not cross_points: return [] log("get_points " + str(time.time()-start_time)) # 清掉外围的没用的线 # list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points) # mat_plot(list_rows+list_cols, "delete_outline", is_test) # 多个表格分割线 start_time = time.time() list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new) split_lines, split_y = get_split_line(cross_points, list_cols, img_new) log("get_split_line " + str(time.time()-start_time)) # 修复边框 start_time = time.time() new_rows, new_cols, long_rows, long_cols = fix_outline(img_new, list_rows, list_cols, cross_points, split_y) # 如有补线 if new_rows or new_cols: # 连接至补线的延长线 if long_rows: list_rows = long_rows if long_cols: list_cols = long_cols # 新的补线 if new_rows: list_rows += new_rows if new_cols: list_cols += new_cols list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new) # 修复边框后重新计算交点、分割线 cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1])) cv_plot(cross_points, img_new.shape, 0, is_test) split_lines, split_y = get_split_line(cross_points, list_cols, img_new) print("fix new split_y", split_y) print("fix new split_lines", split_lines) # 修复内部缺线 # cross_points = fix_inner(list_rows, list_cols, cross_points, split_y) # if not cross_points: # return [] mat_plot(list_rows+list_cols, "fix_outline", is_test) split_lines_show = [] for _l in split_lines: split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]]) mat_plot(split_lines_show+list_cols, "split_lines", is_test) log("fix_outline " + str(time.time()-start_time)) # 修复表格4个角 start_time = time.time() list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0) mat_plot(list_rows+list_cols, "fix_corner", is_test) log("fix_corner " + str(time.time()-start_time)) # 修复内部缺线 start_time = time.time() list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y) mat_plot(list_rows+list_cols, "fix_inner", is_test) log("fix_inner " + str(time.time()-start_time)) # 合并错开线 start_time = time.time() list_rows = merge_line(list_rows, axis=0) list_cols = merge_line(list_cols, axis=1) mat_plot(list_rows+list_cols, "merge_line", is_test) log("merge_line " + str(time.time()-start_time)) list_line = list_rows + list_cols # 打印处理后线 mat_plot(list_line, "all", is_test) log("otr postprocess table_line " + str(time.time()-start_time)) return list_line except Exception as e: log("table post process failed!") traceback.print_exc() return [-1] def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0): log("into table_line, prob is " + str(prob)) sizew, sizeh = size img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA) start_time = time.time() pred = model.predict(np.array([img_new])) log("otr model predict time " + str(time.time()-start_time)) pred = pred[0] draw_pixel(pred, prob, is_test) # 横线预测结果 # row_pred = pred[..., 0] > hprob # row_pred = row_pred.astype(np.uint8) # # 竖线预测结果 # col_pred = pred[..., 1] > vprob # col_pred = col_pred.astype(np.uint8) # # 打印模型输出 # cv2.imshow("predict", (col_pred+row_pred)*255) # cv2.waitKey(0) start_time = time.time() list_line = points2lines(pred, False, prob=prob) mat_plot(list_line, "points2lines", is_test) log("points2lines " + str(time.time()-start_time)) # 清除短线 # print(img_new.shape) start_time = time.time() list_line = delete_short_lines(list_line, img_new.shape) mat_plot(list_line, "delete_short_lines", is_test) log("delete_short_lines " + str(time.time()-start_time)) # # 清除无交点线--无需清除,会影响后面的结果 # start_time = time.time() # list_line = delete_no_cross_lines(list_line) # mat_plot(list_line, "delete_no_cross_lines", is_test) # log("delete_no_cross_lines " + str(time.time()-start_time)) # 分成横竖线 start_time = time.time() list_rows = [] list_cols = [] for line in list_line: if line[0] == line[2]: list_cols.append(line) elif line[1] == line[3]: list_rows.append(line) log("divide rows and cols " + str(time.time()-start_time)) # 合并错开线 start_time = time.time() list_rows = merge_line(list_rows, axis=0) list_cols = merge_line(list_cols, axis=1) mat_plot(list_rows+list_cols, "merge_line", is_test) log("merge_line " + str(time.time()-start_time)) # 计算交点、分割线 start_time = time.time() cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1])) if not cross_points: return [] log("get_points " + str(time.time()-start_time)) # 清掉外围的没用的线 # list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points) # mat_plot(list_rows+list_cols, "delete_outline", is_test) # 多个表格分割线 start_time = time.time() list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new) split_lines, split_y = get_split_line(cross_points, list_cols, img_new) log("get_split_line " + str(time.time()-start_time)) # 修复边框 start_time = time.time() new_rows, new_cols, long_rows, long_cols = fix_outline(img_new, list_rows, list_cols, cross_points, split_y) # 如有补线 if new_rows or new_cols: # 连接至补线的延长线 if long_rows: list_rows = long_rows if long_cols: list_cols = long_cols # 新的补线 if new_rows: list_rows += new_rows if new_cols: list_cols += new_cols list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new) # 修复边框后重新计算交点、分割线 cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1])) cv_plot(cross_points, img_new.shape, 0, is_test) split_lines, split_y = get_split_line(cross_points, list_cols, img_new) print("fix new split_y", split_y) print("fix new split_lines", split_lines) # 修复内部缺线 # cross_points = fix_inner(list_rows, list_cols, cross_points, split_y) # if not cross_points: # return [] mat_plot(list_rows+list_cols, "fix_outline", is_test) split_lines_show = [] for _l in split_lines: split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]]) mat_plot(split_lines_show+list_cols, "split_lines", is_test) log("fix_outline " + str(time.time()-start_time)) # 修复表格4个角 start_time = time.time() list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0) mat_plot(list_rows+list_cols, "fix_corner", is_test) log("fix_corner " + str(time.time()-start_time)) # 修复内部缺线 start_time = time.time() list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y) mat_plot(list_rows+list_cols, "fix_inner", is_test) log("fix_inner " + str(time.time()-start_time)) # 合并错开线 start_time = time.time() list_rows = merge_line(list_rows, axis=0) list_cols = merge_line(list_cols, axis=1) mat_plot(list_rows+list_cols, "merge_line", is_test) log("merge_line " + str(time.time()-start_time)) list_line = list_rows + list_cols # 打印处理后线 mat_plot(list_line, "all", is_test) log("otr postprocess table_line " + str(time.time()-start_time)) return list_line def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=30, alph=15): sizew, sizeh = size # [..., ::-1] 最后一维内部反向输出 # inputBlob, fx, fy = letterbox_image(img[..., ::-1], (sizew, sizeh)) # pred = model.predict(np.array([np.array(inputBlob)])) # pred = model.predict(np.array([np.array(inputBlob)/255.0])) img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA) # log("into table_line 1") pred = model.predict(np.array([img_new])) # log("into table_line 2") pred = pred[0] draw_pixel(pred) _time = time.time() points2lines(pred) log("points2lines takes %ds"%(time.time()-_time)) vpred = pred[..., 1] > vprob # 横线 hpred = pred[..., 0] > hprob # 竖线 vpred = vpred.astype(int) hpred = hpred.astype(int) # print("vpred shape", vpred) # print("hpred shape", hpred) colboxes = get_table_line(vpred, axis=1, lineW=col) rowboxes = get_table_line(hpred, axis=0, lineW=row) # log("into table_line 3") # if len(rowboxes) > 0: # rowboxes = np.array(rowboxes) # rowboxes[:, [0, 2]] = rowboxes[:, [0, 2]]/fx # rowboxes[:, [1, 3]] = rowboxes[:, [1, 3]]/fy # rowboxes = rowboxes.tolist() # if len(colboxes) > 0: # colboxes = np.array(colboxes) # colboxes[:, [0, 2]] = colboxes[:, [0, 2]]/fx # colboxes[:, [1, 3]] = colboxes[:, [1, 3]]/fy # colboxes = colboxes.tolist() nrow = len(rowboxes) ncol = len(colboxes) for i in range(nrow): for j in range(ncol): rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], 10) colboxes[j] = line_to_line(colboxes[j], rowboxes[i], 10) # log("into table_line 4") # 删掉贴着边框的line temp_list = [] threshold = 5 for line in rowboxes: if line[1]-0 <= threshold or size[1]-line[1] <= threshold: continue # 内部排序 if line[0] > line[2]: line = [line[2], line[3], line[0], line[1]] temp_list.append(line) rowboxes = temp_list temp_list = [] for line in colboxes: if line[0]-0 <= threshold or size[0]-line[0] <= threshold: continue # 内部排序 if line[1] > line[3]: line = [line[2], line[3], line[0], line[1]] temp_list.append(line) colboxes = temp_list return rowboxes, colboxes, img_new def fix_in_split_lines(_rows, _cols, _img): # 补线贴着边缘无法得到split_y,导致无法分区 for _row in _rows: if _row[1] >= _img.shape[0] - 5: _row[1] = _img.shape[0] - 6 _row[3] = _img.shape[0] - 6 print("_row", _row) if _row[1] <= 0 + 5: _row[1] = 6 _row[3] = 6 for _col in _cols: if _col[3] >= _img.shape[0] - 5: _col[3] = _img.shape[0] - 6 if _col[1] <= 0 + 5: _col[1] = 6 return _rows, _cols def mat_plot(list_line, name="", is_test=1): if not is_test: return from matplotlib import pyplot as plt plt.figure() plt.title(name) for _line in list_line: x0, y0, x1, y1 = _line plt.plot([x0, x1], [y0, y1]) plt.show() def cv_plot(_list, img_shape, line_or_point=1, is_test=1): if is_test == 0: return img_print = np.zeros(img_shape, np.uint8) img_print.fill(255) if line_or_point: for line in _list: cv2.line(img_print, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0)) cv2.imshow("cv_plot", img_print) cv2.waitKey(0) else: for point in _list: cv2.circle(img_print, (int(point[0]), int(point[1])), 1, (255, 0, 0), 2) cv2.imshow("cv_plot", img_print) cv2.waitKey(0) def delete_no_cross_lines(list_lines): def get_cross_point(l1, l2): # https://www.zhihu.com/question/381406535/answer/1095948349 flag = 0 # 相交一定是一条横线一条竖线 if (l1[0] == l1[2] and l2[1] == l2[3]) or (l1[1] == l1[3] and l2[0] == l2[2]): if l1[0] <= l2[0] <= l1[2] and l2[1] <= l1[1] <= l2[3]: flag = 1 elif l2[0] <= l1[0] <= l2[2] and l1[1] <= l2[1] <= l1[3]: flag = 1 return flag new_list_lines = [] for i in range(len(list_lines)): line1 = list_lines[i] find_flag = 0 for j in range(i+1, len(list_lines)): line2 = list_lines[j] if get_cross_point(line1, line2): # print("delete_no_cross_lines", line1, line2) find_flag = 1 if line2 not in new_list_lines: new_list_lines.append(line2) if find_flag and line1 not in new_list_lines: new_list_lines.append(line1) return new_list_lines def delete_short_lines(list_lines, image_shape, scale=100): x_min_len = max(5, int(image_shape[0] / scale)) y_min_len = max(5, int(image_shape[1] / scale)) new_list_lines = [] for line in list_lines: if line[0] == line[2]: if abs(line[3] - line[1]) >= y_min_len: # print("y_min_len", abs(line[3] - line[1]), y_min_len) new_list_lines.append(line) else: if abs(line[2] - line[0]) >= x_min_len: # print("x_min_len", abs(line[2] - line[0]), x_min_len) new_list_lines.append(line) return new_list_lines def get_outline(points, image_np): # 取出x, y的最大值最小值 x_min = points[0][0] x_max = points[-1][0] points.sort(key=lambda x: (x[1], x[0])) y_min = points[0][1] y_max = points[-1][1] # 创建空图 # outline_img = np.zeros(image_size, np.uint8) outline_img = np.copy(image_np) cv2.rectangle(outline_img, (x_min-5, y_min-5), (x_max+5, y_max+5), (0, 0, 0), 2) # cv2.imshow("outline_img", outline_img) # cv2.waitKey(0) return outline_img def get_split_line(points, col_lines, image_np, threshold=5): # print("get_split_line", image_np.shape) points.sort(key=lambda x: (x[1], x[0])) # 遍历y坐标,并判断y坐标与上一个y坐标是否存在连接线 i = 0 split_line_y = [] for point in points: # 从已分开的线下面开始判断 if split_line_y: if point[1] <= split_line_y[-1] + threshold: last_y = point[1] continue if last_y <= split_line_y[-1] + threshold: last_y = point[1] continue if i == 0: last_y = point[1] i += 1 continue current_line = (last_y, point[1]) split_flag = 1 for col in col_lines: # 只要找到一条col包含就不是分割线 if current_line[0] >= col[1]-3 and current_line[1] <= col[3]+3: split_flag = 0 # print("img", img.shape) # print("col", col) # print("current_line", current_line) break if split_flag: split_line_y.append(current_line[0]+5) split_line_y.append(current_line[1]-5) last_y = point[1] # 加上收尾分割线 points.sort(key=lambda x: (x[1], x[0])) y_min = points[0][1] y_max = points[-1][1] # print("加上收尾分割线", y_min, y_max) if y_min-threshold < 0: split_line_y.append(0) else: split_line_y.append(y_min-threshold) if y_max+threshold > image_np.shape[0]: split_line_y.append(image_np.shape[0]) else: split_line_y.append(y_max+threshold) split_line_y = list(set(split_line_y)) # 剔除两条相隔太近分割线 temp_split_line_y = [] split_line_y.sort(key=lambda x: x) last_y = -20 for y in split_line_y: # print(y) if y - last_y >= 20: # print(y, last_y) temp_split_line_y.append(y) last_y = y split_line_y = temp_split_line_y # print("split_line_y", split_line_y) # 生成分割线 split_line = [] last_y = 0 for y in split_line_y: # if y - last_y <= 15: # continue split_line.append([(0, y), (image_np.shape[1], y)]) last_y = y split_line.append([(0, 0), (image_np.shape[1], 0)]) split_line.append([(0, image_np.shape[0]), (image_np.shape[1], image_np.shape[0])]) split_line.sort(key=lambda x: x[0][1]) # print("split_line", split_line) # 画图画线 # split_line_img = np.copy(image_np) # for y in split_line_y: # cv2.line(split_line_img, (0, y), (image_np.shape[1], y), (0, 0, 0), 1) # cv2.imshow("split_line_img", split_line_img) # cv2.waitKey(0) return split_line, split_line_y def get_points(row_lines, col_lines, image_size): # 创建空图 row_img = np.zeros(image_size, np.uint8) col_img = np.zeros(image_size, np.uint8) # 画线 thresh = 3 for row in row_lines: cv2.line(row_img, (int(row[0]-thresh), int(row[1])), (int(row[2]+thresh), int(row[3])), (255, 255, 255), 1) for col in col_lines: cv2.line(col_img, (int(col[0]), int(col[1]-thresh)), (int(col[2]), int(col[3]+thresh)), (255, 255, 255), 1) # 求出交点 point_img = np.bitwise_and(row_img, col_img) # cv2.imwrite("get_points.jpg", row_img+col_img) # cv2.imshow("get_points", row_img+col_img) # cv2.waitKey(0) # 识别黑白图中的白色交叉点,将横纵坐标取出 ys, xs = np.where(point_img > 0) points = [] for i in range(len(xs)): points.append((xs[i], ys[i])) points.sort(key=lambda x: (x[0], x[1])) return points def get_minAreaRect(image): gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) gray = cv2.bitwise_not(gray) thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1] coords = np.column_stack(np.where(thresh > 0)) return cv2.minAreaRect(coords) def get_contours(image): gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) ret, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY) contours, hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(image, contours, -1, (0, 0, 255), 3) cv2.imshow("get contours", image) cv2.waitKey(0) def merge_line(lines, axis, threshold=5): """ 解决模型预测一条直线错开成多条直线,合并成一条直线 :param lines: 线条列表 :param axis: 0:横线 1:竖线 :param threshold: 两条线间像素差阈值 :return: 合并后的线条列表 """ # 任意一条line获取该合并的line,横线往下找,竖线往右找 lines.sort(key=lambda x: (x[axis], x[1-axis])) merged_lines = [] used_lines = [] for line1 in lines: if line1 in used_lines: continue merged_line = [line1] used_lines.append(line1) for line2 in lines: if line2 in used_lines: continue if line1[1-axis]-threshold <= line2[1-axis] <= line1[1-axis]+threshold: # 计算基准长度 min_axis = 10000 max_axis = 0 for line3 in merged_line: if line3[axis] < min_axis: min_axis = line3[axis] if line3[axis+2] > max_axis: max_axis = line3[axis+2] # 判断两条线有无交集 if min_axis <= line2[axis] <= max_axis \ or min_axis <= line2[axis+2] <= max_axis: merged_line.append(line2) used_lines.append(line2) if merged_line: merged_lines.append(merged_line) # 合并line result_lines = [] for merged_line in merged_lines: # 获取line宽的平均值 axis_average = 0 for line in merged_line: axis_average += line[1-axis] axis_average = int(axis_average/len(merged_line)) # 获取最长line两端 merged_line.sort(key=lambda x: (x[axis])) axis_start = merged_line[0][axis] merged_line.sort(key=lambda x: (x[axis+2])) axis_end = merged_line[-1][axis+2] if axis: result_lines.append([axis_average, axis_start, axis_average, axis_end]) else: result_lines.append([axis_start, axis_average, axis_end, axis_average]) return result_lines def fix_inner2(row_points, col_points, row_lines, col_lines, threshold=3): for i in range(len(row_points)): row = row_points[i] row.sort(key=lambda x: (x[1], x[0])) for j in range(len(row)): # 当前点 point = row[j] # 获取当前点在所在行的下个点 if j >= len(row) - 1: next_row_point = [] else: next_row_point = row[j+1] if next_row_point: for k in range(len(row_lines)): line = row_lines[k] if line[1] - threshold <= point[1] <= line[1] + threshold: if not line[0] <= point[0] <= next_row_point[0] <= line[2]: if point[0] <= line[2] < next_row_point[0]: if line[2] - point[0] >= 1/3 * (next_row_point[0] - point[0]): row_lines[k][2] = next_row_point[0] if point[0] < line[0] <= next_row_point[0]: if next_row_point[0] - line[0] >= 1/3 * (next_row_point[0] - point[0]): row_lines[k][0] = point[0] # 获取当前点所在列的下个点 next_col_point = [] for col in col_points: if point in col: col.sort(key=lambda x: (x[0], x[1])) if col.index(point) < len(col) - 1: next_col_point = col[col.index(point)+1] break # 获取当前点的对角线点,通过该列下个点所在行的下个点获得 next_row_next_col_point = [] if next_col_point: for row2 in row_points: if next_col_point in row2: row2.sort(key=lambda x: (x[1], x[0])) if row2.index(next_col_point) < len(row2) - 1: next_row_next_col_point = row2[row2.index(next_col_point)+1] break # 有该列下一点但没有该列下一点所在行的下个点 if not next_row_next_col_point: # 如果有该行下个点 if next_row_point: next_row_next_col_point = [next_row_point[0], next_col_point[1]] if next_col_point: for k in range(len(col_lines)): line = col_lines[k] if line[0] - threshold <= point[0] <= line[0] + threshold: if not line[1] <= point[1] <= next_col_point[1] <= line[3]: if point[1] <= line[3] < next_col_point[1]: if line[3] - point[1] >= 1/3 * (next_col_point[1] - point[1]): col_lines[k][3] = next_col_point[1] if point[1] < line[1] <= next_col_point[1]: if next_col_point[1] - line[1] >= 1/3 * (next_col_point[1] - point[1]): col_lines[k][1] = point[1] if next_row_next_col_point: for k in range(len(col_lines)): line = col_lines[k] if line[0] - threshold <= next_row_next_col_point[0] <= line[0] + threshold: if not line[1] <= point[1] <= next_row_next_col_point[1] <= line[3]: if point[1] < line[1] <= next_row_next_col_point[1]: if next_row_next_col_point[1] - line[1] >= 1/3 * (next_row_next_col_point[1] - point[1]): col_lines[k][1] = point[1] return row_lines, col_lines def fix_inner1(row_lines, col_lines, points, split_y): def fix(fix_lines, assist_lines, split_points, axis): new_points = [] for line1 in fix_lines: min_assist_line = [[], []] min_distance = [1000, 1000] if_find = [0, 0] # 获取fix_line中的所有col point,里面可能不包括两个顶点,col point是交点,顶点可能不是交点 fix_line_points = [] for point in split_points: if abs(point[1-axis] - line1[1-axis]) <= 2: if line1[axis] <= point[axis] <= line1[axis+2]: fix_line_points.append(point) # 找出离两个顶点最近的assist_line, 并且assist_line与fix_line不相交 line1_point = [line1[:2], line1[2:]] for i in range(2): point = line1_point[i] for line2 in assist_lines: if not if_find[i] and abs(point[axis] - line2[axis]) <= 2: if line1[1-axis] <= point[1-axis] <= line2[1-axis+2]: # print("line1, match line2", line1, line2) if_find[i] = 1 break else: if abs(point[axis] - line2[axis]) < min_distance[i] and line2[1-axis] <= point[1-axis] <= line2[1-axis+2]: if line1[axis] <= line2[axis] <= line1[axis+2]: continue min_distance[i] = abs(line1[axis] - line2[axis]) min_assist_line[i] = line2 # 找出离assist_line最近的交点 # 顶点到交点的距离(多出来的线)需大于assist_line到交点的距离(bbox的边)的1/3 min_distance = [1000, 1000] min_col_point = [[], []] for i in range(2): # print("顶点", i, line1_point[i]) if min_assist_line[i]: for point in fix_line_points: if abs(point[axis] - min_assist_line[i][axis]) < min_distance[i]: min_distance[i] = abs(point[axis] - min_assist_line[i][axis]) min_col_point[i] = point if min_col_point[i]: bbox_len = abs(min_col_point[i][axis] - min_assist_line[i][axis]) line_distance = abs(min_col_point[i][axis] - line1_point[i][axis]) if bbox_len/3 <= line_distance <= bbox_len: add_point = (line1_point[i][1-axis], min_assist_line[i][axis]) print("============================table line==") print("fix_inner add point", add_point) print(min_col_point[i][axis], line1_point[i][axis], min_col_point[i][axis], min_assist_line[i][axis]) print(abs(min_col_point[i][axis] - line1_point[i][axis]), abs(min_col_point[i][axis] - min_assist_line[i][axis])/3) print("line1, line2", line1, min_assist_line[i]) new_points.append(add_point) return new_points new_points = [] for i in range(1, len(split_y)): last_y = split_y[i-1] y = split_y[i] # 先对点线进行分区 split_row_lines = [] split_col_lines = [] split_points = [] for row in row_lines: if last_y <= row[1] <= y: split_row_lines.append(row) for col in col_lines: if last_y <= col[1] <= y: split_col_lines.append(col) for point in points: if last_y <= point[1] <= y: split_points.append(point) new_points += fix(split_col_lines, split_row_lines, split_points, axis=1) new_points += fix(split_row_lines, split_col_lines, split_points, axis=0) # 找出所有col的顶点不在row上的、row的顶点不在col上的 # for col in split_col_lines: # print("*"*30) # # # 获取该line中的所有point # col_points = [] # for point in split_points: # if abs(point[0] - col[0]) <= 2: # if col[1] <= point[1] <= col[3]: # col_points.append(point) # # # 比较顶点 # min_row_1 = [] # min_row_2 = [] # min_distance_1 = 1000 # min_distance_2 = 1000 # if_find_1 = 0 # if_find_2 = 0 # for row in split_row_lines: # # 第一个顶点 # if not if_find_1 and abs(col[1] - row[1]) <= 2: # if row[0] <= col[0] <= row[2]: # print("col, match row", col, row) # if_find_1 = 1 # break # else: # if abs(col[1] - row[1]) < min_distance_1 and row[0] <= col[0] <= row[2]: # if col[1] <= row[1] <= col[3]: # continue # min_distance_1 = abs(col[1] - row[1]) # min_row_1 = row # # # 第二个顶点 # if not if_find_2 and abs(col[3] - row[1]) <= 2: # if row[0] <= col[2] <= row[2]: # if_find_2 = 1 # break # else: # if abs(col[3] - row[1]) < min_distance_2 and row[0] <= col[2] <= row[2]: # min_distance_2 = abs(col[3] - row[1]) # min_row_2 = row # # if not if_find_1: # print("col", col) # print("min_row_1", min_row_1) # if min_row_1: # min_distance_1 = 1000 # min_col_point = [] # for point in col_points: # if abs(point[1] - min_row_1[1]) < min_distance_1: # min_distance_1 = abs(point[1] - min_row_1[1]) # min_col_point = point # # if abs(min_col_point[1] - col[1]) >= abs(min_col_point[1] - min_row_1[1])/3: # # add_point = (col[0], min_row_1[1]) # print("fix_inner add point", add_point) # new_points.append(add_point) # else: # print("distance too long", min_col_point, min_row_1) # print(abs(min_col_point[1] - col[1]), abs(min_col_point[1] - min_row_1[1])/3) return points+new_points def fix_inner(row_lines, col_lines, points, split_y): def fix(fix_lines, assist_lines, split_points, axis): new_points = [] for line1 in fix_lines: min_assist_line = [[], []] min_distance = [1000, 1000] if_find = [0, 0] # 获取fix_line中的所有col point,里面可能不包括两个顶点,col point是交点,顶点可能不是交点 fix_line_points = [] for point in split_points: if abs(point[1-axis] - line1[1-axis]) <= 2: if line1[axis] <= point[axis] <= line1[axis+2]: fix_line_points.append(point) # 找出离两个顶点最近的assist_line, 并且assist_line与fix_line不相交 line1_point = [line1[:2], line1[2:]] for i in range(2): point = line1_point[i] for line2 in assist_lines: if not if_find[i] and abs(point[axis] - line2[axis]) <= 2: if line1[1-axis] <= point[1-axis] <= line2[1-axis+2]: # print("line1, match line2", line1, line2) if_find[i] = 1 break else: if abs(point[axis] - line2[axis]) < min_distance[i] and line2[1-axis] <= point[1-axis] <= line2[1-axis+2]: if line1[axis] <= line2[axis] <= line1[axis+2]: continue min_distance[i] = abs(line1[axis] - line2[axis]) min_assist_line[i] = line2 # 找出离assist_line最近的交点 # 顶点到交点的距离(多出来的线)需大于assist_line到交点的距离(bbox的边)的1/3 min_distance = [1000, 1000] min_col_point = [[], []] for i in range(2): # print("顶点", i, line1_point[i]) if min_assist_line[i]: for point in fix_line_points: if abs(point[axis] - min_assist_line[i][axis]) < min_distance[i]: min_distance[i] = abs(point[axis] - min_assist_line[i][axis]) min_col_point[i] = point # print("min_col_point", min_col_point) # print("min_assist_line", min_assist_line) # print("line1_point", line1_point) if min_assist_line[0] and min_assist_line[0] == min_assist_line[1]: if min_assist_line[0][axis] < line1_point[0][axis]: bbox_len = abs(min_col_point[0][axis] - min_assist_line[0][axis]) line_distance = abs(min_col_point[0][axis] - line1_point[0][axis]) if bbox_len/3 <= line_distance <= bbox_len: if axis == 1: add_point = (line1_point[0][1-axis], min_assist_line[0][axis]) else: add_point = (min_assist_line[0][axis], line1_point[0][1-axis]) new_points.append([line1, add_point]) elif min_assist_line[1][axis] > line1_point[1][axis]: bbox_len = abs(min_col_point[1][axis] - min_assist_line[1][axis]) line_distance = abs(min_col_point[1][axis] - line1_point[1][axis]) if bbox_len/3 <= line_distance <= bbox_len: if axis == 1: add_point = (line1_point[1][1-axis], min_assist_line[1][axis]) else: add_point = (min_assist_line[1][axis], line1_point[1][1-axis]) new_points.append([line1, add_point]) else: for i in range(2): if min_col_point[i]: bbox_len = abs(min_col_point[i][axis] - min_assist_line[i][axis]) line_distance = abs(min_col_point[i][axis] - line1_point[i][axis]) # print("bbox_len, line_distance", bbox_len, line_distance) if bbox_len/3 <= line_distance <= bbox_len: if axis == 1: add_point = (line1_point[i][1-axis], min_assist_line[i][axis]) else: add_point = (min_assist_line[i][axis], line1_point[i][1-axis]) # print("============================table line==") # print("fix_inner add point", add_point) # print(min_col_point[i][axis], line1_point[i][axis], min_col_point[i][axis], min_assist_line[i][axis]) # print(abs(min_col_point[i][axis] - line1_point[i][axis]), abs(min_col_point[i][axis] - min_assist_line[i][axis])/3) # print("line1, line2", line1, min_assist_line[i]) # print("line1, add_point", [line1, add_point]) new_points.append([line1, add_point]) return new_points row_lines_copy = copy.deepcopy(row_lines) col_lines_copy = copy.deepcopy(col_lines) try: new_points = [] for i in range(1, len(split_y)): last_y = split_y[i-1] y = split_y[i] # 先对点线进行分区 split_row_lines = [] split_col_lines = [] split_points = [] for row in row_lines: if last_y <= row[1] <= y: split_row_lines.append(row) for col in col_lines: if last_y <= col[1] <= y: split_col_lines.append(col) for point in points: if last_y <= point[1] <= y: split_points.append(point) new_point_list = fix(split_col_lines, split_row_lines, split_points, axis=1) for line, new_point in new_point_list: if line in col_lines: index = col_lines.index(line) point1 = line[:2] point2 = line[2:] if new_point[1] >= point2[1]: col_lines[index] = [point1[0], point1[1], new_point[0], new_point[1]] elif new_point[1] <= point1[1]: col_lines[index] = [new_point[0], new_point[1], point2[0], point2[1]] new_point_list = fix(split_row_lines, split_col_lines, split_points, axis=0) for line, new_point in new_point_list: if line in row_lines: index = row_lines.index(line) point1 = line[:2] point2 = line[2:] if new_point[0] >= point2[0]: row_lines[index] = [point1[0], point1[1], new_point[0], new_point[1]] elif new_point[0] <= point1[0]: row_lines[index] = [new_point[0], new_point[1], point2[0], point2[1]] return row_lines, col_lines except: traceback.print_exc() return row_lines_copy, col_lines_copy def fix_corner(row_lines, col_lines, split_y, threshold=0): new_row_lines = [] new_col_lines = [] last_y = split_y[0] for y in split_y: if y == last_y: continue split_row_lines = [] split_col_lines = [] for row in row_lines: if last_y-threshold <= row[1] <= y+threshold or last_y-threshold <= row[3] <= y+threshold: split_row_lines.append(row) for col in col_lines: # fix corner 容易因split line 漏掉线 if last_y-threshold <= col[1] <= y+threshold or last_y-threshold <= col[3] <= y+threshold: split_col_lines.append(col) if not split_row_lines or not split_col_lines: last_y = y continue split_row_lines.sort(key=lambda x: (x[1], x[0])) split_col_lines.sort(key=lambda x: (x[0], x[1])) up_line = split_row_lines[0] bottom_line = split_row_lines[-1] left_line = split_col_lines[0] right_line = split_col_lines[-1] # 左上角 if up_line[0:2] != left_line[0:2]: # print("up_line, left_line", up_line, left_line) add_corner = [left_line[0], up_line[1]] split_row_lines[0][0] = add_corner[0] split_col_lines[0][1] = add_corner[1] # 右上角 if up_line[2:] != right_line[:2]: # print("up_line, right_line", up_line, right_line) add_corner = [right_line[0], up_line[1]] split_row_lines[0][2] = add_corner[0] split_col_lines[-1][1] = add_corner[1] new_row_lines = new_row_lines + split_row_lines new_col_lines = new_col_lines + split_col_lines last_y = y return new_row_lines, new_col_lines def delete_outline(row_lines, col_lines, points): row_lines.sort(key=lambda x: (x[1], x[0])) col_lines.sort(key=lambda x: (x[0], x[1])) line = [row_lines[0], row_lines[-1], col_lines[0], col_lines[-1]] threshold = 2 point_cnt = [0, 0, 0, 0] for point in points: for i in range(4): if i < 2: if line[i][1]-threshold <= point[1] <= line[i][1]+threshold: if line[i][0] <= point[0] <= line[i][2]: point_cnt[i] += 1 else: if line[i][0]-threshold <= point[0] <= line[i][0]+threshold: if line[i][1] <= point[1] <= line[i][3]: point_cnt[i] += 1 # if line[0][1]-threshold <= point[1] <= line[0][1]+threshold: # if line[0][0] <= point[0] <= line[0][2]: # point_cnt[0] += 1 # elif line[1][1]-threshold <= point[1] <= line[1][1]+threshold: # if line[1][0] <= point[0] <= line[1][2]: # point_cnt[1] += 1 # elif line[2][0]-threshold <= point[0] <= line[2][0]+threshold: # if line[2][1] <= point[1] <= line[2][3]: # point_cnt[2] += 1 # elif line[3][0]-threshold <= point[0] <= line[3][0]+threshold: # if line[3][1] <= point[1] <= line[3][3]: # point_cnt[3] += 1 # 轮廓line至少包含3个交点 for i in range(4): if point_cnt[i] < 3: if i < 2: if line[i] in row_lines: row_lines.remove(line[i]) else: if line[i] in col_lines: col_lines.remove(line[i]) return row_lines, col_lines def fix_outline2(image, row_lines, col_lines, points, split_y): print("split_y", split_y) # 分割线纵坐标 if len(split_y) < 2: return [], [], [], [] # elif len(split_y) == 2: # split_y = [2000., 2000., 2000., 2000.] split_y.sort(key=lambda x: x) new_split_y = [] for i in range(1, len(split_y), 2): new_split_y.append(int((split_y[i]+split_y[i-1])/2)) # # 查看是否正确输出区域分割线 # for line in split_y: # cv2.line(image, (0, int(line)), (int(image.shape[1]), int(line)), (0, 0, 255), 2) # cv2.imshow("split_y", image) # cv2.waitKey(0) # 预测线根据分割线纵坐标分为多个分割区域 # row_lines.sort(key=lambda x: (x[3], x[2], x[1], x[0])) # col_lines.sort(key=lambda x: (x[3], x[2], x[1], x[0])) # points.sort(key=lambda x: (x[1], x[0])) # row_count = 0 # col_count = 0 # point_count = 0 split_row_list = [] split_col_list = [] split_point_list = [] # for i in range(1, len(split_y)): # y = split_y[i] # last_y = split_y[i-1] # row_lines = row_lines[row_count:] # col_lines = col_lines[col_count:] # points = points[point_count:] # row_count = 0 # col_count = 0 # point_count = 0 # # if not row_lines: # split_row_list.append([]) # for row in row_lines: # if last_y <= row[3] <= y: # row_count += 1 # else: # split_row_list.append(row_lines[:row_count]) # break # if row_count == len(row_lines): # split_row_list.append(row_lines[:row_count]) # break # # if not col_lines: # split_col_list.append([]) # # for col in col_lines: # # if last_y <= col[3] <= y: # if col[1] <= last_y <= y <= col[3] or last_y <= col[3] <= y: # # if last_y <= col[1] <= y or last_y <= col[3] <= y: # col_count += 1 # else: # split_col_list.append(col_lines[:col_count]) # break # if col_count == len(col_lines): # split_col_list.append(col_lines[:col_count]) # break # # if not points: # split_point_list.append([]) # for point in points: # if last_y <= point[1] <= y: # point_count += 1 # else: # split_point_list.append(points[:point_count]) # break # if point_count == len(points): # split_point_list.append(points[:point_count]) # break # # # print("len(split_row_list)", len(split_row_list)) # # print("len(split_col_list)", len(split_col_list)) # if row_count < len(row_lines) - 1 and col_count < len(col_lines) - 1: # row_lines = row_lines[row_count:] # split_row_list.append(row_lines) # col_lines = col_lines[col_count:] # split_col_list.append(col_lines) # # if point_count < len(points) - 1: # points = points[point_count:len(points)] # split_point_list.append(points) for i in range(1, len(split_y)): y = split_y[i] last_y = split_y[i-1] split_row = [] for row in row_lines: if last_y <= row[3] <= y: split_row.append(row) split_row_list.append(split_row) split_col = [] for col in col_lines: if last_y <= col[1] <= y or last_y <= col[3] <= y or col[1] < last_y < y < col[3]: split_col.append(col) split_col_list.append(split_col) split_point = [] for point in points: if last_y <= point[1] <= y: split_point.append(point) split_point_list.append(split_point) # 预测线取上下左右4个边(会有超出表格部分) [(), ()] area_row_line = [] area_col_line = [] for area in split_row_list: if not area: area_row_line.append([]) continue area.sort(key=lambda x: (x[1], x[0])) up_line = area[0] bottom_line = area[-1] area_row_line.append([up_line, bottom_line]) for area in split_col_list: if not area: area_col_line.append([]) continue area.sort(key=lambda x: x[0]) left_line = area[0] right_line = area[-1] area_col_line.append([left_line, right_line]) # 线交点根据分割线纵坐标分为多个分割区域 # points.sort(key=lambda x: (x[1], x[0])) # point_count = 0 # split_point_list = [] # for y in new_split_y: # points = points[point_count:len(points)] # point_count = 0 # for point in points: # if point[1] <= y: # point_count += 1 # else: # split_point_list.append(points[:point_count]) # break # if point_count == len(points): # split_point_list.append(points[:point_count]) # break # if point_count < len(points) - 1: # points = points[point_count:len(points)] # split_point_list.append(points) # print("len(split_point_list)", len(split_point_list)) # 取每个分割区域的4条线(无超出表格部分) area_row_line2 = [] area_col_line2 = [] for area in split_point_list: if not area: area_row_line2.append([]) area_col_line2.append([]) continue area.sort(key=lambda x: (x[0], x[1])) left_up = area[0] right_bottom = area[-1] up_line = [left_up[0], left_up[1], right_bottom[0], left_up[1]] bottom_line = [left_up[0], right_bottom[1], right_bottom[0], right_bottom[1]] left_line = [left_up[0], left_up[1], left_up[0], right_bottom[1]] right_line = [right_bottom[0], left_up[1], right_bottom[0], right_bottom[1]] area_row_line2.append([up_line, bottom_line]) area_col_line2.append([left_line, right_line]) # 判断超出部分的长度,超出一定长度就补线 new_row_lines = [] new_col_lines = [] longer_row_lines = [] longer_col_lines = [] all_longer_row_lines = [] all_longer_col_lines = [] # print("split_y", split_y) # print("split_row_list", split_row_list, len(split_row_list)) # print("split_row_list", split_col_list, len(split_col_list)) # print("area_row_line", area_row_line, len(area_row_line)) # print("area_col_line", area_col_line, len(area_col_line)) for i in range(len(area_row_line)): if not area_row_line[i] or not area_col_line[i]: continue up_line = area_row_line[i][0] up_line2 = area_row_line2[i][0] bottom_line = area_row_line[i][1] bottom_line2 = area_row_line2[i][1] left_line = area_col_line[i][0] left_line2 = area_col_line2[i][0] right_line = area_col_line[i][1] right_line2 = area_col_line2[i][1] # 计算单格高度宽度 if len(split_row_list[i]) > 1: height_dict = {} for j in range(len(split_row_list[i])): if j + 1 > len(split_row_list[i]) - 1: break height = abs(int(split_row_list[i][j][3] - split_row_list[i][j+1][3])) if height in height_dict.keys(): height_dict[height] = height_dict[height] + 1 else: height_dict[height] = 1 height_list = [[x, height_dict[x]] for x in height_dict.keys()] height_list.sort(key=lambda x: (x[1], -x[0]), reverse=True) # print("height_list", height_list) box_height = height_list[0][0] else: box_height = 10 if len(split_col_list[i]) > 1: box_width = abs(split_col_list[i][1][2] - split_col_list[i][0][2]) else: box_width = 10 print("box_height", box_height, "box_width", box_width) # cv2.line(image, (int(up_line[0]), int(up_line[1])), # (int(up_line[2]), int(up_line[3])), # (255, 255, 0), 2) # cv2.line(image, (int(right_line[0]), int(right_line[1])), # (int(right_line[2]), int(right_line[3])), # (0, 255, 255), 2) # cv2.imshow("right_line", image) # cv2.waitKey(0) # 补左右两条竖线超出来的线的row if (up_line[1] - left_line[1] >= 10 and up_line[1] - right_line[1] >= 2) or \ (up_line[1] - left_line[1] >= 2 and up_line[1] - right_line[1] >= 10): if up_line[1] - left_line[1] >= up_line[1] - right_line[1]: new_row_lines.append([left_line[0], left_line[1], right_line[0], left_line[1]]) new_col_y = left_line[1] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差大于一格 # print("abs(new_col_y - col[1])", abs(new_col_y - col[1])) if abs(new_col_y - col[1]) <= box_height: split_col_list[i][j][1] = min([new_col_y, col[1]]) longer_col_lines.append([col[0], min([new_col_y, col[1]]), col[2], col[3]]) else: new_row_lines.append([left_line[0], right_line[1], right_line[0], right_line[1]]) new_col_y = right_line[1] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): # 需判断该线在这个区域中 # if up_line2[1]-3 <= col[1] <= col[3] <= bottom_line2[1]+3: col = split_col_list[i][j] # 且距离不能相差太大 # print("abs(new_col_y - col[1])", abs(new_col_y - col[1])) if abs(new_col_y - col[1]) <= box_height: split_col_list[i][j][1] = min([new_col_y, col[1]]) if (left_line[3] - bottom_line[3] >= 10 and right_line[3] - bottom_line[3] >= 2) or \ (left_line[3] - bottom_line[3] >= 2 and right_line[3] - bottom_line[3] >= 10): if left_line[3] - bottom_line[3] >= right_line[3] - bottom_line[3]: new_row_lines.append([left_line[2], left_line[3], right_line[2], left_line[3]]) new_col_y = left_line[3] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差太大 if abs(new_col_y - col[3]) <= box_height: split_col_list[i][j][3] = max([new_col_y, col[3]]) else: new_row_lines.append([left_line[2], right_line[3], right_line[2], right_line[3]]) new_col_y = right_line[3] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差太大 if abs(new_col_y - col[3]) <= box_height: split_col_list[i][j][3] = max([new_col_y, col[3]]) # 补上下两条横线超出来的线的col if (left_line[0] - up_line[0] >= 10 and left_line[0] - bottom_line[0] >= 2) or \ (left_line[0] - up_line[0] >= 2 and left_line[0] - bottom_line[0] >= 10): if left_line[0] - up_line[0] >= left_line[0] - bottom_line[0]: new_col_lines.append([up_line[0], up_line[1], up_line[0], bottom_line[1]]) new_row_x = up_line[0] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[0]) <= box_width: split_row_list[i][j][0] = min([new_row_x, row[0]]) else: new_col_lines.append([bottom_line[0], up_line[1], bottom_line[0], bottom_line[1]]) new_row_x = bottom_line[0] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[0]) <= box_width: split_row_list[i][j][0] = min([new_row_x, row[0]]) if (up_line[2] - right_line[2] >= 10 and bottom_line[2] - right_line[2] >= 2) or \ (up_line[2] - right_line[2] >= 2 and bottom_line[2] - right_line[2] >= 10): if up_line[2] - right_line[2] >= bottom_line[2] - right_line[2]: new_col_lines.append([up_line[2], up_line[3], up_line[2], bottom_line[3]]) new_row_x = up_line[2] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[2]) <= box_width: split_row_list[i][j][2] = max([new_row_x, row[2]]) else: new_col_lines.append([bottom_line[2], up_line[3], bottom_line[2], bottom_line[3]]) new_row_x = bottom_line[2] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): # 需判断该线在这个区域中 # if up_line2[1]-3 <= row[1] <= bottom_line2[1]+3: row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[2]) <= box_width: split_row_list[i][j][2] = max([new_row_x, row[2]]) all_longer_row_lines += split_row_list[i] all_longer_col_lines += split_col_list[i] # print("all_longer_row_lines", len(all_longer_row_lines), i) # print("all_longer_col_lines", len(all_longer_col_lines), i) # print("new_row_lines", len(new_row_lines), i) # print("new_col_lines", len(new_col_lines), i) # 删除表格内部的补线 # temp_list = [] # for row in new_row_lines: # if up_line[1]-5 <= row[1] <= bottom_line[1]+5: # continue # temp_list.append(row) # print("fix_outline", new_row_lines) # new_row_lines = temp_list # print("fix_outline", new_row_lines) # temp_list = [] # for col in new_col_lines: # if left_line[0]-5 <= col[0] <= right_line[0]+5: # continue # temp_list.append(col) # # new_col_lines = temp_list # print("fix_outline", new_col_lines) # print("fix_outline", new_row_lines) # 删除重复包含的补线 # temp_list = [] # for row in new_row_lines: # if up_line[1]-5 <= row[1] <= bottom_line[1]+5: # continue # temp_list.append(row) # new_row_lines = temp_list # 展示上下左右边框线 # for i in range(len(area_row_line)): # print("row1", area_row_line[i]) # print("row2", area_row_line2[i]) # print("col1", area_col_line[i]) # print("col2", area_col_line2[i]) # cv2.line(image, (int(area_row_line[i][0][0]), int(area_row_line[i][0][1])), # (int(area_row_line[i][0][2]), int(area_row_line[i][0][3])), (0, 255, 0), 2) # cv2.line(image, (int(area_row_line2[i][1][0]), int(area_row_line2[i][1][1])), # (int(area_row_line2[i][1][2]), int(area_row_line2[i][1][3])), (0, 0, 255), 2) # cv2.imshow("fix_outline", image) # cv2.waitKey(0) # 展示所有线 # for line in all_longer_col_lines: # cv2.line(image, (int(line[0]), int(line[1])), # (int(line[2]), int(line[3])), # (0, 255, 0), 2) # cv2.imshow("fix_outline", image) # cv2.waitKey(0) # for line in all_longer_row_lines: # cv2.line(image, (int(line[0]), int(line[1])), # (int(line[2]), int(line[3])), # (0, 0, 255), 2) # cv2.imshow("fix_outline", image) # cv2.waitKey(0) return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines def fix_outline(image, row_lines, col_lines, points, split_y, scale=25): log("into fix_outline") x_min_len = max(10, int(image.shape[0] / scale)) y_min_len = max(10, int(image.shape[1] / scale)) # print("x_min_len", x_min_len, "y_min_len", y_min_len) # print("split_y", split_y) # 分割线纵坐标 if len(split_y) < 2: return [], [], [], [] split_y.sort(key=lambda x: x) new_split_y = [] for i in range(1, len(split_y), 2): new_split_y.append(int((split_y[i]+split_y[i-1])/2)) split_row_list = [] split_col_list = [] split_point_list = [] for i in range(1, len(split_y)): y = split_y[i] last_y = split_y[i-1] split_row = [] for row in row_lines: if last_y <= row[3] <= y: split_row.append(row) split_row_list.append(split_row) split_col = [] for col in col_lines: if last_y <= col[1] <= y or last_y <= col[3] <= y or col[1] < last_y < y < col[3]: split_col.append(col) split_col_list.append(split_col) split_point = [] for point in points: if last_y <= point[1] <= y: split_point.append(point) split_point_list.append(split_point) # 预测线取上下左右4个边(会有超出表格部分) [(), ()] area_row_line = [] area_col_line = [] for area in split_row_list: if not area: area_row_line.append([]) continue area.sort(key=lambda x: (x[1], x[0])) up_line = area[0] bottom_line = area[-1] area_row_line.append([up_line, bottom_line]) for area in split_col_list: if not area: area_col_line.append([]) continue area.sort(key=lambda x: x[0]) left_line = area[0] right_line = area[-1] area_col_line.append([left_line, right_line]) # 取每个分割区域的4条线(无超出表格部分) area_row_line2 = [] area_col_line2 = [] for area in split_point_list: if not area: area_row_line2.append([]) area_col_line2.append([]) continue area.sort(key=lambda x: (x[0], x[1])) left_up = area[0] right_bottom = area[-1] up_line = [left_up[0], left_up[1], right_bottom[0], left_up[1]] bottom_line = [left_up[0], right_bottom[1], right_bottom[0], right_bottom[1]] left_line = [left_up[0], left_up[1], left_up[0], right_bottom[1]] right_line = [right_bottom[0], left_up[1], right_bottom[0], right_bottom[1]] area_row_line2.append([up_line, bottom_line]) area_col_line2.append([left_line, right_line]) # 判断超出部分的长度,超出一定长度就补线 new_row_lines = [] new_col_lines = [] longer_row_lines = [] longer_col_lines = [] all_longer_row_lines = [] all_longer_col_lines = [] for i in range(len(area_row_line)): if not area_row_line[i] or not area_col_line[i]: continue up_line = area_row_line[i][0] up_line2 = area_row_line2[i][0] bottom_line = area_row_line[i][1] bottom_line2 = area_row_line2[i][1] left_line = area_col_line[i][0] left_line2 = area_col_line2[i][0] right_line = area_col_line[i][1] right_line2 = area_col_line2[i][1] # 计算单格高度宽度 if len(split_row_list[i]) > 1: height_dict = {} for j in range(len(split_row_list[i])): if j + 1 > len(split_row_list[i]) - 1: break # print("height_dict", split_row_list[i][j], split_row_list[i][j+1]) height = abs(int(split_row_list[i][j][3] - split_row_list[i][j+1][3])) if height >= 10: if height in height_dict.keys(): height_dict[height] = height_dict[height] + 1 else: height_dict[height] = 1 height_list = [[x, height_dict[x]] for x in height_dict.keys()] height_list.sort(key=lambda x: (x[1], -x[0]), reverse=True) # print("box_height", height_list) box_height = height_list[0][0] else: box_height = y_min_len if len(split_col_list[i]) > 1: box_width = abs(split_col_list[i][1][2] - split_col_list[i][0][2]) else: box_width = x_min_len # print("box_height", box_height, "box_width", box_width) # 设置轮廓线需超出阈值 if box_height >= 2*y_min_len: fix_h_len = y_min_len else: fix_h_len = box_height * 2/3 if box_width >= 2*x_min_len: fix_w_len = x_min_len else: fix_w_len = box_width * 2/3 # 补左右两条竖线超出来的线的row if up_line[1] - left_line[1] >= fix_h_len and up_line[1] - right_line[1] >= fix_h_len: if up_line[1] - left_line[1] >= up_line[1] - right_line[1]: new_row_lines.append([left_line[0], left_line[1], right_line[0], left_line[1]]) new_col_y = left_line[1] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差大于一格 # print("abs(new_col_y - col[1])", abs(new_col_y - col[1])) if abs(new_col_y - col[1]) <= box_height: split_col_list[i][j][1] = min([new_col_y, col[1]]) longer_col_lines.append([col[0], min([new_col_y, col[1]]), col[2], col[3]]) else: new_row_lines.append([left_line[0], right_line[1], right_line[0], right_line[1]]) new_col_y = right_line[1] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): # 需判断该线在这个区域中 # if up_line2[1]-3 <= col[1] <= col[3] <= bottom_line2[1]+3: col = split_col_list[i][j] # 且距离不能相差太大 # print("abs(new_col_y - col[1])", abs(new_col_y - col[1])) if abs(new_col_y - col[1]) <= box_height: split_col_list[i][j][1] = min([new_col_y, col[1]]) if left_line[3] - bottom_line[3] >= fix_h_len and right_line[3] - bottom_line[3] >= fix_h_len: if left_line[3] - bottom_line[3] >= right_line[3] - bottom_line[3]: new_row_lines.append([left_line[2], left_line[3], right_line[2], left_line[3]]) new_col_y = left_line[3] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差太大 if abs(new_col_y - col[3]) <= box_height: split_col_list[i][j][3] = max([new_col_y, col[3]]) else: new_row_lines.append([left_line[2], right_line[3], right_line[2], right_line[3]]) new_col_y = right_line[3] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差太大 if abs(new_col_y - col[3]) <= box_height: split_col_list[i][j][3] = max([new_col_y, col[3]]) # 补上下两条横线超出来的线的col if left_line[0] - up_line[0] >= fix_w_len and left_line[0] - bottom_line[0] >= fix_w_len: if left_line[0] - up_line[0] >= left_line[0] - bottom_line[0]: new_col_lines.append([up_line[0], up_line[1], up_line[0], bottom_line[1]]) new_row_x = up_line[0] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[0]) <= box_width: split_row_list[i][j][0] = min([new_row_x, row[0]]) else: new_col_lines.append([bottom_line[0], up_line[1], bottom_line[0], bottom_line[1]]) new_row_x = bottom_line[0] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[0]) <= box_width: split_row_list[i][j][0] = min([new_row_x, row[0]]) if up_line[2] - right_line[2] >= fix_w_len and bottom_line[2] - right_line[2] >= fix_w_len: if up_line[2] - right_line[2] >= bottom_line[2] - right_line[2]: new_col_lines.append([up_line[2], up_line[3], up_line[2], bottom_line[3]]) new_row_x = up_line[2] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[2]) <= box_width: split_row_list[i][j][2] = max([new_row_x, row[2]]) else: new_col_lines.append([bottom_line[2], up_line[3], bottom_line[2], bottom_line[3]]) new_row_x = bottom_line[2] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): # 需判断该线在这个区域中 # if up_line2[1]-3 <= row[1] <= bottom_line2[1]+3: row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[2]) <= box_width: split_row_list[i][j][2] = max([new_row_x, row[2]]) all_longer_row_lines += split_row_list[i] all_longer_col_lines += split_col_list[i] return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines def fix_table(row_point_list, col_point_list, split_y, row_lines, col_lines): # 分割线纵坐标 if len(split_y) < 2: return [] # 获取bbox bbox = [] # 每个点获取与其x最相近和y最相近的点 for i in range(1, len(split_y)): # 循环每行 for row in row_point_list: row.sort(key=lambda x: (x[0], x[1])) # 行不在该区域跳过 if row[0][1] <= split_y[i-1] or row[0][1] >= split_y[i]: continue # print("len(row)", len(row)) # print("row", row) # 循环行中的点 for j in range(len(row)): if j == len(row) - 1: break current_point = row[j] next_point_in_row_list = row[j+1:] # 循环这一行的下一个点 for next_point_in_row in next_point_in_row_list: # 是否在这一行点找到,找不到就这一行的下个点 not_found = 1 # 查询下个点所在列 next_col = [] for col in col_point_list: col.sort(key=lambda x: (x[1], x[0])) # 列不在该区域跳过 if col[0][1] <= split_y[i-1] or col[-1][1] >= split_y[i]: continue if col[0][0]-3 <= next_point_in_row[0] <= col[0][0]+3: next_col = col break # 循环匹配当前点和下一列点 next_col.sort(key=lambda x: (x[1], x[0])) for point1 in next_col: # 同一行的就跳过 if current_point[1]-3 <= point1[1] <= current_point[1]+3: continue if point1[1] <= current_point[1]-3: continue # 候选bbox candidate_bbox = [current_point[0], current_point[1], point1[0], point1[1]] # print("candidate_bbox", candidate_bbox) # 判断该bbox是否存在,判断bbox的上下两条边是否有包含在row中 contain_flag1 = 0 contain_flag2 = 0 for row1 in row_lines: # 行不在该区域跳过 if row1[1] <= split_y[i-1] or row1[1] >= split_y[i]: continue # bbox上边框 y一样 if not contain_flag1: if row1[1]-3 <= candidate_bbox[1] <= row1[1]+3: # 格子里的断开线段 row1_break = (max([row1[0], candidate_bbox[0]]), row1[1], min([row1[2], candidate_bbox[2]]), row1[3]) if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: contain_flag1 = 1 # bbox下边框 y一样 if not contain_flag2: if row1[1]-3 <= candidate_bbox[3] <= row1[1]+3: # 格子里的断开线段 row1_break = (max([row1[0], candidate_bbox[0]]), row1[1], min([row1[2], candidate_bbox[2]]), row1[3]) if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: contain_flag2 = 1 # 判断该bbox是否存在,判断bbox的左右两条边是否有包含在col中 contain_flag3 = 0 contain_flag4 = 0 for col1 in col_lines: # 列不在该区域跳过 if col1[1] <= split_y[i-1] or col1[3] >= split_y[i]: continue # bbox左边线 x一样 if not contain_flag3: if col1[0]-3 <= candidate_bbox[0] <= col1[0]+3: # 格子里的断开线段 col1_break = (col1[0], max([col1[1], candidate_bbox[1]]), col1[2], min([col1[3], candidate_bbox[3]])) if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: contain_flag3 = 1 # bbox右边框 x一样 if not contain_flag4: if col1[0]-3 <= candidate_bbox[2] <= col1[0]+3: # 格子里的断开线段 col1_break = (col1[0], max([col1[1], candidate_bbox[1]]), col1[2], min([col1[3], candidate_bbox[3]])) if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: contain_flag4 = 1 # 找到了该bbox,并且是存在的 if contain_flag1 and contain_flag2 and contain_flag3 and contain_flag4: bbox.append([(candidate_bbox[0], candidate_bbox[1]), (candidate_bbox[2], candidate_bbox[3])]) not_found = 0 break if not not_found: break return bbox def delete_close_points(point_list, row_point_list, col_point_list, threshold=5): new_point_list = [] delete_point_list = [] point_list.sort(key=lambda x: (x[1], x[0])) for i in range(len(point_list)): point1 = point_list[i] if point1 in delete_point_list: continue if i == len(point_list) - 1: new_point_list.append(point1) break point2 = point_list[i+1] # 判断坐标 if abs(point1[0] - point2[0]) > threshold or abs(point1[1] - point2[1]) > threshold: new_point_list.append(point1) else: # 看两个点上的相同坐标点哪个多,就保留哪个 count1 = 0 count2 = 0 for col in col_point_list: if point1[0] == col[0][0]: count1 += len(col) elif point2[0] == col[0][0]: count2 += len(col) if count1 >= count2: new_point_list.append(point1) delete_point_list.append(point2) else: new_point_list.append(point2) delete_point_list.append(point1) point_list = new_point_list new_point_list = [] delete_point_list = [] point_list.sort(key=lambda x: (x[0], x[1])) for i in range(len(point_list)): point1 = point_list[i] if point1 in delete_point_list: continue if i == len(point_list) - 1: new_point_list.append(point1) break point2 = point_list[i+1] # 判断坐标 if abs(point1[0] - point2[0]) > threshold or abs(point1[1] - point2[1]) > threshold: new_point_list.append(point1) else: count1 = 0 count2 = 0 for row in row_point_list: if point1[0] == row[0][0]: count1 += len(row) elif point2[0] == row[0][0]: count2 += len(row) if count1 >= count2: new_point_list.append(point1) delete_point_list.append(point2) else: new_point_list.append(point2) delete_point_list.append(point1) return new_point_list def get_bbox2(image_np, points): # # 坐标点按行分 # row_point_list = [] # row_point = [] # points.sort(key=lambda x: (x[0], x[1])) # for p in points: # if len(row_point) == 0: # x = p[0] # if x-5 <= p[0] <= x+5: # row_point.append(p) # else: # row_point_list.append(row_point) # row_point = [] # # 坐标点按列分 # col_point_list = [] # col_point = [] # points.sort(key=lambda x: (x[1], x[0])) # for p in points: # if len(col_point) == 0: # y = p[1] # if y-5 <= p[1] <= y+5: # col_point.append(p) # else: # col_point_list.append(col_point) # col_point = [] row_point_list = get_points_row(points) col_point_list = get_points_col(points) print("len(points)", len(points)) for point in points: cv2.circle(image_np, point, 1, (0, 255, 0), 1) cv2.imshow("points_deleted", image_np) points = delete_close_points(points, row_point_list, col_point_list) print("len(points)", len(points)) for point in points: cv2.circle(image_np, point, 1, (255, 0, 0), 3) cv2.imshow("points_deleted", image_np) cv2.waitKey(0) row_point_list = get_points_row(points, 5) col_point_list = get_points_col(points, 5) print("len(row_point_list)", len(row_point_list)) for row in row_point_list: print("row", len(row)) print("col_point_list", len(col_point_list)) for col in col_point_list: print("col", len(col)) bbox = [] for i in range(len(row_point_list)): if i == len(row_point_list) - 1: break # 遍历每个row的point,找到其所在列的下一个点和所在行的下一个点 current_row = row_point_list[i] for j in range(len(current_row)): current_point = current_row[j] if j == len(current_row) - 1: break next_row_point = current_row[j+1] # 找出当前点所在的col,得到该列下一个point current_col = col_point_list[j] for k in range(len(current_col)): if current_col[k][1] > current_point[1] + 10: next_col_point = current_col[k] break next_row = row_point_list[k] for k in range(len(next_row)): if next_row[k][0] >= next_row_point[0] + 5: next_point = next_row[k] break # 得到bbox bbox.append([(current_point[0], current_point[1]), (next_point[0], next_point[1])]) # bbox = [] # for p in points: # # print("p", p) # p_row = [] # p_col = [] # for row in row_point_list: # if p[0] == row[0][0]: # for p1 in row: # if abs(p[1]-p1[1]) <= 5: # continue # p_row.append([p1, abs(p[1]-p1[1])]) # p_row.sort(key=lambda x: x[1]) # for col in col_point_list: # if p[1] == col[0][1]: # for p2 in col: # if abs(p[0]-p2[0]) <= 5: # continue # p_col.append([p2, abs(p[0]-p2[0])]) # p_col.sort(key=lambda x: x[1]) # if len(p_row) == 0 or len(p_col) == 0: # continue # break_flag = 0 # for i in range(len(p_row)): # for j in range(len(p_col)): # # print(p_row[i][0]) # # print(p_col[j][0]) # another_point = (p_col[j][0][0], p_row[i][0][1]) # # print("another_point", another_point) # if abs(p[0]-another_point[0]) <= 5 or abs(p[1]-another_point[1]) <= 5: # continue # if p[0] >= another_point[0] or p[1] >= another_point[1]: # continue # if another_point in points: # box = [p, another_point] # box.sort(key=lambda x: x[0]) # if box not in bbox: # bbox.append(box) # break_flag = 1 # break # if break_flag: # break # # # delete duplicate # delete_bbox = [] # for i in range(len(bbox)): # for j in range(i+1, len(bbox)): # if bbox[i][0] == bbox[j][0]: # if bbox[i][1][0] - bbox[j][1][0] <= 3 \ # and bbox[i][1][1] - bbox[j][1][1] <= 3: # delete_bbox.append(bbox[j]) # if bbox[i][1] == bbox[j][1]: # if bbox[i][0][0] - bbox[j][0][0] <= 3 \ # and bbox[i][0][1] - bbox[j][0][1] <= 3: # delete_bbox.append(bbox[j]) # # delete too small area # # for box in bbox: # # if box[1][0] - box[0][0] <= # for d_box in delete_bbox: # if d_box in bbox: # bbox.remove(d_box) # print bbox bbox.sort(key=lambda x: (x[0][0], x[0][1], x[1][0], x[1][1])) # origin bbox # origin_bbox = [] # for box in bbox: # origin_bbox.append([(box[0][0], box[0][1] - 40), (box[1][0], box[1][1] - 40)]) # for box in origin_bbox: # cv2.rectangle(origin_image, box[0], box[1], (0, 0, 255), 2, 8) # cv2.imshow('AlanWang', origin_image) # cv2.waitKey(0) for box in bbox: cv2.rectangle(image_np, box[0], box[1], (0, 0, 255), 2, 8) cv2.imshow('bboxes', image_np) cv2.waitKey(0) # for point in points: # print(point) # cv2.circle(image_np, point, 1, (0, 0, 255), 3) # cv2.imshow('points', image_np) # cv2.waitKey(0) return bbox def get_bbox1(image_np, points, split_y): # 分割线纵坐标 # print("split_y", split_y) if len(split_y) < 2: return [] # 计算行列,剔除相近交点 row_point_list = get_points_row(points) col_point_list = get_points_col(points) print("len(row_point_list)", row_point_list) print("len(col_point_list)", len(col_point_list)) # for point in points: # cv2.circle(image_np, point, 1, (0, 255, 0), 1) # cv2.imshow("points", image_np) points = delete_close_points(points, row_point_list, col_point_list) # print("len(points)", len(points)) # for point in points: # cv2.circle(image_np, point, 1, (255, 0, 0), 3) # cv2.imshow("points_deleted", image_np) # cv2.waitKey(0) # 获取bbox bbox = [] # 每个点获取与其x最相近和y最相近的点 for i in range(1, len(split_y)): for point1 in points: if point1[1] <= split_y[i-1] or point1[1] >= split_y[i]: continue distance_x = 10000 distance_y = 10000 x = 0 y = 0 threshold = 10 for point2 in points: if point2[1] <= split_y[i-1] or point2[1] >= split_y[i]: continue # 最近 x y if 2 < point2[0] - point1[0] < distance_x and point2[1] - point1[1] <= threshold: distance_x = point2[0] - point1[0] x = point2[0] if 2 < point2[1] - point1[1] < distance_y and point2[0] - point1[0] <= threshold: distance_y = point2[1] - point1[1] y = point2[1] if not x or not y: continue bbox.append([(point1[0], point1[1]), (x, y)]) # 删除包含关系bbox temp_list = [] for i in range(len(bbox)): box1 = bbox[i] for j in range(len(bbox)): if i == j: continue box2 = bbox[j] contain_flag = 0 if box2[0][0] <= box1[0][0] <= box1[1][0] <= box2[1][0] and \ box2[0][1] <= box1[0][1] <= box1[1][1] <= box2[1][1]: contain_flag = 1 break temp_list.append(box1) bbox = temp_list # 展示 for box in bbox: # print(box[0], box[1]) # if abs(box[0][1] - box[1][1]) > abs(box[0][0] - box[1][0]): # continue cv2.rectangle(image_np, box[0], box[1], (0, 0, 255), 2, 8) cv2.imshow('bboxes', image_np) cv2.waitKey(0) return bbox def get_bbox0(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines): # 分割线纵坐标 if len(split_y) < 2: return [] # 计算行列,剔除相近交点 # row_point_list = get_points_row(points) # col_point_list = get_points_col(points) # points = delete_close_points(points, row_point_list, col_point_list) # row_point_list = get_points_row(points) # col_point_list = get_points_col(points) # 获取bbox bbox = [] # print("get_bbox split_y", split_y) # 每个点获取与其x最相近和y最相近的点 for i in range(1, len(split_y)): # 循环每行 for row in row_point_list: row.sort(key=lambda x: (x[0], x[1])) # 行不在该区域跳过 if row[0][1] <= split_y[i-1] or row[0][1] >= split_y[i]: continue # 循环行中的点 for j in range(len(row)): if j == len(row) - 1: break current_point = row[j] next_point_in_row = row[j+1] # 查询下个点所在列 next_col = [] for col in col_point_list: col.sort(key=lambda x: (x[1], x[0])) # 列不在该区域跳过 if col[0][1] <= split_y[i-1] or col[-1][1] >= split_y[i]: continue if col[0][0]-3 <= next_point_in_row[0] <= col[0][0]+3: next_col = col break # 循环匹配当前点和下一列点 for point1 in next_col: # 同一行的就跳过 if current_point[1]-3 <= point1[1] <= current_point[1]+3: continue if point1[1] <= current_point[1]-3: continue # 候选bbox candidate_bbox = [current_point[0], current_point[1], point1[0], point1[1]] # 判断该bbox是否存在,线条包含关系 contain_flag1 = 0 contain_flag2 = 0 for row1 in row_lines: # 行不在该区域跳过 if row1[1] <= split_y[i-1] or row1[1] >= split_y[i]: continue # bbox上边框 y一样 if not contain_flag1: if row1[1]-3 <= candidate_bbox[1] <= row1[1]+3: # candidate的x1,x2需被包含在row线中 if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3: contain_flag1 = 1 # bbox下边框 y一样 if not contain_flag2: if row1[1]-3 <= candidate_bbox[3] <= row1[1]+3: # candidate的x1,x2需被包含在row线中 if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3: contain_flag2 = 1 # 找到了该bbox,并且是存在的 if contain_flag1 and contain_flag2: bbox.append([(candidate_bbox[0], candidate_bbox[1]), (candidate_bbox[2], candidate_bbox[3])]) break return bbox def get_bbox3(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines): # 分割线纵坐标 if len(split_y) < 2: return [] # 获取bbox bbox = [] # 每个点获取与其x最相近和y最相近的点 for i in range(1, len(split_y)): # 循环每行 for row in row_point_list: row.sort(key=lambda x: (x[0], x[1])) # 行不在该区域跳过 if row[0][1] <= split_y[i-1] or row[0][1] >= split_y[i]: continue # print("len(row)", len(row)) # print("row", row) # 循环行中的点 for j in range(len(row)): if j == len(row) - 1: break current_point = row[j] # print("current_point", current_point) next_point_in_row_list = row[j+1:] # 循环这一行的下一个点 for next_point_in_row in next_point_in_row_list: # 是否在这一行点找到,找不到就这一行的下个点 not_found = 1 # 查询下个点所在列 next_col = [] for col in col_point_list: col.sort(key=lambda x: (x[1], x[0])) # 列不在该区域跳过 if col[0][1] <= split_y[i-1] or col[-1][1] >= split_y[i]: continue if col[0][0]-3 <= next_point_in_row[0] <= col[0][0]+3: next_col = col break # 循环匹配当前点和下一列点 next_col.sort(key=lambda x: (x[1], x[0])) for point1 in next_col: # 同一行的就跳过 if current_point[1]-3 <= point1[1] <= current_point[1]+3: continue if point1[1] <= current_point[1]-3: continue # 候选bbox candidate_bbox = [current_point[0], current_point[1], point1[0], point1[1]] # print("candidate_bbox", candidate_bbox) # 判断该bbox是否存在,判断bbox的上下两条边是否有包含在row中 contain_flag1 = 0 contain_flag2 = 0 for row1 in row_lines: # 行不在该区域跳过 if row1[1] <= split_y[i-1] or row1[1] >= split_y[i]: continue # bbox上边框 y一样 if not contain_flag1: if row1[1]-3 <= candidate_bbox[1] <= row1[1]+3: # 格子里的断开线段 row1_break = (max([row1[0], candidate_bbox[0]]), row1[1], min([row1[2], candidate_bbox[2]]), row1[3]) if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: contain_flag1 = 1 # # candidate的x1,x2需被包含在row线中 # if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3: # contain_flag1 = 1 # # # 判断线条有无端点在格子中 # elif candidate_bbox[0] < row1[0] < candidate_bbox[2] \ # or candidate_bbox[0] < row1[2] < candidate_bbox[2]: # # 线条会有缺一点情况,判断长度超过格子一半 # if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: # contain_flag1 = 1 # bbox下边框 y一样 if not contain_flag2: if row1[1]-3 <= candidate_bbox[3] <= row1[1]+3: # 格子里的断开线段 row1_break = (max([row1[0], candidate_bbox[0]]), row1[1], min([row1[2], candidate_bbox[2]]), row1[3]) if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: contain_flag2 = 1 # # candidate的x1,x2需被包含在row线中 # if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3: # contain_flag2 = 1 # # # 判断线条有无端点在格子中 # elif candidate_bbox[0] < row1[0] < candidate_bbox[2] \ # or candidate_bbox[0] < row1[2] < candidate_bbox[2]: # # 线条会有缺一点情况,判断长度超过格子一半 # if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: # contain_flag2 = 1 # 判断该bbox是否存在,判断bbox的左右两条边是否有包含在col中 contain_flag3 = 0 contain_flag4 = 0 for col1 in col_lines: # 列不在该区域跳过 if col1[1] <= split_y[i-1] or col1[3] >= split_y[i]: continue # bbox左边线 x一样 if not contain_flag3: if col1[0]-3 <= candidate_bbox[0] <= col1[0]+3: # 格子里的断开线段 col1_break = (col1[0], max([col1[1], candidate_bbox[1]]), col1[2], min([col1[3], candidate_bbox[3]])) if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: contain_flag3 = 1 # # candidate的y1,y2需被包含在col线中 # if col1[1]-3 <= candidate_bbox[1] <= candidate_bbox[3] <= col1[3]+3: # contain_flag3 = 1 # # # 判断线条有无端点在格子中 # elif candidate_bbox[1] < col1[1] < candidate_bbox[3] \ # or candidate_bbox[1] < col1[3] < candidate_bbox[3]: # # 线条会有缺一点情况,判断长度超过格子一半 # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: # contain_flag3 = 1 # bbox右边框 x一样 if not contain_flag4: if col1[0]-3 <= candidate_bbox[2] <= col1[0]+3: # 格子里的断开线段 # col1_break = (col1[0], # max([col1[1], candidate_bbox[1]]), # col1[2], # min([col1[3], candidate_bbox[3]])) # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: # contain_flag4 = 1 # 如果候选bbox的边的上1/3或下1/3包含在col中 candidate_bbox_line1 = [candidate_bbox[1], candidate_bbox[1] + (candidate_bbox[3]-candidate_bbox[1])/3] candidate_bbox_line2 = [candidate_bbox[3] - (candidate_bbox[3]-candidate_bbox[1])/3, candidate_bbox[3]] if col1[1] <= candidate_bbox_line1[0] <= candidate_bbox_line1[1] <= col1[3] \ or col1[1] <= candidate_bbox_line2[0] <= candidate_bbox_line2[1] <= col1[3]: # print("candidate_bbox", candidate_bbox) # print("col1", col1) contain_flag4 = 1 # # candidate的y1,y2需被包含在col线中 # if col1[1]-3 <= candidate_bbox[1] <= candidate_bbox[3] <= col1[3]+3: # contain_flag4 = 1 # # # 判断线条有无端点在格子中 # elif candidate_bbox[1] < col1[1] < candidate_bbox[3] \ # or candidate_bbox[1] < col1[3] < candidate_bbox[3]: # # 线条会有缺一点情况,判断长度超过格子一半 # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: # contain_flag4 = 1 # 找到了该bbox,并且是存在的 if contain_flag1 and contain_flag2 and contain_flag3 and contain_flag4: bbox.append([(candidate_bbox[0], candidate_bbox[1]), (candidate_bbox[2], candidate_bbox[3])]) not_found = 0 # print("exist candidate_bbox", candidate_bbox) # print(contain_flag1, contain_flag2, contain_flag3, contain_flag4) break # else: # print("candidate_bbox", candidate_bbox) # print(contain_flag1, contain_flag2, contain_flag3, contain_flag4) if not not_found: break return bbox def get_bbox(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines): # 分割线纵坐标 if len(split_y) < 2: return [] # 获取bbox bbox_list = [] for i in range(1, len(split_y)): last_y = split_y[i-1] y = split_y[i] # 先对点线进行分区 split_row_point_list = [] split_col_point_list = [] split_row_lines = [] split_col_lines = [] for row in row_point_list: if last_y <= row[0][1] <= y: row.sort(key=lambda x: (x[1], x[0])) split_row_point_list.append(row) for col in col_point_list: if last_y <= col[0][1] <= y: split_col_point_list.append(col) for row in row_lines: if last_y <= row[1] <= y: split_row_lines.append(row) for col in col_lines: if last_y <= col[1] <= y: split_col_lines.append(col) # 每个点获取其对角线点,以便形成bbox,按行循环 for i in range(len(split_row_point_list)-1): row = split_row_point_list[i] # 循环该行的点 for k in range(len(row)-1): point1 = row[k] next_point1 = row[k+1] # print("*"*30) # print("point1", point1) # 有三种对角线点 # 1. 该点下一行的下一列的点 # 2. 该点下一列的下一行的点 # 3. 上述两个点是同一个点 # 下一行没找到就循环后面的行 if_find = 0 for j in range(i+1, len(split_row_point_list)): if if_find: break next_row = split_row_point_list[j] # print("next_row", next_row) # 循环下一行的点 for point2 in next_row: if abs(point1[0] - point2[0]) <= 2: continue if point2[0] < point1[0]: continue bbox = [point1[0], point1[1], point2[0], point2[1]] if abs(bbox[0] - bbox[2]) <= 10: continue if abs(bbox[1] - bbox[3]) <= 10: continue # bbox的四条边都需要验证是否在line上 if check_bbox(bbox, split_row_lines, split_col_lines): bbox_list.append([(bbox[0], bbox[1]), (bbox[2], bbox[3])]) if_find = 1 # print("check bbox", bbox) break return bbox_list def check_bbox(bbox, rows, cols, threshold=5): def check(check_line, lines, limit_axis, axis): # 需检查的线的1/2段,1/3段,2/3段,1/4段,3/4段 line_1_2 = [check_line[0], (check_line[0]+check_line[1])/2] line_2_2 = [(check_line[0]+check_line[1])/2, check_line[1]] line_1_3 = [check_line[0], check_line[0]+(check_line[1]-check_line[0])/3] line_2_3 = [check_line[1]-(check_line[1]-check_line[0])/3, check_line[1]] line_1_4 = [check_line[0], check_line[0]+(check_line[1]-check_line[0])/4] line_3_4 = [check_line[1]-(check_line[1]-check_line[0])/4, check_line[1]] # 限制row相同y,col相同x if_line = 0 for line1 in lines: if not if_line and abs(line1[1-axis] - limit_axis) <= threshold: # check_line完全包含在line中 if line1[axis] <= check_line[0] <= check_line[1] <= line1[axis+2]: if_line = 1 # check_line的1/2包含在line elif line1[axis] <= line_1_2[0] <= line_1_2[1] <= line1[axis+2] \ or line1[axis] <= line_2_2[0] <= line_2_2[1] <= line1[axis+2]: if_line = 1 # check_line两个1/3段被包含在不同line中 elif line1[axis] <= line_1_3[0] <= line_1_3[1] <= line1[axis+2]: # check_line另一边的1/4被包含 for line2 in lines: if abs(line1[1-axis] - limit_axis) <= threshold: if line2[axis] <= line_3_4[0] <= line_3_4[1] <= line2[axis+2]: if_line = 1 break elif line1[axis] <= line_2_3[0] <= line_2_3[1] <= line1[axis+2]: # check_line另一边的1/4被包含 for line2 in lines: if abs(line1[1-axis] - limit_axis) <= threshold: if line2[axis] <= line_1_4[0] <= line_1_4[1] <= line2[axis+2]: if_line = 1 break return if_line up_down_line = [bbox[0], bbox[2]] up_y, down_y = bbox[1], bbox[3] left_right_line = [bbox[1], bbox[3]] left_x, right_x = bbox[0], bbox[2] # 检查bbox四条边是否存在 if_up = check(up_down_line, rows, up_y, 0) if_down = check(up_down_line, rows, down_y, 0) if_left = check(left_right_line, cols, left_x, 1) if_right = check(left_right_line, cols, right_x, 1) # 检查bbox内部除了四条边,是否有其它line在bbox内部 if_col = 0 if_row = 0 if if_up and if_down and if_left and if_right: for col in cols: if not if_col and left_x+threshold <= col[0] <= right_x-threshold: if col[1] <= left_right_line[0] <= left_right_line[1] <= col[3]: if_col = 1 elif left_right_line[0] <= col[1] <= left_right_line[1]: if left_right_line[1] - col[1] >= (left_right_line[1] + left_right_line[0])/2: if_col = 1 elif left_right_line[0] <= col[3] <= left_right_line[1]: if col[3] - left_right_line[0] >= (left_right_line[1] + left_right_line[0])/2: if_col = 1 for row in rows: if not if_row and up_y+threshold <= row[1] <= down_y-threshold: if row[0] <= up_down_line[0] <= up_down_line[1] <= row[2]: if_row = 1 elif up_down_line[0] <= row[0] <= up_down_line[1]: if up_down_line[1] - row[0] >= (up_down_line[1] + up_down_line[0])/2: if_row = 1 elif up_down_line[0] <= row[2] <= up_down_line[1]: if row[2] - up_down_line[0] >= (up_down_line[1] + up_down_line[0])/2: if_row = 1 if if_up and if_down and if_left and if_right and not if_col and not if_row: return True else: return False def add_continue_bbox(bboxes): add_bbox_list = [] bboxes.sort(key=lambda x: (x[0][0], x[0][1])) last_bbox = bboxes[0] # 先对bbox分区 for i in range(1, len(split_y)): y = split_y[i] last_y = split_y[i-1] split_bbox = [] for bbox in bboxes: if last_y <= bbox[1][1] <= y: split_bbox.append(bbox) split_bbox.sort for i in range(1, len(bboxes)): bbox = bboxes[i] if last_y <= bbox[1][1] <= y and last_y <= last_bbox[1][1] <= y: if abs(last_bbox[1][1] - bbox[0][1]) <= 2: last_bbox = bbox else: if last_bbox[1][1] > bbox[0][1]: last_bbox = bbox else: add_bbox = [(last_bbox[0][0], last_bbox[1][1]), (last_bbox[1][0], bbox[0][1])] add_bbox_list.append(add_bbox) last_y = y print("add_bbox_list", add_bbox_list) if add_bbox_list: bboxes = [str(x) for x in bboxes + add_bbox_list] bboxes = list(set(bboxes)) bboxes = [eval(x) for x in bboxes] bboxes.sort(key=lambda x: (x[0][1], x[0][0])) return bboxes def points_to_line(points_lines, axis): new_line_list = [] for line in points_lines: average = 0 _min = _min = line[0][axis] _max = line[-1][axis] for point in line: average += point[1-axis] if point[axis] < _min: _min = point[axis] if point[axis] > _max: _max = point[axis] average = int(average / len(line)) if axis: new_line = [average, _min, average, _max] else: new_line = [_min, average, _max, average] new_line_list.append(new_line) return new_line_list def get_bbox_by_contours(image_np): img_gray = cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY) ret, img_bin = cv2.threshold(img_gray, 127, 255, cv2.THRESH_BINARY) # 3.连通域分析 img_bin, contours, hierarchy = cv2.findContours(img_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) # 4.获取最小外接圆 圆心 半径 center, radius = cv2.minEnclosingTriangle(contours[0]) center = np.int0(center) # 5.绘制最小外接圆 img_result = image_np.copy() cv2.circle(img_result, tuple(center), int(radius), (255, 255, 255), 2) # # 读入图片 # img = image_np # cv2.imshow("get_bbox_by_contours ", image_np) # # 中值滤波,去噪 # img = cv2.medianBlur(img, 3) # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # cv2.namedWindow('original', cv2.WINDOW_AUTOSIZE) # cv2.imshow('original', gray) # # # 阈值分割得到二值化图片 # ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU) # # # 膨胀操作 # kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) # bin_clo = cv2.dilate(binary, kernel2, iterations=2) # # # 连通域分析 # num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(bin_clo, connectivity=8) # # # 查看各个返回值 # # 连通域数量 # print('num_labels = ',num_labels) # # 连通域的信息:对应各个轮廓的x、y、width、height和面积 # print('stats = ',stats) # # 连通域的中心点 # print('centroids = ',centroids) # # 每一个像素的标签1、2、3.。。,同一个连通域的标签是一致的 # print('labels = ',labels) # # # 不同的连通域赋予不同的颜色 # output = np.zeros((img.shape[0], img.shape[1], 3), np.uint8) # for i in range(1, num_labels): # # mask = labels == i # output[:, :, 0][mask] = np.random.randint(0, 255) # output[:, :, 1][mask] = np.random.randint(0, 255) # output[:, :, 2][mask] = np.random.randint(0, 255) # cv2.imshow('oginal', output) # cv2.waitKey() # cv2.destroyAllWindows() def get_points_col(points, split_y, threshold=5): # 坐标点按行分 row_point_list = [] row_point = [] points.sort(key=lambda x: (x[0], x[1])) # print("get_points_col points sort", points) x = points[0][0] for i in range(1, len(split_y)): for p in points: if p[1] <= split_y[i-1] or p[1] >= split_y[i]: continue if x-threshold <= p[0] <= x+threshold: row_point.append(p) else: # print("row_point", row_point) row_point.sort(key=lambda x: (x[1], x[0])) if row_point: row_point_list.append(row_point) row_point = [] x = p[0] row_point.append(p) if row_point: row_point_list.append(row_point) return row_point_list def get_points_row(points, split_y, threshold=5): # 坐标点按列分 col_point_list = [] col_point = [] points.sort(key=lambda x: (x[1], x[0])) y = points[0][1] for i in range(len(split_y)): for p in points: if p[1] <= split_y[i-1] or p[1] >= split_y[i]: continue if y-threshold <= p[1] <= y+threshold: col_point.append(p) else: col_point.sort(key=lambda x: (x[0], x[1])) if col_point: col_point_list.append(col_point) col_point = [] y = p[1] col_point.append(p) if col_point: col_point_list.append(col_point) return col_point_list def get_outline_point(points, split_y): # 分割线纵坐标 # print("get_outline_point split_y", split_y) if len(split_y) < 2: return [] outline_2point = [] points.sort(key=lambda x: (x[1], x[0])) for i in range(1, len(split_y)): area_points = [] for point in points: if point[1] <= split_y[i-1] or point[1] >= split_y[i]: continue area_points.append(point) if area_points: area_points.sort(key=lambda x: (x[1], x[0])) outline_2point.append([area_points[0], area_points[-1]]) return outline_2point # def merge_row(row_lines): # for row in row_lines: # for row1 in row_lines: def get_best_predict_size(image_np): sizes = [1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128] min_len = 10000 best_height = sizes[0] for height in sizes: if abs(image_np.shape[0] - height) < min_len: min_len = abs(image_np.shape[0] - height) best_height = height min_len = 10000 best_width = sizes[0] for width in sizes: if abs(image_np.shape[1] - width) < min_len: min_len = abs(image_np.shape[1] - width) best_width = width return best_height, best_width def choose_longer_row(lines): new_row = [] jump_row = [] for i in range(len(lines)): row1 = lines[i] jump_flag = 0 if row1 in jump_row: continue for j in range(i+1, len(lines)): row2 = lines[j] if row2 in jump_row: continue if row2[1]-5 <= row1[1] <= row2[1]+5: if row1[0] <= row2[0] and row1[2] >= row2[2]: new_row.append(row1) jump_row.append(row1) jump_row.append(row2) jump_flag = 1 break elif row2[0] <= row1[0] and row2[2] >= row1[2]: new_row.append(row2) jump_row.append(row1) jump_row.append(row2) jump_flag = 1 break if not jump_flag: new_row.append(row1) jump_row.append(row1) return new_row def choose_longer_col(lines): new_col = [] jump_col = [] for i in range(len(lines)): col1 = lines[i] jump_flag = 0 if col1 in jump_col: continue for j in range(i+1, len(lines)): col2 = lines[j] if col2 in jump_col: continue if col2[0]-5 <= col1[0] <= col2[0]+5: if col1[1] <= col2[1] and col1[3] >= col2[3]: new_col.append(col1) jump_col.append(col1) jump_col.append(col2) jump_flag = 1 break elif col2[1] <= col1[1] and col2[3] >= col1[3]: new_col.append(col2) jump_col.append(col1) jump_col.append(col2) jump_flag = 1 break if not jump_flag: new_col.append(col1) jump_col.append(col1) return new_col def delete_contain_bbox(bboxes): # bbox互相包含,取小的bbox delete_bbox = [] for i in range(len(bboxes)): for j in range(i+1, len(bboxes)): bbox1 = bboxes[i] bbox2 = bboxes[j] # 横坐标相等情况 if bbox1[0][0] == bbox2[0][0] and bbox1[1][0] == bbox2[1][0]: if bbox1[0][1] <= bbox2[0][1] <= bbox2[1][1] <= bbox1[1][1]: # print("1", bbox1, bbox2) delete_bbox.append(bbox1) elif bbox2[0][1] <= bbox1[0][1] <= bbox1[1][1] <= bbox2[1][1]: # print("2", bbox1, bbox2) delete_bbox.append(bbox2) # 纵坐标相等情况 elif bbox1[0][1] == bbox2[0][1] and bbox1[1][1] == bbox2[1][1]: if bbox1[0][0] <= bbox2[0][0] <= bbox2[1][0] <= bbox1[1][0]: print("3", bbox1, bbox2) delete_bbox.append(bbox1) elif bbox2[0][0] <= bbox1[0][0] <= bbox1[1][0] <= bbox2[1][0]: print("4", bbox1, bbox2) delete_bbox.append(bbox2) print("delete_contain_bbox len(bboxes)", len(bboxes)) print("delete_contain_bbox len(delete_bbox)", len(delete_bbox)) for bbox in delete_bbox: if bbox in bboxes: bboxes.remove(bbox) print("delete_contain_bbox len(bboxes)", len(bboxes)) return bboxes if __name__ == '__main__': # p = "开标记录表3_page_0.png" # p = "train_data/label_1.jpg" # p = "test_files/train_463.jpg" p = "test_files/8.png" # p = "test_files/无边框3.jpg" # p = "test_files/part1.png" # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00e959a0bc9011ebaf5a00163e0ae709" + \ # "\\00e95f7cbc9011ebaf5a00163e0ae709_pdf_page0.png" # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00fb3e52bc7e11eb836000163e0ae709" + \ # "\\00fb43acbc7e11eb836000163e0ae709.png" # p = "test_files/table.jpg" # p = "data_process/create_data/0.jpg" # p = "../format_conversion/temp/f1fe9c4ac8e511eb81d700163e0857b6/f1fea1e0c8e511eb81d700163e0857b6.png" # p = "../format_conversion/1.png" img = cv2.imread(p) t = time.time() model.load_weights("") best_h, best_w = get_best_predict_size(img) print(img.shape) print((best_h, best_w)) # row_boxes, col_boxes = table_line(img[..., ::-1], model, size=(512, 1024), hprob=0.5, vprob=0.5) # row_boxes, col_boxes, img = table_line(img[..., ::-1], model, size=(best_w, best_h), hprob=0.5, vprob=0.5) row_boxes, col_boxes, img = table_line(img, model, size=(best_w, best_h), hprob=0.5, vprob=0.5) print("len(row_boxes)", len(row_boxes)) print("len(col_boxes)", col_boxes) # 创建空图 test_img = np.zeros((img.shape), np.uint8) test_img.fill(255) for box in row_boxes+col_boxes: cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1) cv2.imshow("test_image", test_img) cv2.waitKey(0) cv2.imwrite("temp.jpg", test_img) # 计算交点、分割线 crossover_points = get_points(row_boxes, col_boxes, (img.shape[0], img.shape[1])) print("len(col_boxes)", len(col_boxes)) split_lines, split_y = get_split_line(crossover_points, col_boxes, img) print("split_y", split_y) # for point in crossover_points: # cv2.circle(test_img, point, 1, (0, 255, 0), 3) # cv2.imshow("point image1", test_img) # cv2.waitKey(0) # 计算行列,剔除相近交点 row_point_list = get_points_row(crossover_points, split_y, 0) col_point_list = get_points_col(crossover_points, split_y, 0) crossover_points = delete_close_points(crossover_points, row_point_list, col_point_list) row_point_list = get_points_row(crossover_points, split_y) col_point_list = get_points_col(crossover_points, split_y) for point in crossover_points: cv2.circle(test_img, point, 1, (0, 0, 255), 3) cv2.imshow("point image1", test_img) cv2.waitKey(0) print("len(row_boxes)", len(row_boxes)) print("len(col_boxes)", len(col_boxes)) # 修复边框 new_row_boxes, new_col_boxes, long_row_boxes, long_col_boxes = \ fix_outline(img, row_boxes, col_boxes, crossover_points, split_y) if new_row_boxes or new_col_boxes: if long_row_boxes: print("long_row_boxes", long_row_boxes) row_boxes = long_row_boxes if long_col_boxes: print("long_col_boxes", long_col_boxes) col_boxes = long_col_boxes if new_row_boxes: row_boxes += new_row_boxes print("new_row_boxes", new_row_boxes) if new_col_boxes: print("new_col_boxes", new_col_boxes) col_boxes += new_col_boxes # print("len(row_boxes)", len(row_boxes)) # print("len(col_boxes)", len(col_boxes)) # row_boxes += new_row_boxes # col_boxes += new_col_boxes # row_boxes = choose_longer_row(row_boxes) # col_boxes = choose_longer_col(col_boxes) # 创建空图 test_img = np.zeros((img.shape), np.uint8) test_img.fill(255) for box in row_boxes+col_boxes: cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1) cv2.imshow("test_image2", test_img) cv2.waitKey(0) # 展示补线 for row in new_row_boxes: cv2.line(test_img, (int(row[0]), int(row[1])), (int(row[2]), int(row[3])), (0, 0, 255), 1) for col in new_col_boxes: cv2.line(test_img, (int(col[0]), int(col[1])), (int(col[2]), int(col[3])), (0, 0, 255), 1) cv2.imshow("fix_outline", test_img) cv2.waitKey(0) cv2.imwrite("temp.jpg", test_img) # 修复边框后重新计算交点、分割线 print("crossover_points", len(crossover_points)) crossover_points = get_points(row_boxes, col_boxes, (img.shape[0], img.shape[1])) print("crossover_points new", len(crossover_points)) split_lines, split_y = get_split_line(crossover_points, col_boxes, img) # 计算行列,剔除相近交点 row_point_list = get_points_row(crossover_points, split_y, 0) col_point_list = get_points_col(crossover_points, split_y, 0) print(len(crossover_points), len(row_point_list), len(col_point_list)) crossover_points = delete_close_points(crossover_points, row_point_list, col_point_list) print(len(crossover_points), len(row_point_list), len(col_point_list)) row_point_list = get_points_row(crossover_points, split_y) col_point_list = get_points_col(crossover_points, split_y) for point in crossover_points: cv2.circle(test_img, point, 1, (0, 255, 0), 3) cv2.imshow("point image2", test_img) cv2.waitKey(0) # 获取每个表格的左上右下两个点 outline_point = get_outline_point(crossover_points, split_y) # print(outline_point) for outline in outline_point: cv2.circle(test_img, outline[0], 1, (255, 0, 0), 5) cv2.circle(test_img, outline[1], 1, (255, 0, 0), 5) cv2.imshow("outline point", test_img) cv2.waitKey(0) # 获取bbox # get_bbox(img, crossover_points, split_y) # for point in crossover_points: # cv2.circle(test_img, point, 1, (0, 255, 0), 3) # cv2.imshow("point image3", test_img) # cv2.waitKey(0) # split_y = [] # for outline in outline_point: # split_y.extend([outline[0][1]-5, outline[1][1]+5]) print("len(row_boxes)", len(row_boxes)) print("len(col_boxes)", len(col_boxes)) bboxes = get_bbox(img, row_point_list, col_point_list, split_y, row_boxes, col_boxes) # 展示 for box in bboxes: # print(box[0], box[1]) # if abs(box[0][1] - box[1][1]) > abs(box[0][0] - box[1][0]): # continue cv2.rectangle(test_img, box[0], box[1], (0, 0, 255), 2, 8) cv2.imshow('bboxes', test_img) cv2.waitKey(0) # img = draw_lines(img, row_boxes+col_boxes, color=(255, 0, 0), lineW=2) # img = draw_boxes(img, rowboxes+colboxes, color=(0, 0, 255)) print(time.time()-t, len(row_boxes), len(col_boxes)) cv2.imwrite('temp.jpg', test_img) # cv2.imshow('main', img) # cv2.waitKey(0)