chd_interface.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. import time
  6. import traceback
  7. from glob import glob
  8. import cv2
  9. import numpy as np
  10. from PIL import Image
  11. from chinese_detect.inference_yolo_char import get_tiny_inference_model, detect
  12. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  13. import tensorflow as tf
  14. from flask import Flask, request
  15. from border_recognize.model import u_net_drag
  16. from border_recognize.inference_drag import recognize
  17. from utils import pil_resize, np2bytes, request_post, bytes2np, get_anchors, get_classes, get_colors
  18. tf.compat.v1.disable_eager_execution()
  19. sess = tf.compat.v1.Session(graph=tf.Graph())
  20. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  21. model_path = _dir + "/models/char_yolo_loss_39.90.h5"
  22. anchors = get_anchors(_dir + "/yolo_data/my_anchors.txt")
  23. classes = get_classes(_dir + "/yolo_data/my_classes.txt")
  24. colors = get_colors(len(classes))
  25. # 接口配置
  26. app = Flask(__name__)
  27. @app.route('/chd', methods=['POST'])
  28. def chd():
  29. start_time = time.time()
  30. logging.info("into chd_interface chd")
  31. try:
  32. # 接收网络数据
  33. if not request.form:
  34. logging.info("chd no data!")
  35. return json.dumps({"data": "", "success": 0})
  36. data = request.form.get("data")
  37. logging.info("chd_interface get data time" + str(time.time()-start_time))
  38. # 加载模型
  39. chd_model = globals().get("global_chd_model")
  40. if chd_model is None:
  41. print("=========== init chd model ===========")
  42. chd_model = ChdModels().get_model()
  43. globals().update({"global_chd_model": chd_model})
  44. # 数据转换
  45. data = base64.b64decode(data)
  46. image_np = bytes2np(data)
  47. # 预测
  48. _, out_boxes, out_classes = detect(image_np, chd_model, sess)
  49. return json.dumps({"data": out_boxes, "success": 1})
  50. except:
  51. traceback.print_exc()
  52. return json.dumps({"data": "", "success": 0})
  53. finally:
  54. logging.info("chd interface finish time " + str(time.time()-start_time))
  55. class ChdModels:
  56. def __init__(self):
  57. # detect
  58. with sess.as_default():
  59. with sess.graph.as_default():
  60. self.model = get_tiny_inference_model(anchors, len(classes), weights_path=model_path)
  61. def get_model(self):
  62. return self.model
  63. def test_chd_model(from_remote=True):
  64. paths = glob("D:/Project/captcha/data/test/phrase_5.jpg")
  65. for file_path in paths:
  66. img_np = cv2.imread(file_path)
  67. file_bytes = np2bytes(img_np)
  68. file_base64 = base64.b64encode(file_bytes)
  69. if from_remote:
  70. file_json = {"data": file_base64}
  71. # _url = "http://192.168.2.102:17000/ocr"
  72. _url = "http://127.0.0.1:17000/chd"
  73. result = json.loads(request_post(_url, file_json))
  74. if result.get("success"):
  75. out_boxes = result.get("data")
  76. print("out_boxes", out_boxes)
  77. for box in out_boxes:
  78. cv2.rectangle(img_np, (box[0], box[1]), (box[2], box[3]), (0, 0, 255))
  79. cv2.imshow("img_np", img_np)
  80. cv2.waitKey(0)
  81. else:
  82. print("failed!")
  83. if __name__ == "__main__":
  84. # app.run(host='127.0.0.1', port=17000, debug=False)
  85. test_chd_model()