# cac_interface.py (3.3 KB)
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import sys
  6. import time
  7. import traceback
  8. from glob import glob
  9. import cv2
  10. import numpy as np
  11. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  12. import tensorflow as tf
  13. from flask import Flask, request
  14. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  15. from captcha_classify.inference_classify import classify
  16. from captcha_classify.model import cnn_net_tiny
  17. from utils import pil_resize, np2bytes, request_post, bytes2np, base64_decode, image_to_str, str_to_image
# Log INFO and above with timestamp, logger name and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Switch TensorFlow to TF1 graph mode. This is called before the session/graph
# below are created and before any model is built inside them.
tf.compat.v1.disable_eager_execution()
# Module-wide TF1 session bound to its own fresh Graph. CacModels builds and
# loads the model inside this session/graph, and cac() passes it to classify().
sess = tf.compat.v1.Session(graph=tf.Graph())
# Absolute directory of this file; the weights path is resolved relative to it.
package_dir = os.path.abspath(os.path.dirname(__file__))
# Trained weights file. NOTE(review): the name presumably encodes epoch 262 /
# accuracy 0.81 -- confirm.
model_path = package_dir + "/models/e262-acc0.81-classify.h5"
# Input shape passed to cnn_net_tiny. NOTE(review): presumably
# (height, width, channels) with a single grayscale channel -- confirm.
image_shape = (128, 128, 1)
# Number of classes, passed to cnn_net_tiny as output_shape.
class_num = 3
# Interface (HTTP endpoint) configuration
app = Flask(__name__)
  27. @app.route('/cac', methods=['POST'])
  28. def cac():
  29. start_time = time.time()
  30. logging.info("into cac_interface cac")
  31. try:
  32. # 接收网络数据
  33. if not request.form:
  34. logging.info("cac no data!")
  35. return json.dumps({"data": "", "success": 0})
  36. data = request.form.get("data")
  37. logging.info("cac_interface get data time" + str(time.time()-start_time))
  38. # 加载模型
  39. cac_model = globals().get("global_cac_model")
  40. if cac_model is None:
  41. print("=========== init cac model ===========")
  42. cac_model = CacModels().get_model()
  43. globals().update({"global_cac_model": cac_model})
  44. # 数据转换
  45. data = base64_decode(data)
  46. image_np = bytes2np(data)
  47. # 预测
  48. result = classify(image_np, cac_model, sess)
  49. logging.info('cac result ' + str(result))
  50. if result is None:
  51. return json.dumps({"data": "", "success": 0})
  52. return json.dumps({"data": result, "success": 1})
  53. except:
  54. traceback.print_exc()
  55. return json.dumps({"data": "", "success": 0})
  56. finally:
  57. logging.info("cac interface finish time " + str(time.time()-start_time))
  58. class CacModels:
  59. def __init__(self):
  60. with sess.as_default():
  61. with sess.graph.as_default():
  62. self.model = cnn_net_tiny(input_shape=image_shape, output_shape=class_num)
  63. self.model.load_weights(model_path)
  64. def get_model(self):
  65. return self.model
  66. def test_cac_model(from_remote=True):
  67. paths = glob("D:/Project/captcha/data/chinese/1.jpg")
  68. for file_path in paths:
  69. img_np = cv2.imread(file_path)
  70. h, w = img_np.shape[:2]
  71. file_bytes = np2bytes(img_np)
  72. file_base64 = base64.b64encode(file_bytes)
  73. if from_remote:
  74. file_json = {"data": file_base64}
  75. # _url = "http://192.168.2.102:17061/cac"
  76. _url = "http://127.0.0.1:17062/cac"
  77. result = json.loads(request_post(_url, file_json))
  78. if result.get("success"):
  79. result = int(result.get("data"))
  80. cv2.imshow("img_np", img_np)
  81. print("classify result", result)
  82. cv2.waitKey(0)
  83. else:
  84. print("failed!")
  85. if __name__ == "__main__":
  86. # app.run(host='127.0.0.1', port=17062, debug=False)
  87. test_cac_model()