# cer_interface_torch.py (3.3 KB)
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import sys
  6. import time
  7. import traceback
  8. from glob import glob
  9. import cv2
  10. # 只导入torch,protobuf会报错。需先导入TensorFlow再导入torch
  11. import tensorflow
  12. import torch
  13. from flask import Flask, request
  14. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  15. from chinese_equation_recognize.inference_equation_torch import recognize
  16. from model_torch import crnn_ctc_equation_torch6
  17. from utils import pil_resize, np2bytes, request_post, bytes2np, base64_decode, image_to_str, str_to_image
  18. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  19. package_dir = os.path.abspath(os.path.dirname(__file__))
  20. model_path = package_dir + "/models/equation6_model_acc-0.853.pth"
  21. image_shape = (32, 192, 3)
  22. # 接口配置
  23. app = Flask(__name__)
  24. @app.route('/cer', methods=['POST'])
  25. def cer():
  26. start_time = time.time()
  27. logging.info("into cer_interface cer")
  28. try:
  29. # 接收网络数据
  30. if not request.form:
  31. logging.info("cer no data!")
  32. return json.dumps({"data": "", "success": 0})
  33. data = request.form.get("data")
  34. logging.info("cer_interface get data time" + str(time.time()-start_time))
  35. # 加载模型
  36. cer_model = globals().get("global_cer_model")
  37. if cer_model is None:
  38. print("=========== init cer model ===========")
  39. cer_model = CerModels().get_model()
  40. globals().update({"global_cer_model": cer_model})
  41. # 数据转换
  42. data = base64_decode(data)
  43. image_np = bytes2np(data)
  44. # 预测
  45. result = recognize(image_np, cer_model)
  46. if result is None:
  47. return json.dumps({"data": "", "success": 0})
  48. return json.dumps({"data": result, "success": 1})
  49. except:
  50. traceback.print_exc()
  51. return json.dumps({"data": "", "success": 0})
  52. finally:
  53. logging.info("cer interface finish time " + str(time.time()-start_time))
  54. class CerModels:
  55. def __init__(self):
  56. device = torch.device("cpu")
  57. class_num = 35 + 1
  58. self.model = crnn_ctc_equation_torch6(class_num)
  59. self.model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
  60. self.model.eval()
  61. def get_model(self):
  62. return self.model
  63. def test_cer_model(from_remote=True):
  64. paths = glob("D:/Project/captcha/data/test/FileInfo1021/1d419189-5116-11ed-851c-b4b5b67760ae_7.jpg")
  65. paths = glob(r'C:\Users\Administrator\Downloads\default.jfif')
  66. for file_path in paths:
  67. img_np = cv2.imread(file_path)
  68. h, w = img_np.shape[:2]
  69. file_bytes = np2bytes(img_np)
  70. file_base64 = base64.b64encode(file_bytes)
  71. if from_remote:
  72. file_json = {"data": file_base64}
  73. # _url = "http://192.168.2.102:17061/cer"
  74. _url = "http://127.0.0.1:17061/cer"
  75. result = json.loads(request_post(_url, file_json))
  76. if result.get("success"):
  77. result = int(result.get("data"))
  78. cv2.imshow("img_np", img_np)
  79. print("equation result", result)
  80. cv2.waitKey(0)
  81. else:
  82. print("failed!")
  83. if __name__ == "__main__":
  84. # app.run(host='127.0.0.1', port=17061, debug=False)
  85. test_cer_model()