# cer_interface.py (3.4 KB)
import base64
import json
import logging
import os
import sys
import time
import traceback
from glob import glob

import cv2
import numpy as np

# Hide all CUDA devices; this is set before TensorFlow is imported so the
# setting is already in place when TensorFlow initializes.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
from flask import Flask, request

# Put the parent directory on the import path before the project imports below.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from chinese_equation_recognize.inference_equation import recognize
from chinese_equation_recognize.model import crnn_ctc_equation_large, crnn_ctc_equation_loss
from utils import pil_resize, np2bytes, request_post, bytes2np, base64_decode, image_to_str, str_to_image
  18. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  19. tf.compat.v1.disable_eager_execution()
  20. sess = tf.compat.v1.Session(graph=tf.Graph())
  21. package_dir = os.path.abspath(os.path.dirname(__file__))
  22. model_path = package_dir + "/models/e55-loss0.14-equation.h5"
  23. image_shape = (32, 192, 1)
  24. # 接口配置
  25. app = Flask(__name__)
  26. @app.route('/cer', methods=['POST'])
  27. def cer():
  28. start_time = time.time()
  29. logging.info("into cer_interface cer")
  30. try:
  31. # 接收网络数据
  32. if not request.form:
  33. logging.info("cer no data!")
  34. return json.dumps({"data": "", "success": 0})
  35. data = request.form.get("data")
  36. logging.info("cer_interface get data time" + str(time.time()-start_time))
  37. # 加载模型
  38. cer_model = globals().get("global_cer_model")
  39. if cer_model is None:
  40. print("=========== init cer model ===========")
  41. cer_model = CerModels().get_model()
  42. globals().update({"global_cer_model": cer_model})
  43. # 数据转换
  44. data = base64_decode(data)
  45. image_np = bytes2np(data)
  46. # 预测
  47. result = recognize(image_np, cer_model, sess)
  48. if result is None:
  49. return json.dumps({"data": "", "success": 0})
  50. return json.dumps({"data": result, "success": 1})
  51. except:
  52. traceback.print_exc()
  53. return json.dumps({"data": "", "success": 0})
  54. finally:
  55. logging.info("cer interface finish time " + str(time.time()-start_time))
  56. class CerModels:
  57. def __init__(self):
  58. with sess.as_default():
  59. with sess.graph.as_default():
  60. self.model = crnn_ctc_equation_loss(input_shape=image_shape, class_num=35+2, is_train=False)
  61. self.model.load_weights(model_path)
  62. def get_model(self):
  63. return self.model
  64. def test_cer_model(from_remote=True):
  65. paths = glob("D:/Project/captcha/data/test/FileInfo1021/1d419189-5116-11ed-851c-b4b5b67760ae_7.jpg")
  66. paths = glob(r'C:\Users\Administrator\Downloads\default.jfif')
  67. for file_path in paths:
  68. img_np = cv2.imread(file_path)
  69. h, w = img_np.shape[:2]
  70. file_bytes = np2bytes(img_np)
  71. file_base64 = base64.b64encode(file_bytes)
  72. if from_remote:
  73. file_json = {"data": file_base64}
  74. # _url = "http://192.168.2.102:17061/cer"
  75. _url = "http://127.0.0.1:17061/cer"
  76. result = json.loads(request_post(_url, file_json))
  77. if result.get("success"):
  78. result = int(result.get("data"))
  79. cv2.imshow("img_np", img_np)
  80. print("equation result", result)
  81. cv2.waitKey(0)
  82. else:
  83. print("failed!")
  84. if __name__ == "__main__":
  85. # app.run(host='127.0.0.1', port=17061, debug=False)
  86. test_cer_model()