yolo_interface.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import base64
  2. import json
  3. import os
  4. import sys
  5. import traceback
  6. import torch
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
  8. from botr.yolov8.model import Predictor
  9. from botr.yolov8.predict import detect
  10. from format_convert.max_compute_config import max_compute
  11. MAX_COMPUTE = max_compute
  12. import time
  13. import cv2
  14. from flask import Flask, request
  15. from format_convert.utils import request_post, log, get_md5_from_bytes, get_platform, bytes2np
  16. from format_convert import _global
  17. ROOT = os.path.abspath(os.path.dirname(__file__)) + '/../../'
  18. model_path = ROOT + 'botr/yolov8/weights.pt'
  19. # 接口配置
  20. app = Flask(__name__)
  21. @app.route('/yolo', methods=['POST'])
  22. def _yolo():
  23. _global._init()
  24. _global.update({"port": globals().get("port")})
  25. start_time = time.time()
  26. log("into yolo_interface _yolo")
  27. try:
  28. if not request.form:
  29. log("yolo no data!")
  30. return json.dumps({"b_table_list": str([-9])})
  31. yolo_predictor = globals().get("global_yolo_predictor")
  32. if yolo_predictor is None:
  33. image_size = 640
  34. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  35. # device = 'cpu'
  36. yolo_predictor = Predictor(image_size, device, model_path)
  37. globals().update({"global_yolo_predictor": yolo_predictor})
  38. data = request.form.get("data")
  39. _md5 = request.form.get("md5")
  40. _global.update({"md5": _md5})
  41. b_table_list = yolo(data, yolo_predictor).get('b_table_list')
  42. return json.dumps({"b_table_list": b_table_list})
  43. except TimeoutError:
  44. return json.dumps({"b_table_list": str([-5])})
  45. except:
  46. traceback.print_exc()
  47. return json.dumps({"b_table_list": str([-1])})
  48. finally:
  49. log("yolo interface finish time " + str(time.time()-start_time))
  50. def yolo(data, predictor):
  51. log("into yolo_interface yolo")
  52. try:
  53. img_data = base64.b64decode(data)
  54. img = bytes2np(img_data)
  55. b_table_list = detect(img, predictor)
  56. return {"b_table_list": b_table_list}
  57. except TimeoutError:
  58. raise TimeoutError
  59. def test_yolo_model(from_remote=True):
  60. _global._init()
  61. file_path = "C:/Users/Administrator/Desktop/test_b_table/yolo_error/error6.png"
  62. # file_path = "C:/Users/Administrator/Desktop/test_b_table/error10.png"
  63. # file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
  64. from format_convert.convert_image import get_best_predict_size
  65. from format_convert.utils import np2bytes, pil_resize
  66. image_np = cv2.imread(file_path)
  67. # best_h, best_w = get_best_predict_size(image_np)
  68. # image_np = pil_resize(image_np, best_h, best_w)
  69. file_bytes = np2bytes(image_np)
  70. file_base64 = base64.b64encode(file_bytes)
  71. _md5 = get_md5_from_bytes(file_bytes)[0]
  72. _global.update({"port": 15010, "md5": _md5})
  73. file_json = {"data": file_base64, "md5": _md5}
  74. # _url = "http://192.168.2.104:18080/yolo"
  75. _url = "http://127.0.0.1:18080/yolo"
  76. r = json.loads(request_post(_url, file_json))
  77. print(r)
  78. r = r.get('b_table_list')[0]
  79. for bbox in r:
  80. bbox = [int(x) for x in bbox]
  81. cv2.rectangle(image_np, bbox[0:2], bbox[2:4], (0, 0, 255), 2)
  82. cv2.namedWindow('img', cv2.WINDOW_NORMAL)
  83. cv2.imshow('img', image_np)
  84. cv2.waitKey(0)
  85. if __name__ == '__main__':
  86. # port = 18080
  87. # using_gpu_index = 0
  88. # app.run(host='0.0.0.0', port=port)
  89. test_yolo_model(True)