|
|
@@ -6,7 +6,11 @@ table line detect
|
|
|
@author: chineseocr
|
|
|
"""
|
|
|
import copy
|
|
|
+import io
|
|
|
import logging
|
|
|
+import sys
|
|
|
+import traceback
|
|
|
+
|
|
|
import tensorflow as tf
|
|
|
import tensorflow.keras.backend as K
|
|
|
from tensorflow.keras.models import Model
|
|
|
@@ -16,7 +20,8 @@ from otr.utils import letterbox_image, get_table_line, adjust_lines, line_to_lin
|
|
|
import numpy as np
|
|
|
import cv2
|
|
|
import time
|
|
|
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
+from format_convert import _global
|
|
|
+from format_convert.utils import log
|
|
|
|
|
|
|
|
|
def dice_coef(y_true, y_pred, smooth=1e-5):
|
|
|
@@ -239,7 +244,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
|
|
|
_last = False
|
|
|
_current = False
|
|
|
while 1:
|
|
|
- h_index += 5
|
|
|
+ h_index += 2
|
|
|
if h_index>=height:
|
|
|
break
|
|
|
w_index = -1
|
|
|
@@ -248,7 +253,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
|
|
|
else:
|
|
|
h_i = h_index
|
|
|
while 1:
|
|
|
- w_index += 5
|
|
|
+ w_index += 2
|
|
|
if w_index>=width:
|
|
|
break
|
|
|
if _sum[w_index]<min_len:
|
|
|
@@ -274,12 +279,12 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
|
|
|
w_index = -1
|
|
|
_sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=1))
|
|
|
while 1:
|
|
|
- w_index += 5
|
|
|
+ w_index += 2
|
|
|
if w_index>=width:
|
|
|
break
|
|
|
h_index = -1
|
|
|
while 1:
|
|
|
- h_index += 5
|
|
|
+ h_index += 2
|
|
|
if h_index>=height:
|
|
|
break
|
|
|
if _sum[h_index]<min_len:
|
|
|
@@ -355,7 +360,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
|
|
|
for _line in final_horizontal:
|
|
|
list_line.append(_line)
|
|
|
|
|
|
- logging.info("points2lines cost %.2fs"%(time.time()-_time))
|
|
|
+ log("points2lines cost %.2fs"%(time.time()-_time))
|
|
|
|
|
|
# import matplotlib.pyplot as plt
|
|
|
# plt.figure()
|
|
|
@@ -443,14 +448,179 @@ def get_line_from_binary_image(image_np, point_value=1, axis=0):
|
|
|
return axis_lines
|
|
|
|
|
|
|
|
|
+def table_preprocess(img_data, prob=0.2):
|
|
|
+ try:
|
|
|
+ log("into table_preprocess, prob is " + str(prob))
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ # 二进制数据流转np.ndarray [np.uint8: 8位像素]
|
|
|
+ img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
|
|
|
+
|
|
|
+ # 将bgr转为rbg
|
|
|
+ image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
|
|
+
|
|
|
+ # 模型输入
|
|
|
+ inputs = np.array([image_np])
|
|
|
+
|
|
|
+ # # 压缩numpy
|
|
|
+ # compressed_array = io.BytesIO()
|
|
|
+ # np.savez_compressed(compressed_array, inputs)
|
|
|
+ # compressed_array.seek(0)
|
|
|
+ # inputs_compressed = compressed_array.read()
|
|
|
+
|
|
|
+ log("otr preprocess time: " + str(round(float(time.time()-start_time), 4)) + "s")
|
|
|
+ return image_np, inputs
|
|
|
+ except Exception as e:
|
|
|
+ log("table pre process failed!")
|
|
|
+ traceback.print_exc()
|
|
|
+ return [-1], [-1]
|
|
|
+
|
|
|
+
|
|
|
+def table_postprocess(img_new, pred, prob=0.2, is_test=0):
|
|
|
+ try:
|
|
|
+ # 横线预测结果
|
|
|
+ # row_pred = pred[..., 0] > hprob
|
|
|
+ # row_pred = row_pred.astype(np.uint8)
|
|
|
+ # # 竖线预测结果
|
|
|
+ # col_pred = pred[..., 1] > vprob
|
|
|
+ # col_pred = col_pred.astype(np.uint8)
|
|
|
+ # # 打印模型输出
|
|
|
+ # cv2.imshow("predict", (col_pred+row_pred)*255)
|
|
|
+ # cv2.waitKey(0)
|
|
|
+
|
|
|
+ start_time = time.time()
|
|
|
+ list_line = points2lines(pred, False, prob=prob)
|
|
|
+ mat_plot(list_line, "points2lines", is_test)
|
|
|
+ log("points2lines " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 清除短线
|
|
|
+ # print(img_new.shape)
|
|
|
+ start_time = time.time()
|
|
|
+ list_line = delete_short_lines(list_line, img_new.shape)
|
|
|
+ mat_plot(list_line, "delete_short_lines", is_test)
|
|
|
+ log("delete_short_lines " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 清除无交点线
|
|
|
+ start_time = time.time()
|
|
|
+ list_line = delete_no_cross_lines(list_line)
|
|
|
+ mat_plot(list_line, "delete_no_cross_lines", is_test)
|
|
|
+ log("delete_no_cross_lines " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 分成横竖线
|
|
|
+ start_time = time.time()
|
|
|
+ list_rows = []
|
|
|
+ list_cols = []
|
|
|
+ for line in list_line:
|
|
|
+ if line[0] == line[2]:
|
|
|
+ list_cols.append(line)
|
|
|
+ elif line[1] == line[3]:
|
|
|
+ list_rows.append(line)
|
|
|
+ log("divide rows and cols " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 合并错开线
|
|
|
+ start_time = time.time()
|
|
|
+ list_rows = merge_line(list_rows, axis=0)
|
|
|
+ list_cols = merge_line(list_cols, axis=1)
|
|
|
+ mat_plot(list_rows+list_cols, "merge_line", is_test)
|
|
|
+ log("merge_line " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 计算交点、分割线
|
|
|
+ start_time = time.time()
|
|
|
+ cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
|
|
|
+ if not cross_points:
|
|
|
+ return []
|
|
|
+ log("get_points " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 清掉外围的没用的线
|
|
|
+ # list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points)
|
|
|
+ # mat_plot(list_rows+list_cols, "delete_outline", is_test)
|
|
|
+
|
|
|
+ # 多个表格分割线
|
|
|
+ start_time = time.time()
|
|
|
+ list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
|
|
|
+ split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
|
|
|
+ log("get_split_line " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 修复边框
|
|
|
+ start_time = time.time()
|
|
|
+ new_rows, new_cols, long_rows, long_cols = fix_outline(img_new, list_rows, list_cols, cross_points,
|
|
|
+ split_y)
|
|
|
+
|
|
|
+ # 如有补线
|
|
|
+ if new_rows or new_cols:
|
|
|
+ # 连接至补线的延长线
|
|
|
+ if long_rows:
|
|
|
+ list_rows = long_rows
|
|
|
+ if long_cols:
|
|
|
+ list_cols = long_cols
|
|
|
+ # 新的补线
|
|
|
+ if new_rows:
|
|
|
+ list_rows += new_rows
|
|
|
+ if new_cols:
|
|
|
+ list_cols += new_cols
|
|
|
+
|
|
|
+ list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
|
|
|
+
|
|
|
+ # 修复边框后重新计算交点、分割线
|
|
|
+ cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
|
|
|
+ cv_plot(cross_points, img_new.shape, 0, is_test)
|
|
|
+
|
|
|
+ split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
|
|
|
+ print("fix new split_y", split_y)
|
|
|
+ print("fix new split_lines", split_lines)
|
|
|
+
|
|
|
+ # 修复内部缺线
|
|
|
+ # cross_points = fix_inner(list_rows, list_cols, cross_points, split_y)
|
|
|
+ # if not cross_points:
|
|
|
+ # return []
|
|
|
+ mat_plot(list_rows+list_cols, "fix_outline", is_test)
|
|
|
+
|
|
|
+ split_lines_show = []
|
|
|
+ for _l in split_lines:
|
|
|
+ split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]])
|
|
|
+ mat_plot(split_lines_show+list_cols,
|
|
|
+ "split_lines", is_test)
|
|
|
+ log("fix_outline " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 修复表格4个角
|
|
|
+ start_time = time.time()
|
|
|
+ list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0)
|
|
|
+ mat_plot(list_rows+list_cols, "fix_corner", is_test)
|
|
|
+ log("fix_corner " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 修复内部缺线
|
|
|
+ start_time = time.time()
|
|
|
+ list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y)
|
|
|
+ mat_plot(list_rows+list_cols, "fix_inner", is_test)
|
|
|
+ log("fix_inner " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 合并错开线
|
|
|
+ start_time = time.time()
|
|
|
+ list_rows = merge_line(list_rows, axis=0)
|
|
|
+ list_cols = merge_line(list_cols, axis=1)
|
|
|
+ mat_plot(list_rows+list_cols, "merge_line", is_test)
|
|
|
+ log("merge_line " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ list_line = list_rows + list_cols
|
|
|
+
|
|
|
+ # 打印处理后线
|
|
|
+ mat_plot(list_line, "all", is_test)
|
|
|
+ log("otr postprocess table_line " + str(time.time()-start_time))
|
|
|
+ return list_line
|
|
|
+ except Exception as e:
|
|
|
+ log("table post process failed!")
|
|
|
+ traceback.print_exc()
|
|
|
+ return [-1]
|
|
|
+
|
|
|
+
|
|
|
def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
|
|
|
- logging.info("into table_line, prob is " + str(prob))
|
|
|
+ log("into table_line, prob is " + str(prob))
|
|
|
sizew, sizeh = size
|
|
|
img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
|
|
|
|
|
|
start_time = time.time()
|
|
|
pred = model.predict(np.array([img_new]))
|
|
|
- logging.info("otr model predict time " + str(time.time()-start_time))
|
|
|
+ log("otr model predict time " + str(time.time()-start_time))
|
|
|
pred = pred[0]
|
|
|
|
|
|
draw_pixel(pred, prob, is_test)
|
|
|
@@ -468,20 +638,20 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
|
|
|
start_time = time.time()
|
|
|
list_line = points2lines(pred, False, prob=prob)
|
|
|
mat_plot(list_line, "points2lines", is_test)
|
|
|
- logging.info("points2lines " + str(time.time()-start_time))
|
|
|
+ log("points2lines " + str(time.time()-start_time))
|
|
|
|
|
|
# 清除短线
|
|
|
# print(img_new.shape)
|
|
|
start_time = time.time()
|
|
|
list_line = delete_short_lines(list_line, img_new.shape)
|
|
|
mat_plot(list_line, "delete_short_lines", is_test)
|
|
|
- logging.info("delete_short_lines " + str(time.time()-start_time))
|
|
|
+ log("delete_short_lines " + str(time.time()-start_time))
|
|
|
|
|
|
# 清除无交点线
|
|
|
start_time = time.time()
|
|
|
list_line = delete_no_cross_lines(list_line)
|
|
|
mat_plot(list_line, "delete_no_cross_lines", is_test)
|
|
|
- logging.info("delete_no_cross_lines " + str(time.time()-start_time))
|
|
|
+ log("delete_no_cross_lines " + str(time.time()-start_time))
|
|
|
|
|
|
# 分成横竖线
|
|
|
start_time = time.time()
|
|
|
@@ -492,21 +662,21 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
|
|
|
list_cols.append(line)
|
|
|
elif line[1] == line[3]:
|
|
|
list_rows.append(line)
|
|
|
- logging.info("divide rows and cols " + str(time.time()-start_time))
|
|
|
+ log("divide rows and cols " + str(time.time()-start_time))
|
|
|
|
|
|
# 合并错开线
|
|
|
start_time = time.time()
|
|
|
list_rows = merge_line(list_rows, axis=0)
|
|
|
list_cols = merge_line(list_cols, axis=1)
|
|
|
mat_plot(list_rows+list_cols, "merge_line", is_test)
|
|
|
- logging.info("merge_line " + str(time.time()-start_time))
|
|
|
+ log("merge_line " + str(time.time()-start_time))
|
|
|
|
|
|
# 计算交点、分割线
|
|
|
start_time = time.time()
|
|
|
cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
|
|
|
if not cross_points:
|
|
|
return []
|
|
|
- logging.info("get_points " + str(time.time()-start_time))
|
|
|
+ log("get_points " + str(time.time()-start_time))
|
|
|
|
|
|
# 清掉外围的没用的线
|
|
|
# list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points)
|
|
|
@@ -516,7 +686,7 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
|
|
|
start_time = time.time()
|
|
|
list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
|
|
|
split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
|
|
|
- logging.info("get_split_line " + str(time.time()-start_time))
|
|
|
+ log("get_split_line " + str(time.time()-start_time))
|
|
|
|
|
|
# 修复边框
|
|
|
start_time = time.time()
|
|
|
@@ -557,32 +727,32 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
|
|
|
split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]])
|
|
|
mat_plot(split_lines_show+list_cols,
|
|
|
"split_lines", is_test)
|
|
|
- logging.info("fix_outline " + str(time.time()-start_time))
|
|
|
+ log("fix_outline " + str(time.time()-start_time))
|
|
|
|
|
|
# 修复表格4个角
|
|
|
start_time = time.time()
|
|
|
list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0)
|
|
|
mat_plot(list_rows+list_cols, "fix_corner", is_test)
|
|
|
- logging.info("fix_corner " + str(time.time()-start_time))
|
|
|
+ log("fix_corner " + str(time.time()-start_time))
|
|
|
|
|
|
# 修复内部缺线
|
|
|
start_time = time.time()
|
|
|
list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y)
|
|
|
mat_plot(list_rows+list_cols, "fix_inner", is_test)
|
|
|
- logging.info("fix_inner " + str(time.time()-start_time))
|
|
|
+ log("fix_inner " + str(time.time()-start_time))
|
|
|
|
|
|
# 合并错开线
|
|
|
start_time = time.time()
|
|
|
list_rows = merge_line(list_rows, axis=0)
|
|
|
list_cols = merge_line(list_cols, axis=1)
|
|
|
mat_plot(list_rows+list_cols, "merge_line", is_test)
|
|
|
- logging.info("merge_line " + str(time.time()-start_time))
|
|
|
+ log("merge_line " + str(time.time()-start_time))
|
|
|
|
|
|
list_line = list_rows + list_cols
|
|
|
|
|
|
# 打印处理后线
|
|
|
mat_plot(list_line, "all", is_test)
|
|
|
- logging.info("otr postprocess table_line " + str(time.time()-start_time))
|
|
|
+ log("otr postprocess table_line " + str(time.time()-start_time))
|
|
|
return list_line
|
|
|
|
|
|
|
|
|
@@ -593,16 +763,16 @@ def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=
|
|
|
# pred = model.predict(np.array([np.array(inputBlob)]))
|
|
|
# pred = model.predict(np.array([np.array(inputBlob)/255.0]))
|
|
|
img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
|
|
|
- # logging.info("into table_line 1")
|
|
|
+ # log("into table_line 1")
|
|
|
|
|
|
pred = model.predict(np.array([img_new]))
|
|
|
- # logging.info("into table_line 2")
|
|
|
+ # log("into table_line 2")
|
|
|
pred = pred[0]
|
|
|
|
|
|
draw_pixel(pred)
|
|
|
_time = time.time()
|
|
|
points2lines(pred)
|
|
|
- logging.info("points2lines takes %ds"%(time.time()-_time))
|
|
|
+ log("points2lines takes %ds"%(time.time()-_time))
|
|
|
|
|
|
vpred = pred[..., 1] > vprob # 横线
|
|
|
hpred = pred[..., 0] > hprob # 竖线
|
|
|
@@ -614,7 +784,7 @@ def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=
|
|
|
|
|
|
colboxes = get_table_line(vpred, axis=1, lineW=col)
|
|
|
rowboxes = get_table_line(hpred, axis=0, lineW=row)
|
|
|
- # logging.info("into table_line 3")
|
|
|
+ # log("into table_line 3")
|
|
|
|
|
|
# if len(rowboxes) > 0:
|
|
|
# rowboxes = np.array(rowboxes)
|
|
|
@@ -633,7 +803,7 @@ def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=
|
|
|
for j in range(ncol):
|
|
|
rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], 10)
|
|
|
colboxes[j] = line_to_line(colboxes[j], rowboxes[i], 10)
|
|
|
- # logging.info("into table_line 4")
|
|
|
+ # log("into table_line 4")
|
|
|
|
|
|
# 删掉贴着边框的line
|
|
|
temp_list = []
|
|
|
@@ -1810,7 +1980,7 @@ def fix_outline2(image, row_lines, col_lines, points, split_y):
|
|
|
|
|
|
|
|
|
def fix_outline(image, row_lines, col_lines, points, split_y, scale=25):
|
|
|
- logging.info("into fix_outline")
|
|
|
+ log("into fix_outline")
|
|
|
x_min_len = max(10, int(image.shape[0] / scale))
|
|
|
y_min_len = max(10, int(image.shape[1] / scale))
|
|
|
# print("x_min_len", x_min_len, "y_min_len", y_min_len)
|