import base64
import json
import logging
import os
import sys
import time
import traceback
from glob import glob

import numpy as np
import cv2

# Hide all GPUs from TensorFlow; this must happen before tensorflow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Make the parent directory importable so `utils` / `puzzle_detect` resolve.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")

import tensorflow as tf
from flask import Flask, request

from utils import (np2bytes, request_post, bytes2np, get_anchors, get_classes,
                   get_colors, base64_decode, rgba_to_rgb)
from puzzle_detect.inference_yolo_puzzle import get_tiny_inference_model, detect

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

tf.compat.v1.disable_eager_execution()
# One TF1-style session (with its own graph), used both to build the model
# (PzdModels) and to run inference (passed to `detect`).
sess = tf.compat.v1.Session(graph=tf.Graph())

package_dir = os.path.abspath(os.path.dirname(__file__))
model_path = package_dir + "/models/puzzle_yolo_loss_4.15.h5"
anchors = get_anchors(package_dir + "/yolo_data/my_anchors_puzzle.txt")
classes = get_classes(package_dir + "/yolo_data/my_classes_puzzle.txt")
colors = get_colors(len(classes))

# HTTP interface configuration
app = Flask(__name__)

# Detection model, built lazily on the first /pzd request and then reused.
global_pzd_model = None


def _get_pzd_model():
    """Return the shared detection model, building it on first use."""
    global global_pzd_model
    if global_pzd_model is None:
        logging.info("=========== init pzd model ===========")
        global_pzd_model = PzdModels().get_model()
    return global_pzd_model


def _json_default(obj):
    """``json.dumps`` fallback that converts NumPy arrays/scalars to Python types.

    NOTE(review): the concrete types returned by `detect` / `rgba_to_rgb` are not
    visible from this file; this hook is only consulted for objects json cannot
    serialise natively, so plain lists/ints pass through unchanged.
    """
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError("Object of type %s is not JSON serializable" % type(obj).__name__)


@app.route('/pzd', methods=['POST'])
def pzd():
    """Detect puzzle-piece boxes in a base64 image posted as form field ``data``.

    :return: JSON string ``{"data": out_boxes, "success": 1}`` on success, or
        ``{"data": "", "success": 0}`` when no data was sent or any step fails.
        For images with an alpha channel, the first position reported by
        ``rgba_to_rgb`` replaces the detector output when the detector found
        nothing, or when exactly one such position exists.
    """
    start_time = time.time()
    logging.info("into pzd_interface pzd")
    try:
        # Receive request data; a missing/empty field is treated like an empty form.
        data = request.form.get("data")
        if not data:
            logging.info("pzd no data!")
            return json.dumps({"data": "", "success": 0})
        logging.info("pzd_interface get data time %s", time.time() - start_time)

        # Load (or reuse) the model
        pzd_model = _get_pzd_model()

        # Convert base64 -> bytes -> image array, keeping any alpha channel.
        image_np = bytes2np(base64_decode(data), unchanged=True)

        # Handle the transparency channel
        rgba_flag = 0
        position_list = []
        # Single-channel images are 2-D, so check ndim before indexing shape[2].
        if image_np.ndim == 3 and image_np.shape[2] == 4:
            # Set transparent areas to black and convert to a 3-channel image
            rgba_flag = 1
            image_np, position_list = rgba_to_rgb(image_np, return_position=True)

        # Predict
        _, out_boxes, _ = detect(image_np, pzd_model, sess)
        if len(out_boxes) == 0 and rgba_flag and position_list:
            logging.info('pzd 直接输出透明位置1')
            out_boxes = [position_list[0]]
        if len(out_boxes) >= 1 and rgba_flag and len(position_list) == 1:
            logging.info('pzd 直接输出透明位置2')
            out_boxes = [position_list[0]]
        return json.dumps({"data": out_boxes, "success": 1}, default=_json_default)
    except Exception:
        # Service boundary: log with traceback and report failure to the client.
        logging.exception("pzd failed")
        return json.dumps({"data": "", "success": 0})
    finally:
        logging.info("pzd interface finish time %s", time.time() - start_time)


class PzdModels:
    """Builds the tiny YOLO puzzle-detection model inside the shared session/graph."""

    def __init__(self):
        # detect: the model must be created in `sess`'s graph so that
        # `detect(..., sess)` can run it later.
        with sess.as_default():
            with sess.graph.as_default():
                self.model = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)

    def get_model(self):
        """Return the built inference model."""
        return self.model


def get_puzzle_tip_location(image_base64, threshold=20):
    """Get the vertical extent (top/bottom row) of the hint puzzle piece.

    :param image_base64: base64-encoded image.
    :param threshold: values strictly greater than this count as content
        (default 20, the original hard-coded value).
    :return: ``[h1, h2]``, the first and last index along axis 0 (rows) that
        contain at least one value above ``threshold``.
    :raises IndexError: if no value in the image exceeds ``threshold``.
    """
    image_bytes = base64_decode(image_base64)
    image_np = bytes2np(image_bytes)
    # np.where yields indices in row-major order, so axis-0 indices are
    # non-decreasing: the first is the top row, the last is the bottom row.
    rows = np.where(image_np > threshold)[0]
    return [int(rows[0]), int(rows[-1])]


def test_pzd_model(from_remote=1, read_base64_file=0):
    """Manual smoke test: post local image(s) to /pzd and display the returned boxes.

    :param from_remote: truthy -> LAN host 192.168.2.103, falsy -> 127.0.0.1.
    :param read_base64_file: truthy -> send the payload stored in ``base64.txt``
        instead of encoding the image file.
    """
    if from_remote:
        _url = "http://192.168.2.103:17059/pzd"
    else:
        _url = "http://127.0.0.1:17059/pzd"

    for file_path in glob('puzzle4.png'):
        if not read_base64_file:
            img_np = cv2.imread(file_path)
            file_bytes = np2bytes(img_np)
            file_base64 = base64.b64encode(file_bytes)
        else:
            with open('base64.txt', 'r') as f:
                file_base64 = f.read()
            file_bytes = base64_decode(file_base64)
            img_np = bytes2np(file_bytes, unchanged=True)
            img_np = rgba_to_rgb(img_np)

        file_json = {"data": file_base64}
        result = json.loads(request_post(_url, file_json))
        if result.get("success"):
            out_boxes = result.get("data")
            print("out_boxes", out_boxes)
            for box in out_boxes:
                # cv2.rectangle needs integer pixel coordinates; JSON may carry floats.
                x1, y1, x2, y2 = (int(v) for v in box[:4])
                cv2.rectangle(img_np, (x1, y1), (x2, y2), (0, 0, 255))
            cv2.imshow("img_np", img_np)
            cv2.waitKey(0)
        else:
            print("failed!")


if __name__ == "__main__":
    # app.run(host='127.0.0.1', port=17059, debug=False)
    test_pzd_model()
    # get_puzzle_tip_location(None)