predict_model.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. # @Author : bidikeji
  4. # @Time : 2019/11/25 0025 9:54
  5. import os
  6. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  7. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  8. from tensorflow.keras import models
  9. import tensorflow.keras.backend as K
  10. import tensorflow as tf
  11. from PIL import Image
  12. import numpy as np
  13. import string
  14. import re
  15. global graph
  16. total_num = 0
  17. neg_num = 0
  18. graph = tf.get_default_graph()
  19. digit_characters = string.digits
  20. digit_base_model = models.load_model('gru_digit_base_model.h5')
  21. # arith_characters = '0123456789+*-%'
  22. # arith_characters = '0123456789+?-×=' #2021/11/17新增几种算术验证码
  23. arith_characters = '0123456789+?-×/=' #2022/6/21 新增除法,新增两种验证码,两个算术符验证码
  24. arith_base_model = models.load_model('gru_arith_base_model.h5') #2021/11/17新增几种算术验证码
  25. # chinese_characters = '四生乐句付仗斥令仔乎白仙甩他瓜们用丘仪失丛代印册匆禾'
  26. with open('chinese_characters.txt', encoding='utf-8') as f:
  27. chinese_characters = f.read().strip() # 20200728 更新到524个中文
  28. chinese_base_model = models.load_model('gru_chinese_base_model.h5') # 20191219 新增 20200728 更新到524个中文
  29. # english_characters = string.digits + string.ascii_uppercase + string.ascii_lowercase
  30. english_characters = string.ascii_lowercase + string.digits # 20200728 更新为全部小写多种验证码
  31. english_base_model = models.load_model('gru_english_base_model.h5') # 20200518 新增 20200728 更新为全部小写多种验证码
  32. up_low_case_characters = string.ascii_uppercase + string.ascii_lowercase + string.digits
  33. up_low_case_model = models.load_model('gru_up_low_case_base_model.h5') # 20250110 区分大小写验证码
  34. digit_input = digit_base_model.output
  35. digit_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  36. digit_decode = K.ctc_decode(y_pred=digit_input, input_length=digit_input_length * K.shape(digit_input)[1])
  37. digit_decode = K.function([digit_base_model.input, digit_input_length], [digit_decode[0][0]])
  38. arith_input = arith_base_model.output
  39. arith_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  40. arith_decode = K.ctc_decode(y_pred=arith_input, input_length=arith_input_length * K.shape(arith_input)[1])
  41. arith_decode = K.function([arith_base_model.input, arith_input_length], [arith_decode[0][0]])
  42. chinese_input = chinese_base_model.output
  43. chinese_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  44. chinese_decode = K.ctc_decode(y_pred=chinese_input, input_length=chinese_input_length * K.shape(chinese_input)[1])
  45. chinese_decode = K.function([chinese_base_model.input, chinese_input_length], [chinese_decode[0][0]])
  46. english_input = english_base_model.output
  47. english_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  48. english_decode = K.ctc_decode(y_pred=english_input, input_length=english_input_length * K.shape(english_input)[1])
  49. english_decode = K.function([english_base_model.input, english_input_length], [english_decode[0][0]])
  50. up_low_case_input = up_low_case_model.output
  51. up_low_case_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  52. up_low_case_decode = K.ctc_decode(y_pred=up_low_case_input, input_length=up_low_case_input_length * K.shape(up_low_case_input)[1])
  53. up_low_case_decode = K.function([up_low_case_model.input, up_low_case_input_length], [up_low_case_decode[0][0]])
  54. # def decode_arith(arith = '2×?=12'):
  55. # arith = arith.replace('×', '*')
  56. # items = re.split('=', arith)
  57. # if len(items)==2:
  58. # if items[-1] in ['?', '']:
  59. # return eval(items[0])
  60. # l = re.split('-|\+|\*', items[0])
  61. # signs = re.findall('-|\+|\*', items[0])
  62. # if len(l)==2 and len(signs)==1:
  63. # if l[1] == '?':
  64. # if signs[0] == '+':
  65. # return eval('%s-%s'%(items[-1], l[0]))
  66. # elif signs[0] == '-':
  67. # return eval('%s-%s'%(l[0],items[-1]))
  68. # elif signs[0] == '*':
  69. # return int(eval('%s/%s'%(items[-1], l[0])))
  70. # elif l[0] == '?':
  71. # if signs[0] == '+':
  72. # return eval('%s-%s'%(items[-1], l[1]))
  73. # elif signs[0] == '-':
  74. # return eval('%s+%s'%(l[1],items[-1]))
  75. # elif signs[0] == '*':
  76. # return int(eval('%s/%s'%(items[-1], l[1])))
  77. # return ''
  78. def decode_arith(arith='2×?=12'):
  79. try:
  80. arith = arith.replace('×', '*')
  81. if re.search('^(\d+|\?)([\+\-\*/](\d+|\?))+=(\d+|\?)?$', arith) and len(re.findall('\?', arith)) <= 1:
  82. if arith[-1] == '?':
  83. answer = str(int(eval(arith[:-2])))
  84. elif arith[-1] == '=':
  85. answer = str(int(eval(arith[:-1])))
  86. elif re.search('^(\d+|\?)[\+\-\*/](\d+|\?)=\d+$', arith):
  87. a, sign, b, _, quest = re.split('(\+|\-|\*|×|/|=)', arith)
  88. if a == '?':
  89. if sign == "+":
  90. sign = '-'
  91. elif sign == '-':
  92. sign = '+'
  93. elif sign == "*":
  94. sign = '/'
  95. elif sign == '/':
  96. sign = '*'
  97. a, quest = quest, a
  98. elif b == '?':
  99. if sign == "+":
  100. sign = '-'
  101. b, quest = quest, b
  102. a, b = b, a
  103. elif sign == '-':
  104. b, quest = quest, b
  105. elif sign == "*":
  106. sign = '/'
  107. b, quest = quest, b
  108. a, b = b, a
  109. elif sign == '/':
  110. b, quest = quest, b
  111. else:
  112. print('公式出错:', arith)
  113. answer = str(int(eval('%s%s%s' % (a, sign, b))))
  114. else:
  115. print('公式出错:', arith)
  116. else:
  117. answer = ''
  118. return answer
  119. except:
  120. answer = ''
  121. return answer
  122. def predict_digit(img):
  123. img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  124. X_test = np.array([img_arr])
  125. with graph.as_default():
  126. out_pre = digit_decode([X_test, np.ones(X_test.shape[0])])[0]
  127. # y_pred = digit_base_model.predict(X_test)
  128. # out_pre = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0]) * y_pred.shape[1])[0][0])[:, :6]
  129. out = ''.join([digit_characters[x] for x in out_pre[0]])
  130. return out
  131. def predict_arith(img):
  132. # img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  133. img_arr = np.array(img.resize((200, 64), Image.BILINEAR)) / 255.0 #20211117更换图片尺寸 20220621 由100,32 改为200 64
  134. X_test = np.array([img_arr])
  135. with graph.as_default():
  136. out_pre = arith_decode([X_test, np.ones(X_test.shape[0])])[0]
  137. out = ''.join([arith_characters[x] for x in out_pre[0]])
  138. try:
  139. out = decode_arith(out)
  140. except:
  141. out = ""
  142. return out
  143. def predict_chinese(img):
  144. # img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  145. img_arr = np.array(img.resize((120, 40), Image.BILINEAR)) / 255.0 # 更新两种中文验证码
  146. X_test = np.array([img_arr])
  147. with graph.as_default():
  148. out_pre = chinese_decode([X_test, np.ones(X_test.shape[0])])[0]
  149. out = ''.join([chinese_characters[x] for x in out_pre[0]])
  150. return out
  151. def predict_english(img):
  152. img_arr = np.array(img.resize((200, 70), Image.BILINEAR)) / 255.0 #BILINEAR NEAREST
  153. X_test = np.array([img_arr])
  154. with graph.as_default():
  155. out_pre = english_decode([X_test, np.ones(X_test.shape[0])])[0]
  156. out = ''.join([english_characters[x] for x in out_pre[0]])
  157. return out
  158. def predict_up_low_english(img):
  159. img_arr = np.array(img.resize((200, 70), Image.BILINEAR)) / 255.0 #BILINEAR NEAREST
  160. X_test = np.array([img_arr])
  161. with graph.as_default():
  162. out_pre = up_low_case_decode([X_test, np.ones(X_test.shape[0])])[0]
  163. out = ''.join([up_low_case_characters[x] for x in out_pre[0]])
  164. return out
  165. if __name__ == "__main__":
  166. import glob
  167. import time
  168. import sys
  169. import shutil
  170. neg = []
  171. # files = glob.glob(r'E:\linuxPro\captcha_pro\FileInfo0526\标注样本\shensexiansandian\*.jpg')[-3000:]
  172. files = glob.glob('E:/captcha_pic/up_low_case/*.jpg')
  173. t1 = time.time()
  174. pos = 0
  175. for i in range(len(files)):
  176. file = files[i].split('/')[-1]
  177. label = files[i].split('\\')[-1].split('.')[0]
  178. img = Image.open(files[i])
  179. if img.mode != "RGB":
  180. img = img.convert("RGB")
  181. pre = predict_up_low_english(img)
  182. if label!=pre:
  183. print(file,label, pre)
  184. neg.append(file)
  185. else:
  186. pos += 1
  187. # elif len(label) == 4:
  188. # if os.path.exists(files[i]):
  189. # try:
  190. # shutil.copy(files[i], 'english_imgs/' + file)
  191. # except IOError as e:
  192. # print('Unable to copy file %s' % e)
  193. # except:
  194. # print('Unexcepted error', sys.exc_info)
  195. print(len(neg), pos, time.time()-t1)
  196. print('准确率:%.4f'%(pos/(len(neg)+pos)))