yolo_interface.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import base64
  2. import json
  3. import os
  4. import sys
  5. import traceback
  6. import torch
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
  8. from botr.yolov8.model import Predictor
  9. from botr.yolov8.predict import detect
  10. from format_convert.max_compute_config import max_compute
  11. MAX_COMPUTE = max_compute
  12. import time
  13. import cv2
  14. from flask import Flask, request
  15. from format_convert.utils import request_post, log, get_md5_from_bytes, get_platform, bytes2np
  16. from format_convert import _global
  17. ROOT = os.path.abspath(os.path.dirname(__file__)) + '/../../'
  18. model_path = ROOT + 'botr/yolov8/weights.pt'
# Interface configuration: the Flask application whose routes (e.g. /yolo,
# registered below via @app.route) expose the detector over HTTP.
app = Flask(__name__)
  21. @app.route('/yolo', methods=['POST'])
  22. def _yolo():
  23. _global._init()
  24. _global.update({"port": globals().get("port")})
  25. start_time = time.time()
  26. log("into yolo_interface _yolo")
  27. try:
  28. if not request.form:
  29. log("yolo no data!")
  30. return json.dumps({"b_table_list": str([-9])})
  31. yolo_predictor = globals().get("global_yolo_predictor")
  32. if yolo_predictor is None:
  33. image_size = 640
  34. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  35. # device = 'cpu'
  36. yolo_predictor = Predictor(image_size, device, model_path)
  37. globals().update({"global_yolo_predictor": yolo_predictor})
  38. data = request.form.get("data")
  39. _md5 = request.form.get("md5")
  40. _global.update({"md5": _md5})
  41. b_table_list = yolo(data, yolo_predictor).get('b_table_list')
  42. return json.dumps({"b_table_list": b_table_list})
  43. except TimeoutError:
  44. return json.dumps({"b_table_list": str([-5])})
  45. except:
  46. traceback.print_exc()
  47. return json.dumps({"b_table_list": str([-1])})
  48. finally:
  49. log("yolo interface finish time " + str(time.time()-start_time))
  50. def yolo(data, predictor):
  51. log("into yolo_interface yolo")
  52. try:
  53. img_data = base64.b64decode(data)
  54. img = bytes2np(img_data)
  55. b_table_list = detect(img, predictor)
  56. return {"b_table_list": b_table_list}
  57. except TimeoutError:
  58. raise TimeoutError
  59. def test_yolo_model(from_remote=True):
  60. _global._init()
  61. from format_convert.convert_image import get_best_predict_size, image_process
  62. if get_platform() == "Windows":
  63. file_path = "C:/Users/Administrator/Desktop/error2.png"
  64. file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
  65. else:
  66. file_path = "1.jpg"
  67. image_np = cv2.imread(file_path)
  68. best_h, best_w = get_best_predict_size(image_np)
  69. image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
  70. cv2.imwrite(file_path, image_resize)
  71. with open(file_path, "rb") as f:
  72. file_bytes = f.read()
  73. file_base64 = base64.b64encode(file_bytes)
  74. _md5 = get_md5_from_bytes(file_bytes)[0]
  75. _global.update({"port": 15010, "md5": _md5})
  76. if from_remote:
  77. file_json = {"data": file_base64, "is_from_pdf": False, "md5": _md5}
  78. # _url = "http://192.168.2.104:18000/otr"
  79. _url = "http://127.0.0.1:18000/otr"
  80. r = json.loads(request_post(_url, file_json))
  81. else:
  82. # otr_model = OtrModels().get_model()
  83. # r = otr(file_base64, otr_model, is_from_pdf=False)
  84. r = image_process(image_resize, file_path)
  85. print(r)
  86. if __name__ == '__main__':
  87. if len(sys.argv) == 2:
  88. port = int(sys.argv[1])
  89. elif len(sys.argv) == 3:
  90. port = int(sys.argv[1])
  91. using_gpu_index = int(sys.argv[2])
  92. else:
  93. port = 18080
  94. using_gpu_index = 0
  95. # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
  96. app.run(host='0.0.0.0', port=port)
  97. log("YOLO running "+str(port))
  98. # test_yolo_model(False)
  99. # print(json.dumps([-2]))
  100. # otr_model = OtrModels().get_model()
  101. # otr("11", otr_model)