ced_interface.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import sys
  6. import time
  7. import traceback
  8. from glob import glob
  9. import cv2
  10. import numpy as np
  11. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  12. import tensorflow as tf
  13. from flask import Flask, request
  14. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  15. from chinese_equation_denoise.inference_equation_denoise import denoise
  16. from chinese_equation_denoise.model import u_net_denoise
  17. from utils import pil_resize, np2bytes, request_post, bytes2np, base64_decode, image_to_str, str_to_image
  18. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  19. tf.compat.v1.disable_eager_execution()
  20. sess = tf.compat.v1.Session(graph=tf.Graph())
  21. package_dir = os.path.abspath(os.path.dirname(__file__))
  22. model_path = package_dir + "/models/denoise_loss_53.97.h5"
  23. image_shape = (32, 192, 1)
  24. # 接口配置
  25. app = Flask(__name__)
  26. @app.route('/ced', methods=['POST'])
  27. def ced():
  28. start_time = time.time()
  29. logging.info("into ced_interface ced")
  30. try:
  31. # 接收网络数据
  32. if not request.form:
  33. logging.info("ced no data!")
  34. return json.dumps({"data": "", "success": 0})
  35. data = request.form.get("data")
  36. logging.info("ced_interface get data time" + str(time.time()-start_time))
  37. # 加载模型
  38. ced_model = globals().get("global_ced_model")
  39. if ced_model is None:
  40. print("=========== init ced model ===========")
  41. ced_model = CedModels().get_model()
  42. globals().update({"global_ced_model": ced_model})
  43. # 数据转换
  44. data = base64_decode(data)
  45. image_np = bytes2np(data)
  46. # 预测
  47. image_result = denoise(image_np, ced_model, sess)
  48. image_str = image_to_str(image_result)
  49. return json.dumps({"data": image_str, "success": 1})
  50. except:
  51. traceback.print_exc()
  52. return json.dumps({"data": "", "success": 0})
  53. finally:
  54. logging.info("ced interface finish time " + str(time.time()-start_time))
  55. class CedModels:
  56. def __init__(self):
  57. with sess.as_default():
  58. with sess.graph.as_default():
  59. self.model = u_net_denoise(input_shape=image_shape, class_num=image_shape[2])
  60. self.model.load_weights(model_path)
  61. def get_model(self):
  62. return self.model
  63. def test_ced_model(from_remote=True):
  64. paths = glob("D:/Project/captcha/data/test/FileInfo1021/19584571-511d-11ed-93ac-b4b5b67760ae_3.jpg")
  65. for file_path in paths:
  66. img_np = cv2.imread(file_path)
  67. h, w = img_np.shape[:2]
  68. file_bytes = np2bytes(img_np)
  69. file_base64 = base64.b64encode(file_bytes)
  70. if from_remote:
  71. file_json = {"data": file_base64}
  72. # _url = "http://192.168.2.102:17060/ced"
  73. _url = "http://127.0.0.1:17060/ced"
  74. result = json.loads(request_post(_url, file_json))
  75. if result.get("success"):
  76. img_new = str_to_image(result.get("data"))
  77. cv2.imshow("img_np", img_np)
  78. cv2.imshow("img_new", img_new)
  79. cv2.waitKey(0)
  80. else:
  81. print("failed!")
  82. if __name__ == "__main__":
  83. # app.run(host='127.0.0.1', port=17060, debug=False)
  84. test_ced_model()
  85. # with open(r'C:\Users\Administrator\Downloads\新建文本文档+(3).txt', 'r') as f:
  86. # _b = f.read()
  87. # # b_str = str(_b)
  88. # # print(len(b_str))
  89. # data = base64_decode(_b)
  90. # with open(r'C:\Users\Administrator\Downloads\11.jpg', 'wb') as f:
  91. # f.write(data)
  92. # image_np = bytes2np(data)
  93. # print(image_np.shape)
  94. # cv2.imwrite(r'C:\Users\Administrator\Downloads\11.jpg', image_np)
  95. # cv2.imshow('img', image_np)
  96. # cv2.waitKey(0)