#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Sep 9 23:11:51 2020 table line detect @author: chineseocr """ import copy import logging import tensorflow as tf import tensorflow.keras.backend as K from tensorflow.keras.models import Model from tensorflow.keras.layers import Input, concatenate, Conv2D, MaxPooling2D, BatchNormalization, UpSampling2D from tensorflow.keras.layers import LeakyReLU from otr.utils import letterbox_image, get_table_line, adjust_lines, line_to_line, draw_boxes import numpy as np import cv2 import time logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') def dice_coef(y_true, y_pred, smooth=1e-5): y_true_f = K.flatten(y_true) y_pred_f = K.flatten(y_pred) intersection = K.sum(y_true_f * y_pred_f) return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) def dice_coef_loss(): def dice_coef_loss_fixed(y_true, y_pred): return -dice_coef(y_true, y_pred) return dice_coef_loss_fixed def focal_loss(gamma=3., alpha=.5): # 3 0.85 2000e acc-0.6 p-0.99 r-0.99 val_acc-0.56 val_p-0.86 val_r-0.95 # 2 0.85 double_gpu acc- # 3 0.25 gpu 50e acc-0.5 p-0.99 r-0.99 val_acc-0.45 val_p-0.96 val_r-0.88 # 2 0.25 gpu acc- # 3 0.5 double_gpu acc-0.6 p-0.99 r-0.99 val_acc-0.60 val_p-0.93 val_r-0.93 def focal_loss_fixed(y_true, y_pred): pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred)) pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred)) return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon()+pt_1))-K.sum((1-alpha) * K.pow( pt_0, gamma) * K.log(1. - pt_0 + K.epsilon())) return focal_loss_fixed def table_net(input_shape=(1152, 896, 3), num_classes=1): inputs = Input(shape=input_shape) # 512 use_bias = False down0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(inputs) down0a = BatchNormalization()(down0a) down0a = LeakyReLU(alpha=0.1)(down0a) down0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(down0a) down0a = BatchNormalization()(down0a) down0a = LeakyReLU(alpha=0.1)(down0a) down0a_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0a) # 256 down0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(down0a_pool) down0 = BatchNormalization()(down0) down0 = LeakyReLU(alpha=0.1)(down0) down0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(down0) down0 = BatchNormalization()(down0) down0 = LeakyReLU(alpha=0.1)(down0) down0_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0) # 128 down1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(down0_pool) down1 = BatchNormalization()(down1) down1 = LeakyReLU(alpha=0.1)(down1) down1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(down1) down1 = BatchNormalization()(down1) down1 = LeakyReLU(alpha=0.1)(down1) down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1) # 64 down2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(down1_pool) down2 = BatchNormalization()(down2) down2 = LeakyReLU(alpha=0.1)(down2) down2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(down2) down2 = BatchNormalization()(down2) down2 = LeakyReLU(alpha=0.1)(down2) down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2) # 32 down3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(down2_pool) down3 = BatchNormalization()(down3) down3 = LeakyReLU(alpha=0.1)(down3) down3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(down3) down3 = BatchNormalization()(down3) down3 = LeakyReLU(alpha=0.1)(down3) down3_pool = MaxPooling2D((2, 2), strides=(2, 2))(down3) # 16 down4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(down3_pool) down4 = BatchNormalization()(down4) down4 = LeakyReLU(alpha=0.1)(down4) down4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(down4) down4 = BatchNormalization()(down4) down4 = LeakyReLU(alpha=0.1)(down4) down4_pool = MaxPooling2D((2, 2), strides=(2, 2))(down4) # 8 center = Conv2D(1024, (3, 3), padding='same', use_bias=use_bias)(down4_pool) center = BatchNormalization()(center) center = LeakyReLU(alpha=0.1)(center) center = Conv2D(1024, (3, 3), padding='same', use_bias=use_bias)(center) center = BatchNormalization()(center) center = LeakyReLU(alpha=0.1)(center) # center up4 = UpSampling2D((2, 2))(center) up4 = concatenate([down4, up4], axis=3) up4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(up4) up4 = BatchNormalization()(up4) up4 = LeakyReLU(alpha=0.1)(up4) up4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(up4) up4 = BatchNormalization()(up4) up4 = LeakyReLU(alpha=0.1)(up4) up4 = Conv2D(512, (3, 3), padding='same', use_bias=use_bias)(up4) up4 = BatchNormalization()(up4) up4 = LeakyReLU(alpha=0.1)(up4) # 16 up3 = UpSampling2D((2, 2))(up4) up3 = concatenate([down3, up3], axis=3) up3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(up3) up3 = BatchNormalization()(up3) up3 = LeakyReLU(alpha=0.1)(up3) up3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(up3) up3 = BatchNormalization()(up3) up3 = LeakyReLU(alpha=0.1)(up3) up3 = Conv2D(256, (3, 3), padding='same', use_bias=use_bias)(up3) up3 = BatchNormalization()(up3) up3 = LeakyReLU(alpha=0.1)(up3) # 32 up2 = UpSampling2D((2, 2))(up3) up2 = concatenate([down2, up2], axis=3) up2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(up2) up2 = BatchNormalization()(up2) up2 = LeakyReLU(alpha=0.1)(up2) up2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(up2) up2 = BatchNormalization()(up2) up2 = LeakyReLU(alpha=0.1)(up2) up2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(up2) up2 = BatchNormalization()(up2) up2 = LeakyReLU(alpha=0.1)(up2) # 64 up1 = UpSampling2D((2, 2))(up2) up1 = concatenate([down1, up1], axis=3) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1) up1 = BatchNormalization()(up1) up1 = LeakyReLU(alpha=0.1)(up1) # 128 up0 = UpSampling2D((2, 2))(up1) up0 = concatenate([down0, up0], axis=3) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0) up0 = BatchNormalization()(up0) up0 = LeakyReLU(alpha=0.1)(up0) # 256 up0a = UpSampling2D((2, 2))(up0) up0a = concatenate([down0a, up0a], axis=3) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a) up0a = BatchNormalization()(up0a) up0a = LeakyReLU(alpha=0.1)(up0a) # 512 classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0a) model = Model(inputs=inputs, outputs=classify) return model model = table_net((None, None, 3), 2) def drawpixel(pred): import matplotlib.pyplot as plt _array = [] for _h in range(len(pred)): _line = [] for _w in range(len(pred[_h])): _prob = pred[_h][_w] if _prob[0]>0.5: _line.append((0,255,255)) elif _prob[1]>0.5: _line.append((255,255,0)) else: _line.append((255,255,255)) _array.append(_line) plt.imshow(np.array(_array)) plt.show() def points2lines(pred,sourceP_LB=True,prob=0.2,line_width=7,padding=3,min_len=10,cell_width=13): def inBbox(bbox,point,line_width): x,y = point if x>=bbox[0]-line_width and x<=bbox[2]+line_width and y>=bbox[1]-line_width and y<=bbox[3]+line_width: return True,[min(x,bbox[0]),min(y,bbox[1]),max(x,bbox[2]),max(y,bbox[3])] return False,None height = len(pred) width = len(pred[0]) clust_horizontal = [] clust_vertical = [] h_index = -1 _step = line_width _sum = list(np.sum(np.array((pred[...,1]>prob)).astype(int),axis=1)) _last = False _current = False while 1: h_index += 1 if h_index>=height: break w_index = -1 if sourceP_LB: h_i = height-1-h_index else: h_i = h_index if _sum[h_index]=width: break _v,_h = pred[h_index][w_index] if _h>prob: _find = False _point = (w_index,h_i) for l_h_i in range(len(clust_vertical)): l_h = clust_vertical[len(clust_vertical)-l_h_i-1] bbox = l_h.get("bbox") b_in,_bbox = inBbox(bbox,_point,line_width) if b_in: _find = True l_h.get("points").append(_point) l_h["bbox"] = _bbox break if not _find: clust_vertical.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]}) w_index = -1 _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=0)) while 1: w_index += 1 if w_index>=width: break h_index = -1 if _sum[w_index]=height: break if sourceP_LB: h_i = height-1-h_index else: h_i = h_index _v,_h = pred[h_index][w_index] if _v>prob: _find = False _point = (w_index,h_i) for l_h_i in range(len(clust_horizontal)): l_h = clust_horizontal[len(clust_horizontal)-l_h_i-1] bbox = l_h.get("bbox") b_in,_bbox = inBbox(bbox,_point,line_width) if b_in: _find = True l_h.get("points").append(_point) l_h["bbox"] = _bbox break if not _find: clust_horizontal.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]}) tmp_vertical = [] for _dict in clust_vertical: _bbox = _dict.get("bbox") if _bbox[2]-_bbox[0]>=min_len or _bbox[3]-_bbox[1]>=min_len: tmp_vertical.append([(_bbox[0]+_bbox[2])/2,_bbox[1]-padding,(_bbox[0]+_bbox[2])/2,_bbox[3]+padding]) tmp_horizontal = [] for _dict in clust_horizontal: _bbox = _dict.get("bbox") if _bbox[2]-_bbox[0]>=min_len or _bbox[3]-_bbox[1]>=min_len: tmp_horizontal.append([_bbox[0]-padding,(_bbox[1]+_bbox[3])/2,_bbox[2]+padding,(_bbox[1]+_bbox[3])/2]) #merge lines tmp_vertical.sort(key=lambda x:x[3],reverse=True) tmp_horizontal.sort(key=lambda x:x[0]) pop_index = [] final_vertical = [] for _line in tmp_vertical: _find = False x0,y0,x1,y1 = _line for _line2 in final_vertical: x2,y2,x3,y3 = _line2 if abs(x0-x2) vprob # 横线 hpred = pred[..., 0] > hprob # 竖线 vpred = vpred.astype(int) hpred = hpred.astype(int) # print("vpred shape", vpred) # print("hpred shape", hpred) colboxes = get_table_line(vpred, axis=1, lineW=col) rowboxes = get_table_line(hpred, axis=0, lineW=row) # logging.info("into table_line 3") # if len(rowboxes) > 0: # rowboxes = np.array(rowboxes) # rowboxes[:, [0, 2]] = rowboxes[:, [0, 2]]/fx # rowboxes[:, [1, 3]] = rowboxes[:, [1, 3]]/fy # rowboxes = rowboxes.tolist() # if len(colboxes) > 0: # colboxes = np.array(colboxes) # colboxes[:, [0, 2]] = colboxes[:, [0, 2]]/fx # colboxes[:, [1, 3]] = colboxes[:, [1, 3]]/fy # colboxes = colboxes.tolist() nrow = len(rowboxes) ncol = len(colboxes) for i in range(nrow): for j in range(ncol): rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], 10) colboxes[j] = line_to_line(colboxes[j], rowboxes[i], 10) # logging.info("into table_line 4") # 删掉贴着边框的line temp_list = [] threshold = 5 for line in rowboxes: if line[1]-0 <= threshold or size[1]-line[1] <= threshold: continue # 内部排序 if line[0] > line[2]: line = [line[2], line[3], line[0], line[1]] temp_list.append(line) rowboxes = temp_list temp_list = [] for line in colboxes: if line[0]-0 <= threshold or size[0]-line[0] <= threshold: continue # 内部排序 if line[1] > line[3]: line = [line[2], line[3], line[0], line[1]] temp_list.append(line) colboxes = temp_list return rowboxes, colboxes, img_new def get_outline(points, image_np): # 取出x, y的最大值最小值 x_min = points[0][0] x_max = points[-1][0] points.sort(key=lambda x: (x[1], x[0])) y_min = points[0][1] y_max = points[-1][1] # 创建空图 # outline_img = np.zeros(image_size, np.uint8) outline_img = np.copy(image_np) cv2.rectangle(outline_img, (x_min-5, y_min-5), (x_max+5, y_max+5), (0, 0, 0), 2) # cv2.imshow("outline_img", outline_img) # cv2.waitKey(0) return outline_img def get_split_line(points, col_lines, image_np): # print("get_split_line", image_np.shape) points.sort(key=lambda x: (x[1], x[0])) # 遍历y坐标,并判断y坐标与上一个y坐标是否存在连接线 i = 0 split_line_y = [] for point in points: # 从已分开的线下面开始判断 if split_line_y: if point[1] <= split_line_y[-1] + 5: last_y = point[1] continue if last_y <= split_line_y[-1] + 5: last_y = point[1] continue if i == 0: last_y = point[1] i += 1 continue current_line = (last_y, point[1]) split_flag = 1 for col in col_lines: # 只要找到一条col包含就不是分割线 if current_line[0] >= col[1]-3 and current_line[1] <= col[3]+3: split_flag = 0 # print("img", img.shape) # print("col", col) # print("current_line", current_line) break if split_flag: split_line_y.append(current_line[0]+5) split_line_y.append(current_line[1]-5) last_y = point[1] # 加上收尾分割线 points.sort(key=lambda x: (x[1], x[0])) y_min = points[0][1] y_max = points[-1][1] # print("加上收尾分割线", y_min, y_max) if y_min-5 < 0: split_line_y.append(0) else: split_line_y.append(y_min-5) if y_max+5 > image_np.shape[0]: split_line_y.append(image_np.shape[0]) else: split_line_y.append(y_max+5) split_line_y = list(set(split_line_y)) # 剔除两条相隔太近分割线 temp_split_line_y = [] split_line_y.sort(key=lambda x: x) last_y = -20 for y in split_line_y: # print(y) if y - last_y >= 20: # print(y, last_y) temp_split_line_y.append(y) last_y = y split_line_y = temp_split_line_y # print("split_line_y", split_line_y) # 生成分割线 split_line = [] last_y = 0 for y in split_line_y: # if y - last_y <= 15: # continue split_line.append([(0, y), (image_np.shape[1], y)]) last_y = y split_line.append([(0, 0), (image_np.shape[1], 0)]) split_line.append([(0, image_np.shape[0]), (image_np.shape[1], image_np.shape[0])]) split_line.sort(key=lambda x: x[0][1]) # print("split_line", split_line) # 画图画线 # split_line_img = np.copy(image_np) # for y in split_line_y: # cv2.line(split_line_img, (0, y), (image_np.shape[1], y), (0, 0, 0), 1) # cv2.imshow("split_line_img", split_line_img) # cv2.waitKey(0) return split_line, split_line_y def get_points(row_lines, col_lines, image_size): # 创建空图 row_img = np.zeros(image_size, np.uint8) col_img = np.zeros(image_size, np.uint8) # 画线 thresh = 3 for row in row_lines: cv2.line(row_img, (int(row[0]-thresh), int(row[1])), (int(row[2]+thresh), int(row[3])), (255, 255, 255), 1) for col in col_lines: cv2.line(col_img, (int(col[0]), int(col[1]-thresh)), (int(col[2]), int(col[3]+thresh)), (255, 255, 255), 1) # 求出交点 point_img = np.bitwise_and(row_img, col_img) # cv2.imshow("point_img", np.bitwise_not(point_img)) # cv2.waitKey(0) # 识别黑白图中的白色交叉点,将横纵坐标取出 ys, xs = np.where(point_img > 0) points = [] for i in range(len(xs)): points.append((xs[i], ys[i])) points.sort(key=lambda x: (x[0], x[1])) return points def get_minAreaRect(image): gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) gray = cv2.bitwise_not(gray) thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1] coords = np.column_stack(np.where(thresh > 0)) return cv2.minAreaRect(coords) def get_contours(image): gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) ret, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY) contours, hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(image, contours, -1, (0, 0, 255), 3) cv2.imshow("get contours", image) cv2.waitKey(0) def merge_line(lines, axis, threshold=5): """ 解决模型预测一条直线错开成多条直线,合并成一条直线 :param lines: 线条列表 :param axis: 0:横线 1:竖线 :param threshold: 两条线间像素差阈值 :return: 合并后的线条列表 """ # 任意一条line获取该合并的line,横线往下找,竖线往右找 lines.sort(key=lambda x: (x[axis], x[1-axis])) merged_lines = [] used_lines = [] for line1 in lines: if line1 in used_lines: continue merged_line = [line1] used_lines.append(line1) for line2 in lines: if line2 in used_lines: continue if line1[1-axis]-threshold <= line2[1-axis] <= line1[1-axis]+threshold: # 计算基准长度 min_axis = 10000 max_axis = 0 for line3 in merged_line: if line3[axis] < min_axis: min_axis = line3[axis] if line3[axis+2] > max_axis: max_axis = line3[axis+2] # 判断两条线有无交集 if min_axis <= line2[axis] <= max_axis \ or min_axis <= line2[axis+2] <= max_axis: merged_line.append(line2) used_lines.append(line2) if merged_line: merged_lines.append(merged_line) # 合并line result_lines = [] for merged_line in merged_lines: # 获取line宽的平均值 axis_average = 0 for line in merged_line: axis_average += line[1-axis] axis_average = int(axis_average/len(merged_line)) # 获取最长line两端 merged_line.sort(key=lambda x: (x[axis])) axis_start = merged_line[0][axis] merged_line.sort(key=lambda x: (x[axis+2])) axis_end = merged_line[-1][axis+2] if axis: result_lines.append([axis_average, axis_start, axis_average, axis_end]) else: result_lines.append([axis_start, axis_average, axis_end, axis_average]) return result_lines def fix_inner2(row_points, col_points, row_lines, col_lines, threshold=3): for i in range(len(row_points)): row = row_points[i] row.sort(key=lambda x: (x[1], x[0])) for j in range(len(row)): # 当前点 point = row[j] # 获取当前点在所在行的下个点 if j >= len(row) - 1: next_row_point = [] else: next_row_point = row[j+1] if next_row_point: for k in range(len(row_lines)): line = row_lines[k] if line[1] - threshold <= point[1] <= line[1] + threshold: if not line[0] <= point[0] <= next_row_point[0] <= line[2]: if point[0] <= line[2] < next_row_point[0]: if line[2] - point[0] >= 1/3 * (next_row_point[0] - point[0]): row_lines[k][2] = next_row_point[0] if point[0] < line[0] <= next_row_point[0]: if next_row_point[0] - line[0] >= 1/3 * (next_row_point[0] - point[0]): row_lines[k][0] = point[0] # 获取当前点所在列的下个点 next_col_point = [] for col in col_points: if point in col: col.sort(key=lambda x: (x[0], x[1])) if col.index(point) < len(col) - 1: next_col_point = col[col.index(point)+1] break # 获取当前点的对角线点,通过该列下个点所在行的下个点获得 next_row_next_col_point = [] if next_col_point: for row2 in row_points: if next_col_point in row2: row2.sort(key=lambda x: (x[1], x[0])) if row2.index(next_col_point) < len(row2) - 1: next_row_next_col_point = row2[row2.index(next_col_point)+1] break # 有该列下一点但没有该列下一点所在行的下个点 if not next_row_next_col_point: # 如果有该行下个点 if next_row_point: next_row_next_col_point = [next_row_point[0], next_col_point[1]] if next_col_point: for k in range(len(col_lines)): line = col_lines[k] if line[0] - threshold <= point[0] <= line[0] + threshold: if not line[1] <= point[1] <= next_col_point[1] <= line[3]: if point[1] <= line[3] < next_col_point[1]: if line[3] - point[1] >= 1/3 * (next_col_point[1] - point[1]): col_lines[k][3] = next_col_point[1] if point[1] < line[1] <= next_col_point[1]: if next_col_point[1] - line[1] >= 1/3 * (next_col_point[1] - point[1]): col_lines[k][1] = point[1] if next_row_next_col_point: for k in range(len(col_lines)): line = col_lines[k] if line[0] - threshold <= next_row_next_col_point[0] <= line[0] + threshold: if not line[1] <= point[1] <= next_row_next_col_point[1] <= line[3]: if point[1] < line[1] <= next_row_next_col_point[1]: if next_row_next_col_point[1] - line[1] >= 1/3 * (next_row_next_col_point[1] - point[1]): col_lines[k][1] = point[1] return row_lines, col_lines def fix_inner(row_lines, col_lines, points, split_y): def fix(fix_lines, assist_lines, split_points, axis): new_points = [] for line1 in fix_lines: min_assist_line = [[], []] min_distance = [1000, 1000] if_find = [0, 0] # 获取fix_line中的所有col point,里面可能不包括两个顶点,col point是交点,顶点可能不是交点 fix_line_points = [] for point in split_points: if abs(point[1-axis] - line1[1-axis]) <= 2: if line1[axis] <= point[axis] <= line1[axis+2]: fix_line_points.append(point) # 找出离两个顶点最近的assist_line, 并且assist_line与fix_line不相交 line1_point = [line1[:2], line1[2:]] for i in range(2): point = line1_point[i] for line2 in assist_lines: if not if_find[i] and abs(point[axis] - line2[axis]) <= 2: if line1[1-axis] <= point[1-axis] <= line2[1-axis+2]: # print("line1, match line2", line1, line2) if_find[i] = 1 break else: if abs(point[axis] - line2[axis]) < min_distance[i] and line2[1-axis] <= point[1-axis] <= line2[1-axis+2]: if line1[axis] <= line2[axis] <= line1[axis+2]: continue min_distance[i] = abs(line1[axis] - line2[axis]) min_assist_line[i] = line2 # 找出离assist_line最近的交点 # 顶点到交点的距离(多出来的线)需大于assist_line到交点的距离(bbox的边)的1/3 min_distance = [1000, 1000] min_col_point = [[], []] for i in range(2): # print("顶点", i, line1_point[i]) if not if_find[i]: if min_assist_line[i]: for point in fix_line_points: if abs(point[axis] - min_assist_line[i][axis]) < min_distance[i]: min_distance[i] = abs(point[axis] - min_assist_line[i][axis]) min_col_point[i] = point if min_col_point[i]: if abs(min_col_point[i][axis] - line1_point[i][axis]) >= abs(min_col_point[i][axis] - min_assist_line[i][axis])/3: add_point = (line1_point[i][1-axis], min_assist_line[i][axis]) # print("fix_inner add point", add_point) # print("line1, line2", line1, min_assist_line[i]) new_points.append(add_point) return new_points new_points = [] for i in range(1, len(split_y)): last_y = split_y[i-1] y = split_y[i] # 先对点线进行分区 split_row_lines = [] split_col_lines = [] split_points = [] for row in row_lines: if last_y <= row[1] <= y: split_row_lines.append(row) for col in col_lines: if last_y <= col[1] <= y: split_col_lines.append(col) for point in points: if last_y <= point[1] <= y: split_points.append(point) new_points += fix(split_col_lines, split_row_lines, split_points, axis=1) new_points += fix(split_row_lines, split_col_lines, split_points, axis=0) # 找出所有col的顶点不在row上的、row的顶点不在col上的 # for col in split_col_lines: # print("*"*30) # # # 获取该line中的所有point # col_points = [] # for point in split_points: # if abs(point[0] - col[0]) <= 2: # if col[1] <= point[1] <= col[3]: # col_points.append(point) # # # 比较顶点 # min_row_1 = [] # min_row_2 = [] # min_distance_1 = 1000 # min_distance_2 = 1000 # if_find_1 = 0 # if_find_2 = 0 # for row in split_row_lines: # # 第一个顶点 # if not if_find_1 and abs(col[1] - row[1]) <= 2: # if row[0] <= col[0] <= row[2]: # print("col, match row", col, row) # if_find_1 = 1 # break # else: # if abs(col[1] - row[1]) < min_distance_1 and row[0] <= col[0] <= row[2]: # if col[1] <= row[1] <= col[3]: # continue # min_distance_1 = abs(col[1] - row[1]) # min_row_1 = row # # # 第二个顶点 # if not if_find_2 and abs(col[3] - row[1]) <= 2: # if row[0] <= col[2] <= row[2]: # if_find_2 = 1 # break # else: # if abs(col[3] - row[1]) < min_distance_2 and row[0] <= col[2] <= row[2]: # min_distance_2 = abs(col[3] - row[1]) # min_row_2 = row # # if not if_find_1: # print("col", col) # print("min_row_1", min_row_1) # if min_row_1: # min_distance_1 = 1000 # min_col_point = [] # for point in col_points: # if abs(point[1] - min_row_1[1]) < min_distance_1: # min_distance_1 = abs(point[1] - min_row_1[1]) # min_col_point = point # # if abs(min_col_point[1] - col[1]) >= abs(min_col_point[1] - min_row_1[1])/3: # # add_point = (col[0], min_row_1[1]) # print("fix_inner add point", add_point) # new_points.append(add_point) # else: # print("distance too long", min_col_point, min_row_1) # print(abs(min_col_point[1] - col[1]), abs(min_col_point[1] - min_row_1[1])/3) return points+new_points def fix_corner(row_lines, col_lines, split_y): new_row_lines = [] new_col_lines = [] last_y = split_y[0] for y in split_y: if y == last_y: continue split_row_lines = [] split_col_lines = [] for row in row_lines: if last_y <= row[1] <= y or last_y <= row[3] <= y: split_row_lines.append(row) for col in col_lines: if last_y <= col[1] <= y or last_y <= col[3] <= y: split_col_lines.append(col) if not split_row_lines or not split_col_lines: last_y = y continue split_row_lines.sort(key=lambda x: (x[1], x[0])) split_col_lines.sort(key=lambda x: (x[0], x[1])) up_line = split_row_lines[0] bottom_line = split_row_lines[-1] left_line = split_col_lines[0] right_line = split_col_lines[-1] # 左上角 if up_line[0:2] != left_line[0:2]: # print("up_line, left_line", up_line, left_line) add_corner = [left_line[0], up_line[1]] split_row_lines[0][0] = add_corner[0] split_col_lines[0][1] = add_corner[1] # 右上角 if up_line[2:] != right_line[:2]: # print("up_line, right_line", up_line, right_line) add_corner = [right_line[0], up_line[1]] split_row_lines[0][2] = add_corner[0] split_col_lines[-1][1] = add_corner[1] new_row_lines = new_row_lines + split_row_lines new_col_lines = new_col_lines + split_col_lines last_y = y return new_row_lines, new_col_lines def delete_outline(row_lines, col_lines, points): row_lines.sort(key=lambda x: (x[1], x[0])) col_lines.sort(key=lambda x: (x[0], x[1])) line = [row_lines[0], row_lines[-1], col_lines[0], col_lines[-1]] threshold = 2 point_cnt = [0, 0, 0, 0] for point in points: for i in range(4): if i < 2: if line[i][1]-threshold <= point[1] <= line[i][1]+threshold: if line[i][0] <= point[0] <= line[i][2]: point_cnt[i] += 1 else: if line[i][0]-threshold <= point[0] <= line[i][0]+threshold: if line[i][1] <= point[1] <= line[i][3]: point_cnt[i] += 1 # if line[0][1]-threshold <= point[1] <= line[0][1]+threshold: # if line[0][0] <= point[0] <= line[0][2]: # point_cnt[0] += 1 # elif line[1][1]-threshold <= point[1] <= line[1][1]+threshold: # if line[1][0] <= point[0] <= line[1][2]: # point_cnt[1] += 1 # elif line[2][0]-threshold <= point[0] <= line[2][0]+threshold: # if line[2][1] <= point[1] <= line[2][3]: # point_cnt[2] += 1 # elif line[3][0]-threshold <= point[0] <= line[3][0]+threshold: # if line[3][1] <= point[1] <= line[3][3]: # point_cnt[3] += 1 # 轮廓line至少包含3个交点 for i in range(4): if point_cnt[i] < 3: if i < 2: if line[i] in row_lines: row_lines.remove(line[i]) else: if line[i] in col_lines: col_lines.remove(line[i]) return row_lines, col_lines def fix_outline(image, row_lines, col_lines, points, split_y): print("split_y", split_y) # 分割线纵坐标 if len(split_y) < 2: return [], [], [], [] # elif len(split_y) == 2: # split_y = [2000., 2000., 2000., 2000.] split_y.sort(key=lambda x: x) new_split_y = [] for i in range(1, len(split_y), 2): new_split_y.append(int((split_y[i]+split_y[i-1])/2)) # # 查看是否正确输出区域分割线 # for line in split_y: # cv2.line(image, (0, int(line)), (int(image.shape[1]), int(line)), (0, 0, 255), 2) # cv2.imshow("split_y", image) # cv2.waitKey(0) # 预测线根据分割线纵坐标分为多个分割区域 # row_lines.sort(key=lambda x: (x[3], x[2], x[1], x[0])) # col_lines.sort(key=lambda x: (x[3], x[2], x[1], x[0])) # points.sort(key=lambda x: (x[1], x[0])) # row_count = 0 # col_count = 0 # point_count = 0 split_row_list = [] split_col_list = [] split_point_list = [] # for i in range(1, len(split_y)): # y = split_y[i] # last_y = split_y[i-1] # row_lines = row_lines[row_count:] # col_lines = col_lines[col_count:] # points = points[point_count:] # row_count = 0 # col_count = 0 # point_count = 0 # # if not row_lines: # split_row_list.append([]) # for row in row_lines: # if last_y <= row[3] <= y: # row_count += 1 # else: # split_row_list.append(row_lines[:row_count]) # break # if row_count == len(row_lines): # split_row_list.append(row_lines[:row_count]) # break # # if not col_lines: # split_col_list.append([]) # # for col in col_lines: # # if last_y <= col[3] <= y: # if col[1] <= last_y <= y <= col[3] or last_y <= col[3] <= y: # # if last_y <= col[1] <= y or last_y <= col[3] <= y: # col_count += 1 # else: # split_col_list.append(col_lines[:col_count]) # break # if col_count == len(col_lines): # split_col_list.append(col_lines[:col_count]) # break # # if not points: # split_point_list.append([]) # for point in points: # if last_y <= point[1] <= y: # point_count += 1 # else: # split_point_list.append(points[:point_count]) # break # if point_count == len(points): # split_point_list.append(points[:point_count]) # break # # # print("len(split_row_list)", len(split_row_list)) # # print("len(split_col_list)", len(split_col_list)) # if row_count < len(row_lines) - 1 and col_count < len(col_lines) - 1: # row_lines = row_lines[row_count:] # split_row_list.append(row_lines) # col_lines = col_lines[col_count:] # split_col_list.append(col_lines) # # if point_count < len(points) - 1: # points = points[point_count:len(points)] # split_point_list.append(points) for i in range(1, len(split_y)): y = split_y[i] last_y = split_y[i-1] split_row = [] for row in row_lines: if last_y <= row[3] <= y: split_row.append(row) split_row_list.append(split_row) split_col = [] for col in col_lines: if last_y <= col[1] <= y or last_y <= col[3] <= y or col[1] < last_y < y < col[3]: split_col.append(col) split_col_list.append(split_col) split_point = [] for point in points: if last_y <= point[1] <= y: split_point.append(point) split_point_list.append(split_point) # 预测线取上下左右4个边(会有超出表格部分) [(), ()] area_row_line = [] area_col_line = [] for area in split_row_list: if not area: area_row_line.append([]) continue area.sort(key=lambda x: (x[1], x[0])) up_line = area[0] bottom_line = area[-1] area_row_line.append([up_line, bottom_line]) for area in split_col_list: if not area: area_col_line.append([]) continue area.sort(key=lambda x: x[0]) left_line = area[0] right_line = area[-1] area_col_line.append([left_line, right_line]) # 线交点根据分割线纵坐标分为多个分割区域 # points.sort(key=lambda x: (x[1], x[0])) # point_count = 0 # split_point_list = [] # for y in new_split_y: # points = points[point_count:len(points)] # point_count = 0 # for point in points: # if point[1] <= y: # point_count += 1 # else: # split_point_list.append(points[:point_count]) # break # if point_count == len(points): # split_point_list.append(points[:point_count]) # break # if point_count < len(points) - 1: # points = points[point_count:len(points)] # split_point_list.append(points) # print("len(split_point_list)", len(split_point_list)) # 取每个分割区域的4条线(无超出表格部分) area_row_line2 = [] area_col_line2 = [] for area in split_point_list: if not area: area_row_line2.append([]) area_col_line2.append([]) continue area.sort(key=lambda x: (x[0], x[1])) left_up = area[0] right_bottom = area[-1] up_line = [left_up[0], left_up[1], right_bottom[0], left_up[1]] bottom_line = [left_up[0], right_bottom[1], right_bottom[0], right_bottom[1]] left_line = [left_up[0], left_up[1], left_up[0], right_bottom[1]] right_line = [right_bottom[0], left_up[1], right_bottom[0], right_bottom[1]] area_row_line2.append([up_line, bottom_line]) area_col_line2.append([left_line, right_line]) # 判断超出部分的长度,超出一定长度就补线 new_row_lines = [] new_col_lines = [] longer_row_lines = [] longer_col_lines = [] all_longer_row_lines = [] all_longer_col_lines = [] # print("split_y", split_y) # print("split_row_list", split_row_list, len(split_row_list)) # print("split_row_list", split_col_list, len(split_col_list)) # print("area_row_line", area_row_line, len(area_row_line)) # print("area_col_line", area_col_line, len(area_col_line)) for i in range(len(area_row_line)): if not area_row_line[i] or not area_col_line[i]: continue up_line = area_row_line[i][0] up_line2 = area_row_line2[i][0] bottom_line = area_row_line[i][1] bottom_line2 = area_row_line2[i][1] left_line = area_col_line[i][0] left_line2 = area_col_line2[i][0] right_line = area_col_line[i][1] right_line2 = area_col_line2[i][1] # 计算单格高度宽度 if len(split_row_list[i]) > 1: height_dict = {} for j in range(len(split_row_list[i])): if j + 1 > len(split_row_list[i]) - 1: break height = abs(int(split_row_list[i][j][3] - split_row_list[i][j+1][3])) if height in height_dict.keys(): height_dict[height] = height_dict[height] + 1 else: height_dict[height] = 1 height_list = [[x, height_dict[x]] for x in height_dict.keys()] height_list.sort(key=lambda x: (x[1], -x[0]), reverse=True) # print("height_list", height_list) box_height = height_list[0][0] else: box_height = 10 if len(split_col_list[i]) > 1: box_width = abs(split_col_list[i][1][2] - split_col_list[i][0][2]) else: box_width = 10 print("box_height", box_height, "box_width", box_width) # cv2.line(image, (int(up_line[0]), int(up_line[1])), # (int(up_line[2]), int(up_line[3])), # (255, 255, 0), 2) # cv2.line(image, (int(right_line[0]), int(right_line[1])), # (int(right_line[2]), int(right_line[3])), # (0, 255, 255), 2) # cv2.imshow("right_line", image) # cv2.waitKey(0) # 补左右两条竖线超出来的线的row if (up_line[1] - left_line[1] >= 10 and up_line[1] - right_line[1] >= 2) or \ (up_line[1] - left_line[1] >= 2 and up_line[1] - right_line[1] >= 10): if up_line[1] - left_line[1] >= up_line[1] - right_line[1]: new_row_lines.append([left_line[0], left_line[1], right_line[0], left_line[1]]) new_col_y = left_line[1] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差大于一格 # print("abs(new_col_y - col[1])", abs(new_col_y - col[1])) if abs(new_col_y - col[1]) <= box_height: split_col_list[i][j][1] = min([new_col_y, col[1]]) longer_col_lines.append([col[0], min([new_col_y, col[1]]), col[2], col[3]]) else: new_row_lines.append([left_line[0], right_line[1], right_line[0], right_line[1]]) new_col_y = right_line[1] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): # 需判断该线在这个区域中 # if up_line2[1]-3 <= col[1] <= col[3] <= bottom_line2[1]+3: col = split_col_list[i][j] # 且距离不能相差太大 # print("abs(new_col_y - col[1])", abs(new_col_y - col[1])) if abs(new_col_y - col[1]) <= box_height: split_col_list[i][j][1] = min([new_col_y, col[1]]) if (left_line[3] - bottom_line[3] >= 10 and right_line[3] - bottom_line[3] >= 2) or \ (left_line[3] - bottom_line[3] >= 2 and right_line[3] - bottom_line[3] >= 10): if left_line[3] - bottom_line[3] >= right_line[3] - bottom_line[3]: new_row_lines.append([left_line[2], left_line[3], right_line[2], left_line[3]]) new_col_y = left_line[3] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差太大 if abs(new_col_y - col[3]) <= box_height: split_col_list[i][j][3] = max([new_col_y, col[3]]) else: new_row_lines.append([left_line[2], right_line[3], right_line[2], right_line[3]]) new_col_y = right_line[3] # 补了row,要将其他短的col连到row上 for j in range(len(split_col_list[i])): col = split_col_list[i][j] # 且距离不能相差太大 if abs(new_col_y - col[3]) <= box_height: split_col_list[i][j][3] = max([new_col_y, col[3]]) # 补上下两条横线超出来的线的col if (left_line[0] - up_line[0] >= 10 and left_line[0] - bottom_line[0] >= 2) or \ (left_line[0] - up_line[0] >= 2 and left_line[0] - bottom_line[0] >= 10): if left_line[0] - up_line[0] >= left_line[0] - bottom_line[0]: new_col_lines.append([up_line[0], up_line[1], up_line[0], bottom_line[1]]) new_row_x = up_line[0] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[0]) <= box_width: split_row_list[i][j][0] = min([new_row_x, row[0]]) else: new_col_lines.append([bottom_line[0], up_line[1], bottom_line[0], bottom_line[1]]) new_row_x = bottom_line[0] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[0]) <= box_width: split_row_list[i][j][0] = min([new_row_x, row[0]]) if (up_line[2] - right_line[2] >= 10 and bottom_line[2] - right_line[2] >= 2) or \ (up_line[2] - right_line[2] >= 2 and bottom_line[2] - right_line[2] >= 10): if up_line[2] - right_line[2] >= bottom_line[2] - right_line[2]: new_col_lines.append([up_line[2], up_line[3], up_line[2], bottom_line[3]]) new_row_x = up_line[2] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[2]) <= box_width: split_row_list[i][j][2] = max([new_row_x, row[2]]) else: new_col_lines.append([bottom_line[2], up_line[3], bottom_line[2], bottom_line[3]]) new_row_x = bottom_line[2] # 补了col,要将其他短的row连到col上 for j in range(len(split_row_list[i])): # 需判断该线在这个区域中 # if up_line2[1]-3 <= row[1] <= bottom_line2[1]+3: row = split_row_list[i][j] # 且距离不能相差太大 if abs(new_row_x - row[2]) <= box_width: split_row_list[i][j][2] = max([new_row_x, row[2]]) all_longer_row_lines += split_row_list[i] all_longer_col_lines += split_col_list[i] # print("all_longer_row_lines", len(all_longer_row_lines), i) # print("all_longer_col_lines", len(all_longer_col_lines), i) # print("new_row_lines", len(new_row_lines), i) # print("new_col_lines", len(new_col_lines), i) # 删除表格内部的补线 # temp_list = [] # for row in new_row_lines: # if up_line[1]-5 <= row[1] <= bottom_line[1]+5: # continue # temp_list.append(row) # print("fix_outline", new_row_lines) # new_row_lines = temp_list # print("fix_outline", new_row_lines) # temp_list = [] # for col in new_col_lines: # if left_line[0]-5 <= col[0] <= right_line[0]+5: # continue # temp_list.append(col) # # new_col_lines = temp_list # print("fix_outline", new_col_lines) # print("fix_outline", new_row_lines) # 删除重复包含的补线 # temp_list = [] # for row in new_row_lines: # if up_line[1]-5 <= row[1] <= bottom_line[1]+5: # continue # temp_list.append(row) # new_row_lines = temp_list # 展示上下左右边框线 # for i in range(len(area_row_line)): # print("row1", area_row_line[i]) # print("row2", area_row_line2[i]) # print("col1", area_col_line[i]) # print("col2", area_col_line2[i]) # cv2.line(image, (int(area_row_line[i][0][0]), int(area_row_line[i][0][1])), # (int(area_row_line[i][0][2]), int(area_row_line[i][0][3])), (0, 255, 0), 2) # cv2.line(image, (int(area_row_line2[i][1][0]), int(area_row_line2[i][1][1])), # (int(area_row_line2[i][1][2]), int(area_row_line2[i][1][3])), (0, 0, 255), 2) # cv2.imshow("fix_outline", image) # cv2.waitKey(0) # 展示所有线 # for line in all_longer_col_lines: # cv2.line(image, (int(line[0]), int(line[1])), # (int(line[2]), int(line[3])), # (0, 255, 0), 2) # cv2.imshow("fix_outline", image) # cv2.waitKey(0) # for line in all_longer_row_lines: # cv2.line(image, (int(line[0]), int(line[1])), # (int(line[2]), int(line[3])), # (0, 0, 255), 2) # cv2.imshow("fix_outline", image) # cv2.waitKey(0) return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines def fix_table(row_point_list, col_point_list, split_y, row_lines, col_lines): # 分割线纵坐标 if len(split_y) < 2: return [] # 获取bbox bbox = [] # 每个点获取与其x最相近和y最相近的点 for i in range(1, len(split_y)): # 循环每行 for row in row_point_list: row.sort(key=lambda x: (x[0], x[1])) # 行不在该区域跳过 if row[0][1] <= split_y[i-1] or row[0][1] >= split_y[i]: continue # print("len(row)", len(row)) # print("row", row) # 循环行中的点 for j in range(len(row)): if j == len(row) - 1: break current_point = row[j] next_point_in_row_list = row[j+1:] # 循环这一行的下一个点 for next_point_in_row in next_point_in_row_list: # 是否在这一行点找到,找不到就这一行的下个点 not_found = 1 # 查询下个点所在列 next_col = [] for col in col_point_list: col.sort(key=lambda x: (x[1], x[0])) # 列不在该区域跳过 if col[0][1] <= split_y[i-1] or col[-1][1] >= split_y[i]: continue if col[0][0]-3 <= next_point_in_row[0] <= col[0][0]+3: next_col = col break # 循环匹配当前点和下一列点 next_col.sort(key=lambda x: (x[1], x[0])) for point1 in next_col: # 同一行的就跳过 if current_point[1]-3 <= point1[1] <= current_point[1]+3: continue if point1[1] <= current_point[1]-3: continue # 候选bbox candidate_bbox = [current_point[0], current_point[1], point1[0], point1[1]] # print("candidate_bbox", candidate_bbox) # 判断该bbox是否存在,判断bbox的上下两条边是否有包含在row中 contain_flag1 = 0 contain_flag2 = 0 for row1 in row_lines: # 行不在该区域跳过 if row1[1] <= split_y[i-1] or row1[1] >= split_y[i]: continue # bbox上边框 y一样 if not contain_flag1: if row1[1]-3 <= candidate_bbox[1] <= row1[1]+3: # 格子里的断开线段 row1_break = (max([row1[0], candidate_bbox[0]]), row1[1], min([row1[2], candidate_bbox[2]]), row1[3]) if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: contain_flag1 = 1 # bbox下边框 y一样 if not contain_flag2: if row1[1]-3 <= candidate_bbox[3] <= row1[1]+3: # 格子里的断开线段 row1_break = (max([row1[0], candidate_bbox[0]]), row1[1], min([row1[2], candidate_bbox[2]]), row1[3]) if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: contain_flag2 = 1 # 判断该bbox是否存在,判断bbox的左右两条边是否有包含在col中 contain_flag3 = 0 contain_flag4 = 0 for col1 in col_lines: # 列不在该区域跳过 if col1[1] <= split_y[i-1] or col1[3] >= split_y[i]: continue # bbox左边线 x一样 if not contain_flag3: if col1[0]-3 <= candidate_bbox[0] <= col1[0]+3: # 格子里的断开线段 col1_break = (col1[0], max([col1[1], candidate_bbox[1]]), col1[2], min([col1[3], candidate_bbox[3]])) if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: contain_flag3 = 1 # bbox右边框 x一样 if not contain_flag4: if col1[0]-3 <= candidate_bbox[2] <= col1[0]+3: # 格子里的断开线段 col1_break = (col1[0], max([col1[1], candidate_bbox[1]]), col1[2], min([col1[3], candidate_bbox[3]])) if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: contain_flag4 = 1 # 找到了该bbox,并且是存在的 if contain_flag1 and contain_flag2 and contain_flag3 and contain_flag4: bbox.append([(candidate_bbox[0], candidate_bbox[1]), (candidate_bbox[2], candidate_bbox[3])]) not_found = 0 break if not not_found: break return bbox def delete_close_points(point_list, row_point_list, col_point_list, threshold=5): new_point_list = [] delete_point_list = [] point_list.sort(key=lambda x: (x[1], x[0])) for i in range(len(point_list)): point1 = point_list[i] if point1 in delete_point_list: continue if i == len(point_list) - 1: new_point_list.append(point1) break point2 = point_list[i+1] # 判断坐标 if abs(point1[0] - point2[0]) > threshold or abs(point1[1] - point2[1]) > threshold: new_point_list.append(point1) else: # 看两个点上的相同坐标点哪个多,就保留哪个 count1 = 0 count2 = 0 for col in col_point_list: if point1[0] == col[0][0]: count1 += len(col) elif point2[0] == col[0][0]: count2 += len(col) if count1 >= count2: new_point_list.append(point1) delete_point_list.append(point2) else: new_point_list.append(point2) delete_point_list.append(point1) point_list = new_point_list new_point_list = [] delete_point_list = [] point_list.sort(key=lambda x: (x[0], x[1])) for i in range(len(point_list)): point1 = point_list[i] if point1 in delete_point_list: continue if i == len(point_list) - 1: new_point_list.append(point1) break point2 = point_list[i+1] # 判断坐标 if abs(point1[0] - point2[0]) > threshold or abs(point1[1] - point2[1]) > threshold: new_point_list.append(point1) else: count1 = 0 count2 = 0 for row in row_point_list: if point1[0] == row[0][0]: count1 += len(row) elif point2[0] == row[0][0]: count2 += len(row) if count1 >= count2: new_point_list.append(point1) delete_point_list.append(point2) else: new_point_list.append(point2) delete_point_list.append(point1) return new_point_list def get_bbox2(image_np, points): # # 坐标点按行分 # row_point_list = [] # row_point = [] # points.sort(key=lambda x: (x[0], x[1])) # for p in points: # if len(row_point) == 0: # x = p[0] # if x-5 <= p[0] <= x+5: # row_point.append(p) # else: # row_point_list.append(row_point) # row_point = [] # # 坐标点按列分 # col_point_list = [] # col_point = [] # points.sort(key=lambda x: (x[1], x[0])) # for p in points: # if len(col_point) == 0: # y = p[1] # if y-5 <= p[1] <= y+5: # col_point.append(p) # else: # col_point_list.append(col_point) # col_point = [] row_point_list = get_points_row(points) col_point_list = get_points_col(points) print("len(points)", len(points)) for point in points: cv2.circle(image_np, point, 1, (0, 255, 0), 1) cv2.imshow("points_deleted", image_np) points = delete_close_points(points, row_point_list, col_point_list) print("len(points)", len(points)) for point in points: cv2.circle(image_np, point, 1, (255, 0, 0), 3) cv2.imshow("points_deleted", image_np) cv2.waitKey(0) row_point_list = get_points_row(points, 5) col_point_list = get_points_col(points, 5) print("len(row_point_list)", len(row_point_list)) for row in row_point_list: print("row", len(row)) print("col_point_list", len(col_point_list)) for col in col_point_list: print("col", len(col)) bbox = [] for i in range(len(row_point_list)): if i == len(row_point_list) - 1: break # 遍历每个row的point,找到其所在列的下一个点和所在行的下一个点 current_row = row_point_list[i] for j in range(len(current_row)): current_point = current_row[j] if j == len(current_row) - 1: break next_row_point = current_row[j+1] # 找出当前点所在的col,得到该列下一个point current_col = col_point_list[j] for k in range(len(current_col)): if current_col[k][1] > current_point[1] + 10: next_col_point = current_col[k] break next_row = row_point_list[k] for k in range(len(next_row)): if next_row[k][0] >= next_row_point[0] + 5: next_point = next_row[k] break # 得到bbox bbox.append([(current_point[0], current_point[1]), (next_point[0], next_point[1])]) # bbox = [] # for p in points: # # print("p", p) # p_row = [] # p_col = [] # for row in row_point_list: # if p[0] == row[0][0]: # for p1 in row: # if abs(p[1]-p1[1]) <= 5: # continue # p_row.append([p1, abs(p[1]-p1[1])]) # p_row.sort(key=lambda x: x[1]) # for col in col_point_list: # if p[1] == col[0][1]: # for p2 in col: # if abs(p[0]-p2[0]) <= 5: # continue # p_col.append([p2, abs(p[0]-p2[0])]) # p_col.sort(key=lambda x: x[1]) # if len(p_row) == 0 or len(p_col) == 0: # continue # break_flag = 0 # for i in range(len(p_row)): # for j in range(len(p_col)): # # print(p_row[i][0]) # # print(p_col[j][0]) # another_point = (p_col[j][0][0], p_row[i][0][1]) # # print("another_point", another_point) # if abs(p[0]-another_point[0]) <= 5 or abs(p[1]-another_point[1]) <= 5: # continue # if p[0] >= another_point[0] or p[1] >= another_point[1]: # continue # if another_point in points: # box = [p, another_point] # box.sort(key=lambda x: x[0]) # if box not in bbox: # bbox.append(box) # break_flag = 1 # break # if break_flag: # break # # # delete duplicate # delete_bbox = [] # for i in range(len(bbox)): # for j in range(i+1, len(bbox)): # if bbox[i][0] == bbox[j][0]: # if bbox[i][1][0] - bbox[j][1][0] <= 3 \ # and bbox[i][1][1] - bbox[j][1][1] <= 3: # delete_bbox.append(bbox[j]) # if bbox[i][1] == bbox[j][1]: # if bbox[i][0][0] - bbox[j][0][0] <= 3 \ # and bbox[i][0][1] - bbox[j][0][1] <= 3: # delete_bbox.append(bbox[j]) # # delete too small area # # for box in bbox: # # if box[1][0] - box[0][0] <= # for d_box in delete_bbox: # if d_box in bbox: # bbox.remove(d_box) # print bbox bbox.sort(key=lambda x: (x[0][0], x[0][1], x[1][0], x[1][1])) # origin bbox # origin_bbox = [] # for box in bbox: # origin_bbox.append([(box[0][0], box[0][1] - 40), (box[1][0], box[1][1] - 40)]) # for box in origin_bbox: # cv2.rectangle(origin_image, box[0], box[1], (0, 0, 255), 2, 8) # cv2.imshow('AlanWang', origin_image) # cv2.waitKey(0) for box in bbox: cv2.rectangle(image_np, box[0], box[1], (0, 0, 255), 2, 8) cv2.imshow('bboxes', image_np) cv2.waitKey(0) # for point in points: # print(point) # cv2.circle(image_np, point, 1, (0, 0, 255), 3) # cv2.imshow('points', image_np) # cv2.waitKey(0) return bbox def get_bbox1(image_np, points, split_y): # 分割线纵坐标 # print("split_y", split_y) if len(split_y) < 2: return [] # 计算行列,剔除相近交点 row_point_list = get_points_row(points) col_point_list = get_points_col(points) print("len(row_point_list)", row_point_list) print("len(col_point_list)", len(col_point_list)) # for point in points: # cv2.circle(image_np, point, 1, (0, 255, 0), 1) # cv2.imshow("points", image_np) points = delete_close_points(points, row_point_list, col_point_list) # print("len(points)", len(points)) # for point in points: # cv2.circle(image_np, point, 1, (255, 0, 0), 3) # cv2.imshow("points_deleted", image_np) # cv2.waitKey(0) # 获取bbox bbox = [] # 每个点获取与其x最相近和y最相近的点 for i in range(1, len(split_y)): for point1 in points: if point1[1] <= split_y[i-1] or point1[1] >= split_y[i]: continue distance_x = 10000 distance_y = 10000 x = 0 y = 0 threshold = 10 for point2 in points: if point2[1] <= split_y[i-1] or point2[1] >= split_y[i]: continue # 最近 x y if 2 < point2[0] - point1[0] < distance_x and point2[1] - point1[1] <= threshold: distance_x = point2[0] - point1[0] x = point2[0] if 2 < point2[1] - point1[1] < distance_y and point2[0] - point1[0] <= threshold: distance_y = point2[1] - point1[1] y = point2[1] if not x or not y: continue bbox.append([(point1[0], point1[1]), (x, y)]) # 删除包含关系bbox temp_list = [] for i in range(len(bbox)): box1 = bbox[i] for j in range(len(bbox)): if i == j: continue box2 = bbox[j] contain_flag = 0 if box2[0][0] <= box1[0][0] <= box1[1][0] <= box2[1][0] and \ box2[0][1] <= box1[0][1] <= box1[1][1] <= box2[1][1]: contain_flag = 1 break temp_list.append(box1) bbox = temp_list # 展示 for box in bbox: # print(box[0], box[1]) # if abs(box[0][1] - box[1][1]) > abs(box[0][0] - box[1][0]): # continue cv2.rectangle(image_np, box[0], box[1], (0, 0, 255), 2, 8) cv2.imshow('bboxes', image_np) cv2.waitKey(0) return bbox def get_bbox0(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines): # 分割线纵坐标 if len(split_y) < 2: return [] # 计算行列,剔除相近交点 # row_point_list = get_points_row(points) # col_point_list = get_points_col(points) # points = delete_close_points(points, row_point_list, col_point_list) # row_point_list = get_points_row(points) # col_point_list = get_points_col(points) # 获取bbox bbox = [] # print("get_bbox split_y", split_y) # 每个点获取与其x最相近和y最相近的点 for i in range(1, len(split_y)): # 循环每行 for row in row_point_list: row.sort(key=lambda x: (x[0], x[1])) # 行不在该区域跳过 if row[0][1] <= split_y[i-1] or row[0][1] >= split_y[i]: continue # 循环行中的点 for j in range(len(row)): if j == len(row) - 1: break current_point = row[j] next_point_in_row = row[j+1] # 查询下个点所在列 next_col = [] for col in col_point_list: col.sort(key=lambda x: (x[1], x[0])) # 列不在该区域跳过 if col[0][1] <= split_y[i-1] or col[-1][1] >= split_y[i]: continue if col[0][0]-3 <= next_point_in_row[0] <= col[0][0]+3: next_col = col break # 循环匹配当前点和下一列点 for point1 in next_col: # 同一行的就跳过 if current_point[1]-3 <= point1[1] <= current_point[1]+3: continue if point1[1] <= current_point[1]-3: continue # 候选bbox candidate_bbox = [current_point[0], current_point[1], point1[0], point1[1]] # 判断该bbox是否存在,线条包含关系 contain_flag1 = 0 contain_flag2 = 0 for row1 in row_lines: # 行不在该区域跳过 if row1[1] <= split_y[i-1] or row1[1] >= split_y[i]: continue # bbox上边框 y一样 if not contain_flag1: if row1[1]-3 <= candidate_bbox[1] <= row1[1]+3: # candidate的x1,x2需被包含在row线中 if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3: contain_flag1 = 1 # bbox下边框 y一样 if not contain_flag2: if row1[1]-3 <= candidate_bbox[3] <= row1[1]+3: # candidate的x1,x2需被包含在row线中 if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3: contain_flag2 = 1 # 找到了该bbox,并且是存在的 if contain_flag1 and contain_flag2: bbox.append([(candidate_bbox[0], candidate_bbox[1]), (candidate_bbox[2], candidate_bbox[3])]) break return bbox def get_bbox3(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines): # 分割线纵坐标 if len(split_y) < 2: return [] # 获取bbox bbox = [] # 每个点获取与其x最相近和y最相近的点 for i in range(1, len(split_y)): # 循环每行 for row in row_point_list: row.sort(key=lambda x: (x[0], x[1])) # 行不在该区域跳过 if row[0][1] <= split_y[i-1] or row[0][1] >= split_y[i]: continue # print("len(row)", len(row)) # print("row", row) # 循环行中的点 for j in range(len(row)): if j == len(row) - 1: break current_point = row[j] # print("current_point", current_point) next_point_in_row_list = row[j+1:] # 循环这一行的下一个点 for next_point_in_row in next_point_in_row_list: # 是否在这一行点找到,找不到就这一行的下个点 not_found = 1 # 查询下个点所在列 next_col = [] for col in col_point_list: col.sort(key=lambda x: (x[1], x[0])) # 列不在该区域跳过 if col[0][1] <= split_y[i-1] or col[-1][1] >= split_y[i]: continue if col[0][0]-3 <= next_point_in_row[0] <= col[0][0]+3: next_col = col break # 循环匹配当前点和下一列点 next_col.sort(key=lambda x: (x[1], x[0])) for point1 in next_col: # 同一行的就跳过 if current_point[1]-3 <= point1[1] <= current_point[1]+3: continue if point1[1] <= current_point[1]-3: continue # 候选bbox candidate_bbox = [current_point[0], current_point[1], point1[0], point1[1]] # print("candidate_bbox", candidate_bbox) # 判断该bbox是否存在,判断bbox的上下两条边是否有包含在row中 contain_flag1 = 0 contain_flag2 = 0 for row1 in row_lines: # 行不在该区域跳过 if row1[1] <= split_y[i-1] or row1[1] >= split_y[i]: continue # bbox上边框 y一样 if not contain_flag1: if row1[1]-3 <= candidate_bbox[1] <= row1[1]+3: # 格子里的断开线段 row1_break = (max([row1[0], candidate_bbox[0]]), row1[1], min([row1[2], candidate_bbox[2]]), row1[3]) if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: contain_flag1 = 1 # # candidate的x1,x2需被包含在row线中 # if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3: # contain_flag1 = 1 # # # 判断线条有无端点在格子中 # elif candidate_bbox[0] < row1[0] < candidate_bbox[2] \ # or candidate_bbox[0] < row1[2] < candidate_bbox[2]: # # 线条会有缺一点情况,判断长度超过格子一半 # if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: # contain_flag1 = 1 # bbox下边框 y一样 if not contain_flag2: if row1[1]-3 <= candidate_bbox[3] <= row1[1]+3: # 格子里的断开线段 row1_break = (max([row1[0], candidate_bbox[0]]), row1[1], min([row1[2], candidate_bbox[2]]), row1[3]) if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: contain_flag2 = 1 # # candidate的x1,x2需被包含在row线中 # if row1[0]-3 <= candidate_bbox[0] <= candidate_bbox[2] <= row1[2]+3: # contain_flag2 = 1 # # # 判断线条有无端点在格子中 # elif candidate_bbox[0] < row1[0] < candidate_bbox[2] \ # or candidate_bbox[0] < row1[2] < candidate_bbox[2]: # # 线条会有缺一点情况,判断长度超过格子一半 # if row1_break[2] - row1_break[0] >= (candidate_bbox[2] - candidate_bbox[0])/3: # contain_flag2 = 1 # 判断该bbox是否存在,判断bbox的左右两条边是否有包含在col中 contain_flag3 = 0 contain_flag4 = 0 for col1 in col_lines: # 列不在该区域跳过 if col1[1] <= split_y[i-1] or col1[3] >= split_y[i]: continue # bbox左边线 x一样 if not contain_flag3: if col1[0]-3 <= candidate_bbox[0] <= col1[0]+3: # 格子里的断开线段 col1_break = (col1[0], max([col1[1], candidate_bbox[1]]), col1[2], min([col1[3], candidate_bbox[3]])) if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: contain_flag3 = 1 # # candidate的y1,y2需被包含在col线中 # if col1[1]-3 <= candidate_bbox[1] <= candidate_bbox[3] <= col1[3]+3: # contain_flag3 = 1 # # # 判断线条有无端点在格子中 # elif candidate_bbox[1] < col1[1] < candidate_bbox[3] \ # or candidate_bbox[1] < col1[3] < candidate_bbox[3]: # # 线条会有缺一点情况,判断长度超过格子一半 # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: # contain_flag3 = 1 # bbox右边框 x一样 if not contain_flag4: if col1[0]-3 <= candidate_bbox[2] <= col1[0]+3: # 格子里的断开线段 # col1_break = (col1[0], # max([col1[1], candidate_bbox[1]]), # col1[2], # min([col1[3], candidate_bbox[3]])) # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: # contain_flag4 = 1 # 如果候选bbox的边的上1/3或下1/3包含在col中 candidate_bbox_line1 = [candidate_bbox[1], candidate_bbox[1] + (candidate_bbox[3]-candidate_bbox[1])/3] candidate_bbox_line2 = [candidate_bbox[3] - (candidate_bbox[3]-candidate_bbox[1])/3, candidate_bbox[3]] if col1[1] <= candidate_bbox_line1[0] <= candidate_bbox_line1[1] <= col1[3] \ or col1[1] <= candidate_bbox_line2[0] <= candidate_bbox_line2[1] <= col1[3]: # print("candidate_bbox", candidate_bbox) # print("col1", col1) contain_flag4 = 1 # # candidate的y1,y2需被包含在col线中 # if col1[1]-3 <= candidate_bbox[1] <= candidate_bbox[3] <= col1[3]+3: # contain_flag4 = 1 # # # 判断线条有无端点在格子中 # elif candidate_bbox[1] < col1[1] < candidate_bbox[3] \ # or candidate_bbox[1] < col1[3] < candidate_bbox[3]: # # 线条会有缺一点情况,判断长度超过格子一半 # if col1_break[3] - col1_break[1] >= (candidate_bbox[3] - candidate_bbox[1])/3: # contain_flag4 = 1 # 找到了该bbox,并且是存在的 if contain_flag1 and contain_flag2 and contain_flag3 and contain_flag4: bbox.append([(candidate_bbox[0], candidate_bbox[1]), (candidate_bbox[2], candidate_bbox[3])]) not_found = 0 # print("exist candidate_bbox", candidate_bbox) # print(contain_flag1, contain_flag2, contain_flag3, contain_flag4) break # else: # print("candidate_bbox", candidate_bbox) # print(contain_flag1, contain_flag2, contain_flag3, contain_flag4) if not not_found: break return bbox def get_bbox(image_np, row_point_list, col_point_list, split_y, row_lines, col_lines): # 分割线纵坐标 if len(split_y) < 2: return [] # 获取bbox bbox_list = [] for i in range(1, len(split_y)): last_y = split_y[i-1] y = split_y[i] # 先对点线进行分区 split_row_point_list = [] split_col_point_list = [] split_row_lines = [] split_col_lines = [] for row in row_point_list: if last_y <= row[0][1] <= y: row.sort(key=lambda x: (x[1], x[0])) split_row_point_list.append(row) for col in col_point_list: if last_y <= col[0][1] <= y: split_col_point_list.append(col) for row in row_lines: if last_y <= row[1] <= y: split_row_lines.append(row) for col in col_lines: if last_y <= col[1] <= y: split_col_lines.append(col) # 每个点获取其对角线点,以便形成bbox,按行循环 for i in range(len(split_row_point_list)-1): row = split_row_point_list[i] # 循环该行的点 for k in range(len(row)-1): point1 = row[k] next_point1 = row[k+1] # print("*"*30) # print("point1", point1) # 有三种对角线点 # 1. 该点下一行的下一列的点 # 2. 该点下一列的下一行的点 # 3. 上述两个点是同一个点 # 下一行没找到就循环后面的行 if_find = 0 for j in range(i+1, len(split_row_point_list)): if if_find: break next_row = split_row_point_list[j] # print("next_row", next_row) # 循环下一行的点 for point2 in next_row: if abs(point1[0] - point2[0]) <= 2: continue if point2[0] < point1[0]: continue bbox = [point1[0], point1[1], point2[0], point2[1]] if abs(bbox[0] - bbox[2]) <= 10: continue if abs(bbox[1] - bbox[3]) <= 10: continue # bbox的四条边都需要验证是否在line上 if check_bbox(bbox, split_row_lines, split_col_lines): bbox_list.append([(bbox[0], bbox[1]), (bbox[2], bbox[3])]) if_find = 1 # print("check bbox", bbox) break return bbox_list def check_bbox(bbox, rows, cols, threshold=5): def check(check_line, lines, limit_axis, axis): # 需检查的线的1/2段,1/3段,2/3段,1/4段,3/4段 line_1_2 = [check_line[0], (check_line[0]+check_line[1])/2] line_2_2 = [(check_line[0]+check_line[1])/2, check_line[1]] line_1_3 = [check_line[0], check_line[0]+(check_line[1]-check_line[0])/3] line_2_3 = [check_line[1]-(check_line[1]-check_line[0])/3, check_line[1]] line_1_4 = [check_line[0], check_line[0]+(check_line[1]-check_line[0])/4] line_3_4 = [check_line[1]-(check_line[1]-check_line[0])/4, check_line[1]] # 限制row相同y,col相同x if_line = 0 for line1 in lines: if not if_line and abs(line1[1-axis] - limit_axis) <= threshold: # check_line完全包含在line中 if line1[axis] <= check_line[0] <= check_line[1] <= line1[axis+2]: if_line = 1 # check_line的1/2包含在line elif line1[axis] <= line_1_2[0] <= line_1_2[1] <= line1[axis+2] \ or line1[axis] <= line_2_2[0] <= line_2_2[1] <= line1[axis+2]: if_line = 1 # check_line两个1/3段被包含在不同line中 elif line1[axis] <= line_1_3[0] <= line_1_3[1] <= line1[axis+2]: # check_line另一边的1/4被包含 for line2 in lines: if abs(line1[1-axis] - limit_axis) <= threshold: if line2[axis] <= line_3_4[0] <= line_3_4[1] <= line2[axis+2]: if_line = 1 break elif line1[axis] <= line_2_3[0] <= line_2_3[1] <= line1[axis+2]: # check_line另一边的1/4被包含 for line2 in lines: if abs(line1[1-axis] - limit_axis) <= threshold: if line2[axis] <= line_1_4[0] <= line_1_4[1] <= line2[axis+2]: if_line = 1 break return if_line up_down_line = [bbox[0], bbox[2]] up_y, down_y = bbox[1], bbox[3] left_right_line = [bbox[1], bbox[3]] left_x, right_x = bbox[0], bbox[2] # 检查bbox四条边是否存在 if_up = check(up_down_line, rows, up_y, 0) if_down = check(up_down_line, rows, down_y, 0) if_left = check(left_right_line, cols, left_x, 1) if_right = check(left_right_line, cols, right_x, 1) # 检查bbox内部除了四条边,是否有其它line在bbox内部 if_col = 0 if_row = 0 if if_up and if_down and if_left and if_right: for col in cols: if not if_col and left_x+threshold <= col[0] <= right_x-threshold: if col[1] <= left_right_line[0] <= left_right_line[1] <= col[3]: if_col = 1 elif left_right_line[0] <= col[1] <= left_right_line[1]: if left_right_line[1] - col[1] >= (left_right_line[1] + left_right_line[0])/2: if_col = 1 elif left_right_line[0] <= col[3] <= left_right_line[1]: if col[3] - left_right_line[0] >= (left_right_line[1] + left_right_line[0])/2: if_col = 1 for row in rows: if not if_row and up_y+threshold <= row[1] <= down_y-threshold: if row[0] <= up_down_line[0] <= up_down_line[1] <= row[2]: if_row = 1 elif up_down_line[0] <= row[0] <= up_down_line[1]: if up_down_line[1] - row[0] >= (up_down_line[1] + up_down_line[0])/2: if_row = 1 elif up_down_line[0] <= row[2] <= up_down_line[1]: if row[2] - up_down_line[0] >= (up_down_line[1] + up_down_line[0])/2: if_row = 1 if if_up and if_down and if_left and if_right and not if_col and not if_row: return True else: return False def add_continue_bbox(bboxes): add_bbox_list = [] bboxes.sort(key=lambda x: (x[0][0], x[0][1])) last_bbox = bboxes[0] # 先对bbox分区 for i in range(1, len(split_y)): y = split_y[i] last_y = split_y[i-1] split_bbox = [] for bbox in bboxes: if last_y <= bbox[1][1] <= y: split_bbox.append(bbox) split_bbox.sort for i in range(1, len(bboxes)): bbox = bboxes[i] if last_y <= bbox[1][1] <= y and last_y <= last_bbox[1][1] <= y: if abs(last_bbox[1][1] - bbox[0][1]) <= 2: last_bbox = bbox else: if last_bbox[1][1] > bbox[0][1]: last_bbox = bbox else: add_bbox = [(last_bbox[0][0], last_bbox[1][1]), (last_bbox[1][0], bbox[0][1])] add_bbox_list.append(add_bbox) last_y = y print("add_bbox_list", add_bbox_list) if add_bbox_list: bboxes = [str(x) for x in bboxes + add_bbox_list] bboxes = list(set(bboxes)) bboxes = [eval(x) for x in bboxes] bboxes.sort(key=lambda x: (x[0][1], x[0][0])) return bboxes def points_to_line(points_lines, axis): new_line_list = [] for line in points_lines: average = 0 _min = _min = line[0][axis] _max = line[-1][axis] for point in line: average += point[1-axis] if point[axis] < _min: _min = point[axis] if point[axis] > _max: _max = point[axis] average = int(average / len(line)) if axis: new_line = [average, _min, average, _max] else: new_line = [_min, average, _max, average] new_line_list.append(new_line) return new_line_list def get_bbox_by_contours(image_np): img_gray = cv2.cvtColor(image_np, cv2.COLOR_BGR2GRAY) ret, img_bin = cv2.threshold(img_gray, 127, 255, cv2.THRESH_BINARY) # 3.连通域分析 img_bin, contours, hierarchy = cv2.findContours(img_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) # 4.获取最小外接圆 圆心 半径 center, radius = cv2.minEnclosingTriangle(contours[0]) center = np.int0(center) # 5.绘制最小外接圆 img_result = image_np.copy() cv2.circle(img_result, tuple(center), int(radius), (255, 255, 255), 2) # # 读入图片 # img = image_np # cv2.imshow("get_bbox_by_contours ", image_np) # # 中值滤波,去噪 # img = cv2.medianBlur(img, 3) # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # cv2.namedWindow('original', cv2.WINDOW_AUTOSIZE) # cv2.imshow('original', gray) # # # 阈值分割得到二值化图片 # ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU) # # # 膨胀操作 # kernel2 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) # bin_clo = cv2.dilate(binary, kernel2, iterations=2) # # # 连通域分析 # num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(bin_clo, connectivity=8) # # # 查看各个返回值 # # 连通域数量 # print('num_labels = ',num_labels) # # 连通域的信息:对应各个轮廓的x、y、width、height和面积 # print('stats = ',stats) # # 连通域的中心点 # print('centroids = ',centroids) # # 每一个像素的标签1、2、3.。。,同一个连通域的标签是一致的 # print('labels = ',labels) # # # 不同的连通域赋予不同的颜色 # output = np.zeros((img.shape[0], img.shape[1], 3), np.uint8) # for i in range(1, num_labels): # # mask = labels == i # output[:, :, 0][mask] = np.random.randint(0, 255) # output[:, :, 1][mask] = np.random.randint(0, 255) # output[:, :, 2][mask] = np.random.randint(0, 255) # cv2.imshow('oginal', output) # cv2.waitKey() # cv2.destroyAllWindows() def get_points_col(points, split_y, threshold=5): # 坐标点按行分 row_point_list = [] row_point = [] points.sort(key=lambda x: (x[0], x[1])) # print("get_points_col points sort", points) x = points[0][0] for i in range(1, len(split_y)): for p in points: if p[1] <= split_y[i-1] or p[1] >= split_y[i]: continue if x-threshold <= p[0] <= x+threshold: row_point.append(p) else: # print("row_point", row_point) row_point.sort(key=lambda x: (x[1], x[0])) if row_point: row_point_list.append(row_point) row_point = [] x = p[0] row_point.append(p) if row_point: row_point_list.append(row_point) return row_point_list def get_points_row(points, split_y, threshold=5): # 坐标点按列分 col_point_list = [] col_point = [] points.sort(key=lambda x: (x[1], x[0])) y = points[0][1] for i in range(len(split_y)): for p in points: if p[1] <= split_y[i-1] or p[1] >= split_y[i]: continue if y-threshold <= p[1] <= y+threshold: col_point.append(p) else: col_point.sort(key=lambda x: (x[0], x[1])) if col_point: col_point_list.append(col_point) col_point = [] y = p[1] col_point.append(p) if col_point: col_point_list.append(col_point) return col_point_list def get_outline_point(points, split_y): # 分割线纵坐标 # print("get_outline_point split_y", split_y) if len(split_y) < 2: return [] outline_2point = [] points.sort(key=lambda x: (x[1], x[0])) for i in range(1, len(split_y)): area_points = [] for point in points: if point[1] <= split_y[i-1] or point[1] >= split_y[i]: continue area_points.append(point) if area_points: area_points.sort(key=lambda x: (x[1], x[0])) outline_2point.append([area_points[0], area_points[-1]]) return outline_2point # def merge_row(row_lines): # for row in row_lines: # for row1 in row_lines: def get_best_predict_size(image_np): sizes = [1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128] min_len = 10000 best_height = sizes[0] for height in sizes: if abs(image_np.shape[0] - height) < min_len: min_len = abs(image_np.shape[0] - height) best_height = height min_len = 10000 best_width = sizes[0] for width in sizes: if abs(image_np.shape[1] - width) < min_len: min_len = abs(image_np.shape[1] - width) best_width = width return best_height, best_width def choose_longer_row(lines): new_row = [] jump_row = [] for i in range(len(lines)): row1 = lines[i] jump_flag = 0 if row1 in jump_row: continue for j in range(i+1, len(lines)): row2 = lines[j] if row2 in jump_row: continue if row2[1]-5 <= row1[1] <= row2[1]+5: if row1[0] <= row2[0] and row1[2] >= row2[2]: new_row.append(row1) jump_row.append(row1) jump_row.append(row2) jump_flag = 1 break elif row2[0] <= row1[0] and row2[2] >= row1[2]: new_row.append(row2) jump_row.append(row1) jump_row.append(row2) jump_flag = 1 break if not jump_flag: new_row.append(row1) jump_row.append(row1) return new_row def choose_longer_col(lines): new_col = [] jump_col = [] for i in range(len(lines)): col1 = lines[i] jump_flag = 0 if col1 in jump_col: continue for j in range(i+1, len(lines)): col2 = lines[j] if col2 in jump_col: continue if col2[0]-5 <= col1[0] <= col2[0]+5: if col1[1] <= col2[1] and col1[3] >= col2[3]: new_col.append(col1) jump_col.append(col1) jump_col.append(col2) jump_flag = 1 break elif col2[1] <= col1[1] and col2[3] >= col1[3]: new_col.append(col2) jump_col.append(col1) jump_col.append(col2) jump_flag = 1 break if not jump_flag: new_col.append(col1) jump_col.append(col1) return new_col def delete_contain_bbox(bboxes): # bbox互相包含,取小的bbox delete_bbox = [] for i in range(len(bboxes)): for j in range(i+1, len(bboxes)): bbox1 = bboxes[i] bbox2 = bboxes[j] # 横坐标相等情况 if bbox1[0][0] == bbox2[0][0] and bbox1[1][0] == bbox2[1][0]: if bbox1[0][1] <= bbox2[0][1] <= bbox2[1][1] <= bbox1[1][1]: # print("1", bbox1, bbox2) delete_bbox.append(bbox1) elif bbox2[0][1] <= bbox1[0][1] <= bbox1[1][1] <= bbox2[1][1]: # print("2", bbox1, bbox2) delete_bbox.append(bbox2) # 纵坐标相等情况 elif bbox1[0][1] == bbox2[0][1] and bbox1[1][1] == bbox2[1][1]: if bbox1[0][0] <= bbox2[0][0] <= bbox2[1][0] <= bbox1[1][0]: print("3", bbox1, bbox2) delete_bbox.append(bbox1) elif bbox2[0][0] <= bbox1[0][0] <= bbox1[1][0] <= bbox2[1][0]: print("4", bbox1, bbox2) delete_bbox.append(bbox2) print("delete_contain_bbox len(bboxes)", len(bboxes)) print("delete_contain_bbox len(delete_bbox)", len(delete_bbox)) for bbox in delete_bbox: if bbox in bboxes: bboxes.remove(bbox) print("delete_contain_bbox len(bboxes)", len(bboxes)) return bboxes if __name__ == '__main__': # p = "开标记录表3_page_0.png" # p = "train_data/label_1.jpg" # p = "test_files/train_463.jpg" p = "test_files/8.png" # p = "test_files/无边框3.jpg" # p = "test_files/part1.png" # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00e959a0bc9011ebaf5a00163e0ae709" + \ # "\\00e95f7cbc9011ebaf5a00163e0ae709_pdf_page0.png" # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00fb3e52bc7e11eb836000163e0ae709" + \ # "\\00fb43acbc7e11eb836000163e0ae709.png" # p = "test_files/table.jpg" # p = "data_process/create_data/0.jpg" # p = "../format_conversion/temp/f1fe9c4ac8e511eb81d700163e0857b6/f1fea1e0c8e511eb81d700163e0857b6.png" # p = "../format_conversion/1.png" img = cv2.imread(p) t = time.time() model.load_weights("") best_h, best_w = get_best_predict_size(img) print(img.shape) print((best_h, best_w)) # row_boxes, col_boxes = table_line(img[..., ::-1], model, size=(512, 1024), hprob=0.5, vprob=0.5) # row_boxes, col_boxes, img = table_line(img[..., ::-1], model, size=(best_w, best_h), hprob=0.5, vprob=0.5) row_boxes, col_boxes, img = table_line(img, model, size=(best_w, best_h), hprob=0.5, vprob=0.5) print("len(row_boxes)", len(row_boxes)) print("len(col_boxes)", col_boxes) # 创建空图 test_img = np.zeros((img.shape), np.uint8) test_img.fill(255) for box in row_boxes+col_boxes: cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1) cv2.imshow("test_image", test_img) cv2.waitKey(0) cv2.imwrite("temp.jpg", test_img) # 计算交点、分割线 crossover_points = get_points(row_boxes, col_boxes, (img.shape[0], img.shape[1])) print("len(col_boxes)", len(col_boxes)) split_lines, split_y = get_split_line(crossover_points, col_boxes, img) print("split_y", split_y) # for point in crossover_points: # cv2.circle(test_img, point, 1, (0, 255, 0), 3) # cv2.imshow("point image1", test_img) # cv2.waitKey(0) # 计算行列,剔除相近交点 row_point_list = get_points_row(crossover_points, split_y, 0) col_point_list = get_points_col(crossover_points, split_y, 0) crossover_points = delete_close_points(crossover_points, row_point_list, col_point_list) row_point_list = get_points_row(crossover_points, split_y) col_point_list = get_points_col(crossover_points, split_y) for point in crossover_points: cv2.circle(test_img, point, 1, (0, 0, 255), 3) cv2.imshow("point image1", test_img) cv2.waitKey(0) print("len(row_boxes)", len(row_boxes)) print("len(col_boxes)", len(col_boxes)) # 修复边框 new_row_boxes, new_col_boxes, long_row_boxes, long_col_boxes = \ fix_outline(img, row_boxes, col_boxes, crossover_points, split_y) if new_row_boxes or new_col_boxes: if long_row_boxes: print("long_row_boxes", long_row_boxes) row_boxes = long_row_boxes if long_col_boxes: print("long_col_boxes", long_col_boxes) col_boxes = long_col_boxes if new_row_boxes: row_boxes += new_row_boxes print("new_row_boxes", new_row_boxes) if new_col_boxes: print("new_col_boxes", new_col_boxes) col_boxes += new_col_boxes # print("len(row_boxes)", len(row_boxes)) # print("len(col_boxes)", len(col_boxes)) # row_boxes += new_row_boxes # col_boxes += new_col_boxes # row_boxes = choose_longer_row(row_boxes) # col_boxes = choose_longer_col(col_boxes) # 创建空图 test_img = np.zeros((img.shape), np.uint8) test_img.fill(255) for box in row_boxes+col_boxes: cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1) cv2.imshow("test_image2", test_img) cv2.waitKey(0) # 展示补线 for row in new_row_boxes: cv2.line(test_img, (int(row[0]), int(row[1])), (int(row[2]), int(row[3])), (0, 0, 255), 1) for col in new_col_boxes: cv2.line(test_img, (int(col[0]), int(col[1])), (int(col[2]), int(col[3])), (0, 0, 255), 1) cv2.imshow("fix_outline", test_img) cv2.waitKey(0) cv2.imwrite("temp.jpg", test_img) # 修复边框后重新计算交点、分割线 print("crossover_points", len(crossover_points)) crossover_points = get_points(row_boxes, col_boxes, (img.shape[0], img.shape[1])) print("crossover_points new", len(crossover_points)) split_lines, split_y = get_split_line(crossover_points, col_boxes, img) # 计算行列,剔除相近交点 row_point_list = get_points_row(crossover_points, split_y, 0) col_point_list = get_points_col(crossover_points, split_y, 0) print(len(crossover_points), len(row_point_list), len(col_point_list)) crossover_points = delete_close_points(crossover_points, row_point_list, col_point_list) print(len(crossover_points), len(row_point_list), len(col_point_list)) row_point_list = get_points_row(crossover_points, split_y) col_point_list = get_points_col(crossover_points, split_y) for point in crossover_points: cv2.circle(test_img, point, 1, (0, 255, 0), 3) cv2.imshow("point image2", test_img) cv2.waitKey(0) # 获取每个表格的左上右下两个点 outline_point = get_outline_point(crossover_points, split_y) # print(outline_point) for outline in outline_point: cv2.circle(test_img, outline[0], 1, (255, 0, 0), 5) cv2.circle(test_img, outline[1], 1, (255, 0, 0), 5) cv2.imshow("outline point", test_img) cv2.waitKey(0) # 获取bbox # get_bbox(img, crossover_points, split_y) # for point in crossover_points: # cv2.circle(test_img, point, 1, (0, 255, 0), 3) # cv2.imshow("point image3", test_img) # cv2.waitKey(0) # split_y = [] # for outline in outline_point: # split_y.extend([outline[0][1]-5, outline[1][1]+5]) print("len(row_boxes)", len(row_boxes)) print("len(col_boxes)", len(col_boxes)) bboxes = get_bbox(img, row_point_list, col_point_list, split_y, row_boxes, col_boxes) # 展示 for box in bboxes: # print(box[0], box[1]) # if abs(box[0][1] - box[1][1]) > abs(box[0][0] - box[1][0]): # continue cv2.rectangle(test_img, box[0], box[1], (0, 0, 255), 2, 8) cv2.imshow('bboxes', test_img) cv2.waitKey(0) # img = draw_lines(img, row_boxes+col_boxes, color=(255, 0, 0), lineW=2) # img = draw_boxes(img, rowboxes+colboxes, color=(0, 0, 255)) print(time.time()-t, len(row_boxes), len(col_boxes)) cv2.imwrite('temp.jpg', test_img) # cv2.imshow('main', img) # cv2.waitKey(0)