bdr_interface.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import time
  6. import traceback
  7. from glob import glob
  8. import cv2
  9. import numpy as np
  10. from PIL import Image
  11. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  12. import tensorflow as tf
  13. from flask import Flask, request
  14. from border_recognize.model import u_net_drag
  15. from border_recognize.inference_drag import recognize
  16. from utils import pil_resize, np2bytes, request_post, bytes2np
  17. tf.compat.v1.disable_eager_execution()
  18. sess = tf.compat.v1.Session(graph=tf.Graph())
  19. # 接口配置
  20. app = Flask(__name__)
  21. @app.route('/bdr', methods=['POST'])
  22. def bdr():
  23. start_time = time.time()
  24. logging.info("into bdr_interface bdr")
  25. try:
  26. # 接收网络数据
  27. if not request.form:
  28. logging.info("bdr no data!")
  29. return json.dumps({"data": "", "success": 0})
  30. data = request.form.get("data")
  31. logging.info("bdr_interface get data time" + str(time.time()-start_time))
  32. # 加载模型
  33. bdr_model = globals().get("global_bdr_model")
  34. if bdr_model is None:
  35. print("=========== init bdr model ===========")
  36. bdr_model = BdrModels().get_model()
  37. globals().update({"global_bdr_model": bdr_model})
  38. # 数据转换
  39. data = base64.b64decode(data)
  40. image_np = bytes2np(data)
  41. # 预测
  42. w = recognize(image_np, bdr_model, sess)
  43. return json.dumps({"data": w, "success": 1})
  44. except:
  45. traceback.print_exc()
  46. return json.dumps({"data": "", "success": 0})
  47. finally:
  48. logging.info("bdr interface finish time " + str(time.time()-start_time))
  49. class BdrModels:
  50. def __init__(self):
  51. # python文件所在目录
  52. _dir = os.path.abspath(os.path.dirname(__file__))
  53. # detect
  54. model_path = _dir + "/models/drag_f1_0.42.h5"
  55. image_shape = (128, 256, 3)
  56. with sess.as_default():
  57. with sess.graph.as_default():
  58. self.model = u_net_drag(input_shape=image_shape)
  59. self.model.load_weights(model_path)
  60. def get_model(self):
  61. return self.model
  62. def test_bdr_model(from_remote=True):
  63. paths = glob("D:/Project/captcha/data/test/yolo_18.jpg")
  64. for file_path in paths:
  65. img_np = cv2.imread(file_path)
  66. h, w = img_np.shape[:2]
  67. file_bytes = np2bytes(img_np)
  68. file_base64 = base64.b64encode(file_bytes)
  69. if from_remote:
  70. file_json = {"data": file_base64}
  71. # _url = "http://192.168.2.102:17000/ocr"
  72. _url = "http://127.0.0.1:17000/bdr"
  73. result = json.loads(request_post(_url, file_json))
  74. if result.get("success"):
  75. w = int(result.get("data"))
  76. print("w", w)
  77. img_new = np.concatenate([img_np[:, w:, :], img_np[:, :w, :]], axis=1)
  78. cv2.imshow("img_np", img_np)
  79. cv2.imshow("img_new", img_new)
  80. cv2.waitKey(0)
  81. else:
  82. print("failed!")
  83. if __name__ == "__main__":
  84. # app.run(host='127.0.0.1', port=17000, debug=False)
  85. test_bdr_model()