idc_interface.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import base64
  2. import copy
  3. import json
  4. import os
  5. import time
  6. import sys
  7. import traceback
  8. from glob import glob
# Set CUDA_VISIBLE_DEVICES to "-1" (the usual way to hide every GPU from the
# process). This runs before the `import tensorflow` further down, so the
# value is already in the environment when TensorFlow is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Put the parent of this file's directory on sys.path; presumably that is
# where the project packages imported below live -- verify against the repo.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  11. from format_convert.max_compute_config import max_compute
  12. import tensorflow as tf
# Module-level alias of the imported max_compute flag.
MAX_COMPUTE = max_compute

# NOTE(review): the original indentation of this section was lost. The four
# statements after the commented-out lines are placed inside the `except`
# branch here, i.e. read as a fallback that uses the TF1-style API when the
# tf.config calls above fail. Confirm against the repository copy.
if not MAX_COMPUTE:
    # TensorFlow memory settings
    try:
        gpus = tf.config.list_physical_devices('GPU')
        if len(gpus) > 0:
            # Restrict the first GPU to a single virtual device with
            # memory_limit=1024 (megabytes, per the TensorFlow docs).
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0],
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
    except:
        traceback.print_exc()
        # pass
        # gpus = tf.config.list_physical_devices('GPU')
        # for gpu in gpus:  # when several GPUs are used
        #     tf.config.experimental.set_memory_growth(gpu, True)

        os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
        os.environ['CUDA_CACHE_DISABLE'] = str(0)
        # Session limited to 5% of GPU memory per process. `sess` is bound
        # again later in this module.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  32. from format_convert import _global
  33. import cv2
  34. import numpy as np
  35. from PIL import Image
  36. from idc.model import direction_model
  37. from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, pil_resize, np2bytes
  38. import tensorflow as tf
  39. from flask import Flask, request
  40. from idc.pre_process import get_text_region, get_best_predict_size2
# Switch TensorFlow to graph (non-eager) execution.
tf.compat.v1.disable_eager_execution()
# Module-wide session bound to its own new graph. The model is built and run
# inside `sess.as_default()` / `sess.graph.as_default()` elsewhere in this file.
sess = tf.compat.v1.Session(graph=tf.Graph())

# image_shape = (192, 192)
# Size (first, second dimension) the text region is brought to before
# prediction; the model is built with the same two values as its input size.
image_shape = (640, 640)
  45. def adjust_direction(image_np, model, if_return_angle=False):
  46. # 4个方向
  47. cls_num = 4
  48. # 构建数据
  49. origin_image = copy.deepcopy(image_np)
  50. # image_np = pil_resize(image_np, image_shape[0], image_shape[1])
  51. # 获取合适的文字区域
  52. image_np = get_text_region(image_np, image_shape)
  53. # cv2.imshow("get_text_region", image_np)
  54. # cv2.waitKey(0)
  55. # print(type(image_np))
  56. if type(image_np) != np.ndarray:
  57. return None
  58. if len(image_np.shape) < 3:
  59. image_np = np.expand_dims(image_np, axis=-1)
  60. if image_np.shape[0] != image_shape[0] or image_np.shape[1] != image_shape[1]:
  61. image_np = pil_resize(image_np, image_shape[0], image_shape[1])
  62. image_np = np.expand_dims(image_np[:, :, 0], axis=-1)
  63. X = np.expand_dims(np.array(image_np), 0)
  64. # 预测
  65. with sess.as_default():
  66. with sess.graph.as_default():
  67. pred = model.predict(X, batch_size=1)
  68. pred = pred.astype(np.float64)
  69. pred = np.argmax(pred[0])
  70. # 根据分类计算角度
  71. angle = int(360 - pred*int((360/cls_num)))
  72. if if_return_angle:
  73. return angle
  74. else:
  75. if angle not in [0, 360]:
  76. # 根据角度旋转
  77. image_pil = Image.fromarray(origin_image)
  78. image_rotate = np.array(image_pil.rotate(angle, expand=1))
  79. else:
  80. image_rotate = origin_image
  81. return image_rotate
  82. def idc(data, model):
  83. log("into idc_interface idc")
  84. try:
  85. # start_time = time.time()
  86. img_data = base64.b64decode(data)
  87. img_np = bytes2np(img_data)
  88. angle = adjust_direction(img_np, model, if_return_angle=True)
  89. if angle is None:
  90. angle = 0
  91. # print(time.time()-start_time)
  92. log("idc angle " + str(angle))
  93. return {"angle": angle}
  94. except TimeoutError:
  95. return {"angle": [-5]}
  96. except:
  97. traceback.print_exc()
  98. return {"angle": [-1]}
  99. # 接口配置
  100. app = Flask(__name__)
  101. @app.route('/idc', methods=['POST'])
  102. def _idc():
  103. _global._init()
  104. _global.update({"port": globals().get("port")})
  105. start_time = time.time()
  106. log("into idc_interface _idc")
  107. try:
  108. if not request.form:
  109. log("idc no data!")
  110. return json.dumps({"angle": str([-9])})
  111. data = request.form.get("data")
  112. log("idc_interface get data time" + str(time.time()-start_time))
  113. _md5 = request.form.get("md5")
  114. _global.update({"md5": _md5})
  115. idc_model = globals().get("global_idc_model")
  116. if idc_model is None:
  117. print("=========== init idc model ===========")
  118. idc_model = IdcModels().get_model()
  119. globals().update({"global_idc_model": idc_model})
  120. angle = idc(data, idc_model).get("angle")
  121. return json.dumps({"angle": angle})
  122. except TimeoutError:
  123. return json.dumps({"angle": str([-5])})
  124. except:
  125. traceback.print_exc()
  126. return json.dumps({"angle": str([-1])})
  127. finally:
  128. log("idc interface finish time " + str(time.time()-start_time))
  129. class IdcModels:
  130. def __init__(self):
  131. # python文件所在目录
  132. _dir = os.path.abspath(os.path.dirname(__file__))
  133. # detect
  134. model_path = _dir + "/models/e484-f10.96.h5"
  135. with sess.as_default():
  136. with sess.graph.as_default():
  137. self.model = direction_model(input_shape=(image_shape[0], image_shape[1], 1),
  138. output_shape=4)
  139. self.model.load_weights(model_path)
  140. def get_model(self):
  141. return self.model
  142. def test_idc_model(from_remote=False):
  143. idc_model = IdcModels().get_model()
  144. paths = glob("C:/Users/Administrator/Desktop/test_image/error43.png")
  145. # file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
  146. for file_path in paths:
  147. img_np = cv2.imread(file_path)
  148. # img_np = pil_resize(img_np, 640, 640)
  149. h, w = get_best_predict_size2(img_np, threshold=1080)
  150. img_np = pil_resize(img_np, h, w)
  151. # print(img_np.shape)
  152. file_bytes = np2bytes(img_np)
  153. file_base64 = base64.b64encode(file_bytes)
  154. _md5 = get_md5_from_bytes(file_bytes)[0]
  155. _global._init()
  156. _global.update({"port": 15010, "md5": _md5})
  157. if from_remote:
  158. file_json = {"data": file_base64, "md5": _md5}
  159. # _url = "http://192.168.2.102:17000/ocr"
  160. _url = "http://127.0.0.1:17000/ocr"
  161. print(json.loads(request_post(_url, file_json)))
  162. else:
  163. result = idc(file_base64, idc_model)
  164. # print(result)
  165. if type(result.get("angle")) == list:
  166. print(result)
  167. else:
  168. angle = result.get("angle")
  169. img = Image.fromarray(img_np)
  170. img = np.array(img.rotate(angle, expand=1))
  171. print("angle", angle)
  172. print(img.shape)
  173. cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
  174. cv2.imshow("img", img)
  175. cv2.waitKey(0)
  176. # print(result)
# When run as a script, run the local manual test rather than the Flask app.
if __name__ == "__main__":
    test_idc_model()