idc_interface.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import base64
  2. import copy
  3. import json
  4. import os
  5. import time
  6. import sys
  7. import traceback
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  9. from format_convert import _global
  10. import cv2
  11. import numpy as np
  12. from PIL import Image
  13. from idc.model import direction_model
  14. from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, pil_resize
  15. import tensorflow as tf
  16. sess = tf.compat.v1.Session(graph=tf.Graph())
  17. image_shape = (640, 640)
  18. def adjust_direction(image_np, model):
  19. # 4个方向
  20. cls_num = 4
  21. # 构建数据
  22. origin_image = copy.deepcopy(image_np)
  23. image_np = pil_resize(image_np, image_shape[0], image_shape[1])
  24. X = np.expand_dims(np.array(image_np), 0)
  25. # 预测
  26. with sess.as_default():
  27. with sess.graph.as_default():
  28. pred = model.predict(X)
  29. pred = pred.astype(np.float64)
  30. pred = np.argmax(pred[0])
  31. # 根据分类计算角度
  32. angle = 360 - pred*int((360/cls_num))
  33. # 根据角度旋转
  34. image_pil = Image.fromarray(origin_image)
  35. image_rotate = np.array(image_pil.rotate(angle, expand=1))
  36. return image_rotate
  37. def idc(data, model):
  38. log("into idc_interface isr")
  39. try:
  40. # start_time = time.time()
  41. img_data = base64.b64decode(data)
  42. img_np = bytes2np(img_data)
  43. image_rotate = adjust_direction(img_np, model)
  44. # print(time.time()-start_time)
  45. return {"image": image_rotate}
  46. except TimeoutError:
  47. return {"image": [-5]}
  48. except:
  49. traceback.print_exc()
  50. return {"image": [-1]}
  51. class IdcModels:
  52. def __init__(self):
  53. # python文件所在目录
  54. _dir = os.path.abspath(os.path.dirname(__file__))
  55. # detect
  56. model_path = _dir + "/models/model.h5"
  57. with sess.as_default():
  58. with sess.graph.as_default():
  59. self.model = direction_model(input_shape=(image_shape[0], image_shape[1], 3),
  60. output_shape=4)
  61. self.model.load_weights(model_path)
  62. def get_model(self):
  63. return self.model
  64. def test_idc_model(from_remote=False):
  65. file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
  66. with open(file_path, "rb") as f:
  67. file_bytes = f.read()
  68. img_np = cv2.imread(file_path)
  69. print(img_np.shape)
  70. file_base64 = base64.b64encode(file_bytes)
  71. _md5 = get_md5_from_bytes(file_bytes)[0]
  72. _global._init()
  73. _global.update({"port": 15010, "md5": _md5})
  74. if from_remote:
  75. file_json = {"data": file_base64, "md5": _md5}
  76. # _url = "http://192.168.2.102:17000/ocr"
  77. _url = "http://127.0.0.1:17000/ocr"
  78. print(json.loads(request_post(_url, file_json)))
  79. else:
  80. idc_model = IdcModels().get_model()
  81. result = idc(file_base64, idc_model)
  82. # print(result)
  83. if type(result.get("image")) == list:
  84. print(result)
  85. else:
  86. img = result.get("image")
  87. print(img.shape)
  88. cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
  89. cv2.imshow("img", img)
  90. cv2.waitKey(0)
  91. # print(result)
  92. if __name__ == "__main__":
  93. test_idc_model()