123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import base64
- import json
- import logging
- import os
- import sys
- import time
- import traceback
- from glob import glob
- import cv2
- import numpy as np
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
- import tensorflow as tf
- from flask import Flask, request
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- from chinese_equation_denoise.inference_equation_denoise import denoise
- from chinese_equation_denoise.model import u_net_denoise
- from utils import pil_resize, np2bytes, request_post, bytes2np, base64_decode, image_to_str, str_to_image
# Log INFO and above, prefixed with timestamp, logger name and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Run TensorFlow in TF1-style graph mode: the denoise model is built inside
# this session's graph (CedModels enters sess / sess.graph), and the same
# session object is passed to denoise() for every request.
# Must run before the session/model are created.
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session(graph=tf.Graph())
# Absolute directory of this file; trained weights live in its models/ subfolder.
package_dir = os.path.abspath(os.path.dirname(__file__))
model_path = package_dir + "/models/denoise_loss_53.97.h5"
# Network input shape passed to u_net_denoise; its last element is also used
# as class_num. Presumably (height, width, channels) — TODO confirm.
image_shape = (32, 192, 1)
# API (Flask) configuration
app = Flask(__name__)
def _get_ced_model():
    """Return the process-wide denoise model, building it on first use.

    The model is cached in this module's globals under ``"global_ced_model"``
    so the weights are loaded only once per process.
    """
    # NOTE(review): not guarded by a lock; if the server handles requests on
    # several threads, concurrent first requests may each build the model.
    ced_model = globals().get("global_ced_model")
    if ced_model is None:
        logging.info("=========== init ced model ===========")
        ced_model = CedModels().get_model()
        globals().update({"global_ced_model": ced_model})
    return ced_model


@app.route('/ced', methods=['POST'])
def ced():
    """Denoise an image posted as the base64-encoded form field ``data``.

    Returns:
        A JSON string. On success it is ``{"data": <encoded image>, "success": 1}``,
        where the image is encoded by ``image_to_str``. It is
        ``{"data": "", "success": 0}`` when no data was sent or any step failed.
    """
    start_time = time.time()
    logging.info("into ced_interface ced")
    try:
        # Receive the request payload; a missing form or a missing/empty
        # "data" field is rejected up front instead of failing in decoding.
        data = request.form.get("data") if request.form else None
        if not data:
            logging.info("ced no data!")
            return json.dumps({"data": "", "success": 0})
        logging.info("ced_interface get data time" + str(time.time()-start_time))

        # Load (or reuse) the model.
        ced_model = _get_ced_model()

        # Data conversion: base64 text -> raw bytes -> numpy image.
        image_np = bytes2np(base64_decode(data))

        # Predict, then encode the result for the JSON response.
        image_result = denoise(image_np, ced_model, sess)
        image_str = image_to_str(image_result)
        return json.dumps({"data": image_str, "success": 1})
    except Exception:
        # Request boundary: record the full traceback and report failure to
        # the client rather than propagating the error to Flask.
        logging.exception("ced interface failed")
        return json.dumps({"data": "", "success": 0})
    finally:
        logging.info("ced interface finish time " + str(time.time()-start_time))
class CedModels:
    """Build the U-Net denoiser and restore its trained weights.

    The network is constructed with the module-level ``sess`` (and its graph)
    as the defaults. The same session is later passed to ``denoise``.
    """

    def __init__(self):
        # Entering both contexts in one statement is equivalent to nesting
        # them: the session first, then its graph.
        with sess.as_default(), sess.graph.as_default():
            network = u_net_denoise(input_shape=image_shape,
                                    class_num=image_shape[2])
            network.load_weights(model_path)
        self.model = network

    def get_model(self):
        """Return the model built in ``__init__``."""
        return self.model
def test_ced_model(
    from_remote=True,
    pattern="D:/Project/captcha/data/test/FileInfo1021/19584571-511d-11ed-93ac-b4b5b67760ae_3.jpg",
    url="http://127.0.0.1:17060/ced",
):
    """Manual smoke test: post local images to a running /ced service.

    For every file matching ``pattern``, the image is encoded and posted to
    ``url``. On success, the original and the denoised image are shown side
    by side; otherwise "failed!" is printed.

    Args:
        from_remote: Send images to ``url`` when true. When false, nothing is
            sent or displayed.
        pattern: glob pattern selecting the images to send.
        url: endpoint of a running ``ced`` service.
    """
    if not from_remote:
        return
    for file_path in glob(pattern):
        img_np = cv2.imread(file_path)
        if img_np is None:
            # cv2.imread returns None (it does not raise) for unreadable files.
            print("cannot read image: " + file_path)
            continue
        file_base64 = base64.b64encode(np2bytes(img_np))
        result = json.loads(request_post(url, {"data": file_base64}))
        if result.get("success"):
            img_new = str_to_image(result.get("data"))
            cv2.imshow("img_np", img_np)
            cv2.imshow("img_new", img_new)
            cv2.waitKey(0)
        else:
            print("failed!")
if __name__ == "__main__":
    # Runs the client-side smoke test, which posts to 127.0.0.1:17060, so the
    # service must already be running. To start the service itself instead:
    #     app.run(host='127.0.0.1', port=17060, debug=False)
    test_ced_model()
|