bdr_interface.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import sys
  6. import time
  7. import traceback
  8. from glob import glob
  9. import cv2
  10. import numpy as np
  11. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  13. import tensorflow as tf
  14. from flask import Flask, request
  15. from border_recognize.model import u_net_drag
  16. from border_recognize.inference_drag import recognize
  17. from utils import pil_resize, np2bytes, request_post, bytes2np
  18. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  19. tf.compat.v1.disable_eager_execution()
  20. sess = tf.compat.v1.Session(graph=tf.Graph())
  21. package_dir = os.path.abspath(os.path.dirname(__file__))
  22. model_path = package_dir + "/models/drag_f1_0.42.h5"
  23. image_shape = (128, 256, 3)
  24. # 接口配置
  25. app = Flask(__name__)
  26. @app.route('/bdr', methods=['POST'])
  27. def bdr():
  28. start_time = time.time()
  29. logging.info("into bdr_interface bdr")
  30. try:
  31. # 接收网络数据
  32. if not request.form:
  33. logging.info("bdr no data!")
  34. return json.dumps({"data": "", "success": 0})
  35. data = request.form.get("data")
  36. logging.info("bdr_interface get data time" + str(time.time()-start_time))
  37. # 加载模型
  38. bdr_model = globals().get("global_bdr_model")
  39. if bdr_model is None:
  40. print("=========== init bdr model ===========")
  41. bdr_model = BdrModels().get_model()
  42. globals().update({"global_bdr_model": bdr_model})
  43. # 数据转换
  44. data = base64.b64decode(data)
  45. image_np = bytes2np(data)
  46. # 预测
  47. w = recognize(image_np, bdr_model, sess)
  48. return json.dumps({"data": w, "success": 1})
  49. except:
  50. traceback.print_exc()
  51. return json.dumps({"data": "", "success": 0})
  52. finally:
  53. logging.info("bdr interface finish time " + str(time.time()-start_time))
  54. class BdrModels:
  55. def __init__(self):
  56. with sess.as_default():
  57. with sess.graph.as_default():
  58. self.model = u_net_drag(input_shape=image_shape)
  59. self.model.load_weights(model_path)
  60. def get_model(self):
  61. return self.model
  62. def test_bdr_model(from_remote=True):
  63. paths = glob("D:/Project/captcha/data/test/yolo_18.jpg")
  64. for file_path in paths:
  65. img_np = cv2.imread(file_path)
  66. h, w = img_np.shape[:2]
  67. file_bytes = np2bytes(img_np)
  68. file_base64 = base64.b64encode(file_bytes)
  69. if from_remote:
  70. file_json = {"data": file_base64}
  71. # _url = "http://192.168.2.102:17000/ocr"
  72. _url = "http://127.0.0.1:17000/bdr"
  73. result = json.loads(request_post(_url, file_json))
  74. if result.get("success"):
  75. w = int(result.get("data"))
  76. print("w", w)
  77. img_new = np.concatenate([img_np[:, w:, :], img_np[:, :w, :]], axis=1)
  78. cv2.imshow("img_np", img_np)
  79. cv2.imshow("img_new", img_new)
  80. cv2.waitKey(0)
  81. else:
  82. print("failed!")
  83. if __name__ == "__main__":
  84. app.run(host='127.0.0.1', port=17055, debug=False)
  85. # test_bdr_model()