chr_interface.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import sys
  6. import time
  7. import traceback
  8. from glob import glob
  9. import cv2
  10. import numpy as np
  11. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  13. import tensorflow as tf
  14. from flask import Flask, request
  15. from chinese_recognize.model import cnn_net, cnn_net_tiny, cnn_net_small
  16. from chinese_recognize.inference_char import recognize
  17. from utils import pil_resize, np2bytes, request_post, bytes2np, base64_decode, str_to_image
  18. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  19. tf.compat.v1.disable_eager_execution()
  20. sess = tf.compat.v1.Session(graph=tf.Graph())
  21. package_dir = os.path.abspath(os.path.dirname(__file__))
  22. image_shape = (40, 40, 1)
  23. model_path = package_dir + "/models/char_acc_0.89.h5"
  24. # 接口配置
  25. app = Flask(__name__)
  26. @app.route('/chr', methods=['POST'])
  27. def _chr():
  28. start_time = time.time()
  29. logging.info("into chr_interface chr")
  30. try:
  31. # 接收网络数据
  32. if not request.form:
  33. logging.info("chr no data!")
  34. return json.dumps({"data": "", "success": 0})
  35. data = request.form.get("data")
  36. logging.info("chr_interface get data time" + str(time.time()-start_time))
  37. # 加载模型
  38. chr_model = globals().get("global_chr_model")
  39. if chr_model is None:
  40. print("=========== init chr model ===========")
  41. chr_model = ChrModels().get_model()
  42. globals().update({"global_chr_model": chr_model})
  43. # 数据转换
  44. str_list = json.loads(data)
  45. image_np_list = []
  46. for _str in str_list:
  47. image_np = str_to_image(_str)
  48. # b64 = _str.encode("utf-8")
  49. # image_np = bytes2np(base64_decode(b64))
  50. image_np_list.append(image_np)
  51. # 预测
  52. char_list = recognize(image_np_list, chr_model, sess)
  53. return json.dumps({"data": char_list, "success": 1})
  54. except:
  55. traceback.print_exc()
  56. return json.dumps({"data": "", "success": 0})
  57. finally:
  58. logging.info("chr interface finish time " + str(time.time()-start_time))
  59. class ChrModels:
  60. def __init__(self):
  61. with sess.as_default():
  62. with sess.graph.as_default():
  63. self.model = cnn_net_small(input_shape=image_shape)
  64. self.model.load_weights(model_path)
  65. def get_model(self):
  66. return self.model
  67. def test_chr_model(from_remote=True):
  68. paths = glob("D:/Project/captcha/data/test/char_9.jpg")
  69. str_list = []
  70. for file_path in paths:
  71. img_np = cv2.imread(file_path)
  72. cv2.imshow("img_np", img_np)
  73. cv2.waitKey(0)
  74. h, w = img_np.shape[:2]
  75. file_bytes = np2bytes(img_np)
  76. file_base64 = base64.b64encode(file_bytes)
  77. file_str = file_base64.decode("utf-8")
  78. str_list.append(file_str)
  79. if from_remote:
  80. file_json = {"data": json.dumps(str_list)}
  81. _url = "http://192.168.2.103:17057/chr"
  82. result = json.loads(request_post(_url, file_json))
  83. if result.get("success"):
  84. char_list = result.get("data")
  85. for i in range(len(paths)):
  86. print("image_path, char", paths[i], char_list[i])
  87. else:
  88. print("failed!")
  89. if __name__ == "__main__":
  90. app.run(host='127.0.0.1', port=17057, debug=False)
  91. # test_chr_model()