"""Image direction classifier (idc) service helpers.

An image is resized and fed to a 4-class direction model; the original
full-resolution image is then rotated by the angle that corresponds to the
predicted class.
"""
import base64
import copy
import json
import os
import time
import sys
import traceback

# Make the parent directory importable so the project packages below resolve.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from format_convert import _global
import cv2
import numpy as np
from PIL import Image
from idc.model import direction_model
from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, pil_resize
import tensorflow as tf

# One TF1-style session with its own graph. It is shared by model loading
# (IdcModels.__init__) and inference (adjust_direction) so both use the same graph.
sess = tf.compat.v1.Session(graph=tf.Graph())

# Model input size: passed to pil_resize and used as direction_model's input_shape.
image_shape = (640, 640)


def adjust_direction(image_np, model):
    """Predict the direction class of an image and return the image rotated accordingly.

    :param image_np: decoded image as a numpy array (as produced by bytes2np).
        NOTE(review): the model is built with a 3-channel input, so this
        presumably must be H x W x 3 -- confirm grayscale never reaches here.
    :param model: model returned by IdcModels.get_model(); it must have been
        built inside the module-level ``sess`` graph.
    :return: numpy array of the original-resolution image after rotation.
    """
    # 4 direction classes.
    cls_num = 4
    # Angular step between adjacent classes (90 degrees for 4 classes).
    step = 360 // cls_num

    # Build the input: keep an untouched full-resolution copy for the final
    # rotation, and feed the model a resized batch of one.
    origin_image = copy.deepcopy(image_np)
    image_np = pil_resize(image_np, image_shape[0], image_shape[1])
    X = np.expand_dims(np.array(image_np), 0)

    # Predict inside the session/graph the model was loaded into.
    with sess.as_default():
        with sess.graph.as_default():
            pred = model.predict(X)
    pred_cls = int(np.argmax(pred[0]))

    # Map the class to a rotation angle. Normalise to [0, 360) so class 0
    # yields 0 instead of 360 (the same orientation).
    angle = (360 - pred_cls * step) % 360

    # Rotate; expand=1 enlarges the canvas so the rotated image is not cropped.
    image_pil = Image.fromarray(origin_image)
    image_rotate = np.array(image_pil.rotate(angle, expand=1))
    return image_rotate


def idc(data, model):
    """Decode a base64-encoded image and return it direction-corrected.

    :param data: base64-encoded image bytes.
    :param model: model returned by IdcModels.get_model().
    :return: ``{"image": ndarray}`` on success;
        ``{"image": [-5]}`` if a TimeoutError is raised;
        ``{"image": [-1]}`` on any other error (the traceback is printed).
    """
    log("into idc_interface isr")
    try:
        img_data = base64.b64decode(data)
        img_np = bytes2np(img_data)
        image_rotate = adjust_direction(img_np, model)
        return {"image": image_rotate}
    except TimeoutError:
        return {"image": [-5]}
    except Exception:
        # Catch ordinary errors only, so KeyboardInterrupt / SystemExit (and any
        # BaseException-based timeout used by a caller) still propagate.
        traceback.print_exc()
        return {"image": [-1]}


class IdcModels:
    """Loads the direction model once, inside the shared module-level session."""

    def __init__(self):
        # Directory containing this python file.
        _dir = os.path.abspath(os.path.dirname(__file__))

        # Direction-detection model weights.
        model_path = _dir + "/models/model.h5"
        # Build and load inside ``sess`` so adjust_direction can predict in it.
        with sess.as_default():
            with sess.graph.as_default():
                self.model = direction_model(input_shape=(image_shape[0], image_shape[1], 3),
                                             output_shape=4)
                self.model.load_weights(model_path)

    def get_model(self):
        """Return the loaded direction model."""
        return self.model


def test_idc_model(from_remote=False,
                   file_path="C:/Users/Administrator/Desktop/test_image/error10.jpg"):
    """Manual smoke test: run one image through idc locally or via the HTTP service.

    :param from_remote: if True, POST the image to the local OCR service at
        port 17000; otherwise load the model in-process and call idc().
    :param file_path: image file to test with (defaults to the original
        developer path).
    """
    with open(file_path, "rb") as f:
        file_bytes = f.read()
    img_np = cv2.imread(file_path)
    # cv2.imread returns None instead of raising when it cannot decode the file.
    if img_np is None:
        print("cv2 could not decode", file_path)
    else:
        print(img_np.shape)
    file_base64 = base64.b64encode(file_bytes)
    _md5 = get_md5_from_bytes(file_bytes)[0]

    _global._init()
    _global.update({"port": 15010, "md5": _md5})

    if from_remote:
        file_json = {"data": file_base64, "md5": _md5}
        _url = "http://127.0.0.1:17000/ocr"
        print(json.loads(request_post(_url, file_json)))
    else:
        idc_model = IdcModels().get_model()
        result = idc(file_base64, idc_model)
        # idc signals failure with a list error code, e.g. [-1] or [-5].
        if isinstance(result.get("image"), list):
            print(result)
        else:
            img = result.get("image")
            print(img.shape)
            cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
            cv2.imshow("img", img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()


if __name__ == "__main__":
    test_idc_model()