chd_interface.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import sys
  6. import time
  7. import traceback
  8. from glob import glob
  9. import numpy as np
  10. import cv2
  11. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  13. import tensorflow as tf
  14. from flask import Flask, request
  15. from chinese_detect.inference_yolo_char import get_tiny_inference_model, detect
  16. from utils import pil_resize, np2bytes, request_post, bytes2np, get_anchors, get_classes, get_colors
  logging.basicConfig(
      level=logging.INFO,
      format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  )

  # Run TensorFlow in TF1-style graph mode: the model is built into, and
  # executed through, this explicit session and its private graph.
  tf.compat.v1.disable_eager_execution()
  sess = tf.compat.v1.Session(graph=tf.Graph())

  # Model weights and YOLO metadata are located relative to this file.
  package_dir = os.path.abspath(os.path.dirname(__file__))
  model_path = package_dir + "/models/char_yolo_loss_39.90.h5"
  anchors = get_anchors(package_dir + "/yolo_data/my_anchors.txt")
  classes = get_classes(package_dir + "/yolo_data/my_classes.txt")
  # One color per class, derived from the number of classes loaded above.
  colors = get_colors(len(classes))

  # Interface configuration
  app = Flask(__name__)
  def _get_chd_model():
      """Return the cached CHD model, building it on first use.

      The model is stored in this module's globals under "global_chd_model"
      so that later requests reuse it instead of rebuilding it.

      NOTE(review): the check-then-build below is not locked. If the server
      handles requests on several threads, two simultaneous first requests
      could each build a model in the shared session; consider loading the
      model at startup if that matters for the deployment.
      """
      chd_model = globals().get("global_chd_model")
      if chd_model is None:
          logging.info("=========== init chd model ===========")
          chd_model = ChdModels().get_model()
          globals().update({"global_chd_model": chd_model})
      return chd_model


  @app.route('/chd', methods=['POST'])
  def chd():
      """Detect characters in a base64-encoded image posted as form data.

      Form fields:
          data: base64-encoded image bytes (required).
          tips: passed through unchanged to ``detect`` as ``is_tips``.

      Returns:
          A JSON string ``{"data": out_boxes, "success": 1}`` on success, or
          ``{"data": "", "success": 0}`` when no data was sent or any step
          (model loading, decoding, detection, serialisation) fails.
      """
      start_time = time.time()
      logging.info("into chd_interface chd")
      try:
          # Receive the request payload. A missing or empty "data" field is
          # reported as a normal failure instead of a b64decode traceback.
          data = request.form.get("data")
          if not data:
              logging.info("chd no data!")
              return json.dumps({"data": "", "success": 0})
          is_tips = request.form.get("tips")
          logging.info("chd_interface get data time %s", time.time() - start_time)

          # Load the model once and reuse it across requests.
          chd_model = _get_chd_model()

          # base64 text -> raw bytes -> image via bytes2np.
          image_np = bytes2np(base64.b64decode(data))

          # Predict; only the boxes are returned to the caller.
          _, out_boxes, _ = detect(image_np, chd_model, sess, is_tips=is_tips)
          return json.dumps({"data": out_boxes, "success": 1})
      except Exception:
          # Any failure becomes success=0; log the traceback instead of
          # printing it. Catching Exception (not everything) lets
          # SystemExit/KeyboardInterrupt propagate.
          logging.exception("chd interface failed")
          return json.dumps({"data": "", "success": 0})
      finally:
          logging.info("chd interface finish time %s", time.time() - start_time)
  class ChdModels:
      """Build the tiny YOLO character-detection model in the shared session.

      The model is constructed while the module-level ``sess`` and its graph
      are the defaults, i.e. in the same session that ``chd`` passes to
      ``detect``.
      """

      def __init__(self):
          # Make the session and its graph the defaults while building, so
          # the model's layers and weights are created in that graph.
          with sess.as_default(), sess.graph.as_default():
              built = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)
          self.model = built

      def get_model(self):
          """Return the model built in ``__init__``."""
          return self.model
  def test_chd_model(from_remote=True, *,
                     pattern="D:/Project/captcha/data/test/phrase_5.jpg",
                     url="http://127.0.0.1:17000/chd"):
      """Manually exercise the /chd endpoint with local image files.

      Each file matching ``pattern`` is read with OpenCV, encoded to bytes
      and then base64, and posted to ``url``. On success the returned boxes
      are printed and drawn in red onto the image, which is shown until a
      key is pressed. Otherwise "failed!" is printed.

      Args:
          from_remote: only the remote (HTTP) path is implemented; when False
              the images are read but not sent anywhere.
          pattern: glob pattern selecting the test images.
          url: endpoint to post each image to.
      """
      for file_path in glob(pattern):
          img_np = cv2.imread(file_path)
          if img_np is None:
              # cv2.imread signals an unreadable/missing file by returning
              # None rather than raising; skip it instead of crashing later.
              print("failed to read", file_path)
              continue
          if not from_remote:
              continue

          file_base64 = base64.b64encode(np2bytes(img_np))
          result = json.loads(request_post(url, {"data": file_base64}))
          if not result.get("success"):
              print("failed!")
              continue

          out_boxes = result.get("data")
          print("out_boxes", out_boxes)
          for box in out_boxes:
              # Coordinates come back through JSON; cv2.rectangle needs
              # integer points.
              x1, y1, x2, y2 = (int(v) for v in box[:4])
              cv2.rectangle(img_np, (x1, y1), (x2, y2), (0, 0, 255))
          cv2.imshow("img_np", img_np)
          cv2.waitKey(0)
  if __name__ == "__main__":
      # Serve the /chd endpoint on localhost only. To exercise an already
      # running server instead, call test_chd_model().
      server_options = {"host": '127.0.0.1', "port": 17000, "debug": False}
      app.run(**server_options)