123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- import base64
- import json
- import logging
- import os
- import sys
- import time
- import traceback
- from glob import glob
- import numpy as np
- import cv2
# Hide all CUDA devices so TensorFlow runs on CPU only; this is set before
# TensorFlow is imported below so that it takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Make the parent directory importable; presumably that is where the
# project-local `utils` and `puzzle_detect` modules live -- TODO confirm.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
import tensorflow as tf
from flask import Flask, request
from utils import np2bytes, request_post, bytes2np, get_anchors, get_classes, get_colors, base64_decode, rgba_to_rgb
from puzzle_detect.inference_yolo_puzzle import get_tiny_inference_model, detect
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Use TF1-style graph mode: the model is built and run through an explicit
# Session instead of eagerly.
tf.compat.v1.disable_eager_execution()
# Single module-wide session with its own graph; PzdModels builds the model
# under it and pzd() passes it to detect().
sess = tf.compat.v1.Session(graph=tf.Graph())
# Weights and YOLO configuration files, resolved relative to this file's
# directory.
package_dir = os.path.abspath(os.path.dirname(__file__))
model_path = package_dir + "/models/puzzle_yolo_loss_4.15.h5"
anchors = get_anchors(package_dir + "/yolo_data/my_anchors_puzzle.txt")
classes = get_classes(package_dir + "/yolo_data/my_classes_puzzle.txt")
# Presumably one drawing colour per class (get_colors lives in utils).
colors = get_colors(len(classes))
# HTTP interface configuration
app = Flask(__name__)
def _json_default(obj):
    """Convert NumPy values into builtin types for ``json.dumps``.

    ``json`` only calls this hook for objects it cannot serialize natively,
    so plain lists/ints/floats are unaffected.

    :param obj: object that ``json.dumps`` could not serialize.
    :return: a JSON-serializable equivalent of ``obj``.
    :raises TypeError: if ``obj`` is not a NumPy array or scalar.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError("Object of type %s is not JSON serializable" % type(obj).__name__)


@app.route('/pzd', methods=['POST'])
def pzd():
    """Detect puzzle-piece boxes in an image POSTed as base64 form field ``data``.

    The detector is built lazily on the first request and cached in the
    module global ``global_pzd_model``. For images with an alpha channel, the
    transparent region reported by ``rgba_to_rgb`` may replace the detections
    (see the fallback rules below).

    :return: JSON string ``{"data": <boxes>, "success": 1}`` on success, or
        ``{"data": "", "success": 0}`` when no data is posted or any error
        occurs (the error is logged with its traceback).
    """
    start_time = time.time()
    logging.info("into pzd_interface pzd")
    failure = json.dumps({"data": "", "success": 0})
    try:
        # Receive request data. A missing form, a missing "data" field and an
        # empty value are all rejected here instead of failing later inside
        # base64_decode.
        data = request.form.get("data")
        if not data:
            logging.info("pzd no data!")
            return failure
        logging.info("pzd_interface get data time" + str(time.time() - start_time))

        # Build the model on first use and cache it as a module global.
        # NOTE(review): nothing locks this; under a threaded server two
        # concurrent first requests could each build a model.
        pzd_model = globals().get("global_pzd_model")
        if pzd_model is None:
            logging.info("=========== init pzd model ===========")
            pzd_model = PzdModels().get_model()
            globals()["global_pzd_model"] = pzd_model

        # base64 -> bytes -> image array, keeping any alpha channel.
        image_np = bytes2np(base64_decode(data), unchanged=True)

        # Only a 3-D, 4-channel image carries alpha. Checking ndim first keeps
        # single-channel (2-D) images from raising IndexError on shape[2].
        is_rgba = image_np.ndim == 3 and image_np.shape[2] == 4
        position_list = []
        if is_rgba:
            # Transparent parts are set to black and the image is turned into
            # 3 channels; position_list presumably holds boxes of the
            # transparent regions (rgba_to_rgb lives in utils).
            image_np, position_list = rgba_to_rgb(image_np, return_position=True)

        _, out_boxes, _ = detect(image_np, pzd_model, sess)

        if is_rgba and position_list:
            # Nothing detected: fall back to the first transparent region.
            if len(out_boxes) == 0:
                logging.info('pzd 直接输出透明位置1')
                out_boxes = [position_list[0]]
            # Exactly one transparent region: prefer it over the detections.
            if len(out_boxes) >= 1 and len(position_list) == 1:
                logging.info('pzd 直接输出透明位置2')
                out_boxes = [position_list[0]]

        # default= converts any NumPy boxes/scalars that detect() may return.
        return json.dumps({"data": out_boxes, "success": 1}, default=_json_default)
    except Exception:
        logging.exception("pzd failed")
        return failure
    finally:
        logging.info("pzd interface finish time " + str(time.time() - start_time))
class PzdModels:
    """Holder for the tiny-YOLO puzzle detector built in the shared session.

    The network is created while the module-level ``sess`` and its graph are
    the defaults, so it lives in the same graph that ``sess`` later runs.
    """

    def __init__(self):
        # Session first, then its graph -- same enter/exit order as nesting.
        with sess.as_default(), sess.graph.as_default():
            num_classes = len(classes)
            self.model = get_tiny_inference_model(
                anchors,
                num_classes,
                weights_path=model_path,
            )

    def get_model(self):
        """Return the detector created in ``__init__``."""
        return self.model
def get_puzzle_tip_location(image_base64, threshold=20):
    """Return the vertical extent of the hint puzzle piece.

    The image is decoded and the first and last rows that contain at least
    one value strictly greater than ``threshold`` are reported.

    :param image_base64: base64-encoded image, decoded with ``base64_decode``
        followed by ``bytes2np``.
    :param threshold: a row belongs to the piece if any of its pixel values
        exceeds this; defaults to 20 (the previously hard-coded value).
    :return: ``[h1, h2]`` -- top and bottom row indices (inclusive) as plain
        Python ints, so the result is JSON-serializable.
    :raises IndexError: if no value in the image exceeds ``threshold`` (the
        same exception type the previous implementation raised).
    """
    image_np = bytes2np(base64_decode(image_base64))
    bright = np.asarray(image_np) > threshold
    # Reduce every axis except rows: a row counts if any pixel/channel in it
    # is bright. Same result as np.where(bright)[0][0] / [-1], without
    # allocating one index per bright pixel.
    row_has_bright = bright.any(axis=tuple(range(1, bright.ndim)))
    rows = np.flatnonzero(row_has_bright)
    if rows.size == 0:
        raise IndexError("no pixel value above %s found in image" % threshold)
    return [int(rows[0]), int(rows[-1])]
def test_pzd_model(from_remote=1, read_base64_file=0, pattern='puzzle4.png'):
    """Manually exercise the ``/pzd`` endpoint and show the returned boxes.

    For every file matching ``pattern``, a base64 image is POSTed to the
    service. On success the returned boxes are drawn in red and the image is
    shown in an OpenCV window, which blocks until a key is pressed.

    :param from_remote: truthy -> LAN server 192.168.2.103, falsy -> localhost.
    :param read_base64_file: truthy -> send the contents of ``base64.txt``
        instead of the matched image file.
    :param pattern: glob pattern of image files to iterate over; defaults to
        the previously hard-coded ``'puzzle4.png'``.
    """
    if from_remote:
        url = "http://192.168.2.103:17059/pzd"
    else:
        url = "http://127.0.0.1:17059/pzd"

    for file_path in glob(pattern):
        if not read_base64_file:
            img_np = cv2.imread(file_path)
            file_base64 = base64.b64encode(np2bytes(img_np))
        else:
            # NOTE: the same base64.txt is sent once per matched path.
            with open('base64.txt', 'r', encoding='utf-8') as f:
                file_base64 = f.read().strip()
            img_np = bytes2np(base64_decode(file_base64), unchanged=True)
            img_np = rgba_to_rgb(img_np)

        result = json.loads(request_post(url, {"data": file_base64}))
        if not result.get("success"):
            print("failed!")
            continue

        out_boxes = result.get("data")
        print("out_boxes", out_boxes)
        for box in out_boxes:
            # cv2.rectangle requires integer points; JSON may carry floats.
            x1, y1, x2, y2 = (int(v) for v in box[:4])
            cv2.rectangle(img_np, (x1, y1), (x2, y2), (0, 0, 255))
        # Show once per image, after all boxes are drawn.
        cv2.imshow("img_np", img_np)
        cv2.waitKey(0)
if __name__ == "__main__":
    # `python <this file> serve` starts the HTTP service. With no argument,
    # the manual endpoint test runs, which was the previous default.
    if len(sys.argv) > 1 and sys.argv[1] == "serve":
        app.run(host='127.0.0.1', port=17059, debug=False)
    else:
        test_pzd_model()
|