# pzd_interface.py (3.3 KB)
# Module dependencies and process-wide setup.
# NOTE: the statement order below matters and is intentional.
import base64
import json
import logging
import os
import sys
import time
import traceback
from glob import glob
import cv2
# Hide all CUDA devices so TensorFlow runs CPU-only. This is kept above the
# `import tensorflow` line; presumably the variable must be set before TF/CUDA
# initialises for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Make this file's parent directory importable so that the project-local
# `utils` and `puzzle_detect` modules imported below can be resolved.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
import tensorflow as tf
from flask import Flask, request
from utils import np2bytes, request_post, bytes2np, get_anchors, get_classes, get_colors
from puzzle_detect.inference_yolo_puzzle import get_tiny_inference_model, detect
# Root-logger configuration used by the module's logging.info calls.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# TF1-style graph mode. This must run before the Session below is created.
# The detector is built inside this session's graph (see PzdModels), and the
# session is passed to detect() by the /pzd handler.
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session(graph=tf.Graph())
# Model weights and YOLO metadata files are resolved relative to this file.
package_dir = os.path.abspath(os.path.dirname(__file__))
model_path = package_dir + "/models/puzzle_yolo_loss_4.15.h5"
anchors = get_anchors(package_dir + "/yolo_data/my_anchors_puzzle.txt")
classes = get_classes(package_dir + "/yolo_data/my_classes_puzzle.txt")
# One colour per class, as produced by get_colors(len(classes)).
colors = get_colors(len(classes))
# API configuration
app = Flask(__name__)
  26. @app.route('/pzd', methods=['POST'])
  27. def pzd():
  28. start_time = time.time()
  29. logging.info("into pzd_interface pzd")
  30. try:
  31. # 接收网络数据
  32. if not request.form:
  33. logging.info("pzd no data!")
  34. return json.dumps({"data": "", "success": 0})
  35. data = request.form.get("data")
  36. logging.info("pzd_interface get data time" + str(time.time()-start_time))
  37. # 加载模型
  38. pzd_model = globals().get("global_pzd_model")
  39. if pzd_model is None:
  40. print("=========== init pzd model ===========")
  41. pzd_model = PzdModels().get_model()
  42. globals().update({"global_pzd_model": pzd_model})
  43. # 数据转换
  44. data = base64.b64decode(data)
  45. image_np = bytes2np(data)
  46. # 预测
  47. _, out_boxes, out_classes = detect(image_np, pzd_model, sess)
  48. return json.dumps({"data": out_boxes, "success": 1})
  49. except:
  50. traceback.print_exc()
  51. return json.dumps({"data": "", "success": 0})
  52. finally:
  53. logging.info("pzd interface finish time " + str(time.time()-start_time))
  54. class PzdModels:
  55. def __init__(self):
  56. # detect
  57. with sess.as_default():
  58. with sess.graph.as_default():
  59. self.model = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)
  60. def get_model(self):
  61. return self.model
  62. def test_pzd_model(from_remote=True):
  63. paths = glob("D:/Project/captcha/data/test/yolo_12.jpg")
  64. for file_path in paths:
  65. img_np = cv2.imread(file_path)
  66. file_bytes = np2bytes(img_np)
  67. file_base64 = base64.b64encode(file_bytes)
  68. if from_remote:
  69. file_json = {"data": file_base64}
  70. # _url = "http://192.168.2.102:17000/ocr"
  71. _url = "http://127.0.0.1:17000/pzd"
  72. result = json.loads(request_post(_url, file_json))
  73. if result.get("success"):
  74. out_boxes = result.get("data")
  75. print("out_boxes", out_boxes)
  76. for box in out_boxes:
  77. cv2.rectangle(img_np, (box[0], box[1]), (box[2], box[3]), (0, 0, 255))
  78. cv2.imshow("img_np", img_np)
  79. cv2.waitKey(0)
  80. else:
  81. print("failed!")
  82. if __name__ == "__main__":
  83. # app.run(host='127.0.0.1', port=17000, debug=False)
  84. test_pzd_model()