yolo_interface.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import base64
  2. import json
  3. import os
  4. import sys
  5. import traceback
  6. import torch
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
  8. from botr.yolov8.model import Predictor
  9. from botr.yolov8.predict import detect
  10. from config.max_compute_config import MAX_COMPUTE
  11. import time
  12. import cv2
  13. from flask import Flask, request
  14. from format_convert.utils import request_post, log, get_md5_from_bytes, get_platform, bytes2np
  15. from format_convert import _global
  16. ROOT = os.path.abspath(os.path.dirname(__file__)) + '/../../'
  17. model_path = ROOT + 'botr/yolov8/weights.pt'
  18. # 接口配置
  19. app = Flask(__name__)
  20. @app.route('/yolo', methods=['POST'])
  21. def _yolo():
  22. _global._init()
  23. _global.update({"port": globals().get("port")})
  24. start_time = time.time()
  25. log("into yolo_interface _yolo")
  26. try:
  27. if not request.form:
  28. log("yolo no data!")
  29. return json.dumps({"b_table_list": str([-9])})
  30. yolo_predictor = globals().get("global_yolo_predictor")
  31. if yolo_predictor is None:
  32. image_size = 640
  33. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  34. # device = 'cpu'
  35. yolo_predictor = Predictor(image_size, device, model_path)
  36. globals().update({"global_yolo_predictor": yolo_predictor})
  37. data = request.form.get("data")
  38. _md5 = request.form.get("md5")
  39. _global.update({"md5": _md5})
  40. b_table_list = yolo(data, yolo_predictor).get('b_table_list')
  41. return json.dumps({"b_table_list": b_table_list})
  42. except TimeoutError:
  43. return json.dumps({"b_table_list": str([-5])})
  44. except:
  45. traceback.print_exc()
  46. return json.dumps({"b_table_list": str([-1])})
  47. finally:
  48. log("yolo interface finish time " + str(time.time()-start_time))
  49. def yolo(data, predictor):
  50. log("into yolo_interface yolo")
  51. try:
  52. img_data = base64.b64decode(data)
  53. img = bytes2np(img_data)
  54. b_table_list = detect(img, predictor)
  55. return {"b_table_list": b_table_list}
  56. except TimeoutError:
  57. raise TimeoutError
  58. def test_yolo_model(from_remote=True):
  59. _global._init()
  60. file_path = "C:/Users/Administrator/Desktop/test_b_table/yolo_error/error6.png"
  61. # file_path = "C:/Users/Administrator/Desktop/test_b_table/error10.png"
  62. # file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
  63. from format_convert.convert_image import get_best_predict_size
  64. from format_convert.utils import np2bytes, pil_resize
  65. image_np = cv2.imread(file_path)
  66. # best_h, best_w = get_best_predict_size(image_np)
  67. # image_np = pil_resize(image_np, best_h, best_w)
  68. file_bytes = np2bytes(image_np)
  69. file_base64 = base64.b64encode(file_bytes)
  70. _md5 = get_md5_from_bytes(file_bytes)[0]
  71. _global.update({"port": 15010, "md5": _md5})
  72. file_json = {"data": file_base64, "md5": _md5}
  73. # _url = "http://192.168.2.104:18080/yolo"
  74. _url = "http://127.0.0.1:18080/yolo"
  75. r = json.loads(request_post(_url, file_json))
  76. print(r)
  77. r = r.get('b_table_list')[0]
  78. for bbox in r:
  79. bbox = [int(x) for x in bbox]
  80. cv2.rectangle(image_np, bbox[0:2], bbox[2:4], (0, 0, 255), 2)
  81. cv2.namedWindow('img', cv2.WINDOW_NORMAL)
  82. cv2.imshow('img', image_np)
  83. cv2.waitKey(0)
  84. if __name__ == '__main__':
  85. # port = 18080
  86. # using_gpu_index = 0
  87. # app.run(host='0.0.0.0', port=port)
  88. test_yolo_model(True)