predict_model.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. # @Author : bidikeji
  4. # @Time : 2019/11/25 0025 9:54
  5. import os
  6. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  7. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  8. from tensorflow.keras import models
  9. import tensorflow.keras.backend as K
  10. import tensorflow as tf
  11. from PIL import Image
  12. import numpy as np
  13. import string
  14. global graph
  15. total_num = 0
  16. neg_num = 0
  17. graph = tf.get_default_graph()
  18. digit_characters = string.digits
  19. digit_base_model = models.load_model('gru_digit_base_model.h5')
  20. arith_characters = '0123456789+*-%'
  21. arith_base_model = models.load_model('gru_arith_base_model.h5')
  22. def predict_digit(img):
  23. img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  24. X_test = np.array([img_arr])
  25. with graph.as_default():
  26. y_pred = digit_base_model.predict(X_test)
  27. out_pre = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0]) * y_pred.shape[1])[0][0])[:, :6]
  28. out = ''.join([digit_characters[x] for x in out_pre[0]])
  29. return out
  30. def predict_arith(img):
  31. img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  32. X_test = np.array([img_arr])
  33. with graph.as_default():
  34. y_pred = arith_base_model.predict(X_test)
  35. out_pre = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0]) * y_pred.shape[1])[0][0])[:, :6]
  36. out = ''.join([arith_characters[x] for x in out_pre[0]])
  37. return out