|
@@ -197,16 +197,16 @@ def table_net(input_shape=(1152, 896, 3), num_classes=1):
|
|
model = table_net((None, None, 3), 2)
|
|
model = table_net((None, None, 3), 2)
|
|
|
|
|
|
|
|
|
|
-def draw_pixel(pred):
|
|
|
|
|
|
+def draw_pixel(pred, prob=0.2):
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.pyplot as plt
|
|
_array = []
|
|
_array = []
|
|
for _h in range(len(pred)):
|
|
for _h in range(len(pred)):
|
|
_line = []
|
|
_line = []
|
|
for _w in range(len(pred[_h])):
|
|
for _w in range(len(pred[_h])):
|
|
_prob = pred[_h][_w]
|
|
_prob = pred[_h][_w]
|
|
- if _prob[0]>0.5:
|
|
|
|
|
|
+ if _prob[0]>prob:
|
|
_line.append((0,255,255))
|
|
_line.append((0,255,255))
|
|
- elif _prob[1]>0.5:
|
|
|
|
|
|
+ elif _prob[1]>prob:
|
|
_line.append((255,255,0))
|
|
_line.append((255,255,0))
|
|
else:
|
|
else:
|
|
_line.append((255,255,255))
|
|
_line.append((255,255,255))
|
|
@@ -367,6 +367,77 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
|
|
return list_line
|
|
return list_line
|
|
|
|
|
|
|
|
|
|
|
|
def get_line_from_binary_image(image_np, point_value=1, axis=0):
    """Extract straight horizontal or vertical line segments from a binary image.

    Every pixel whose value is ``>= point_value`` is treated as "on".  Runs of
    consecutive "on" pixels along one row (``axis=0``) or one column
    (``axis=1``) are collapsed into a single segment described by its two
    end points.  Only axis-aligned lines are supported.

    :param image_np: 2-D numpy array (an already binarised image).
    :param point_value: threshold; pixels ``>= point_value`` belong to a line.
    :param axis: ``0`` to scan row by row (horizontal segments),
        ``1`` to scan column by column (vertical segments).
    :return: list of ``[x1, y1, x2, y2]`` segments with inclusive end points,
        ordered by the scanned row/column and then by position inside it.
        A run of a single pixel yields ``[x, y, x, y]``.  An image with no
        "on" pixel yields ``[]``.
    :raises ValueError: if ``image_np`` is not 2-D or ``axis`` is not 0/1.
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 (rows) or 1 (columns)")

    mask = np.asarray(image_np) >= point_value
    if mask.ndim != 2:
        raise ValueError("image_np must be a 2-D (binarised) image")

    # Scanning columns is the same as scanning the rows of the transposed
    # mask, so from here on "fixed" is the coordinate shared by a run and
    # "moving" is the coordinate that increases by one along the run.
    if axis == 1:
        mask = mask.T

    # np.nonzero yields indices in row-major order: sorted by `fixed`, then
    # by `moving`, which is exactly the order the run detection needs.
    fixed, moving = np.nonzero(mask)
    if fixed.size == 0:
        return []

    # A new run starts wherever the row/column changes or the position
    # inside it is not the direct successor of the previous pixel.
    is_break = (np.diff(fixed) != 0) | (np.diff(moving) != 1)
    breaks = np.flatnonzero(is_break) + 1
    starts = np.concatenate(([0], breaks))
    ends = np.concatenate((breaks - 1, [fixed.size - 1]))

    lines = []
    for start, end in zip(starts, ends):
        const = int(fixed[start])
        first = int(moving[start])
        last = int(moving[end])
        if axis == 0:
            # Horizontal segment: y is constant, x runs first..last.
            lines.append([first, const, last, const])
        else:
            # Vertical segment: x is constant, y runs first..last.
            lines.append([const, first, const, last])
    return lines
|
|
|
|
+
|
|
|
|
+
|
|
def table_line(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=30, alph=15):
|
|
def table_line(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=30, alph=15):
|
|
sizew, sizeh = size
|
|
sizew, sizeh = size
|
|
img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
|
|
img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
|
|
@@ -374,8 +445,95 @@ def table_line(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=3
|
|
pred = model.predict(np.array([img_new]))
|
|
pred = model.predict(np.array([img_new]))
|
|
pred = pred[0]
|
|
pred = pred[0]
|
|
|
|
|
|
|
|
+ # 横线预测结果
|
|
|
|
+ # row_pred = pred[..., 0] > hprob
|
|
|
|
+ # row_pred = row_pred.astype(np.uint8)
|
|
|
|
+ # # 竖线预测结果
|
|
|
|
+ # col_pred = pred[..., 1] > vprob
|
|
|
|
+ # col_pred = col_pred.astype(np.uint8)
|
|
|
|
+ # # 打印模型输出
|
|
|
|
+ # cv2.imshow("predict", (col_pred+row_pred)*255)
|
|
|
|
+ # cv2.waitKey(0)
|
|
|
|
+
|
|
_time = time.time()
|
|
_time = time.time()
|
|
list_line = points2lines(pred, False)
|
|
list_line = points2lines(pred, False)
|
|
|
|
+
|
|
|
|
+ # 分成横竖线
|
|
|
|
+ list_rows = []
|
|
|
|
+ list_cols = []
|
|
|
|
+ for line in list_line:
|
|
|
|
+ if line[0] == line[2]:
|
|
|
|
+ list_cols.append(line)
|
|
|
|
+ elif line[1] == line[3]:
|
|
|
|
+ list_rows.append(line)
|
|
|
|
+
|
|
|
|
+ # 删掉贴着边框的line
|
|
|
|
+ temp_list = []
|
|
|
|
+ threshold = 5
|
|
|
|
+ for line in list_rows:
|
|
|
|
+ if line[1]-0 <= threshold or size[1]-line[1] <= threshold:
|
|
|
|
+ continue
|
|
|
|
+ # 内部排序
|
|
|
|
+ if line[0] > line[2]:
|
|
|
|
+ line = [line[2], line[3], line[0], line[1]]
|
|
|
|
+ temp_list.append(line)
|
|
|
|
+ list_rows = temp_list
|
|
|
|
+ temp_list = []
|
|
|
|
+ for line in list_cols:
|
|
|
|
+ if line[0]-0 <= threshold or size[0]-line[0] <= threshold:
|
|
|
|
+ continue
|
|
|
|
+ # 内部排序
|
|
|
|
+ if line[1] > line[3]:
|
|
|
|
+ line = [line[2], line[3], line[0], line[1]]
|
|
|
|
+ temp_list.append(line)
|
|
|
|
+ list_cols = temp_list
|
|
|
|
+ if not list_rows or not list_cols:
|
|
|
|
+ return []
|
|
|
|
+
|
|
|
|
+ # 合并错开线
|
|
|
|
+ list_rows = merge_line(list_rows, axis=0)
|
|
|
|
+ list_cols = merge_line(list_cols, axis=1)
|
|
|
|
+
|
|
|
|
+ # 计算交点、分割线
|
|
|
|
+ cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
|
|
|
|
+ if not cross_points:
|
|
|
|
+ return []
|
|
|
|
+
|
|
|
|
+ # 清掉外围的没用的线
|
|
|
|
+ list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points)
|
|
|
|
+
|
|
|
|
+ # 多个表格分割线
|
|
|
|
+ split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
|
|
|
|
+
|
|
|
|
+ # 修复边框
|
|
|
|
+ new_rows, new_cols, long_rows, long_cols = fix_outline(img_new, list_rows, list_cols, cross_points,
|
|
|
|
+ split_y)
|
|
|
|
+ # 如有补线
|
|
|
|
+ if new_rows or new_cols:
|
|
|
|
+ # 连接至补线的延长线
|
|
|
|
+ if long_rows:
|
|
|
|
+ list_rows = long_rows
|
|
|
|
+ if long_cols:
|
|
|
|
+ list_cols = long_cols
|
|
|
|
+ # 新的补线
|
|
|
|
+ if new_rows:
|
|
|
|
+ list_rows += new_rows
|
|
|
|
+ if new_cols:
|
|
|
|
+ list_cols += new_cols
|
|
|
|
+
|
|
|
|
+ # 修复边框后重新计算交点、分割线
|
|
|
|
+ cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
|
|
|
|
+ split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
|
|
|
|
+
|
|
|
|
+ # 修复内部缺线
|
|
|
|
+ cross_points = fix_inner(list_rows, list_cols, cross_points, split_y)
|
|
|
|
+ if not cross_points:
|
|
|
|
+ return []
|
|
|
|
+
|
|
|
|
+ # 修复表格4个角
|
|
|
|
+ list_rows, list_cols = fix_corner(list_rows, list_cols, split_y)
|
|
|
|
+
|
|
|
|
+ list_line = list_rows + list_cols
|
|
return list_line
|
|
return list_line
|
|
|
|
|
|
|
|
|