123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import base64
- import multiprocessing as mp
- import os
- # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
- import sys
- import time
- import traceback
- from multiprocessing.context import Process
- import multiprocessing
- from flask import Flask, jsonify
- from flask import request
- import logging
- # from table_line import *
- import cv2
- import numpy as np
- import tensorflow as tf
- from table_line import get_best_predict_size, table_line, table_net, line_fix
- from table_predict import predict
- from table_postprocess import *
app = Flask(__name__)
# Emit non-ASCII characters (e.g. Chinese text) verbatim in JSON responses.
# Older Flask reads the JSON_AS_ASCII config key (deprecated in 2.2, removed
# in 2.3); Flask >= 2.2 exposes a JSON provider on ``app.json`` whose
# ``ensure_ascii`` attribute controls the same thing. Set both so the
# behaviour is the same on every Flask version.
app.config['JSON_AS_ASCII'] = False
if hasattr(getattr(app, "json", None), "ensure_ascii"):
    app.json.ensure_ascii = False
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def log(msg):
    """Write *msg* to this module's logger at INFO level."""
    logger.log(logging.INFO, msg)
@app.route('/otr', methods=['POST'])
def otr():
    """HTTP endpoint: run table detection on a base64-encoded image.

    The request body is expected to be the base64 encoding of an image file.

    :return: the string ``'no data'`` for an empty body,
        ``'invalid base64 data'`` when the body cannot be base64-decoded,
        otherwise whatever ``table_detect`` returns (a dict, which Flask
        >= 1.1 turns into a JSON response).
    """
    # The route is registered for POST only, so no request.method check is
    # needed here.
    if not request.data:
        return 'no data'
    try:
        img_data = base64.b64decode(request.data)
    except ValueError:
        # binascii.Error (e.g. incorrect padding) subclasses ValueError; reply
        # like the empty-body case instead of failing with a 500.
        return 'invalid base64 data'
    return table_detect(img_data)
# ``flag`` is not read anywhere in this file; it is kept only in case
# something imports it. ``model_path`` points at the weights file that is
# loaded into ``otr_model`` when the module is imported.
flag, model_path = 0, "models/table-line.h5"
def _empty_otr_result():
    """Return the placeholder OTR response in which every field is the string "[]"."""
    return {"points": str([]), "split_lines": str([]),
            "bboxes": str([]), "outline_points": str([])}


def _draw_lines(shape, lines, color=(0, 0, 0), thickness=1, canvas=None):
    """Draw line segments with ``cv2.line`` and return the image drawn on.

    :param shape: shape of the white uint8 canvas to allocate when ``canvas``
        is None (ignored otherwise).
    :param lines: iterable of segments indexed as (x1, y1, x2, y2); the
        coordinates are truncated with ``int`` because cv2 needs integers.
    :param color: BGR colour passed to ``cv2.line``.
    :param thickness: line thickness passed to ``cv2.line``.
    :param canvas: existing image to draw on in place, or None.
    :return: the canvas that was drawn on.
    """
    if canvas is None:
        canvas = np.full(shape, 255, np.uint8)
    for box in lines:
        cv2.line(canvas, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)
    return canvas


def _show_and_wait(title, image):
    """Display ``image`` in a window named ``title`` and block until a key is pressed."""
    cv2.imshow(title, image)
    cv2.waitKey(0)


def _load_image(img_data, image_path):
    """Return a BGR image decoded from ``img_data``, or read from ``image_path``.

    :param img_data: raw bytes of an encoded image file; when empty
        (``b""``/``""``/None) the image is read from ``image_path`` instead.
    :param image_path: file read with ``cv2.imread`` when ``img_data`` is empty.
    :raises ValueError: if OpenCV cannot decode/read the image (both cv2
        calls return None on failure instead of raising).
    """
    if img_data:
        image_np = cv2.imdecode(np.frombuffer(img_data, dtype=np.uint8), cv2.IMREAD_COLOR)
        source = "request data"
    else:
        image_np = cv2.imread(image_path)
        source = image_path
    if image_np is None:
        raise ValueError("cannot decode image from %s" % source)
    return image_np


def table_detect(img_data, image_path="test_files/8.png"):
    """Run the table-line pipeline on one image and show debug visualisations.

    Steps: model prediction (``predict``), line merging and gap fixing,
    intersection points, horizontal split lines, then per-area outline
    repair (``fix_outline_area``). Intermediate results are printed, written
    to ``predict.jpg`` / ``temp.jpg`` in the working directory, and shown in
    OpenCV windows; each window blocks on ``cv2.waitKey(0)``.
    NOTE(review): the blocking windows make this unsuitable for the HTTP
    route as-is.

    :param img_data: raw bytes of an encoded image (as decoded by ``otr``);
        when empty, ``image_path`` is read instead (the ``__main__`` debug run
        passes ``""``).
    :param image_path: fallback image file used when ``img_data`` is empty.
    :return: dict with keys "points", "split_lines", "bboxes",
        "outline_points". At present every value is the string "[]" on
        success and on failure; the pipeline's output is only printed and
        displayed, not returned.
    """
    print("child process ", os.getpid())
    start_time = time.time()
    try:
        image_np = _load_image(img_data, image_path)
        # Choose the model input size closest to the image resolution to
        # avoid distortion.
        best_h, best_w = get_best_predict_size(image_np)
        # Run the model.
        rows, cols, image_np = predict(image_np, otr_model, size=(best_w, best_h))
        print("len(cols), len(rows)", len(cols), len(rows))
        cols = merge_line(cols, 1)
        rows = merge_line(rows, 0)
        rows, cols = fix_gap(rows, cols)
        print("len(cols), len(rows)", len(cols), len(rows))
        print("rows", rows)
        print("cols", cols)

        # Draw the post-processed lines on a blank white image.
        test_img = _draw_lines(image_np.shape, rows + cols)
        cv2.imwrite("predict.jpg", test_img)
        _show_and_wait("my post process", test_img)

        image_hw = (image_np.shape[0], image_np.shape[1])
        # Intersection points of the row and column lines.
        points = get_points(rows, cols, image_hw)
        # Horizontal split lines (y coordinates) dividing the page into areas.
        split_y = get_split_line(cols, image_hw)
        print("split_y", split_y)
        print("image shape", test_img.shape)
        for y in split_y:
            # cv2 requires integer pixel coordinates.
            cv2.line(test_img, (0, int(y)), (test_img.shape[1], int(y)), (0, 0, 255), 1)
        _show_and_wait("split line", test_img)
        cv2.imwrite("temp.jpg", test_img)

        # Split points and lines into the areas between split lines.
        points_area_list = get_point_area(points, split_y)
        rows_area_list = get_line_area(rows, split_y)
        cols_area_list = get_line_area(cols, split_y)
        print("points_area_list", points_area_list)
        print("rows_area_list", rows_area_list)
        print("cols_area_list", cols_area_list)

        # Process each area independently.
        areas = zip(points_area_list, rows_area_list, cols_area_list)
        for i, (points_area, rows_area, cols_area) in enumerate(areas):
            print("i", i)
            if not points_area or not rows_area or not cols_area:
                continue
            print("len(rows_area)", len(rows_area))
            print("len(cols_area)", len(cols_area))
            # Repair the table outline of this area.
            new_rows, new_cols, longer_rows, longer_cols = \
                fix_outline_area(rows_area, cols_area, points_area)
            # Pick both bases before extending, so that when new rows AND new
            # columns are added neither one overwrites the other (the old
            # sequential code discarded new_rows in that case). Building new
            # lists also avoids mutating rows_area_list / cols_area_list.
            if new_rows:
                cols_area = longer_cols
            if new_cols:
                rows_area = longer_rows
            rows_area = list(rows_area) + list(new_rows)
            cols_area = list(cols_area) + list(new_cols)

            # Lines of this area in black, newly added outline lines in green.
            test_img = _draw_lines(image_np.shape, rows_area + cols_area)
            _draw_lines(None, new_rows + new_cols, color=(0, 255, 0), thickness=2, canvas=test_img)
            _show_and_wait("fix_outline", test_img)
        print("use time: ", time.time() - start_time)
        return _empty_otr_result()
    except Exception:
        # Top-level boundary for the request: report the full traceback and
        # answer with the empty result instead of propagating.
        print("otr_interface cannot detected table!", traceback.format_exc())
        print("points", 0, "split_lines", 0, "bboxes", 0)
        return _empty_otr_result()
# Build the table-line network once, at import time, and load its weights
# from ``model_path``; ``table_detect`` uses this module-level ``otr_model``.
# NOTE(review): (None, None, 3) presumably means variable height/width with
# 3 channels, and 2 presumably the number of output maps (rows/cols);
# confirm against table_net.
otr_model = table_net((None, None, 3), 2)
otr_model.load_weights(model_path)
if __name__ == '__main__':
    # Starting the HTTP service is currently disabled (commented out below).
    # Instead, the script calls table_detect with empty image data as an
    # offline debug run.
    # if len(sys.argv) == 2:
    #     port = int(sys.argv[1])
    # else:
    #     port = 15017
    # app.run(host='0.0.0.0', port=port, threaded=True, debug=True)
    # log("OTR running "+str(port))
    table_detect("")
|