"""Flask service exposing the table-line detection (OTR) model at ``POST /otr``."""
import base64
import json
import logging
import multiprocessing as mp
import os
import sys
import time
import traceback

# Optional environment toggles, kept for reference:
# os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
# os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# Make the parent directory importable before the project-local imports below.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from config.max_compute_config import MAX_COMPUTE
import tensorflow as tf

if not MAX_COMPUTE:
    # TensorFlow GPU memory settings.
    try:
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            # Cap the first GPU at 2048 MB for this process.
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0],
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)])
    except Exception:
        # Best effort: a failure here is reported but must not stop the service.
        traceback.print_exc()
        # Alternative when several GPUs are used:
        # gpus = tf.config.list_physical_devices('GPU')
        # for gpu in gpus:
        #     tf.config.experimental.set_memory_growth(gpu, True)

    # NOTE(review): the statements below are assumed to belong to the
    # non-MaxCompute branch -- confirm against the deployed version.
    os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
    os.environ['CUDA_CACHE_DISABLE'] = str(0)
    # NOTE(review): GPUOptions / Session / ConfigProto are TF1-style top-level
    # names; under TensorFlow 2.x they are only available via tf.compat.v1.
    # Confirm the pinned TensorFlow version before changing this.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

import cv2
import numpy as np
from flask import Flask, request
from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform
from otr.table_line import table_net
from otr.table_line_new import table_line
from format_convert import _global


# Service configuration.
app = Flask(__name__)


@app.route('/otr', methods=['POST'])
def _otr():
    """Handle ``POST /otr``: detect table lines in a base64-encoded image.

    Reads the form fields ``data`` (base64 image), ``is_from_pdf`` and ``md5``.

    Returns:
        A JSON string. On success it is the dict produced by :func:`otr`.
        On failure ``list_line`` holds a one-element error list:
        ``[-9]`` when the request carries no form data, ``[-5]`` on
        ``TimeoutError`` and ``[-1]`` on any other exception.
    """
    _global._init()
    # ``port`` only exists as a module global when this file runs as a script.
    _global.update({"port": globals().get("port")})
    start_time = time.time()

    log("into otr_interface _otr")
    try:
        if not request.form:
            log("otr no data!")
            return json.dumps({"list_line": str([-9])})

        # Build the model on first use and cache it as a module global.
        otr_model = globals().get("global_otr_model")
        if otr_model is None:
            otr_model = OtrModels().get_model()
            globals().update({"global_otr_model": otr_model})

        data = request.form.get("data")
        is_from_pdf = request.form.get("is_from_pdf")
        _md5 = request.form.get("md5")
        _global.update({"md5": _md5})

        list_lines = otr(data, otr_model, is_from_pdf)
        return json.dumps(list_lines)
    except TimeoutError:
        return json.dumps({"list_line": str([-5])})
    except Exception:
        traceback.print_exc()
        return json.dumps({"list_line": str([-1])})
    finally:
        log("otr interface finish time " + str(time.time()-start_time))


def otr(data, otr_model, is_from_pdf):
    """Decode a base64 image and run line detection on it.

    Args:
        data: Base64-encoded image bytes/str.
        otr_model: Model object forwarded to :func:`line_detect`.
        is_from_pdf: Flag selecting the detection probability. It is compared
            through ``str()`` because the HTTP handler passes the raw form
            value; only the text ``'True'`` selects the PDF setting.

    Returns:
        The dict returned by :func:`line_detect`.

    Raises:
        TimeoutError: Propagated unchanged from :func:`line_detect`.
    """
    log("into otr_interface otr")
    log('otr is_from_pdf ' + str(is_from_pdf))

    img_data = base64.b64decode(data)
    # Images that come from PDFs use the lower probability (0.2 instead of 0.5).
    prob = 0.2 if str(is_from_pdf) == 'True' else 0.5
    return line_detect(img_data, otr_model, prob=prob)


# NOTE: the legacy ``table_detect2`` pipeline that used to live here as
# commented-out code has been removed; it called post-processing helpers that
# this module does not import. Recover it from version control if needed.


def line_detect(img_data, otr_model, prob=0.2):
    """Run the table-line model on encoded image bytes.

    Args:
        img_data: Encoded image bytes (anything ``cv2.imdecode`` can read).
        otr_model: Model object forwarded to ``table_line``.
        prob: Probability value forwarded to ``table_line``.

    Returns:
        ``{"list_line": str(lines)}``; ``lines`` is ``[]`` when any exception
        other than ``TimeoutError`` occurs.

    Raises:
        TimeoutError: Re-raised unchanged so the caller can report a timeout.
    """
    log("into otr_interface table_detect")
    log('line_detect prob ' + str(prob))
    start_time1 = time.time()
    try:
        # Binary stream -> np.ndarray of uint8 pixels.
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)

        # imdecode yields BGR; swap to RGB (same constant value as the
        # previously used COLOR_RGB2BGR, so the result is unchanged).
        image_np = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Predict at the image's own resolution to avoid distortion.
        log("image_np.shape" + str(image_np.shape))
        best_h, best_w, _ = image_np.shape
        log("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")

        # Run the model.
        start_time1 = time.time()
        list_line = table_line(image_np, otr_model, size=(best_w, best_h), prob=prob)
        log("otr finish " + str(round(float(time.time()-start_time1), 4)) + "s")
        return {"list_line": str(list_line)}
    except TimeoutError:
        # Bare raise keeps the original message and traceback.
        raise
    except Exception:
        log("otr_interface cannot detected table!")
        print("otr_interface cannot detected table!")
        traceback.print_exc()
        log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
        return {"list_line": str([])}


class OtrModels:
    """Builds the table-line network and loads its weights from disk."""

    def __init__(self):
        # Directory containing this Python file; weights live in ./models.
        _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
        model_path = _dir + "/models/table-line.h5"
        # NOTE(review): presumably (input shape, number of outputs) -- confirm
        # against ``table_net``.
        self.otr_model = table_net((None, None, 3), 2)
        self.otr_model.load_weights(model_path)

    def get_model(self):
        """Return the loaded model."""
        return self.otr_model


def test_otr_model(from_remote=True):
    """Manual smoke test: send a local image through the OTR pipeline.

    The image is resized to the size chosen by ``get_best_predict_size`` and
    written back to the same path before being sent.

    Args:
        from_remote: When true, POST the image to the locally running ``/otr``
            service; otherwise call ``image_process`` in-process.

    Raises:
        FileNotFoundError: If the test image cannot be read.
    """
    _global._init()
    from format_convert.convert_image import get_best_predict_size, image_process

    if get_platform() == "Windows":
        # file_path = "C:/Users/Administrator/Desktop/error2.png"
        file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
    else:
        file_path = "1.jpg"

    image_np = cv2.imread(file_path)
    if image_np is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(file_path)

    best_h, best_w = get_best_predict_size(image_np)
    image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
    cv2.imwrite(file_path, image_resize)

    with open(file_path, "rb") as f:
        file_bytes = f.read()
    file_base64 = base64.b64encode(file_bytes)
    _md5 = get_md5_from_bytes(file_bytes)[0]

    _global.update({"port": 15010, "md5": _md5})

    if from_remote:
        file_json = {"data": file_base64, "is_from_pdf": False, "md5": _md5}

        # _url = "http://192.168.2.104:18000/otr"
        _url = "http://127.0.0.1:18000/otr"
        r = json.loads(request_post(_url, file_json))
    else:
        # otr_model = OtrModels().get_model()
        # r = otr(file_base64, otr_model, is_from_pdf=False)
        r = image_process(image_resize, file_path)
    print(r)


if __name__ == '__main__':
    # Usage: script.py [port [gpu_index]]; any other argument count falls back
    # to the defaults. ``port`` must stay a module global: ``_otr`` reads it.
    argc = len(sys.argv)
    if argc in (2, 3):
        port = int(sys.argv[1])
        using_gpu_index = int(sys.argv[2]) if argc == 3 else 0
    else:
        port = 18000
        using_gpu_index = 0

    # Log before starting: app.run() blocks until the server shuts down.
    log("OTR running "+str(port))
    # Previously: app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
    app.run(port=port)

    # test_otr_model(False)