import base64
import multiprocessing as mp
import os
# os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
import sys
import time
import traceback
from multiprocessing.context import Process
import multiprocessing
from flask import Flask, jsonify
from flask import request
import logging
# from table_line import *
import cv2
import numpy as np
import tensorflow as tf
from table_line import get_best_predict_size, table_line, table_net, line_fix
from table_predict import predict
from table_postprocess import *

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def log(msg):
    """
    @summary: log a message
    """
    logger.info(msg)


@app.route('/otr', methods=['POST'])
def otr():
    start_time = time.time()
    if request.method == "POST":
        # check whether the request contains data
        if not request.data:
            return 'no data'
        img_data = base64.b64decode(request.data)
        # points_and_lines = pool.apply(table_detect, (img_data,))
        points_and_lines = table_detect(img_data)
        return points_and_lines


flag = 0
model_path = "models/table-line.h5"


def table_detect(img_data):
    print("child process ", os.getpid())
    start_time = time.time()
    try:
        # p = "test_files/train_463.jpg"
        p = "test_files/8.png"
        # p = "test_files/无边框1.jpg"
        # p = "test_files/part1.png"
        # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00e959a0bc9011ebaf5a00163e0ae709" + \
        #     "\\00e95f7cbc9011ebaf5a00163e0ae709_pdf_page0.png"
        # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00fb3e52bc7e11eb836000163e0ae709" + \
        #     "\\00fb43acbc7e11eb836000163e0ae709.png"
        # p = "test_files/table.jpg"
        # p = "C:\\Users\\Administrator\\Desktop\\table2-1.jpg"
        # p = r"D:\Project\format_conversion_maxcompute\format_convert\temp" \
        #     r"\3cf0bfae0a1d11ec9622f02f748bb81b\3cf0e6a40a1d11ecb866f02f748bb81b_pdf_page1.png"
        # p = "data_process/create_data/0.jpg"
        # p = "../format_conversion/temp/f1fe9c4ac8e511eb81d700163e0857b6/f1fea1e0c8e511eb81d700163e0857b6.png"
        # p = "../format_conversion/1.png"
        image_np = cv2.imread(p)

        # choose the prediction size closest to the image resolution to avoid distortion
        best_h, best_w = get_best_predict_size(image_np)

        # run the model
        rows, cols, image_np = predict(image_np, otr_model, size=(best_w, best_h))
        print("len(cols), len(rows)", len(cols), len(rows))

        cols = merge_line(cols, 1)
        rows = merge_line(rows, 0)
        rows, cols = fix_gap(rows, cols)
        print("len(cols), len(rows)", len(cols), len(rows))
        print("rows", rows)
        print("cols", cols)

        # create a blank image and draw the detected lines on it
        test_img = np.zeros(image_np.shape, np.uint8)
        test_img.fill(255)
        for box in rows+cols:
            cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                     (0, 0, 0), 1)
        cv2.imwrite("predict.jpg", test_img)
        cv2.imshow("my post process", test_img)
        cv2.waitKey(0)

        # test_img = line_fix(test_img)
        # cv2.imshow("line fix", test_img)
        # cv2.waitKey(0)

        # run the model (alternative path, commented out)
        # rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
        #
        # if not rows or not cols:
        #     print("points", 0, "split_lines", 0, "bboxes", 0)
        #     return {"points": str([]), "split_lines": str([]),
        #             "bboxes": str([]), "outline_points": str([])}
        #
        # print("len(rows)", len(rows))
        # print("len(cols)", len(cols))

        # create a blank image
        # test_img = np.zeros(image_np.shape, np.uint8)
        # test_img.fill(255)
        # for box in rows+cols:
        #     cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1)
        # cv2.imshow("his post process", test_img)
        # cv2.waitKey(0)
        # cv2.imwrite("temp.jpg", test_img)

        # process the result:
        # merge the lines, then compute the intersection points
        points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))

        # compute the split lines between table regions
        split_y = get_split_line(cols, (image_np.shape[0], image_np.shape[1]))
        print("split_y", split_y)
        print("image shape", test_img.shape)
        for y in split_y:
            cv2.line(test_img, (0, y), (test_img.shape[1], y), (0, 0, 255), 1)
        cv2.imshow("split line", test_img)
        cv2.waitKey(0)
        cv2.imwrite("temp.jpg", test_img)

        # split the points and lines into regions
        points_area_list = get_point_area(points, split_y)
        rows_area_list = get_line_area(rows, split_y)
        cols_area_list = get_line_area(cols, split_y)
        print("points_area_list", points_area_list)
        print("rows_area_list", rows_area_list)
        print("cols_area_list", cols_area_list)

        # process each region
        for i in range(len(points_area_list)):
            print("i", i)
            points_area = points_area_list[i]
            rows_area = rows_area_list[i]
            cols_area = cols_area_list[i]
            if not points_area or not rows_area or not cols_area:
                continue
            print("len(rows_area)", len(rows_area))
            print("len(cols_area)", len(cols_area))

            # fix the table outline
            new_rows, new_cols, longer_rows, longer_cols = \
                fix_outline_area(rows_area, cols_area, points_area)
            if new_rows:
                cols_area = longer_cols
                rows_area += new_rows
            if new_cols:
                rows_area = longer_rows
                cols_area += new_cols

            # draw the lines of this single region
            test_img = np.zeros(image_np.shape, np.uint8)
            test_img.fill(255)
            for box in rows_area+cols_area:
                cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                         (0, 0, 0), 1)
            for box in new_rows+new_cols:
                cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                         (0, 255, 0), 2)
            cv2.imshow("fix_outline", test_img)
            cv2.waitKey(0)
            # cv2.imwrite("temp.jpg", test_img)

        print("use time: ", time.time()-start_time)
        return {"points": str([]), "split_lines": str([]),
                "bboxes": str([]), "outline_points": str([])}
    except Exception as e:
        print("otr_interface cannot detect table!", traceback.format_exc())
        print("points", 0, "split_lines", 0, "bboxes", 0)
        return {"points": str([]), "split_lines": str([]),
                "bboxes": str([]), "outline_points": str([])}


otr_model = table_net((None, None, 3), 2)
otr_model.load_weights(model_path)

if __name__ == '__main__':
    # if len(sys.argv) == 2:
    #     port = int(sys.argv[1])
    # else:
    #     port = 15017
    # app.run(host='0.0.0.0', port=port, threaded=True, debug=True)
    # log("OTR running "+str(port))
    table_detect("")
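
# Example client call for the /otr endpoint (a sketch, not part of the service):
# this assumes the Flask app is started with app.run on port 15017, as in the
# commented-out __main__ block above, and that the `requests` package is available.
# `call_otr` is a hypothetical helper name used only for illustration. The caller
# base64-encodes the raw image bytes and POSTs them to /otr; the endpoint returns
# a dict whose values ("points", "split_lines", "bboxes", "outline_points") are
# serialized as strings.
#
# import base64
# import requests
#
# def call_otr(image_path, url="http://127.0.0.1:15017/otr"):
#     # read the image and base64-encode its raw bytes, as the endpoint expects
#     with open(image_path, "rb") as f:
#         img_b64 = base64.b64encode(f.read())
#     resp = requests.post(url, data=img_b64)
#     return resp.json()
#
# print(call_otr("test_files/8.png"))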