123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288 |
- import base64
- import multiprocessing as mp
- import os
- # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
- import sys
- import time
- import traceback
- from multiprocessing.context import Process
- import multiprocessing
- import logging
- # from table_line import *
- import cv2
- import numpy as np
- import tensorflow as tf
- from otr.table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, \
- get_points_col, \
- delete_close_points, fix_outline, get_bbox, get_outline_point, table_net, delete_contain_bbox, points_to_line, \
- fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(name=__name__)


def log(msg):
    """Write *msg* to this module's logger at INFO level.

    :param msg: message to log.
    """
    logger.log(logging.INFO, msg)
def otr(data, otr_model):
    """Decode a base64-encoded image and run table detection on it.

    :param data: base64-encoded image (``str`` or ``bytes``).
    :param otr_model: model forwarded unchanged to ``table_detect``, e.g. the
        one returned by ``OtrModels().get_model()``.
    :return: the result dict produced by ``table_detect``.
    :raises binascii.Error: if ``data`` is not valid base64.
    :raises TimeoutError: propagated from ``table_detect`` with its original
        message and traceback intact.
    """
    img_data = base64.b64decode(data)
    return table_detect(img_data, otr_model)
# NOTE(review): `flag` is not read anywhere in the visible part of this file;
# it may be legacy state or imported by another module -- confirm before removing.
flag = 0
# model_path = "models/table-line.h5"
def _empty_table_result():
    """Print the "nothing found" summary and return the empty detection result.

    Every value is ``str([])`` so the dict has the same keys and value types
    as a successful ``table_detect`` result.
    """
    print("points", 0, "split_lines", 0, "bboxes", 0)
    return {"points": str([]), "split_lines": str([]),
            "bboxes": str([]), "outline_points": str([]),
            "lines": str([])}


def _elapsed(start):
    """Format the seconds elapsed since *start* as e.g. ``"0.1234s"``."""
    return str(round(float(time.time() - start), 4)) + "s"


def table_detect(img_data, otr_model):
    """Detect table lines, intersection points and cells in an encoded image.

    :param img_data: encoded image bytes (any format ``cv2.imdecode`` accepts).
    :param otr_model: model forwarded to ``table_line``, e.g. the one returned
        by ``OtrModels().get_model()``.
    :return: dict with keys ``points``, ``split_lines``, ``bboxes``,
        ``outline_points`` and ``lines``; each value is the ``str()`` of a
        list, and all are ``"[]"`` when no table is found or an error occurs.
    :raises TimeoutError: re-raised unchanged; any other exception is logged
        with its traceback and converted into the empty result.
    """
    logger.info("into otr_interface table_detect")
    start_time = time.time()
    # Defined before ``try`` so the error path can always report timing.
    start_time1 = time.time()
    try:
        # Decode the binary stream into an np.ndarray of uint8 pixels.
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals undecodable data by returning None.
            logger.info("otr_interface cannot decode image data!")
            return _empty_table_result()
        # imdecode yields BGR; swap channels to RGB (the same swap as COLOR_RGB2BGR).
        image_np = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Predict at the image's own resolution to avoid resampling distortion.
        print("image_np.shape", image_np.shape)
        best_h, best_w, _ = image_np.shape
        logger.info("otr preprocess time: " + _elapsed(start_time1))

        # Run the line-detection model.
        start_time1 = time.time()
        rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h),
                                          hprob=0.5, vprob=0.5)
        logger.info("otr model predict time: " + _elapsed(start_time1))

        start_time1 = time.time()
        if not rows or not cols:
            return _empty_table_result()
        # table_line returns image_np; its (h, w) bounds every get_points call below.
        shape_hw = (image_np.shape[0], image_np.shape[1])

        # Merge line segments that are slightly offset from one another.
        rows = merge_line(rows, axis=0)
        cols = merge_line(cols, axis=1)
        # Intersection points and split lines.
        points = get_points(rows, cols, shape_hw)
        if not points:
            return _empty_table_result()
        # Remove useless lines around the outside of the table.
        rows, cols = delete_outline(rows, cols, points)
        split_lines, split_y = get_split_line(points, cols, image_np)
        # Assign points to rows/columns and remove points that are too close.
        row_point_list = get_points_row(points, split_y, 5)
        col_point_list = get_points_col(points, split_y, 5)
        points = delete_close_points(points, row_point_list, col_point_list)

        # Repair the table border.
        new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
                                                               split_y)
        if new_rows or new_cols:
            # Use the lines extended to reach the added border lines.
            if long_rows:
                rows = long_rows
            if long_cols:
                cols = long_cols
            # Append the newly added border lines.
            if new_rows:
                rows += new_rows
            if new_cols:
                cols += new_cols
            # Recompute points and split lines now that the border is repaired.
            points = get_points(rows, cols, shape_hw)
            split_lines, split_y = get_split_line(points, cols, image_np)
            row_point_list = get_points_row(points, split_y, 0)
            col_point_list = get_points_col(points, split_y, 0)
            points = delete_close_points(points, row_point_list, col_point_list)

        # Repair the four table corners.
        rows, cols = fix_corner(rows, cols, split_y)
        points = get_points(rows, cols, shape_hw)
        # Repair missing inner lines.
        points = fix_inner(rows, cols, points, split_y)
        if not points:
            return _empty_table_result()
        row_point_list = get_points_row(points, split_y, 5)
        col_point_list = get_points_col(points, split_y, 5)

        # Build cell bounding boxes, then remove containing boxes.
        bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
        if bboxes:
            bboxes = delete_contain_bbox(bboxes)

        # Top-left and bottom-right points of each table.
        outline_points = get_outline_point(points, split_y)
        print("bboxes number", len(bboxes) if bboxes else "None")
        logger.info("otr postprocess time: " + _elapsed(start_time1))
        logger.info("use time: " + str(time.time() - start_time))
        return {"points": str(points), "split_lines": str(split_lines),
                "bboxes": str(bboxes), "outline_points": str(outline_points),
                "lines": str(rows + cols)}
    except TimeoutError:
        # Bare raise keeps the original exception object and traceback.
        raise
    except Exception:
        # Best-effort boundary: log the full traceback, report "no table".
        logger.exception("otr_interface cannot detected table!")
        logger.info("otr postprocess time: " + _elapsed(start_time1))
        return _empty_table_result()
class OtrModels:
    """Build the table-line network and load its trained weights."""

    def __init__(self, model_path=None):
        """
        :param model_path: path to the weights file; defaults to
            ``models/table-line.h5`` in the directory containing this file.
        """
        if model_path is None:
            model_path = self._default_model_path()
        self.otr_model = table_net((None, None, 3), 2)
        self.otr_model.load_weights(model_path)

    @staticmethod
    def _default_model_path():
        """Return the absolute path of ``models/table-line.h5`` next to this file."""
        here = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(here, "models", "table-line.h5")

    def get_model(self):
        """Return the model built in ``__init__``."""
        return self.otr_model
- # otr_model = table_net((None, None, 3), 2)
- # otr_model.load_weights(model_path)
if __name__ == '__main__':
    # Manual smoke test: run table detection on one image file.
    # Usage: python <this file> <image_path>
    if len(sys.argv) != 2:
        print("usage: python " + sys.argv[0] + " <image_path>")
        sys.exit(1)
    with open(sys.argv[1], "rb") as f:
        # otr() expects base64-encoded image data.
        image_b64 = base64.b64encode(f.read())
    otr_model = OtrModels().get_model()
    print(otr(image_b64, otr_model))
|