{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0123456789+?-×/= 17\n" ] } ], "source": [ "import os\n", "# os.environ['CUDA_VISIBLE_DEVICES'] = ''\n", "\n", "from captcha.image import ImageCaptcha\n", "from PIL import Image, ImageFont, ImageDraw\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import random\n", "import re  # the next cell calls re.split before any other cell imports re\n", "import uuid\n", "import math\n", "import glob\n", "import string\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina'\n", "\n", "characters = '0123456789+?-×/=' # captcha character set\n", "\n", "width, height, n_len, n_class = 200, 64, 12, len(characters) + 1 # image width, height, max label length, classes = charset + 1 CTC blank\n", "\n", "font_paths = glob.glob('latin/*')\n", "# '/usr/share/fonts/opentype/noto/NotoSerifCJK-Bold.ttc', , '/usr/share/fonts/truetype/arphic/ukai.ttc' '/usr/share/fonts/truetype/arphic/uming.ttc', 'latin/arialbi.ttf',\n", "fonts = [ '/usr/share/fonts/opentype/malayalam/Manjari-Regular.otf', '/usr/share/fonts/opentype/malayalam/Manjari-Thin.otf', '/usr/share/fonts/opentype/noto/NotoSerifCJK-Regular.ttc', '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc', '/usr/share/fonts/opentype/noto/NotoSansCJK-Bold.ttc']\n", "fonts2 = ['latin/segoeuil.ttf', 'latin/verdana.ttf', 'latin/calibri.ttf', 'latin/SIMLI.TTF', 'latin/verdanai.ttf', 'latin/framd.ttf', 'latin/ariali.ttf', 'latin/LSANS.TTF']\n", "fonts = fonts+fonts2\n", "print(characters, n_class)" ] }, { "cell_type": "code", "execution_count": 155, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['70', '/', '10', '+', '0', '×', '20', '=', '?']" ] }, "execution_count": 155, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# paths = glob.glob('/data/captcha/arithmetic/160_60/*.jpg') # 100_26 70_25 100_40 330_69 160_60 146_46\n", "# i = 12\n", "# img = Image.open(paths[i])\n", "# img2 = img.resize((width, height), Image.BILINEAR)\n", "\n", "# plt.imshow(img)\n", "# 
plt.show()\n", "# plt.imshow(img2)\n", "# plt.show()\n", "text = '70/10+0×20=?'\n", "re.split('(\\+|\\-|\\*|×|/|=)', text)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0+1-0=?\n", "image size (200, 64) (200, 64)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 279, "width": 2329 }, "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "random_color(rs, re, gs, ge, bs, be) (10, 18, 14)\n" ] } ], "source": [ "'''生成彩色图像'''\n", "import re\n", "\n", "def get_arith(top=9, i=None, que_mark=True, numOfas=1):\n", " '''生成带等号问号+x-三种算法,返回公式及求解答案\n", " i: 第几个位置为问号\n", " numOfas:算术符号数量\n", " que_mark:True公式包含问号,False不包含'''\n", " a = random.randint(0,top)\n", " sign = random.choice(['+','*','-','/']) if top<10 else random.choice(['+','-'])\n", " b = random.randint(0,top) if sign!='/' else random.randint(1, top) \n", " if sign=='/':\n", " answer = int(eval(('%d*%d'%(a,b)))) if sign!='/' else int(eval(('%d*%d'%(a,b))))\n", " a, answer = answer, a\n", " elif sign=='-':\n", " b = random.randint(0,a)\n", " answer = int(eval('%d%s%d'%(a,sign,b)))\n", " else:\n", " answer = int(eval('%d%s%d'%(a,sign,b))) \n", " \n", " if numOfas==2: \n", " sign2 = random.choice(['+','-'])\n", " c = random.randint(0,top) if sign2=='+' else random.randint(0,answer)\n", " answer = int(eval('%d%s%d%s%d'%(a,sign,b,sign2,c)))\n", " if sign=='*' and random.random()>0.5:\n", " sign = '×' \n", "# arith = '%d%s%d%s%d=%d'%(a,sign,b,sign2,c,answer)\n", " a = str(a)\n", " b = str(b)\n", " c = str(c)\n", " answer = str(answer)\n", " l = [a,b,c,answer]\n", " if que_mark:\n", " i = random.choice([0,1,2,3]) if i==None else i\n", " question = l[i]\n", " l[i] = '?'\n", " arith = '%s%s%s%s%s=%s'%(l[0],sign,l[1],sign2,l[2],l[3])\n", " return arith, question\n", " arith = '%s%s%s%s%s='%(l[0],sign,l[1],sign2,l[2])\n", " return arith, answer\n", " else:\n", " if sign=='*' and random.random()>0.5:\n", " sign = '×' \n", "# arith = '%d%s%d=%d'%(a,sign,b,answer)\n", " a = str(a)\n", " b = str(b)\n", " answer = str(answer)\n", " l = [a,b,answer] \n", " if que_mark:\n", " i = random.choice([0,1,2]) if i==None else i\n", " question = l[i]\n", " l[i] = '?'\n", " arith = 
'%s%s%s=%s'%(l[0],sign,l[1], l[2])\n", " return arith, question\n", " arith = '%s%s%s='%(l[0],sign,l[1])\n", " return arith, answer\n", "\n", "\n", "def get_wavy_line(w = (0, 100),h = (30, 50)):\n", " '''产生波浪线坐标'''\n", " import random\n", " n = 50\n", " x = 0\n", " y = random.randint(h[0],h[1])\n", " flag = random.randint(0,2)\n", " xy = [(x, y)]\n", " while x < w[1]:\n", " temp_y = random.randint(1, 3)\n", " temp_x = random.randint(5, 10)\n", " if flag == 0:\n", " if y + temp_y > h[1]:\n", " y -= temp_y\n", " flag = 1\n", " else:\n", " y += temp_y\n", " else:\n", " if y - temp_y < h[0]:\n", " y += temp_y\n", " flag = 0\n", " else:\n", " y -= temp_y\n", " x = x+temp_x if x+temp_x < w[1] else w[1]\n", " xy.append((x, y))\n", " return xy\n", "def Asin(x, A=8,w=0.05, b=6, k=40):\n", " '''\n", " y=Asin(ωx+φ)+k在直角坐标系上的图象\n", " A——振幅,当物体作轨迹符合正弦曲线的直线往复运动时,其值为行程的1/2。\n", " (ωx+φ)——相位,反映变量y所处的状态。\n", " φ——初相,x=0时的相位;反映在坐标系上则为图像的左右移动。\n", " k——偏距,反映在坐标系上则为图像的上移或下移。\n", " ω——角速度, 控制正弦周期(单位弧度内震动的次数)。\n", " '''\n", " return A*math.sin(w*x+b)+k\n", "\n", "def random_xy(width,height): \n", " '''\n", " 随机位置函数,返回指定范围随机位置坐标\n", " 参数:width:图片宽,height:图片高\n", " '''\n", " x = random.randint(0, width)\n", " y = random.randint(0, height)\n", " return x, y\n", "def random_color(color_tuple):\n", " '''\n", " 随机颜色函数,返回指定范围随机颜色值\n", " 参数:start:颜色最低值,end:颜色最高值\n", " '''\n", " if len(color_tuple)==2:\n", " rs, re = color_tuple\n", " gs = bs = rs\n", " ge = be = re\n", " else:\n", " rs, re, gs, ge, bs, be = color_tuple\n", " red = random.randint(rs, re)\n", " green = random.randint(gs, ge)\n", " blue = random.randint(bs, be)\n", " return (red, green, blue)\n", "\n", "def gen_captcha(text, fig_size=(200,70), fonts=['fonts/ANTQUAB.TTF'],font_color=(10,100),same_color=1, font_size=(25, 35), rotate=(0,0),\n", " font_noise=0, offset_w=(0,0), offset_h=0, line=(0,0), shortline=(0,0), line_width=(0,1), line_color=(200,250), point=(0,500), \n", " point_color=(150,250), frame_color=None, wavy=(0,0), 
bg=(200,255)):\n", " '''\n", " text:验证码文本\n", " size:验证码图片宽高\n", " fonts:字体列表,随机选择一个\n", " font_noise: 字体散点干扰,0不加干扰,1加干扰\n", " offset_hor: 左右偏移值\n", " offset_var: 上下偏移值\n", " fill:字体颜色范围\n", " rotate:字体旋转角度\n", " line:干扰线条数范围\n", " point:干扰点数范围\n", " wavy:波浪线数范围\n", " color:干扰线、点 颜色\n", " bg:背景色范围\n", " '''\n", " bg = random_color(bg)\n", " img = Image.new(mode='RGB', size=fig_size, color=bg) #\n", " draw = ImageDraw.Draw(im=img, mode='RGB') # im, mode=None\n", " \n", " font_path = random.choice(fonts)\n", "# font_path = 'latin/verdana.ttf'\n", "# print('font_path:',font_path)\n", "# font_name = font_path.split('/')[-1][:-4]\n", "# print('font_name:', font_name)\n", " rotate = (rotate, rotate) if isinstance(rotate, int) else rotate\n", " font = ImageFont.truetype(font_path, size=random.randint(font_size[0], font_size[1])) # font=None, size=10, index=0, encoding=\"\"\n", " \n", " def get_char_img(char,font,font_color,rotate,bg, font_noise=0):\n", " '''\n", " 生成单个字符图片,随机颜色加随机旋转\n", " \n", " '''\n", "# print('get_char_img', char)\n", " w, h = draw.textsize(char, font=font)\n", " im = Image.new('RGBA',(w,h), color=bg)\n", " ImageDraw.Draw(im).text((0,0), char, font=font, fill=font_color) \n", " if rotate!=(0, 0) and char not in ['+','-','×']:\n", " im = im.rotate(random.randint(rotate[0], rotate[1]),Image.BILINEAR,expand=1)\n", " im = im.crop(im.getbbox())\n", " if font_noise: \n", " im_draw = ImageDraw.Draw(im)\n", "# for i in range(random.randint(1,20)):\n", " for i in range(random.randint(int(w*h*0.01),min(int(w*h*0.05), 5))):\n", " im_draw.point(xy=(random.randint(0, w), random.randint(0, h)),fill=bg)\n", "\n", " table = []\n", " for i in range(256):\n", " table.append(i * 97) # 5.97\n", " mask = im.convert('L').point(table) \n", " return (im, mask)\n", " \n", "# char_color = random.randint(font_color[0],font_color[1])\n", " char_color = random_color(font_color)\n", " \n", " # 解决两位数问题\n", " chars = re.split('(\\+|\\-|\\*|×|/|=)', text)\n", " char_imgs = []\n", " 
char_list = []\n", " if same_color: \n", " for char in chars:\n", " char_list.append(char)\n", " char_imgs.append(get_char_img(char, font, font_color=char_color, rotate=rotate, bg=bg, font_noise=font_noise))\n", " else:\n", " for char in chars:\n", " char_list.append(char)\n", " char_imgs.append(get_char_img(char, font, font_color=random_color(font_color), rotate=rotate, bg=bg, font_noise=font_noise))\n", "\n", "\n", "# re_s = re.search('(\\d+|\\?)(\\+|-|\\*|×)(\\d+|\\?)(=)(-?\\d+|\\?)?', text)\n", "# if re_s:\n", "# # print(re_s.group(0))\n", "# char_imgs = []\n", "# char_list = []\n", "# if same_color: \n", "# for i in range(1,6):\n", "# if re_s.group(i)!=None:\n", "# char_list.append(re_s.group(i))\n", "# char_imgs.append(get_char_img(re_s.group(i), font, font_color=char_color, rotate=rotate, bg=bg, font_noise=font_noise))\n", "# else:\n", "# for i in range(1,6):\n", "# if re_s.group(i)!=None:\n", "# char_list.append(re_s.group(i))\n", "# char_imgs.append(get_char_img(re_s.group(i), font, font_color=random_color(font_color), rotate=rotate, bg=bg, font_noise=font_noise))\n", "# else:\n", "# if same_color: \n", "# char_imgs = [get_char_img(char, font, font_color=char_color, rotate=rotate, bg=bg, font_noise=font_noise) for char in text]\n", "# else:\n", "# # char_imgs = [get_char_img(char, font, font_color=random.randint(font_color[0],font_color[1]), rotate=rotate, bg=bg, font_noise=font_noise) for char in text]\n", "# char_imgs = [get_char_img(char, font, font_color=random_color(font_color), rotate=rotate, bg=bg, font_noise=font_noise) for char in text] \n", " ws = [img[0].size[0] for img in char_imgs]\n", " hs = [img[0].size[1] for img in char_imgs]\n", " w = max(sum(ws), fig_size[0])\n", " h = max(max(hs), fig_size[1])\n", " if w>fig_size[0] or h>fig_size[1]:\n", " img = Image.new('RGB',(w+6,h+6), color=bg)\n", " draw = ImageDraw.Draw(im=img, mode='RGB') # im, mode=None\n", " w, h = img.size\n", " fig_size = img.size\n", " \n", "\n", " # 短线\n", " for i in 
range(random.randint(shortline[0], shortline[1])):\n", " x0, y0 = random_xy(w, h)\n", " x1 = x0 + random.randint(2, 5)\n", " y1 = y0 + random.randint(2, 5)\n", " draw.line(xy=((x0,y0),(x1,y1)),\n", " fill=random_color(line_color),\n", " width=random.randint(line_width[0], line_width[1])) # xy, fill=None, width=0\n", " \n", " if rotate!=(0, 0):\n", " temp_x = random.randint(0, min(50,int((fig_size[0]-sum(ws))/2+1))) #int((fig_size[0]-sum(ws))/5)\n", " temp_y = random.randint(int((fig_size[1]-hs[0])/8), int((fig_size[1]-hs[0])/2+1))\n", "# print('len(char_imgs):',len(char_imgs))\n", " for i in range(len(char_imgs)):\n", " tmp_offset = random.randint(offset_w[0], offset_w[1]) if sum(ws)+(len(ws)-1)*offset_w[1] 0:\n", " temp_x = new_x if new_x+ws[i]=0.5:\n", " A_ = random.uniform(hs[1]*0.1,hs[1]*0.2)\n", " w_ = math.pi*4/w#random.uniform(0.04, 0.06)\n", " b_ = random.random()*math.pi\n", " k_ = random.uniform(h*0.5, h*0.7)\n", " # 波浪线\n", " for _ in range(random.randint(wavy[0],wavy[1])): \n", " draw.line(xy=[(x, Asin(x, A_, w_, b_, k_)) for x in range(int(w))], \n", " fill=char_color, width=random.randint(line_width[0], line_width[1])) \n", " else:\n", " # 波浪线\n", " for _ in range(random.randint(wavy[0],wavy[1])): \n", " draw.line(xy=get_wavy_line(w = (0, w),h = (min(hs)-5, max(hs)+5)), \n", " fill=char_color, width=random.randint(line_width[0], line_width[1])) \n", " \n", " # 边框\n", " if frame_color!=None:\n", " draw.line(xy=[(0,0),(0, h), (0, 0), (w, 0),(w-1,0),(w-1, h), (0,h-1),(w-1, h-1)], fill=random_color(frame_color))\n", " \n", " if rotate==(0, 0): \n", " temp_x = random.randint(0, min(50, int((fig_size[0]-sum(ws))/2+1))) #int((fig_size[0]-sum(ws))/5)\n", " temp_y = random.randint(int((fig_size[1]-hs[0])/8), int((fig_size[1]-hs[0])/2+1))\n", " for i in range(len(char_imgs)):\n", " tmp_offset = random.randint(offset_w[0], offset_w[1]) if sum(ws)+(len(ws)-1)*offset_w[1] 0:\n", " temp_x = new_x if new_x+ws[i]0.3\n", "re_s = 
re.search('(\\d+|\\?)(\\+|-|\\*|×)(\\d+|\\?)(=)(\\d+|\\?)?', random_str)\n", "print(random_str)\n", "for i in range(1,6):\n", "# print(i)\n", " if re_s.group(i)!=None:\n", " print(re_s.group(i))" ] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6×0=?\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 140, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [
 "# Recompose 70x25-style arithmetic captchas from pre-cropped glyph images.\n",
 "# Every glob pattern is resolved once here instead of on each call.\n",
 "CROP_ROOT_7025 = '/data/captcha/arithmetic/crop_70_25'\n",
 "crops = glob.glob(CROP_ROOT_7025 + '/crop_num/*.jpg')\n",
 "bgs_7025 = glob.glob(CROP_ROOT_7025 + '/crop_bg/*.jpg')\n",
 "signs_7025 = glob.glob(CROP_ROOT_7025 + '/crop_sign/jiajiancheng/*.jpg')\n",
 "equals_7025 = glob.glob(CROP_ROOT_7025 + '/crop_sign/denghao/*.jpg')\n",
 "qmarks_7025 = glob.glob(CROP_ROOT_7025 + '/crop_sign/wenhao/*.jpg')\n",
 "\n",
 "# File-name prefix (pinyin) of a sign crop -> label character.\n",
 "SIGN_LABELS_7025 = {'jia': '+', 'jian': '-', 'cheng': '×', 'deng': '=', 'wen': '?'}\n",
 "\n",
 "\n",
 "def _crop_prefix(path):\n",
 "    '''Return the label prefix of a crop file: its file name up to the first underscore.'''\n",
 "    return os.path.basename(path).split('_')[0]\n",
 "\n",
 "\n",
 "def _paste_crop(img, path, x, y):\n",
 "    '''Paste the crop at `path` onto `img` with its top-left corner at (x, y); return the crop width.'''\n",
 "    with Image.open(path) as crop:  # paste() loads the pixels, so the file can be closed right after\n",
 "        img.paste(crop, (x, y))\n",
 "        return crop.size[0]\n",
 "\n",
 "\n",
 "def _paste_number(img, x, y, label):\n",
 "    '''Paste a random 1- or 2-digit number starting at x, append its digits to `label`,\n",
 "    and return the x position after it. A leading 0 of a 2-digit draw is skipped.'''\n",
 "    n_digits = random.randint(1, 2)\n",
 "    for i in range(n_digits):\n",
 "        path = random.choice(crops)\n",
 "        digit = _crop_prefix(path)\n",
 "        if digit == '0' and n_digits == 2 and i == 0:\n",
 "            continue\n",
 "        label.append(digit)\n",
 "        x += _paste_crop(img, path, x, y)\n",
 "    return x\n",
 "\n",
 "\n",
 "def _paste_sign(img, paths, x, y, label):\n",
 "    '''Paste a random sign crop from `paths`, append its label character and return the new x.\n",
 "\n",
 "    Raises ValueError when the crop prefix has no entry in SIGN_LABELS_7025, so a raw\n",
 "    pinyin prefix can never leak into a training label.\n",
 "    '''\n",
 "    path = random.choice(paths)\n",
 "    prefix = _crop_prefix(path)\n",
 "    if prefix not in SIGN_LABELS_7025:\n",
 "        raise ValueError('unknown sign crop prefix %r: %s' % (prefix, path))\n",
 "    label.append(SIGN_LABELS_7025[prefix])\n",
 "    return x + _paste_crop(img, path, x, y)\n",
 "\n",
 "\n",
 "def merge_img_7025():\n",
 "    '''Build one synthetic captcha of the form <number><op><number>=? on a random background crop.\n",
 "\n",
 "    Returns:\n",
 "        (image resized to (width, height), label string such as '12+7=?')\n",
 "    '''\n",
 "    with Image.open(random.choice(bgs_7025)) as bg:\n",
 "        img = bg.copy()\n",
 "    draw = ImageDraw.Draw(img)\n",
 "\n",
 "    x = random.randint(0, 4)  # left margin\n",
 "    y = random.randint(1, 5)  # top offset shared by every glyph\n",
 "    label = []\n",
 "    x = _paste_number(img, x, y, label)\n",
 "    x = _paste_sign(img, signs_7025, x, y, label)\n",
 "    x = _paste_number(img, x, y, label)\n",
 "    x = _paste_sign(img, equals_7025, x, y, label)\n",
 "    _paste_sign(img, qmarks_7025, x, y, label)\n",
 "\n",
 "    # Light noise: two short strokes and up to ten dots.\n",
 "    w, h = img.size\n",
 "    for _ in range(2):\n",
 "        x0, y0 = random_xy(w, h)\n",
 "        x1 = x0 + random.randint(2, 5)\n",
 "        y1 = y0 + random.randint(2, 5)\n",
 "        draw.line(xy=((x0, y0), (x1, y1)), fill=random_color((200, 250)), width=1)\n",
 "    for _ in range(random.randint(0, 10)):\n",
 "        draw.point(xy=random_xy(w, h), fill=random_color((180, 250)))\n",
 "\n",
 "    return img.resize((width, height), Image.BILINEAR), ''.join(label)\n",
 "\n",
 "\n",
 "img, label = merge_img_7025()\n",
 "print(label)\n",
 "plt.imshow(img)\n",
 "plt.show()"
 ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "7-2=?\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 140, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [
 "'''添加真实验证码'''\n",
 "# Add real captchas: labelled real images with light synthetic noise.\n",
 "# glob order depends on the filesystem; sorting makes the 90% slice sampled by get_real_img reproducible.\n",
 "imgs_60 = sorted(glob.glob('/data/captcha/arithmetic/160_60/*.jpg'))  # distorted fonts\n",
 "imgs_46 = sorted(glob.glob('/data/captcha/arithmetic/146_46/*.jpg'))  # two arithmetic operators\n",
 "\n",
 "\n",
 "def _load_answers(path, name_label_dic, keep_existing=False):\n",
 "    '''Read a tab-separated answer file (file name, question, answer) into name_label_dic.\n",
 "\n",
 "    Only the question column is stored, keyed by file name; blank lines are skipped.\n",
 "    With keep_existing=True a file name that is already present is reported and left unchanged.\n",
 "    '''\n",
 "    with open(path, 'r', encoding='utf-8') as f:\n",
 "        for line in f:\n",
 "            line = line.strip()\n",
 "            if not line:\n",
 "                continue\n",
 "            f_n, q, a = line.split('\\t')\n",
 "            if keep_existing and f_n in name_label_dic:\n",
 "                print('file 已存在:', f_n)\n",
 "                continue\n",
 "            name_label_dic[f_n] = q\n",
 "\n",
 "\n",
 "name_label_dic = dict()\n",
 "_load_answers('/data/captcha/arithmetic/146_46/answer.txt', name_label_dic)\n",
 "_load_answers('/data/captcha/arithmetic/160_60/answer.txt', name_label_dic, keep_existing=True)\n",
 "\n",
 "\n",
 "def get_real_img(imgs):\n",
 "    '''Return (noisy copy of a random real captcha resized to (width, height), its question label).\n",
 "\n",
 "    Only the first 90% of `imgs` (at least one image) is sampled; the remainder is presumably\n",
 "    held out for validation -- confirm against the evaluation code.\n",
 "    '''\n",
 "    n_train = max(1, int(0.9 * len(imgs)))\n",
 "    im_p = random.choice(imgs[:n_train])\n",
 "    label = name_label_dic[os.path.basename(im_p)]\n",
 "    with Image.open(im_p) as src:  # close the file once the pixels are copied\n",
 "        img = src.copy()\n",
 "    w, h = img.size\n",
 "    draw = ImageDraw.Draw(img)\n",
 "    for _ in range(18):  # short strokes\n",
 "        x0, y0 = random_xy(w, h)\n",
 "        x1 = x0 + random.randint(2, 5)\n",
 "        y1 = y0 + random.randint(2, 5)\n",
 "        draw.line(xy=((x0, y0), (x1, y1)), fill=random_color((200, 250)), width=1)\n",
 "    for _ in range(random.randint(20, 500)):  # dots\n",
 "        draw.point(xy=random_xy(w, h), fill=random_color((180, 250)))\n",
 "    return img.resize((width, height), Image.BILINEAR), label\n",
 "\n",
 "\n",
 "img, label = get_real_img(imgs_60)  # imgs_46 or imgs_60\n",
 "print(label)\n",
 "plt.imshow(img)\n",
 "plt.show()"
 ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.utils import Sequence\n", "# from collections import Counter" ] }, { "cell_type": "code", "execution_count": 114, 
"metadata": {}, "outputs": [], "source": [
 "''' 彩色图像生成 '''\n",
 "# Colour captcha generation for training.\n",
 "from tensorflow.keras.utils import Sequence\n",
 "from collections import Counter\n",
 "\n",
 "class CaptchaSequence(Sequence):\n",
 "    '''\n",
 "    Keras Sequence that synthesises batches of arithmetic captchas for CTC training\n",
 "    (being a Sequence lets Keras generate batches with several workers).\n",
 "\n",
 "    Args:\n",
 "        characters: label alphabet; a character's index in it is its class id.\n",
 "        batch_size: samples per batch.\n",
 "        steps: number of batches per epoch (returned by __len__).\n",
 "        n_len: width of the padded label matrix y.\n",
 "        width, height: size of the images placed in X.\n",
 "        input_length: CTC input (time-step) length given for every sample.\n",
 "            NOTE(review): base_model's summary shows 25 output time steps -- confirm the value passed matches.\n",
 "        label_length, chars_len: stored on the instance but not used by this class.\n",
 "\n",
 "    Each item is ([X, y, input_length, label_length], np.ones(batch_size)) where X is\n",
 "    (batch, height, width, 3) float32 pixel values divided by 255, y is (batch, n_len) uint8\n",
 "    padded with n_class, and label_length holds each sample's real label length so CTC\n",
 "    ignores the padding.\n",
 "    '''\n",
 "    def __init__(self, characters, batch_size, steps, n_len=n_len, width=width, height=height,\n",
 "                 input_length=12, label_length=6, chars_len=(4, 6)):\n",
 "        self.characters = characters\n",
 "        self.batch_size = batch_size\n",
 "        self.steps = steps\n",
 "        self.n_len = n_len\n",
 "        self.width = width\n",
 "        self.height = height\n",
 "        self.input_length = input_length\n",
 "        self.label_length = label_length\n",
 "        self.chars_len = chars_len\n",
 "        self.n_class = len(characters) + 1  # also used as the padding id in y\n",
 "\n",
 "    def __len__(self):\n",
 "        return self.steps\n",
 "\n",
 "    def _sample(self, slot):\n",
 "        '''Return one (image, label text) pair; `slot` in [0, 64] selects the captcha style.\n",
 "\n",
 "        Most slots render a get_arith() formula with gen_captcha(); slots 41-45 use\n",
 "        merge_img_7025() and slots 46-55 use real images via get_real_img().\n",
 "        '''\n",
 "        if slot <= 3:\n",
 "            random_str, _ = get_arith(top=9, i=1)\n",
 "            image = gen_captcha(random_str, fig_size=(100,26), fonts=fonts, font_color=(20,230,20,230,20,230), same_color=1, font_size=(15, 20), rotate=0,\n",
 "                                font_noise=0, offset_w=(-1,3), offset_h=0, line=(0,0), shortline=(10,20), line_width=(0,1), line_color=(100,150), point=(0,0),\n",
 "                                point_color=(0,0), frame_color=(120,150), wavy=(0,0), bg=(255,255))\n",
 "        elif slot <= 6:\n",
 "            random_str, _ = get_arith(top=99, i=2)\n",
 "            image = gen_captcha(random_str, fig_size=(70,25), fonts=fonts, font_color=(70,100), same_color=1, font_size=(12, 15), rotate=0,\n",
 "                                font_noise=0, offset_w=(-1,0), offset_h=0, line=(0,0), shortline=(150,200), line_width=(0,1), line_color=(180,230), point=(200,300),\n",
 "                                point_color=(200,250), frame_color=None, wavy=(0,0), bg=(210,255))\n",
 "        elif slot <= 9:\n",
 "            random_str, _ = get_arith(top=9)\n",
 "            image = gen_captcha(random_str, fig_size=(100,26), fonts=fonts, font_color=(20,230,20,230,20,230), same_color=1, font_size=(15, 20), rotate=0,\n",
 "                                font_noise=0, offset_w=(-1,3), offset_h=0, line=(0,0), shortline=(0,0), line_width=(0,1), line_color=(100,150), point=(0,0),\n",
 "                                point_color=(0,0), frame_color=(120,150), wavy=(0,0), bg=(255,255))\n",
 "        elif slot <= 12:\n",
 "            random_str, _ = get_arith(top=99, i=2)\n",
 "            image = gen_captcha(random_str, fig_size=(70,25), fonts=fonts, font_color=(10,230,10,230,10,230), same_color=0, font_size=(12, 15), rotate=0,\n",
 "                                font_noise=0, offset_w=(-1,1), offset_h=0, line=(0,0), shortline=(0,0), line_width=(0,1), line_color=(150,200), point=(0,0),\n",
 "                                point_color=(0,0), frame_color=None, wavy=(0,0), bg=(150,255))\n",
 "        elif slot <= 15:\n",
 "            random_str, _ = get_arith(top=9, que_mark=False)\n",
 "            image = gen_captcha(random_str, fig_size=(100,40), fonts=fonts, font_color=(0,0), same_color=1, font_size=(20, 25), rotate=0,\n",
 "                                font_noise=0, offset_w=(-1,1), offset_h=0, line=(3,3), shortline=(0,0), line_width=(0,1), line_color=(0,0), point=(0,0),\n",
 "                                point_color=(200,250), frame_color=None, wavy=(0,0), bg=(250,255))\n",
 "        elif slot <= 20:\n",
 "            random_str, _ = get_arith(top=99, i=2)\n",
 "            image = gen_captcha(random_str, fig_size=(330, 69), fonts=fonts, font_color=(10,250,10,250,10,250), same_color=0, font_size=(35, 40), rotate=30,\n",
 "                                font_noise=0, offset_w=(5,5), offset_h=0, line=(3,6), shortline=(0,5), line_width=(1,2), line_color=(150,230), point=(30,130),\n",
 "                                point_color=(50,230), frame_color=None, wavy=(0,0), bg=(255,255))\n",
 "        elif slot <= 30:\n",
 "            random_str, _ = get_arith(top=9)\n",
 "            tmp_w = random.randint(70,100)\n",
 "            tmp_h = random.randint(25, 35)\n",
 "            font_s = (int(tmp_h*0.8), int(tmp_h*0.9))\n",
 "            image = gen_captcha(random_str, fig_size=(tmp_w,tmp_h), fonts=fonts, font_color=(200,250), same_color=0, font_size=font_s, rotate=20,\n",
 "                                font_noise=0, offset_w=(-2,1), offset_h=2, line=(0,5), shortline=(0,100), line_width=(0,1), line_color=(10,150), point=(0,200),\n",
 "                                point_color=(50,255), frame_color=None, wavy=(0,0), bg=(10,150))\n",
 "        elif slot <= 35:\n",
 "            random_str, _ = get_arith(top=9, i=2)\n",
 "            image = gen_captcha(random_str, fig_size=(160, 60), fonts=fonts, font_color=(0,250,0,250,0,250), same_color=1, font_size=(35, 40), rotate=(20,30),\n",
 "                                font_noise=0, offset_w=(-6,-2), offset_h=0, line=(0,0), shortline=(0,0), line_width=(1,2), line_color=(150,230), point=(0,0),\n",
 "                                point_color=(50,230), frame_color=(120,150), wavy=(0,0), bg=(190,250))\n",
 "        elif slot <= 40:\n",
 "            random_str, _ = get_arith(top=9, i=3, numOfas=2)\n",
 "            image = gen_captcha(random_str, fig_size=(146, 46), fonts=fonts, font_color=(0,250,0,250,0,250), same_color=1, font_size=(25, 30), rotate=(0,0),\n",
 "                                font_noise=0, offset_w=(-2,3), offset_h=0, line=(0,0), shortline=(0,0), line_width=(1,2), line_color=(150,230), point=(10,50),\n",
 "                                point_color=(150,230), frame_color=(150,200), wavy=(0,0), bg=(220,250))\n",
 "        elif slot <= 45:\n",
 "            image, random_str = merge_img_7025()\n",
 "        elif slot <= 50:\n",
 "            image, random_str = get_real_img(imgs_46)\n",
 "        elif slot <= 55:\n",
 "            image, random_str = get_real_img(imgs_60)\n",
 "        else:\n",
 "            random_str, _ = get_arith(top=99, i=2)\n",
 "            tmp_w = random.randint(70,100)\n",
 "            tmp_h = random.randint(25, 35)\n",
 "            font_s = (int(tmp_h*0.8), int(tmp_h*0.9))\n",
 "            image = gen_captcha(random_str, fig_size=(tmp_w,tmp_h), fonts=fonts, font_color=(0,180), same_color=0, font_size=font_s, rotate=20,\n",
 "                                font_noise=0, offset_w=(2,5), offset_h=2, line=(0,5), shortline=(0,100), line_width=(0,1), line_color=(10,200), point=(0,200),\n",
 "                                point_color=(50,255), frame_color=None, wavy=(0,0), bg=(150,255))\n",
 "        return image, random_str\n",
 "\n",
 "    def _encode(self, random_str):\n",
 "        '''Map label text to class ids, padded with n_class up to n_len.\n",
 "\n",
 "        Raises ValueError if the text contains a character outside `characters`\n",
 "        (str.find would give -1, which wraps silently in the uint8 label array)\n",
 "        or is longer than n_len.\n",
 "        '''\n",
 "        label = []\n",
 "        for ch in random_str:\n",
 "            class_id = self.characters.find(ch)\n",
 "            if class_id < 0:\n",
 "                raise ValueError('label %r contains %r, which is not in characters' % (random_str, ch))\n",
 "            label.append(class_id)\n",
 "        if len(label) > self.n_len:\n",
 "            raise ValueError('label %r is longer than n_len=%d' % (random_str, self.n_len))\n",
 "        return label + [self.n_class] * (self.n_len - len(label))\n",
 "\n",
 "    def __getitem__(self, idx):\n",
 "        X = np.zeros((self.batch_size, self.height, self.width, 3), dtype=np.float32)\n",
 "        y = np.zeros((self.batch_size, self.n_len), dtype=np.uint8)\n",
 "        input_length = np.ones(self.batch_size) * self.input_length\n",
 "        # CTC must see each sample's true label length; using n_len for every sample\n",
 "        # would make the padding ids part of the target sequence.\n",
 "        label_length = np.zeros(self.batch_size)\n",
 "        max_num = 65  # number of style slots cycled through by sample index\n",
 "        for i in range(self.batch_size):\n",
 "            image, random_str = self._sample(i % max_num)\n",
 "            X[i] = np.array(image) / 255.0\n",
 "            random_str = random_str.replace('*', '×')\n",
 "            y[i] = self._encode(random_str)\n",
 "            label_length[i] = len(random_str)\n",
 "        return [X, y, input_length, label_length], np.ones(self.batch_size)"
 ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "import re\n", "a = re.search('(\\d+|\\?)(\\+|-|\\*|×)(\\d+|\\?)(=)(-?\\d+|\\?)?', '2-?=-7')" ] }, { "cell_type": "code", "execution_count": 131, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('6*23=?', 138)" ] }, "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_arith(top=99, i=2)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "data = CaptchaSequence(characters, batch_size=64, steps=2,input_length=12, 
label_length=10,chars_len=(5, 5)) # (characters, batch_size=128, steps=100)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Preview one generated sample from batch 1.\n",
 "(batch_x, batch_y, batch_input_length, batch_label_length), _ = data[1]\n",
 "print(batch_x.shape)\n",
 "idx = 18\n",
 "# batch_x[idx] is a (height, width, 3) array, which imshow displays directly.\n",
 "plt.imshow(batch_x[idx])\n",
 "plt.show()"
 ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {
"scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "input_1 (InputLayer) (None, 64, 200, 3) 0 \n", "_________________________________________________________________\n", "conv2d (Conv2D) (None, 64, 200, 32) 896 \n", "_________________________________________________________________\n", "batch_normalization (BatchNo (None, 64, 200, 32) 128 \n", "_________________________________________________________________\n", "leaky_re_lu (LeakyReLU) (None, 64, 200, 32) 0 \n", "_________________________________________________________________\n", "conv2d_1 (Conv2D) (None, 64, 200, 32) 1056 \n", "_________________________________________________________________\n", "batch_normalization_1 (Batch (None, 64, 200, 32) 128 \n", "_________________________________________________________________\n", "leaky_re_lu_1 (LeakyReLU) (None, 64, 200, 32) 0 \n", "_________________________________________________________________\n", "max_pooling2d (MaxPooling2D) (None, 32, 100, 32) 0 \n", "_________________________________________________________________\n", "conv2d_2 (Conv2D) (None, 32, 100, 64) 18496 \n", "_________________________________________________________________\n", "batch_normalization_2 (Batch (None, 32, 100, 64) 256 \n", "_________________________________________________________________\n", "leaky_re_lu_2 (LeakyReLU) (None, 32, 100, 64) 0 \n", "_________________________________________________________________\n", "conv2d_3 (Conv2D) (None, 32, 100, 64) 4160 \n", "_________________________________________________________________\n", "batch_normalization_3 (Batch (None, 32, 100, 64) 256 \n", "_________________________________________________________________\n", "leaky_re_lu_3 (LeakyReLU) (None, 32, 100, 64) 0 \n", 
"_________________________________________________________________\n", "max_pooling2d_1 (MaxPooling2 (None, 16, 50, 64) 0 \n", "_________________________________________________________________\n", "conv2d_4 (Conv2D) (None, 16, 50, 128) 73856 \n", "_________________________________________________________________\n", "batch_normalization_4 (Batch (None, 16, 50, 128) 512 \n", "_________________________________________________________________\n", "leaky_re_lu_4 (LeakyReLU) (None, 16, 50, 128) 0 \n", "_________________________________________________________________\n", "conv2d_5 (Conv2D) (None, 16, 50, 128) 16512 \n", "_________________________________________________________________\n", "batch_normalization_5 (Batch (None, 16, 50, 128) 512 \n", "_________________________________________________________________\n", "leaky_re_lu_5 (LeakyReLU) (None, 16, 50, 128) 0 \n", "_________________________________________________________________\n", "max_pooling2d_2 (MaxPooling2 (None, 8, 25, 128) 0 \n", "_________________________________________________________________\n", "conv2d_6 (Conv2D) (None, 8, 25, 256) 295168 \n", "_________________________________________________________________\n", "batch_normalization_6 (Batch (None, 8, 25, 256) 1024 \n", "_________________________________________________________________\n", "leaky_re_lu_6 (LeakyReLU) (None, 8, 25, 256) 0 \n", "_________________________________________________________________\n", "conv2d_7 (Conv2D) (None, 8, 25, 256) 65792 \n", "_________________________________________________________________\n", "batch_normalization_7 (Batch (None, 8, 25, 256) 1024 \n", "_________________________________________________________________\n", "leaky_re_lu_7 (LeakyReLU) (None, 8, 25, 256) 0 \n", "_________________________________________________________________\n", "max_pooling2d_3 (MaxPooling2 (None, 4, 25, 256) 0 \n", "_________________________________________________________________\n", "conv2d_8 (Conv2D) (None, 4, 25, 256) 590080 
\n", "_________________________________________________________________\n", "batch_normalization_8 (Batch (None, 4, 25, 256) 1024 \n", "_________________________________________________________________\n", "leaky_re_lu_8 (LeakyReLU) (None, 4, 25, 256) 0 \n", "_________________________________________________________________\n", "conv2d_9 (Conv2D) (None, 4, 25, 256) 65792 \n", "_________________________________________________________________\n", "batch_normalization_9 (Batch (None, 4, 25, 256) 1024 \n", "_________________________________________________________________\n", "leaky_re_lu_9 (LeakyReLU) (None, 4, 25, 256) 0 \n", "_________________________________________________________________\n", "max_pooling2d_4 (MaxPooling2 (None, 2, 25, 256) 0 \n", "_________________________________________________________________\n", "permute (Permute) (None, 25, 2, 256) 0 \n", "_________________________________________________________________\n", "time_distributed (TimeDistri (None, 25, 512) 0 \n", "_________________________________________________________________\n", "bidirectional (Bidirectional (None, 25, 256) 492288 \n", "_________________________________________________________________\n", "bidirectional_1 (Bidirection (None, 25, 256) 295680 \n", "_________________________________________________________________\n", "dense (Dense) (None, 25, 17) 4369 \n", "=================================================================\n", "Total params: 1,930,033\n", "Trainable params: 1,927,089\n", "Non-trainable params: 2,944\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [
"# Network definition: 5 CNN stages -> 2 bidirectional GRUs -> per-timestep softmax over n_class\n",
"from tensorflow.keras.models import *\n",
"from tensorflow.keras.layers import *\n",
"\n",
"# CTC loss\n",
"import tensorflow.keras.backend as K\n",
"\n",
"def ctc_lambda_func(args):\n",
"    '''\n",
"    CTC loss for use inside a Keras Lambda layer.\n",
"\n",
"    args: (y_pred, labels, input_length, label_length)\n",
"      y_pred: softmax output of the base model\n",
"      labels: label index sequences\n",
"      input_length: number of RNN timesteps fed to CTC for each sample\n",
"      label_length: number of real characters in each label\n",
"    Returns K.ctc_batch_cost(labels, y_pred, input_length, label_length).\n",
"    '''\n",
"    y_pred, labels, input_length, label_length = args\n",
"    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)\n",
"\n",
"input_tensor = Input((height, width, 3))\n",
"x = input_tensor\n",
"\n",
"for i, n_cnn in enumerate([2, 2, 2, 2, 2]):\n",
"    for j in range(n_cnn):\n",
"        kernel_size = 3 if j==0 else 1\n",
"        x = Conv2D(32*2**min(i, 3), kernel_size=kernel_size, padding='same', kernel_initializer='he_uniform')(x) # 32*2**min(i, 3)\n",
"        x = BatchNormalization()(x)\n",
"#         x = Activation('relu')(x) # 20200729: relu replaced by LeakyReLU\n",
"        x = LeakyReLU(0.01)(x)\n",
"    # Stages 0-2 halve height and width; stages 3-4 halve height only,\n",
"    # so the 200px width ends up as 25 columns = 25 CTC timesteps.\n",
"    x = MaxPooling2D(2 if i < 3 else (2, 1))(x)\n",
"\n",
"# (batch, h, w, c) -> (batch, w, h, c): image columns become the time axis\n",
"x = Permute((2, 1, 3))(x)\n",
"x = TimeDistributed(Flatten())(x)\n",
"rnn_size = 128 # 128 32\n",
"\n",
"x = Bidirectional(GRU(rnn_size, return_sequences=True))(x)\n",
"x = Bidirectional(GRU(rnn_size, return_sequences=True))(x) # 200epoch 0.0153 - val_loss: 0.0136\n",
"\n",
"x = Dense(n_class, activation='softmax')(x)\n",
"base_model = Model(inputs=input_tensor, outputs=x)\n",
"base_model.summary()  # summary() prints the table itself and returns None\n",
"\n",
"# Training-only wrapper: the extra inputs carry what CTC needs and the output is the loss itself\n",
"labels = Input(name='the_labels', shape=[None], dtype='float32')\n",
"input_length = Input(name='input_length', shape=[1], dtype='int64')\n",
"label_length = Input(name='label_length', shape=[1], dtype='int64')\n",
"loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([x, labels, input_length, label_length])\n",
"model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=loss_out)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "1000/1000 [==============================] - 273s 273ms/step - loss: 0.0988 - val_loss: 0.6654\n", "Epoch 2/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0516 - val_loss: 0.1077\n", "Epoch 3/100\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 
0.0344 - val_loss: 0.0714\n", "Epoch 4/100\n", "1000/1000 [==============================] - 265s 265ms/step - loss: 0.0243 - val_loss: 0.0662\n", "Epoch 5/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0202 - val_loss: 0.2005\n", "Epoch 6/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0163 - val_loss: 0.0482\n", "Epoch 7/100\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 0.0162 - val_loss: 0.0576\n", "Epoch 8/100\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 0.0124 - val_loss: 0.2479\n", "Epoch 9/100\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 0.0130 - val_loss: 0.0381\n", "Epoch 10/100\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 0.0092 - val_loss: 0.0623\n", "Epoch 11/100\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 0.0130 - val_loss: 0.0643\n", "Epoch 12/100\n", "1000/1000 [==============================] - 262s 262ms/step - loss: 0.0105 - val_loss: 0.3409\n", "Epoch 13/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0102 - val_loss: 0.6846\n", "Epoch 14/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0099 - val_loss: 0.0280\n", "Epoch 15/100\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 0.0114 - val_loss: 0.0403\n", "Epoch 16/100\n", "1000/1000 [==============================] - 265s 265ms/step - loss: 0.0080 - val_loss: 0.0108\n", "Epoch 17/100\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 0.0078 - val_loss: 0.2057\n", "Epoch 18/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0083 - val_loss: 0.0482\n", "Epoch 19/100\n", "1000/1000 [==============================] - 265s 265ms/step - loss: 0.0070 - val_loss: 0.0104\n", "Epoch 20/100\n", "1000/1000 [==============================] - 262s 262ms/step - loss: 
0.0063 - val_loss: 0.0053\n", "Epoch 21/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0058 - val_loss: 0.0156\n", "Epoch 22/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0048 - val_loss: 0.0793\n", "Epoch 23/100\n", "1000/1000 [==============================] - 262s 262ms/step - loss: 0.0082 - val_loss: 1.3160\n", "Epoch 24/100\n", "1000/1000 [==============================] - 265s 265ms/step - loss: 0.0076 - val_loss: 0.0200\n", "Epoch 25/100\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 0.0049 - val_loss: 0.0352\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [
"from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint\n",
"from tensorflow.keras.optimizers import *\n",
"import gc\n",
"\n",
"# Earlier weight files tried as starting points:\n",
"# model.load_weights('gru_DigitAndEnglist_ctc_best.h5') # gru_DigitAndEnglist_ctc_best_0924\n",
"# model.load_weights('gru_DigitAndEnglist_ctc_best_0927.h5') #DigitAndEnglist_cnn5gru_ctc_best2.h5 DigitAndEnglist_cnn5gru_ctc_best\n",
"# 'mobilenet_DigitAndEnglist_ctc_best_32.h5' loss plateaued around 0.2, accuracy ~97%\n",
"# model.load_weights('gru_english4to6_ctc_best_1012.h5')\n",
"\n",
"# CTC input_length must equal the number of timesteps base_model emits (25 for a 200px width);\n",
"# read it from the model instead of hard-coding it so the two cannot drift apart.\n",
"seq_len = base_model.output_shape[1]\n",
"train_data = CaptchaSequence(characters, batch_size=128, steps=1000, input_length=seq_len, label_length=n_len, chars_len=(4, 6))\n",
"valid_data = CaptchaSequence(characters, batch_size=128, steps=100, input_length=seq_len, label_length=n_len, chars_len=(4, 6))\n",
"\n",
"# Stop after 5 epochs without val_loss improvement; keep only the best checkpoint on disk.\n",
"callbacks = [EarlyStopping(patience=5), ModelCheckpoint('gru_arithmetic_ctc_best_20220617.h5', save_best_only=True)]\n",
"# The 'ctc' Lambda layer already outputs the loss, so the Keras loss just passes y_pred through.\n",
"model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=Adam(1e-3, amsgrad=True))\n",
"model.fit_generator(train_data, epochs=100, validation_data=valid_data, workers=4, use_multiprocessing=True,\n",
"                    callbacks=callbacks)" ] }, { "cell_type": "code", "
"execution_count": 115, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "1000/1000 [==============================] - 267s 267ms/step - loss: 0.0983 - val_loss: 0.0179\n", "Epoch 2/100\n", "1000/1000 [==============================] - 254s 254ms/step - loss: 0.0109 - val_loss: 0.0143\n", "Epoch 3/100\n", "1000/1000 [==============================] - 254s 254ms/step - loss: 0.0075 - val_loss: 0.0107\n", "Epoch 4/100\n", "1000/1000 [==============================] - 257s 257ms/step - loss: 0.0063 - val_loss: 0.0016\n", "Epoch 5/100\n", "1000/1000 [==============================] - 255s 255ms/step - loss: 0.0051 - val_loss: 0.0084\n", "Epoch 6/100\n", "1000/1000 [==============================] - 255s 255ms/step - loss: 0.0042 - val_loss: 0.0024\n", "Epoch 7/100\n", "1000/1000 [==============================] - 254s 254ms/step - loss: 0.0037 - val_loss: 0.0037\n", "Epoch 8/100\n", "1000/1000 [==============================] - 253s 253ms/step - loss: 0.0035 - val_loss: 0.0032\n", "Epoch 9/100\n", "1000/1000 [==============================] - 253s 253ms/step - loss: 0.0031 - val_loss: 3.8587e-04\n", "Epoch 10/100\n", "1000/1000 [==============================] - 253s 253ms/step - loss: 0.0034 - val_loss: 0.0012\n", "Epoch 11/100\n", "1000/1000 [==============================] - 253s 253ms/step - loss: 0.0035 - val_loss: 5.9768e-04\n", "Epoch 12/100\n", "1000/1000 [==============================] - 254s 254ms/step - loss: 0.0028 - val_loss: 0.0012\n", "Epoch 13/100\n", "1000/1000 [==============================] - 253s 253ms/step - loss: 0.0029 - val_loss: 6.9224e-04\n", "Epoch 14/100\n", "1000/1000 [==============================] - 255s 255ms/step - loss: 0.0026 - val_loss: 3.2590e-04\n", "Epoch 15/100\n", "1000/1000 [==============================] - 255s 255ms/step - loss: 0.0027 - val_loss: 2.5925e-04\n", "Epoch 16/100\n", "1000/1000 [==============================] - 254s 254ms/step - loss: 
0.0026 - val_loss: 0.0045\n", "Epoch 17/100\n", "1000/1000 [==============================] - 255s 255ms/step - loss: 0.0026 - val_loss: 4.6918e-04\n", "Epoch 18/100\n", "1000/1000 [==============================] - 255s 255ms/step - loss: 0.0028 - val_loss: 6.1181e-04\n", "Epoch 19/100\n", " 9/1000 [..............................] - ETA: 4:23 - loss: 2.4819e-04" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Process ForkPoolWorker-1440:\n", "Process ForkPoolWorker-1439:\n", "Process ForkPoolWorker-1438:\n", "Process ForkPoolWorker-1437:\n", "Traceback (most recent call last):\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/process.py\", line 254, in _bootstrap\n", " self.run()\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/process.py\", line 93, in run\n", " self._target(*self._args, **self._kwargs)\n", "Traceback (most recent call last):\n", "Traceback (most recent call last):\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/pool.py\", line 119, in worker\n", " result = (True, func(*args, **kwds))\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/process.py\", line 254, in _bootstrap\n", " self.run()\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/process.py\", line 254, in _bootstrap\n", " self.run()\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/process.py\", line 93, in run\n", " self._target(*self._args, **self._kwargs)\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/process.py\", line 93, in run\n", " self._target(*self._args, **self._kwargs)\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/pool.py\", line 119, in worker\n", " result = (True, func(*args, **kwds))\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/tensorflow/python/keras/utils/data_utils.py\", line 432, in get_index\n", 
" return _SHARED_SEQUENCES[uid][i]\n", "Traceback (most recent call last):\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/pool.py\", line 108, in worker\n", " task = get()\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/tensorflow/python/keras/utils/data_utils.py\", line 432, in get_index\n", " return _SHARED_SEQUENCES[uid][i]\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/queues.py\", line 342, in get\n", " with self._rlock:\n", " File \"\", line 82, in __getitem__\n", " point_color=(50,230),frame_color=None,wavy=(0,0), bg=(255,255))\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/process.py\", line 254, in _bootstrap\n", " self.run()\n", " File \"\", line 108, in __getitem__\n", " image, random_str = get_real_img(imgs_60) #imgs_46 imgs_60\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/synchronize.py\", line 96, in __enter__\n", " return self._semlock.__enter__()\n", " File \"\", line 30, in get_real_img\n", " x1 = x0 + random.randint(2, 5)\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/process.py\", line 93, in run\n", " self._target(*self._args, **self._kwargs)\n", "KeyboardInterrupt\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/random.py\", line 218, in randint\n", " return self.randrange(a, b+1)\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/pool.py\", line 108, in worker\n", " task = get()\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/queues.py\", line 343, in get\n", " res = self._reader.recv_bytes()\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/connection.py\", line 216, in recv_bytes\n", " buf = self._recv_bytes(maxlength)\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/connection.py\", line 407, in _recv_bytes\n", " buf = 
self._recv(4)\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/multiprocessing/connection.py\", line 379, in _recv\n", " chunk = read(handle, remaining)\n", "KeyboardInterrupt\n", " File \"\", line 188, in gen_captcha\n", " char_imgs.append(get_char_img(char, font, font_color=random_color(font_color), rotate=rotate, bg=bg, font_noise=font_noise))\n", " File \"\", line 160, in get_char_img\n", " im = im.rotate(random.randint(rotate[0], rotate[1]),Image.BILINEAR,expand=1)\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/random.py\", line 189, in randrange\n", " istop = _int(stop)\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/PIL/Image.py\", line 1915, in rotate\n", " fillcolor=fillcolor)\n", "KeyboardInterrupt\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/PIL/Image.py\", line 2192, in transform\n", " return self.convert('RGBa').transform(\n", " File \"/home/python/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/PIL/Image.py\", line 1026, in convert\n", " if dither is None:\n", "KeyboardInterrupt\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m'ctc'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1e-4\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mamsgrad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m model.fit_generator(train_data, epochs=100, validation_data=valid_data, workers=4, use_multiprocessing=True,\n\u001b[0;32m---> 17\u001b[0;31m callbacks=callbacks)\n\u001b[0m", "\u001b[0;32m~/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py\u001b[0m in \u001b[0;36mfit_generator\u001b[0;34m(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 1777\u001b[0m \u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1778\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1779\u001b[0;31m initial_epoch=initial_epoch)\n\u001b[0m\u001b[1;32m 1780\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1781\u001b[0m def evaluate_generator(self,\n", "\u001b[0;32m~/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/tensorflow/python/keras/engine/training_generator.py\u001b[0m in \u001b[0;36mfit_generator\u001b[0;34m(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m outs = model.train_on_batch(\n\u001b[0;32m--> 204\u001b[0;31m x, y, sample_weight=sample_weight, class_weight=class_weight)\n\u001b[0m\u001b[1;32m 205\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m 
\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py\u001b[0m in \u001b[0;36mtrain_on_batch\u001b[0;34m(self, x, y, sample_weight, class_weight)\u001b[0m\n\u001b[1;32m 1550\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_train_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1552\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1553\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1554\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/tensorflow/python/keras/backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2912\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeed_arrays\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_symbols\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msymbol_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msession\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2913\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2914\u001b[0;31m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2915\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_fetch_callbacks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2916\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1380\u001b[0m ret = tf_session.TF_SessionRunCallable(\n\u001b[1;32m 1381\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1382\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 1383\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1384\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [
"from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint\n",
"from tensorflow.keras.optimizers import *\n",
"import gc\n",
"\n",
"# Fine-tune from the best checkpoint of the first run with a 10x lower learning rate.\n",
"# CTC input_length must equal the number of timesteps base_model emits; read it from the model.\n",
"seq_len = base_model.output_shape[1]\n",
"train_data = CaptchaSequence(characters, batch_size=128, steps=1000, input_length=seq_len, label_length=n_len, chars_len=(4, 6))\n",
"# Only 10 validation batches here (the first run used 100), so val_loss is noisier.\n",
"valid_data = CaptchaSequence(characters, batch_size=128, steps=10, input_length=seq_len, label_length=n_len, chars_len=(4, 6))\n",
"\n",
"# NOTE(review): this checkpoint path is the file loaded below. ModelCheckpoint starts with a fresh\n",
"# 'best' value, so the first epoch of this run overwrites the previously saved best weights.\n",
"callbacks = [CSVLogger('ctc.csv', append=True), ModelCheckpoint('gru_arithmetic_ctc_best_20220617.h5', save_best_only=True)]\n",
"# model.load_weights('gru_english4to6_ctc_best_5.h5') # earlier digit+English model\n",
"# model.load_weights('gru_english4to6_ctc_best_1014.h5') # lose:0.0203 val_loss:0.012\n",
"# model.load_weights('gru_english4to6_ctc_best_1104.h5') # 1104: kernels 3/5; 1105: kernels 3/1\n",
"model.load_weights('gru_arithmetic_ctc_best_20220617.h5')\n",
"# gru_DigitAndEnglist_ctc_best.h5 mobilenet_DigitAndEnglist_ctc_best0930\n",
"# callbacks = [CSVLogger('ctc.csv', append=True), ModelCheckpoint('DigitAndEnglist_cnn5gru_ctc_best2.h5', save_best_only=True)]\n",
"# model.load_weights('DigitAndEnglist_cnn5gru_ctc_best2.h5')\n",
"model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=Adam(1e-4, amsgrad=True))\n",
"model.fit_generator(train_data, epochs=100, validation_data=valid_data, workers=4, use_multiprocessing=True,\n",
"                    callbacks=callbacks)" ] }, { "cell_type": "code", "execution_count": 121, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 121, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Accuracy check on generated captchas\n",
"from tqdm import tqdm\n",
"\n",
"def evaluate(model, batch_size=128, steps=1):\n",
"    '''\n",
"    Exact-match accuracy of `model` on freshly generated captchas.\n",
"\n",
"    model: prediction model mapping images to per-timestep softmax (e.g. base_model).\n",
"    batch_size, steps: passed to CaptchaSequence; `steps` batches are evaluated.\n",
"    Returns the fraction of samples whose greedy CTC decoding equals the label.\n",
"\n",
"    Both sequences are stripped of padding before comparison: K.ctc_decode pads\n",
"    shorter decodings with -1, and label padding is assumed to use indices\n",
"    >= len(characters) (the filter used when printing labels elsewhere in this notebook).\n",
"    Comparing the raw padded arrays would count every sample shorter than the\n",
"    longest decoding in its batch as wrong.\n",
"    '''\n",
"    correct = total = 0\n",
"    valid_data = CaptchaSequence(characters, batch_size, steps)\n",
"    for i in range(len(valid_data)):\n",
"        [X_test, y_test, _, _], _ = valid_data[i]\n",
"        y_pred = model.predict(X_test)\n",
"        n_samples, n_steps = y_pred.shape[:2]\n",
"        decoded = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(n_samples) * n_steps)[0][0])\n",
"        for pred_row, true_row in zip(decoded, y_test):\n",
"            pred_seq = [int(c) for c in pred_row if c != -1]\n",
"            true_seq = [int(c) for c in true_row if c < len(characters)]\n",
"            correct += pred_seq == true_seq\n",
"            total += 1\n",
"    return correct / total if total else 0.0\n",
"\n",
"evaluate(base_model,batch_size=256, steps=10)" ] }, { "cell_type": "code", "execution_count": 133, "metadata": {}, "outputs": [], "source": [
"import re\n",
"\n",
"base_model.save('gru_arithmetic_base_model_20220620.h5') # save the base model for inference\n",
"\n",
"# Greedy CTC decoder as a backend function: decode([images, np.ones(batch)]) -> [dense index matrix].\n",
"# Dedicated names keep the globals `x` / `input_length` from the model-building cell untouched.\n",
"pred_seq_tensor = base_model.output  # (batch, timesteps, n_class)\n",
"length_scale = Input(batch_shape=[None], dtype='int32')  # multiplied by the timestep count below\n",
"ctc_decode = K.ctc_decode(pred_seq_tensor, input_length=length_scale * K.shape(pred_seq_tensor)[1])\n",
"decode = K.function([base_model.input, length_scale], [ctc_decode[0][0]])\n",
"\n",
"def decode_arith(arith='2×?=12'):\n",
"    '''\n",
"    Solve a recognised arithmetic captcha and return the answer as a string.\n",
"\n",
"    '×' is treated as '*'. Accepted shapes (digits, + - * /, at most one '?'):\n",
"      'a op b [op c ...]=?' or '...='  -> value of the left-hand side\n",
"      '? op b=c' or 'a op ?=c'         -> value of the unknown\n",
"    The result is truncated with int(). Returns '' for text that does not match\n",
"    these shapes or cannot be evaluated (e.g. division by zero).\n",
"    '''\n",
"    arith = arith.replace('×', '*')\n",
"    # This whitelist is what keeps the eval() calls below safe: only digits,\n",
"    # the four operators, '?' and '=' can get through.\n",
"    if not (re.search(r'^(\\d+|\\?)([\\+\\-\\*/](\\d+|\\?))+=(\\d+|\\?)?$', arith)\n",
"            and len(re.findall(r'\\?', arith)) <= 1):\n",
"        return ''\n",
"    lhs, rhs = arith.split('=')\n",
"    try:\n",
"        if rhs in ('?', ''):\n",
"            if '?' in lhs:\n",
"                print('公式出错:', arith)\n",
"                return ''\n",
"            return str(int(eval(lhs)))\n",
"        # Numeric right-hand side: only a single binary operation can be solved.\n",
"        m = re.search(r'^(\\d+|\\?)([\\+\\-\\*/])(\\d+|\\?)$', lhs)\n",
"        if m is None:\n",
"            print('公式出错:', arith)\n",
"            return ''\n",
"        a, sign, b = m.groups()\n",
"        if a == '?':\n",
"            # ? op b = rhs  ->  ? = rhs (inverse op) b\n",
"            inverse = {'+': '-', '-': '+', '*': '/', '/': '*'}[sign]\n",
"            expr = '%s%s%s' % (rhs, inverse, b)\n",
"        elif b == '?':\n",
"            # a op ? = rhs\n",
"            expr = {'+': '%s-%s' % (rhs, a),   # ? = rhs - a\n",
"                    '-': '%s-%s' % (a, rhs),   # ? = a - rhs\n",
"                    '*': '%s/%s' % (rhs, a),   # ? = rhs / a\n",
"                    '/': '%s/%s' % (a, rhs)}[sign]  # ? = a / rhs\n",
"        else:\n",
"            # No unknown was recognised: report it, but still return the left-hand side.\n",
"            print('公式出错:', arith)\n",
"            expr = lhs\n",
"        return str(int(eval(expr)))\n",
"    except (ZeroDivisionError, SyntaxError):\n",
"        # e.g. '0×?=60' (division by zero) or a number with a leading zero such as '05'\n",
"        return ''" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out ccntq\n" ] }, { "data": { "text/plain": [ "Text(0.5, 1.0, 'ccntq')" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 163, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"from collections import Counter\n",
"\n",
"img = Image.open('FileInfo0508/31c1f481-912a-11ea-b24d-408d5cd36814_cmftq.jpg') # wavy-line captcha\n",
"# img = Image.open('/data/captcha/shensebeijingsandian/pgv4_d58a8328-c425-11ea-be07-ecf4bbc56acd.jpg') # dark-background captcha\n",
"# img = Image.open('/data/captcha/0ad9.jpg').resize((200,70), Image.BILINEAR) # small image with noise dots\n",
"img = img.resize((width, height), Image.BILINEAR)\n",
"\n",
"def img2array(image, width=width, height=height):\n",
"    '''\n",
"    Convert a PIL image into a (1, height, width, 3) float batch scaled to [0, 1].\n",
"\n",
"    The image is converted to grayscale and pasted at the top-left of a\n",
"    width x height canvas filled with the most common gray value of pixel row 2\n",
"    (taken as the background colour); anything outside the canvas is clipped.\n",
"    The single gray channel is broadcast into all 3 channels.\n",
"    '''\n",
"    X = np.zeros((1, height, width, 3))\n",
"    image = image.convert('L')\n",
"    row = [image.getpixel((col, 2)) for col in range(image.size[0])]\n",
"    bg = Counter(row).most_common(1)[0][0]\n",
"    bg_img = Image.new(mode='L', size=(width, height), color=bg)\n",
"    bg_img.paste(image, box=(0, 0))\n",
"    X[0] = np.expand_dims(np.array(bg_img)/255.0, axis=-1)\n",
"    return X\n",
"\n",
"img_arr = img2array(img)\n",
"\n",
"out_pre = decode([img_arr, np.ones(img_arr.shape[0])])\n",
"out = ''.join([characters[x] for x in out_pre[0][0]])\n",
"plt.imshow(img)\n",
"print('out', out)\n",
"plt.title(out)" ] }, { "cell_type": "code", "execution_count": 117, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pred:17+12=?\ttrue:17-12=?\n", "pred:5+5=?\ttrue:5×5=?\n", "2\n", "总耗时: 23.909424543380737\n", "正确数:998, 错误数:2, 总样本:1000, 准确率:0.9980\n" ] } ], "source": [
"import time\n",
"# Exact-match accuracy on generated captchas, collecting the failures for inspection.\n",
"data = CaptchaSequence(characters, batch_size=200, steps=5, input_length=base_model.output_shape[1], chars_len=(6,6))\n",
"# model.load_weights('gru_DigitAndEnglist_ctc_best_0927.h5')\n",
"# model.load_weights('mobilenet_DigitAndEnglist_ctc_best0930.h5')\n",
"# model.load_weights('mobilenet_DigitAndEnglist_ctc_best_32.h5')\n",
"# model.load_weights('gru_english4to6_ctc_best_5.h5')\n",
"pos = neg = 0\n",
"t1 = time.time()\n",
"err_img = []\n",
"err_label = []\n",
"for i in range(len(data)):\n",
"    [X_test, y_test, _, _], _ = data[i]\n",
"    # Decode the whole batch in one call; rows are padded with -1 after the last symbol.\n",
"    decoded = decode([X_test, np.ones(X_test.shape[0])])[0]\n",
"    for idx, row in enumerate(decoded):\n",
"        out = ''.join([characters[c] for c in row if c != -1])\n",
"        y_true = ''.join([characters[x] for x in y_test[idx] if x < len(characters)])\n",
"        if out != y_true:\n",
"            err_img.append(X_test[idx])\n",
"            err_label.append('pre: %s, lab: %s'%(out, y_true))\n",
"            print('pred:' + str(out) + '\\ttrue:' + str(y_true))\n",
"            neg += 1\n",
"        else:\n",
"            pos += 1\n",
"print(len(err_img))\n",
"\n",
"t2 = time.time()\n",
"print('总耗时:',t2-t1)\n",
"total = pos + neg\n",
"print('正确数:%d, 错误数:%d, 总样本:%d, 准确率:%.4f'%(pos, neg, total, pos/total if total else 0.0))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'pre: 51-1+5=?, lab: 5-1+5=?')" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 153, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"# Show one misrecognised generated sample (index clamped to the errors actually collected).\n",
"i = 3\n",
"if err_img:\n",
"    i = min(i, len(err_img) - 1)\n",
"    plt.imshow(err_img[i])\n",
"    plt.title(err_label[i])\n",
"else:\n",
"    print('no recognition errors to show')" ] }, { "cell_type": "code", "execution_count": 132, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "预测错误 8 9 False\n", "正确数:838, 总数:839, 准确率:0.9988\n" ] } ], "source": [
"# Recognise real captcha images, solve the formulas and measure answer accuracy.\n",
"import re\n",
"pos = neg = 0\n",
"n = 0\n",
"model.load_weights('gru_arithmetic_ctc_best_20220617.h5')\n",
"\n",
"# path3 = '/data/captcha/arithmetic/100_40/*.jpg' # correct 588/1505 (0.3907), later 1500/1505 (0.9967)\n",
"# path3 = '/data/captcha/arithmetic/100_26/*.jpg' # correct 1122/2822 (0.3976), later 2761/2822 (0.9784)\n",
"\n",
"path3 = '/data/captcha/arithmetic/160_60/*.jpg'\n",
"# path3 = '/data/captcha/arithmetic/146_46/*.jpg'\n",
"# answer.txt: tab-separated lines; field 0 is the image file name, field 2 the expected answer\n",
"with open('/data/captcha/arithmetic/160_60/answer.txt', 'r', encoding='utf-8') as f:\n",
"    lines = f.readlines()\n",
"d = dict()\n",
"for line in lines:\n",
"    if not line.strip():\n",
"        continue  # skip blank lines instead of failing the 3-way unpack\n",
"    f_n, _, a = line.strip().split('\\t')\n",
"    d[f_n] = a\n",
"err_imgs = []\n",
"err_labels = []\n",
"files = glob.glob(path3)\n",
"sp = int(len(files)*0.8)\n",
"# sp = min(int(len(files)*0.8), 3000)\n",
"for file in files:\n",
"    name = os.path.basename(file)\n",
"    if name not in d:\n",
"        print(name)\n",
"        continue\n",
"    label = d[name]  # answer taken from answer.txt\n",
"    # label = file.split('_')[-1][:-4].lower() # when the answer is encoded in the file name after '_'\n",
"    try:\n",
"        # `with` closes the file handle; resize() inside it forces the pixel data to load first.\n",
"        with Image.open(file) as raw:\n",
"            img = raw.resize((width, height), Image.BILINEAR)\n",
"    except OSError:\n",
"        print('打开错误:',file)\n",
"        continue\n",
"\n",
"    img = img.convert('RGB')\n",
"    X = np.zeros((1, height, width, 3))\n",
"    X[0] = np.array(img)/255.0\n",
"\n",
"    out_pre = decode([X, np.ones(X.shape[0])])\n",
"    out = ''.join([characters[x] for x in out_pre[0][0]])\n",
"\n",
"    gs = out  # recognised formula, kept for the error report\n",
"    try:\n",
"        out = str(int(decode_arith(arith = out)))\n",
"    except Exception:\n",
"        # Unsolvable formula: no answer. Exception (not a bare except) so Ctrl-C still stops the loop.\n",
"        print('计算错误:输出公式:',gs)\n",
"        out = ''\n",
"\n",
"    if label == out:\n",
"        pos += 1\n",
"    else:\n",
"        neg += 1\n",
"        print('预测错误',label, out, label==out)\n",
"        err_imgs.append(img)\n",
"        err_labels.append('label:%s pred:%s pred_gs:%s'%(label,out,gs))\n",
"    n += 1\n",
"#     if n > 100:\n",
"#         break\n",
"total = pos + neg\n",
"print('正确数:%d, 总数:%d, 准确率:%.4f'%(pos, total, pos/total if total else 0.0))" ] }, { "cell_type": "code", "execution_count": 120, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 153, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"import time\n",
"# Show at most the first 10 misrecognised real captchas with their labels.\n",
"for i in range(min(10, len(err_imgs))):\n",
"    plt.imshow(err_imgs[i])\n",
"    plt.title(err_labels[i])\n",
"    plt.show()\n",
"    time.sleep(0.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "decode_arith('0×?=60')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }