pzd_interface.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import sys
  6. import time
  7. import traceback
  8. from glob import glob
  9. import numpy as np
  10. import cv2
# "-1" hides every CUDA device, so TensorFlow (imported below) should run on
# CPU only; it is set here, before the `tensorflow` import, on purpose.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Put this file's parent directory on the import path; presumably that is where
# the project-local `utils` and `puzzle_detect` packages imported below live.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  13. import tensorflow as tf
  14. from flask import Flask, request
  15. from utils import np2bytes, request_post, bytes2np, get_anchors, get_classes, get_colors, base64_decode, rgba_to_rgb
  16. from puzzle_detect.inference_yolo_puzzle import get_tiny_inference_model, detect
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Use TF1-style graph execution. The session below owns a private graph; the
# model is built inside it (see PzdModels) and inference is run through it
# (detect(..., sess) in pzd).
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session(graph=tf.Graph())
# Model weights and YOLO anchor/class definition files, resolved relative to
# this module's own directory so the service works from any working directory.
package_dir = os.path.abspath(os.path.dirname(__file__))
model_path = package_dir + "/models/puzzle_yolo_loss_4.15.h5"
anchors = get_anchors(package_dir + "/yolo_data/my_anchors_puzzle.txt")
classes = get_classes(package_dir + "/yolo_data/my_classes_puzzle.txt")
# NOTE(review): `colors` is not referenced by any code visible in this file;
# presumably kept for drawing detections while debugging - confirm before removing.
colors = get_colors(len(classes))
# HTTP interface configuration
app = Flask(__name__)
  27. @app.route('/pzd', methods=['POST'])
  28. def pzd():
  29. start_time = time.time()
  30. logging.info("into pzd_interface pzd")
  31. try:
  32. # 接收网络数据
  33. if not request.form:
  34. logging.info("pzd no data!")
  35. return json.dumps({"data": "", "success": 0})
  36. data = request.form.get("data")
  37. logging.info("pzd_interface get data time" + str(time.time()-start_time))
  38. # 加载模型
  39. pzd_model = globals().get("global_pzd_model")
  40. if pzd_model is None:
  41. print("=========== init pzd model ===========")
  42. pzd_model = PzdModels().get_model()
  43. globals().update({"global_pzd_model": pzd_model})
  44. # 数据转换
  45. # print("base64_data", data)
  46. data = base64_decode(data)
  47. image_np = bytes2np(data, unchanged=True)
  48. # 处理透明通道
  49. rgba_flag = 0
  50. position_list = []
  51. if image_np.shape[2] == 4:
  52. # 将透明部分置为黑色,转为3通道输出
  53. rgba_flag = 1
  54. image_np, position_list = rgba_to_rgb(image_np, return_position=True)
  55. # 预测
  56. _, out_boxes, out_classes = detect(image_np, pzd_model, sess)
  57. if len(out_boxes) == 0 and rgba_flag and position_list:
  58. logging.info('pzd 直接输出透明位置1')
  59. out_boxes = [position_list[0]]
  60. if len(out_boxes) >= 1 and rgba_flag and len(position_list) == 1:
  61. logging.info('pzd 直接输出透明位置2')
  62. out_boxes = [position_list[0]]
  63. return json.dumps({"data": out_boxes, "success": 1})
  64. except:
  65. traceback.print_exc()
  66. return json.dumps({"data": "", "success": 0})
  67. finally:
  68. logging.info("pzd interface finish time " + str(time.time()-start_time))
  69. class PzdModels:
  70. def __init__(self):
  71. # detect
  72. with sess.as_default():
  73. with sess.graph.as_default():
  74. self.model = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)
  75. def get_model(self):
  76. return self.model
  77. def get_puzzle_tip_location(image_base64):
  78. """
  79. 获取提示拼图的位置高度
  80. :param image_base64:
  81. :return:
  82. """
  83. image_bytes = base64_decode(image_base64)
  84. image_np = bytes2np(image_bytes)
  85. # image_np = cv2.imread(r'C:\Users\Administrator\Desktop\test_capture\puzzle2.jpg')
  86. # print(image_np)
  87. # image_np = np.mean(image_np, axis=0)
  88. hs = np.where(image_np > 20)[0]
  89. h1 = hs[0]
  90. h2 = hs[-1]
  91. # cv2.line(image_np, (0, h1), (image_np.shape[1], h1), (0, 0, 255))
  92. # cv2.line(image_np, (0, h2), (image_np.shape[1], h2), (0, 0, 255))
  93. # print('hs', hs)
  94. # cv2.imshow('img', image_np)
  95. # cv2.waitKey(0)
  96. return [h1, h2]
  97. def test_pzd_model(from_remote=1, read_base64_file=0):
  98. paths = glob('puzzle4.png')
  99. for file_path in paths:
  100. if not read_base64_file:
  101. img_np = cv2.imread(file_path)
  102. file_bytes = np2bytes(img_np)
  103. file_base64 = base64.b64encode(file_bytes)
  104. else:
  105. with open('base64.txt', 'r') as f:
  106. file_base64 = f.read()
  107. file_bytes = base64_decode(file_base64)
  108. img_np = bytes2np(file_bytes, unchanged=True)
  109. img_np = rgba_to_rgb(img_np)
  110. # cv2.imshow('img', img_np)
  111. # cv2.waitKey(0)
  112. if from_remote:
  113. _url = "http://192.168.2.103:17059/pzd"
  114. else:
  115. _url = "http://127.0.0.1:17059/pzd"
  116. file_json = {"data": file_base64}
  117. result = json.loads(request_post(_url, file_json))
  118. if result.get("success"):
  119. out_boxes = result.get("data")
  120. print("out_boxes", out_boxes)
  121. for box in out_boxes:
  122. cv2.rectangle(img_np, (box[0], box[1]), (box[2], box[3]), (0, 0, 255))
  123. cv2.imshow("img_np", img_np)
  124. cv2.waitKey(0)
  125. else:
  126. print("failed!")
if __name__ == "__main__":
    # To serve the HTTP API instead of running the manual client check, use the
    # line below (port 17059 matches the URLs used in test_pzd_model):
    # app.run(host='127.0.0.1', port=17059, debug=False)
    # Default: send test image(s) to an already-running /pzd server.
    test_pzd_model()
    # NOTE(review): passing None presumably fails inside base64_decode; supply a
    # real base64-encoded hint image when re-enabling this call.
    # get_puzzle_tip_location(None)