predict_model.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. # @Author : bidikeji
  4. # @Time : 2019/11/25 0025 9:54
  5. import os
# Enumerate CUDA devices in PCI bus order and expose none ("-1"), so that
# TensorFlow runs on the CPU. These are set before tensorflow is imported
# below so they are already in place when CUDA is initialised.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  8. from tensorflow.keras import models
  9. import tensorflow.keras.backend as K
  10. import tensorflow as tf
  11. from PIL import Image
  12. import numpy as np
  13. import string
  14. import re
  15. global graph
  16. total_num = 0
  17. neg_num = 0
  18. graph = tf.get_default_graph()
  19. digit_characters = string.digits
  20. digit_base_model = models.load_model('gru_digit_base_model.h5')
  21. # arith_characters = '0123456789+*-%'
  22. # arith_characters = '0123456789+?-×=' #2021/11/17新增几种算术验证码
  23. arith_characters = '0123456789+?-×/=' #2022/6/21 新增除法,新增两种验证码,两个算术符验证码
  24. arith_base_model = models.load_model('gru_arith_base_model.h5') #2021/11/17新增几种算术验证码
  25. # chinese_characters = '四生乐句付仗斥令仔乎白仙甩他瓜们用丘仪失丛代印册匆禾'
  26. with open('chinese_characters.txt', encoding='utf-8') as f:
  27. chinese_characters = f.read().strip() # 20200728 更新到524个中文
  28. chinese_base_model = models.load_model('gru_chinese_base_model.h5') # 20191219 新增 20200728 更新到524个中文
  29. # english_characters = string.digits + string.ascii_uppercase + string.ascii_lowercase
  30. english_characters = string.ascii_lowercase + string.digits # 20200728 更新为全部小写多种验证码
  31. english_base_model = models.load_model('gru_english_base_model.h5') # 20200518 新增 20200728 更新为全部小写多种验证码
  32. digit_input = digit_base_model.output
  33. digit_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  34. digit_decode = K.ctc_decode(y_pred=digit_input, input_length=digit_input_length * K.shape(digit_input)[1])
  35. digit_decode = K.function([digit_base_model.input, digit_input_length], [digit_decode[0][0]])
  36. arith_input = arith_base_model.output
  37. arith_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  38. arith_decode = K.ctc_decode(y_pred=arith_input, input_length=arith_input_length * K.shape(arith_input)[1])
  39. arith_decode = K.function([arith_base_model.input, arith_input_length], [arith_decode[0][0]])
  40. chinese_input = chinese_base_model.output
  41. chinese_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  42. chinese_decode = K.ctc_decode(y_pred=chinese_input, input_length=chinese_input_length * K.shape(chinese_input)[1])
  43. chinese_decode = K.function([chinese_base_model.input, chinese_input_length], [chinese_decode[0][0]])
  44. english_input = english_base_model.output
  45. english_input_length = tf.keras.Input(batch_shape=[None], dtype='int32')
  46. english_decode = K.ctc_decode(y_pred=english_input, input_length=english_input_length * K.shape(english_input)[1])
  47. english_decode = K.function([english_base_model.input, english_input_length], [english_decode[0][0]])
  48. # def decode_arith(arith = '2×?=12'):
  49. # arith = arith.replace('×', '*')
  50. # items = re.split('=', arith)
  51. # if len(items)==2:
  52. # if items[-1] in ['?', '']:
  53. # return eval(items[0])
  54. # l = re.split('-|\+|\*', items[0])
  55. # signs = re.findall('-|\+|\*', items[0])
  56. # if len(l)==2 and len(signs)==1:
  57. # if l[1] == '?':
  58. # if signs[0] == '+':
  59. # return eval('%s-%s'%(items[-1], l[0]))
  60. # elif signs[0] == '-':
  61. # return eval('%s-%s'%(l[0],items[-1]))
  62. # elif signs[0] == '*':
  63. # return int(eval('%s/%s'%(items[-1], l[0])))
  64. # elif l[0] == '?':
  65. # if signs[0] == '+':
  66. # return eval('%s-%s'%(items[-1], l[1]))
  67. # elif signs[0] == '-':
  68. # return eval('%s+%s'%(l[1],items[-1]))
  69. # elif signs[0] == '*':
  70. # return int(eval('%s/%s'%(items[-1], l[1])))
  71. # return ''
  72. def decode_arith(arith='2×?=12'):
  73. try:
  74. arith = arith.replace('×', '*')
  75. if re.search('^(\d+|\?)([\+\-\*/](\d+|\?))+=(\d+|\?)?$', arith) and len(re.findall('\?', arith)) <= 1:
  76. if arith[-1] == '?':
  77. answer = str(int(eval(arith[:-2])))
  78. elif arith[-1] == '=':
  79. answer = str(int(eval(arith[:-1])))
  80. elif re.search('^(\d+|\?)[\+\-\*/](\d+|\?)=\d+$', arith):
  81. a, sign, b, _, quest = re.split('(\+|\-|\*|×|/|=)', arith)
  82. if a == '?':
  83. if sign == "+":
  84. sign = '-'
  85. elif sign == '-':
  86. sign = '+'
  87. elif sign == "*":
  88. sign = '/'
  89. elif sign == '/':
  90. sign = '*'
  91. a, quest = quest, a
  92. elif b == '?':
  93. if sign == "+":
  94. sign = '-'
  95. b, quest = quest, b
  96. a, b = b, a
  97. elif sign == '-':
  98. b, quest = quest, b
  99. elif sign == "*":
  100. sign = '/'
  101. b, quest = quest, b
  102. a, b = b, a
  103. elif sign == '/':
  104. b, quest = quest, b
  105. else:
  106. print('公式出错:', arith)
  107. answer = str(int(eval('%s%s%s' % (a, sign, b))))
  108. else:
  109. print('公式出错:', arith)
  110. else:
  111. answer = ''
  112. return answer
  113. except:
  114. answer = ''
  115. return answer
  116. def predict_digit(img):
  117. img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  118. X_test = np.array([img_arr])
  119. with graph.as_default():
  120. out_pre = digit_decode([X_test, np.ones(X_test.shape[0])])[0]
  121. # y_pred = digit_base_model.predict(X_test)
  122. # out_pre = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0]) * y_pred.shape[1])[0][0])[:, :6]
  123. out = ''.join([digit_characters[x] for x in out_pre[0]])
  124. return out
  125. def predict_arith(img):
  126. # img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  127. img_arr = np.array(img.resize((200, 64), Image.BILINEAR)) / 255.0 #20211117更换图片尺寸 20220621 由100,32 改为200 64
  128. X_test = np.array([img_arr])
  129. with graph.as_default():
  130. out_pre = arith_decode([X_test, np.ones(X_test.shape[0])])[0]
  131. out = ''.join([arith_characters[x] for x in out_pre[0]])
  132. try:
  133. out = decode_arith(out)
  134. except:
  135. out = ""
  136. return out
  137. def predict_chinese(img):
  138. # img_arr = np.array(img.resize((100, 50), Image.BILINEAR)) / 255.0
  139. img_arr = np.array(img.resize((120, 40), Image.BILINEAR)) / 255.0 # 更新两种中文验证码
  140. X_test = np.array([img_arr])
  141. with graph.as_default():
  142. out_pre = chinese_decode([X_test, np.ones(X_test.shape[0])])[0]
  143. out = ''.join([chinese_characters[x] for x in out_pre[0]])
  144. return out
  145. def predict_english(img):
  146. img_arr = np.array(img.resize((200, 70), Image.BILINEAR)) / 255.0 #BILINEAR NEAREST
  147. X_test = np.array([img_arr])
  148. with graph.as_default():
  149. out_pre = english_decode([X_test, np.ones(X_test.shape[0])])[0]
  150. out = ''.join([english_characters[x] for x in out_pre[0]])
  151. return out
  152. if __name__ == "__main__":
  153. import glob
  154. import time
  155. import sys
  156. import shutil
  157. neg = []
  158. files = glob.glob(r'E:\linuxPro\captcha_pro\FileInfo0526\标注样本\shensexiansandian\*.jpg')[-3000:]
  159. t1 = time.time()
  160. for i in range(len(files)):
  161. file = files[i].split('\\')[-1]
  162. label = files[i].split('\\')[-1].split('_')[0]
  163. img = Image.open(files[i])
  164. if img.mode != "RGB":
  165. img = img.convert("RGB")
  166. pre = predict_english(img)
  167. if label!=pre:
  168. print(file,label, pre)
  169. neg.append(file)
  170. elif len(label) == 4:
  171. if os.path.exists(files[i]):
  172. try:
  173. shutil.copy(files[i], 'english_imgs/' + file)
  174. except IOError as e:
  175. print('Unable to copy file %s' % e)
  176. except:
  177. print('Unexcepted error', sys.exc_info)
  178. print(len(neg), time.time()-t1)