predict_model.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. # @Author : bidikeji
  4. # @Time : 2019/11/25 0025 9:54
  5. import os
  6. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  7. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  8. from tensorflow.keras import models
  9. import tensorflow.keras.backend as K
  10. import tensorflow as tf
  11. from PIL import Image
  12. import numpy as np
  13. import string
  14. import re
  15. global graph
  16. total_num = 0
  17. neg_num = 0
  18. graph = tf.get_default_graph()
  19. digit_characters = string.digits
  20. digit_base_model = models.load_model('gru_digit_base_model.h5')
  21. # arith_characters = '0123456789+*-%'
  22. # arith_characters = '0123456789+?-×=' #2021/11/17新增几种算术验证码
  23. arith_characters = '0123456789+?-×/=' #2022/6/21 新增除法,新增两种验证码,两个算术符验证码
  24. arith_base_model = models.load_model('gru_arith_base_model.h5') #2021/11/17新增几种算术验证码
  25. # chinese_characters = '四生乐句付仗斥令仔乎白仙甩他瓜们用丘仪失丛代印册匆禾'
  26. with open('chinese_characters.txt', encoding='utf-8') as f:
  27. chinese_characters = f.read().strip() # 20200728 更新到524个中文
  28. chinese_base_model = models.load_model('gru_chinese_base_model.h5') # 20191219 新增 20200728 更新到524个中文
  29. # english_characters = string.digits + string.ascii_uppercase + string.ascii_lowercase
  30. english_characters = string.ascii_lowercase + string.digits # 20200728 更新为全部小写多种验证码
  31. english_base_model = models.load_model('gru_english_base_model.h5') # 20200518 新增 20200728 更新为全部小写多种验证码
  32. up_low_case_characters = string.ascii_uppercase + string.ascii_lowercase + string.digits
  33. up_low_case_model = models.load_model('gru_up_low_case_base_model.h5') # 20250110 区分大小写验证码
  34. digit_input = digit_base_model.output
  35. digit_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  36. digit_decode = K.ctc_decode(y_pred=digit_input, input_length=digit_input_length * K.shape(digit_input)[1])
  37. digit_decode = K.function([digit_base_model.input, digit_input_length], [digit_decode[0][0]])
  38. arith_input = arith_base_model.output
  39. arith_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  40. arith_decode = K.ctc_decode(y_pred=arith_input, input_length=arith_input_length * K.shape(arith_input)[1])
  41. arith_decode = K.function([arith_base_model.input, arith_input_length], [arith_decode[0][0]])
  42. chinese_input = chinese_base_model.output
  43. chinese_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  44. chinese_decode = K.ctc_decode(y_pred=chinese_input, input_length=chinese_input_length * K.shape(chinese_input)[1])
  45. chinese_decode = K.function([chinese_base_model.input, chinese_input_length], [chinese_decode[0][0]])
  46. english_input = english_base_model.output
  47. english_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  48. english_decode = K.ctc_decode(y_pred=english_input, input_length=english_input_length * K.shape(english_input)[1])
  49. english_decode = K.function([english_base_model.input, english_input_length], [english_decode[0][0]])
  50. up_low_case_input = up_low_case_model.output
  51. up_low_case_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  52. up_low_case_decode = K.ctc_decode(y_pred=up_low_case_input, input_length=up_low_case_input_length * K.shape(up_low_case_input)[1])
  53. up_low_case_decode = K.function([up_low_case_model.input, up_low_case_input_length], [up_low_case_decode[0][0]])
  54. # def decode_arith(arith = '2×?=12'):
  55. # arith = arith.replace('×', '*')
  56. # items = re.split('=', arith)
  57. # if len(items)==2:
  58. # if items[-1] in ['?', '']:
  59. # return eval(items[0])
  60. # l = re.split('-|\+|\*', items[0])
  61. # signs = re.findall('-|\+|\*', items[0])
  62. # if len(l)==2 and len(signs)==1:
  63. # if l[1] == '?':
  64. # if signs[0] == '+':
  65. # return eval('%s-%s'%(items[-1], l[0]))
  66. # elif signs[0] == '-':
  67. # return eval('%s-%s'%(l[0],items[-1]))
  68. # elif signs[0] == '*':
  69. # return int(eval('%s/%s'%(items[-1], l[0])))
  70. # elif l[0] == '?':
  71. # if signs[0] == '+':
  72. # return eval('%s-%s'%(items[-1], l[1]))
  73. # elif signs[0] == '-':
  74. # return eval('%s+%s'%(l[1],items[-1]))
  75. # elif signs[0] == '*':
  76. # return int(eval('%s/%s'%(items[-1], l[1])))
  77. # return ''
  78. def decode_arith(arith='2×?=12'):
  79. try:
  80. arith = arith.replace('×', '*')
  81. if re.search('^(\d+|\?)([\+\-\*/](\d+|\?))+=(\d+|\?)?$', arith) and len(re.findall('\?', arith)) <= 1:
  82. if arith[-1] == '?':
  83. answer = str(int(eval(arith[:-2])))
  84. elif arith[-1] == '=':
  85. answer = str(int(eval(arith[:-1])))
  86. elif re.search('^(\d+|\?)[\+\-\*/](\d+|\?)=\d+$', arith):
  87. a, sign, b, _, quest = re.split('(\+|\-|\*|×|/|=)', arith)
  88. if a == '?':
  89. if sign == "+":
  90. sign = '-'
  91. elif sign == '-':
  92. sign = '+'
  93. elif sign == "*":
  94. sign = '/'
  95. elif sign == '/':
  96. sign = '*'
  97. a, quest = quest, a
  98. elif b == '?':
  99. if sign == "+":
  100. sign = '-'
  101. b, quest = quest, b
  102. a, b = b, a
  103. elif sign == '-':
  104. b, quest = quest, b
  105. elif sign == "*":
  106. sign = '/'
  107. b, quest = quest, b
  108. a, b = b, a
  109. elif sign == '/':
  110. b, quest = quest, b
  111. else:
  112. print('公式出错:', arith)
  113. answer = str(int(eval('%s%s%s' % (a, sign, b))))
  114. else:
  115. print('公式出错:', arith)
  116. elif re.search('^\d+[\+\-\*/]\d+$', arith):
  117. answer = str(int(eval(arith)))
  118. else:
  119. answer = ''
  120. return answer
  121. except:
  122. answer = ''
  123. return answer
  124. def predict_digit(img):
  125. img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  126. X_test = np.array([img_arr])
  127. with graph.as_default():
  128. out_pre = digit_decode([X_test, np.ones(X_test.shape[0])])[0]
  129. # y_pred = digit_base_model.predict(X_test)
  130. # out_pre = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0]) * y_pred.shape[1])[0][0])[:, :6]
  131. out = ''.join([digit_characters[x] for x in out_pre[0]])
  132. return out
  133. def predict_arith(img):
  134. # img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  135. img_arr = np.array(img.resize((200, 64), Image.BILINEAR)) / 255.0 #20211117更换图片尺寸 20220621 由100,32 改为200 64
  136. X_test = np.array([img_arr])
  137. with graph.as_default():
  138. out_pre = arith_decode([X_test, np.ones(X_test.shape[0])])[0]
  139. out = ''.join([arith_characters[x] for x in out_pre[0]])
  140. try:
  141. out = decode_arith(out)
  142. except:
  143. out = ""
  144. return out
  145. def predict_chinese(img):
  146. # img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  147. img_arr = np.array(img.resize((120, 40), Image.BILINEAR)) / 255.0 # 更新两种中文验证码
  148. X_test = np.array([img_arr])
  149. with graph.as_default():
  150. out_pre = chinese_decode([X_test, np.ones(X_test.shape[0])])[0]
  151. out = ''.join([chinese_characters[x] for x in out_pre[0]])
  152. return out
  153. def predict_english(img):
  154. img_arr = np.array(img.resize((200, 70), Image.BILINEAR)) / 255.0 #BILINEAR NEAREST
  155. X_test = np.array([img_arr])
  156. with graph.as_default():
  157. out_pre = english_decode([X_test, np.ones(X_test.shape[0])])[0]
  158. out = ''.join([english_characters[x] for x in out_pre[0]])
  159. return out
  160. def predict_up_low_english(img):
  161. img_arr = np.array(img.resize((200, 70), Image.BILINEAR)) / 255.0 #BILINEAR NEAREST
  162. X_test = np.array([img_arr])
  163. with graph.as_default():
  164. out_pre = up_low_case_decode([X_test, np.ones(X_test.shape[0])])[0]
  165. out = ''.join([up_low_case_characters[x] for x in out_pre[0]])
  166. return out
  167. if __name__ == "__main__":
  168. import glob
  169. import time
  170. import sys
  171. import shutil
  172. neg = []
  173. # files = glob.glob(r'E:\linuxPro\captcha_pro\FileInfo0526\标注样本\shensexiansandian\*.jpg')[-3000:]
  174. files = glob.glob('E:/captcha_pic/up_low_case/*.jpg')
  175. t1 = time.time()
  176. pos = 0
  177. for i in range(len(files)):
  178. file = files[i].split('/')[-1]
  179. label = files[i].split('\\')[-1].split('.')[0]
  180. img = Image.open(files[i])
  181. if img.mode != "RGB":
  182. img = img.convert("RGB")
  183. pre = predict_up_low_english(img)
  184. if label!=pre:
  185. print(file,label, pre)
  186. neg.append(file)
  187. else:
  188. pos += 1
  189. # elif len(label) == 4:
  190. # if os.path.exists(files[i]):
  191. # try:
  192. # shutil.copy(files[i], 'english_imgs/' + file)
  193. # except IOError as e:
  194. # print('Unable to copy file %s' % e)
  195. # except:
  196. # print('Unexcepted error', sys.exc_info)
  197. print(len(neg), pos, time.time()-t1)
  198. print('准确率:%.4f'%(pos/(len(neg)+pos)))