{ "cells": [ { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "524\n", "特啦同爱手使清弟时还睡近线却文政完么展便工在不办入强路安起理斥队何听农验它哦呢座口刚叔沙信內笑装往用更站社亲房跑啊章屋马直众立前都其造该瓜片论水夜习白感离下神送常当只认干发到思席研画领极指吗讲九越仗过受这已乐穿原才包丛满岸第千限今色种头被要活或转紧没后向气咯者跟想你吧学如问忙界解坚科一确非经张快顶由争目雪仔拿分小照般表牛命上北各他部之新孩令黑爬无级友睛脸关忽给动嘴南事写应光将情八晚成紫步所呀些飞息似车开边敢匆阵伯志禾四着够底处道衣说难七音伟仪儿通术面形停胜尤二谁题您赶热万深历导帮反收行传少生度侯员位斗会流五岁村因决但轿歌菜地找古再作物力总她样字围性准苦和怎人咱六候冲叶空加付提望外而熟共坐战连读吃点把是让心打于丘那间河整劳老建倒数治女本十接每高竺世身单山土城个敌明类士己失乎觉合门区轻究子定中我雨记体从至必甩半任告就得册民草姑产什报现青算钱太比大压见师国运石取们出怕像句唱知话家旧主眼自野去变化哪重花急火然哥并星别乡多书法月阶回相早意系以天很破的件业带跳仙机了林印声先代旁风渐长进许名晴块阳船也放几实年团此果最军史际脚革又公树呼婆切饭群两做全里平有庄等改未拉来根亮叫次百刻响海结三落走场识教且为义内条慢方寻真观看东日利答兴服枪掉量住好对可背细\n", "524\n" ] } ], "source": [ "from captcha.image import ImageCaptcha\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import random\n", "import glob\n", "import re\n", "from pylab import mpl\n", "mpl.rcParams['font.sans-serif'] = ['SimHei'] #中文显示问题\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina'\n", "\n", "import string\n", "# characters = string.digits + string.ascii_uppercase # 验证码字符集合数字+英文\n", "with open('/data/captcha/chinese_characters.txt', encoding='utf-8') as f:\n", " characters = f.read().strip()\n", "print(len(characters))\n", "# characters = '四生乐句付仗斥令仔乎白仙甩他瓜们用丘仪失丛代印册匆禾' # 中文字符集合\n", "print(characters)\n", "print(len(set(characters)))\n", "width, height, n_len, n_class = 120, 40, 4, len(characters) + 1 #图片宽、高,验证码最大长度,分类类别:字符集+1个空值" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ " # 防止 tensorflow 占用所有显存\n", "import tensorflow as tf\n", "import tensorflow.keras.backend as K\n", "\n", "config = tf.ConfigProto()\n", "config.gpu_options.allow_growth=True #True \n", "sess = tf.Session(config=config)\n", "K.set_session(sess)\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# 定义 CTC Loss\n", "import tensorflow.keras.backend as K\n", "\n", "def 
ctc_lambda_func(args):\n", "    '''\n", "    CTC loss wrapper applied through the 'ctc' Lambda layer.\n", "    args: (y_pred, labels, input_length, label_length)\n", "      y_pred: per-timestep softmax output of base_model\n", "      labels: label index sequences\n", "      input_length: number of RNN timesteps fed to CTC for each sample\n", "      label_length: number of characters in each label\n", "    Returns the loss computed by K.ctc_batch_cost.\n", "    '''\n", "    y_pred, labels, input_length, label_length = args\n", "    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Define the network: conv stages -> per-column features -> 2 x BiGRU -> softmax over n_class\n", "from tensorflow.keras.models import *\n", "from tensorflow.keras.layers import *\n", "\n", "input_tensor = Input((height, width, 3))\n", "x = input_tensor\n", "\n", "# each stage: n_cnn x (Conv -> BN -> ReLU), then pooling; the last stage pools height only, keeping the width as time steps\n", "for i, n_cnn in enumerate([2, 2, 2, 2]):\n", "    for j in range(n_cnn):\n", "        x = Conv2D(32*2**min(i, 3), kernel_size=3, padding='same', kernel_initializer='he_uniform')(x)\n", "        x = BatchNormalization()(x)\n", "        x = Activation('relu')(x)\n", "    x = MaxPooling2D(2 if i < 3 else (2, 1))(x)\n", "\n", "# swap the height and width axes so image columns become the RNN time axis\n", "x = Permute((2, 1, 3))(x)\n", "x = TimeDistributed(Flatten())(x)\n", "\n", "rnn_size = 64 # 128\n", "# x = Bidirectional(CuDNNGRU(rnn_size, return_sequences=True))(x)\n", "# x = Bidirectional(CuDNNGRU(rnn_size, return_sequences=True))(x)\n", "x = Bidirectional(GRU(rnn_size, return_sequences=True))(x)\n", "x = Bidirectional(GRU(rnn_size, return_sequences=True))(x)\n", "x = Dense(n_class, activation='softmax')(x)\n", "\n", "base_model = Model(inputs=input_tensor, outputs=x)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "labels = Input(name='the_labels', shape=[n_len], dtype='float32')\n", "input_length = Input(name='input_length', shape=[1], dtype='int64')\n", "label_length = Input(name='label_length', shape=[1], dtype='int64')\n", "loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([x, labels, input_length, label_length])\n", "\n", "model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=loss_out)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# # 网络结构可视化\n", "# from tensorflow.keras.utils 
import plot_model\n", "# from IPython.display import Image\n", "\n", "# plot_model(model, to_file='ctc.png', show_shapes=True)\n", "# Image('ctc.png')\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "input_1 (InputLayer) (None, 40, 120, 3) 0 \n", "_________________________________________________________________\n", "conv2d (Conv2D) (None, 40, 120, 32) 896 \n", "_________________________________________________________________\n", "batch_normalization (BatchNo (None, 40, 120, 32) 128 \n", "_________________________________________________________________\n", "activation (Activation) (None, 40, 120, 32) 0 \n", "_________________________________________________________________\n", "conv2d_1 (Conv2D) (None, 40, 120, 32) 9248 \n", "_________________________________________________________________\n", "batch_normalization_1 (Batch (None, 40, 120, 32) 128 \n", "_________________________________________________________________\n", "activation_1 (Activation) (None, 40, 120, 32) 0 \n", "_________________________________________________________________\n", "max_pooling2d (MaxPooling2D) (None, 20, 60, 32) 0 \n", "_________________________________________________________________\n", "conv2d_2 (Conv2D) (None, 20, 60, 64) 18496 \n", "_________________________________________________________________\n", "batch_normalization_2 (Batch (None, 20, 60, 64) 256 \n", "_________________________________________________________________\n", "activation_2 (Activation) (None, 20, 60, 64) 0 \n", "_________________________________________________________________\n", "conv2d_3 (Conv2D) (None, 20, 60, 64) 36928 \n", "_________________________________________________________________\n", "batch_normalization_3 (Batch (None, 20, 60, 
64) 256 \n", "_________________________________________________________________\n", "activation_3 (Activation) (None, 20, 60, 64) 0 \n", "_________________________________________________________________\n", "max_pooling2d_1 (MaxPooling2 (None, 10, 30, 64) 0 \n", "_________________________________________________________________\n", "conv2d_4 (Conv2D) (None, 10, 30, 128) 73856 \n", "_________________________________________________________________\n", "batch_normalization_4 (Batch (None, 10, 30, 128) 512 \n", "_________________________________________________________________\n", "activation_4 (Activation) (None, 10, 30, 128) 0 \n", "_________________________________________________________________\n", "conv2d_5 (Conv2D) (None, 10, 30, 128) 147584 \n", "_________________________________________________________________\n", "batch_normalization_5 (Batch (None, 10, 30, 128) 512 \n", "_________________________________________________________________\n", "activation_5 (Activation) (None, 10, 30, 128) 0 \n", "_________________________________________________________________\n", "max_pooling2d_2 (MaxPooling2 (None, 5, 15, 128) 0 \n", "_________________________________________________________________\n", "conv2d_6 (Conv2D) (None, 5, 15, 256) 295168 \n", "_________________________________________________________________\n", "batch_normalization_6 (Batch (None, 5, 15, 256) 1024 \n", "_________________________________________________________________\n", "activation_6 (Activation) (None, 5, 15, 256) 0 \n", "_________________________________________________________________\n", "conv2d_7 (Conv2D) (None, 5, 15, 256) 590080 \n", "_________________________________________________________________\n", "batch_normalization_7 (Batch (None, 5, 15, 256) 1024 \n", "_________________________________________________________________\n", "activation_7 (Activation) (None, 5, 15, 256) 0 \n", "_________________________________________________________________\n", "max_pooling2d_3 (MaxPooling2 
(None, 2, 15, 256) 0 \n", "_________________________________________________________________\n", "permute (Permute) (None, 15, 2, 256) 0 \n", "_________________________________________________________________\n", "time_distributed (TimeDistri (None, 15, 512) 0 \n", "_________________________________________________________________\n", "bidirectional (Bidirectional (None, 15, 128) 221568 \n", "_________________________________________________________________\n", "bidirectional_1 (Bidirection (None, 15, 128) 74112 \n", "_________________________________________________________________\n", "dense (Dense) (None, 15, 525) 67725 \n", "=================================================================\n", "Total params: 1,539,501\n", "Trainable params: 1,537,581\n", "Non-trainable params: 1,920\n", "_________________________________________________________________\n" ] } ], "source": [ "base_model.summary()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "等伯今立 [478, 207, 114, 62] 等伯今立\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/python/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/matplotlib/font_manager.py:1241: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.\n", " (prop.get_family(), self.defaultFamily[fontext]))\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 201, "width": 1151 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from PIL import Image, ImageFont, ImageDraw\n", "\n", "def random_color(start, end, opacity=None):\n", " '''\n", " 随机颜色函数,返回指定范围随机颜色值\n", " 参数:start:颜色最低值,end:颜色最高值\n", " '''\n", " red = random.randint(start, end)\n", " green = random.randint(start, end)\n", " blue = random.randint(start, end)\n", " if opacity is None:\n", " return (red, green, blue)\n", " return (red, green, blue, opacity)\n", "def random_xy(width,height): \n", " '''\n", " 随机位置函数,返回指定范围随机位置坐标\n", " 参数:width:图片宽,height:图片高\n", " '''\n", " x = random.randint(0, width)\n", " y = random.randint(0, height)\n", " return x, y\n", "\n", "table = []\n", "for i in range( 256 ):\n", " table.append( i * 1.97 )\n", " \n", "def create_captcha_image(chars, background, width=120, height=40):\n", " '''\n", " 生成验证码图片\n", " chars:要生成的字符串\n", " background:背景颜色\n", " '''\n", " image = Image.new('RGB', (width, height), color=background)\n", " draw = ImageDraw.Draw(image)\n", " def get_char_img(char,font,color,angle):\n", " '''\n", " 生成单个字符图片,随机颜色加随机旋转\n", " \n", " '''\n", " w, h = draw.textsize(char, font=font)\n", " im = Image.new('RGBA',(w,h), color=background)\n", " ImageDraw.Draw(im).text((0,0), char, font=font, fill=color)\n", " im = im.crop(im.getbbox())\n", " rot = im.rotate(angle,Image.BILINEAR,expand=1)\n", " bg = Image.new('RGBA',rot.size,background)\n", " im = Image.composite(rot, bg, rot)\n", " return im\n", " w_all = 0\n", " im_list = []\n", " w_list = []\n", " for c in chars:\n", " fonts = ['/usr/share/fonts/WindowsFonts/fonts/STXINGKA.TTF','/usr/share/fonts/WindowsFonts/fonts/simhei.ttf']\n", " font = ImageFont.truetype(font=random.choice(fonts), size=random.randint(18,20))\n", " char_img = get_char_img(char=c, font=font, color=random_color(0,90), angle=random.randint(0,0))\n", " w, h = char_img.size\n", " w_all += random.randint(0,5) \n", " 
w_list.append(w_all)\n", " im_list.append(char_img)\n", "# image.paste(char_img, (w_all,random.randint(0,image.size[1]-h)))\n", " w_all += w\n", " if w_all > width:\n", " image = Image.new('RGB', (w_all, height), color=background)\n", "\n", " for i in range(len(w_list)):\n", " image.paste(im_list[i], (w_list[i],random.randint(0,height-im_list[i].size[1])))\n", " return image.resize((width, height))\n", "\n", "def generate_image(random_str, width=120, height=40):\n", " '''\n", " 随机生成验证码,从四种字体随机抽取一种生成文字,加随机线干扰和点干扰\n", " 参数:random_str:要生成验证码的文字\n", " 返回:验证码图片\n", " ''' \n", "# image = Image.new(mode='RGB', size=(width, height), color=(255,255,255))\n", "# fonts = ['/usr/share/fonts/WindowsFonts/fonts/simsunb.ttf']\n", " fonts = ['/usr/share/fonts/WindowsFonts/fonts/ariali.ttf', '/usr/share/fonts/WindowsFonts/fonts/simhei.ttf',\n", " '/usr/share/fonts/WindowsFonts/fonts/simsunb.ttf', '/usr/share/fonts/WindowsFonts/fonts/calibri.ttf']\n", " font = ImageFont.truetype(font=random.choice(fonts), size=20)\n", " chars = random_str\n", " color = random_color(120,250)\n", " background = random_color(110,255)\n", " image = create_captcha_image(chars, background, width=width, height=height)\n", " draw = ImageDraw.Draw(image) \n", "# for _ in range(random.randint(5, 10)):\n", "# draw.line(xy=(random_xy(width,height),random_xy(width,height)),fill=random_color(80, 255))\n", " for _ in range(random.randint(100,150)):\n", " draw.point(xy=(random_xy(width,height)),fill=random_color(20, 250))\n", "# for _ in range(random.randint(50,80)):\n", "# x,y = random_xy(width,height)\n", "# draw.line(xy=(x,y,x+random.randint(-10,15),y+random.randint(-10,15)),fill=random_color(120, 255))\n", " \n", "# draw.text(xy=(random.randint(0,10),random.randint(0,2)),text= random_str, fill=text_fill, font=font)\n", "# len_t = len(random_str)\n", "# for i in range(len_t):\n", "# draw.text(xy=(random.randint(0,2)+i*int(width/len_t),random.randint(1,6)),text= random_str[i], \n", "# 
fill=random_color(0,180,opacity=80), font=font)\n", " \n", " return image.resize((120,40), Image.BILINEAR)\n", "\n", "random_str = ''.join([random.choice(characters) for j in range(random.randint(4,4))])\n", "label = [characters.find(x) for x in random_str]\n", "label_str = ''.join([characters[x] for x in label if x < len(characters)])\n", "print(random_str, label, label_str)\n", "img = generate_image(random_str, 100, 25) \n", "# img = generate_image('瓜用匆生')\n", "img2 = Image.open('Chinese/1da7be39-2189-11ea-b304-408d5cd36814_瓜用匆生.jpg')\n", "img2 = img2.resize((120,40), Image.BILINEAR)\n", "im = [img, img2]\n", "plt.figure(figsize=(20,10))\n", "for i in range(1,3): \n", " plt.subplot(2,2,i)\n", " plt.imshow(im[i-1])\n", "plt.show()\n", "\n", "# plt.imshow(img)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "用四匆匆\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 144, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 重组验证码\n", "def rebuild_img(path):\n", " '''\n", " 读取本地4-5位验证码图片进行裁剪分割为单个数字,从分割的字符随机抽取n个重组为新图片\n", " 参数:path:图片路径\n", " 返回:重组后图片\n", " '''\n", " label = path.split('_')[-1][:-4]\n", " if label.isalpha() and len(label) > 3:\n", " crop_n = len(label) \n", " img = Image.open(path)\n", " w, h = img.size\n", " fig_size = int(w/crop_n)\n", " fig_list = []\n", " new_label = []\n", " if crop_n == 4:\n", " for i in range(crop_n):\n", " fig_list.append(img.crop((i*fig_size+2, 2, (i+1)*fig_size-2, h-2 )))\n", " for i in range(crop_n):\n", " idx = random.randint(0,crop_n-1) # 修改为打乱顺序\n", " img.paste(fig_list[idx], (i*fig_size+2, 2, (i+1)*fig_size-2, h-2 ))\n", " new_label.append(label[idx])\n", " elif crop_n == 5: \n", " for i in range(crop_n):\n", " fig_list.append(img.crop((i*fig_size, 0, (i+1)*fig_size, h )))\n", " for i in range(crop_n):\n", " idx = random.randint(0,crop_n-1)\n", " img.paste(fig_list[idx], (i*fig_size, 0, (i+1)*fig_size, h ))\n", " new_label.append(label[idx])\n", " \n", " draw = ImageDraw.Draw(img) \n", " for _ in range(random.randint(0,50)): # 在重组的验证码图片上加噪声\n", " draw.point(xy=(random_xy(width,height)),fill=random_color(20, 250))\n", " \n", " return img.resize((120,40), Image.BILINEAR), ''.join(new_label)\n", "image, random_str = rebuild_img(random.choice(glob.glob('Chinese/*.jpg'))) \n", "print(random_str)\n", "plt.imshow(image)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "文件名:标签 字典大小 4092\n", "标签: 究旧当物\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 144, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def get_label_dic(label_file):\n", " '''通过文件名=标签 文件重构 名字:标签 字典'''\n", " from collections import defaultdict\n", " label_dic = defaultdict(str)\n", " with open(label_file, encoding='utf-8') as f:\n", " lines = f.readlines()\n", " for line in lines:\n", " result = line.split('=')[-1].strip()\n", " name = line.split('=')[0].strip()\n", " if len(result) != 4:\n", " continue\n", " else: \n", " label_dic[name] = result\n", " return label_dic\n", "\n", "label_dic = get_label_dic('/data/captcha/total_chinese/total_chinese.txt')\n", "print('文件名:标签 字典大小',len(label_dic))\n", "\n", "def smartSliceImg(img_path, label_dic, count=4, p_w=3):\n", " '''\n", " :param img:\n", " :param outDir:\n", " :param count: 图片中有多少个图片\n", " :param p_w: 对切割地方多少像素内进行判断\n", " :return:\n", " '''\n", " img = Image.open(img_path)\n", " w, h = img.size\n", " pixdata = img.load()\n", " eachWidth = int(w / count)\n", " beforeX = 1\n", " name = img_path.split('/')[-1][:-4]\n", " label = label_dic.get(name, None)\n", " if not label or re.search('[\\u4e00-\\u9fa5]', label)==None:\n", " print(label)\n", " return\n", "\n", " imgs = []\n", " for i in range(count):\n", " allBCount = []\n", " nextXOri = (i + 1) * eachWidth\n", "\n", " for x in range(nextXOri - p_w, nextXOri + p_w):\n", " if x >= w:\n", " x = w - 1\n", " if x < 0:\n", " x = 0\n", " b_count = 1\n", " for y in range(h):\n", "# if pixdata[x,y]==pixdata[x,2]:\n", " if pixdata[x,y][0]==pixdata[x,y][1]==pixdata[x,y][2]:\n", " b_count += 1\n", " allBCount.append({'x_pos': x, 'count': b_count})\n", " sort = sorted(allBCount, key=lambda e: e.get('count'), reverse=True)\n", "# print(allBCount)\n", "# print(sort)\n", " nextX = sort[0]['x_pos'] if sort[0]['x_pos']" ] }, "metadata": { "image/png": { "height": 144, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 
测试生成器\n", "# import matplotlib\n", "# # matplotlib.use('qt4agg')\n", "# #指定默认字体\n", "# matplotlib.rcParams['font.sans-serif'] = ['SimHei']\n", "# matplotlib.rcParams['font.family']='sans-serif'\n", "# #解决负号'-'显示为方块的问题\n", "# matplotlib.rcParams['axes.unicode_minus'] = False\n", "\n", "data = CaptchaSequence(characters, batch_size=10, steps=1)\n", "# discard the length arrays: binding them to input_length/label_length would overwrite\n", "# the Input tensors of the same name that the CTC model was built from\n", "[X_test, y_test, _, _], _ = data[0]\n", "idx = 3\n", "plt.imshow(X_test[idx])\n", "# indices >= len(characters) (e.g. the extra CTC class) have no printable character\n", "print(''.join([characters[x] for x in y_test[idx] if x < len(characters)]))\n", "print(y_test[idx])" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "于话重胜\n" ] } ], "source": [ "# # 从现有图片生成测试数据\n", "# def get_data(img_path):\n", "# img = Image.open(img_path)\n", "# # img = img.crop((0, height-25, width, height))\n", "# w, h = img.size\n", "# data = np.zeros((1,h, w, 3))\n", "# data[0] = np.array(img)/255.0\n", "# return data\n", "# img_path = '../FileInfo/ffc510f4-f977-11e9-b970-408d5cd36814_5802.jpg'\n", "\n", "# data = get_data(img_path)\n", "# print(data.shape)\n", "# plt.imshow(data[0])\n", "print(''.join([characters[x] for x in [307, 390, 401, 229]]))" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "# 准确率回调函数\n", "from tqdm import tqdm\n", "\n", "def evaluate(model, batch_size=128, steps=1):\n", "    '''\n", "    Exact-match accuracy of `model` on freshly generated captchas.\n", "    model: image -> per-timestep softmax model (e.g. base_model); this is the model that gets evaluated.\n", "    batch_size, steps: size and number of CaptchaSequence batches to draw.\n", "    Returns the fraction of samples whose greedy CTC decoding equals the label exactly\n", "    (0.0 when the sequence yields no batches).\n", "    '''\n", "    batch_acc = 0\n", "    valid_data = CaptchaSequence(characters, batch_size, steps)\n", "    n_batches = len(valid_data)\n", "    for i in range(n_batches):\n", "        [X_test, y_test, _, _], _ = valid_data[i]\n", "        y_pred = model.predict(X_test)\n", "        shape = y_pred.shape\n", "        out = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(shape[0])*shape[1])[0][0])\n", "        # ctc_decode fills unused positions with -1. Pad decoding and label to a common width\n", "        # so a too-short or too-long decoding counts as wrong instead of raising a shape error.\n", "        n_cols = max(out.shape[1], y_test.shape[1])\n", "        pred = np.full((out.shape[0], n_cols), -1)\n", "        pred[:, :out.shape[1]] = out\n", "        truth = np.full((y_test.shape[0], n_cols), -1)\n", "        truth[:, :y_test.shape[1]] = y_test\n", "        batch_acc += (pred == truth).all(axis=1).mean()\n", "    return batch_acc / n_batches if n_batches else 0.0" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.27000000000000002" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# model.load_weights('digit4to6_ctc_best2.h5')\n", "evaluate(base_model,batch_size=1, steps=10)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.callbacks import Callback\n", "\n", "class Evaluate(Callback):\n", "    '''\n", "    Callback that runs evaluate(base_model) at the end of every epoch,\n", "    writes the result into logs['val_acc'] and appends it to self.accs.\n", "    '''\n", "    def __init__(self):\n", "        # initialise the keras Callback base class before adding our own state\n", "        super(Evaluate, self).__init__()\n", "        self.accs = []\n", "\n", "    def on_epoch_end(self, epoch, logs=None):\n", "        logs = logs or {}\n", "        acc = evaluate(base_model, batch_size=128)\n", "        logs['val_acc'] = acc\n", "        self.accs.append(acc)\n", "        print('\\nacc%.4f'%acc)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/200\n", "1000/1000 [==============================] - 262s 262ms/step - loss: 23.9300 - val_loss: 22.1291\n", "Epoch 2/200\n", "1000/1000 [==============================] - 266s 266ms/step - loss: 20.2433 - val_loss: 19.9841\n", "Epoch 3/200\n", "1000/1000 [==============================] - 264s 264ms/step - loss: 19.5101 - val_loss: 18.1271\n", "Epoch 4/200\n", "1000/1000 [==============================] - 270s 270ms/step - loss: 15.8576 - val_loss: 13.6020\n", "Epoch 5/200\n", "1000/1000 [==============================] - 251s 251ms/step - loss: 11.2304 - 
val_loss: 10.5907\n", "Epoch 6/200\n", "1000/1000 [==============================] - 253s 253ms/step - loss: 7.7945 - val_loss: 6.8276\n", "Epoch 7/200\n", "1000/1000 [==============================] - 249s 249ms/step - loss: 5.0638 - val_loss: 4.8995\n", "Epoch 8/200\n", "1000/1000 [==============================] - 269s 269ms/step - loss: 3.2676 - val_loss: 3.1209\n", "Epoch 9/200\n", "1000/1000 [==============================] - 259s 259ms/step - loss: 2.2798 - val_loss: 2.4022\n", "Epoch 10/200\n", "1000/1000 [==============================] - 268s 268ms/step - loss: 1.6577 - val_loss: 1.5855\n", "Epoch 11/200\n", "1000/1000 [==============================] - 258s 258ms/step - loss: 1.3260 - val_loss: 1.3435\n", "Epoch 12/200\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 1.0780 - val_loss: 1.1741\n", "Epoch 13/200\n", "1000/1000 [==============================] - 262s 262ms/step - loss: 0.9467 - val_loss: 0.9290\n", "Epoch 14/200\n", "1000/1000 [==============================] - 262s 262ms/step - loss: 0.8385 - val_loss: 0.7924\n", "Epoch 15/200\n", "1000/1000 [==============================] - 248s 248ms/step - loss: 0.7515 - val_loss: 0.7596\n", "Epoch 16/200\n", "1000/1000 [==============================] - 222s 222ms/step - loss: 0.6868 - val_loss: 0.6856\n", "Epoch 17/200\n", "1000/1000 [==============================] - 219s 219ms/step - loss: 0.6275 - val_loss: 0.6722\n", "Epoch 18/200\n", "1000/1000 [==============================] - 250s 250ms/step - loss: 0.6068 - val_loss: 0.6041\n", "Epoch 19/200\n", "1000/1000 [==============================] - 256s 256ms/step - loss: 0.5443 - val_loss: 0.7427\n", "Epoch 20/200\n", "1000/1000 [==============================] - 247s 247ms/step - loss: 0.5192 - val_loss: 0.6104\n", "Epoch 21/200\n", "1000/1000 [==============================] - 192s 192ms/step - loss: 0.5181 - val_loss: 0.5642\n", "Epoch 22/200\n", "1000/1000 [==============================] - 175s 175ms/step - loss: 
0.4950 - val_loss: 0.4901\n", "Epoch 23/200\n", "1000/1000 [==============================] - 179s 179ms/step - loss: 0.4738 - val_loss: 0.8496\n", "Epoch 24/200\n", "1000/1000 [==============================] - 179s 179ms/step - loss: 0.4547 - val_loss: 0.4348\n", "Epoch 25/200\n", "1000/1000 [==============================] - 178s 178ms/step - loss: 0.4420 - val_loss: 0.5450\n", "Epoch 26/200\n", "1000/1000 [==============================] - 197s 197ms/step - loss: 0.4077 - val_loss: 0.5357\n", "Epoch 27/200\n", "1000/1000 [==============================] - 179s 179ms/step - loss: 0.4126 - val_loss: 0.3803\n", "Epoch 28/200\n", "1000/1000 [==============================] - 194s 194ms/step - loss: 0.3915 - val_loss: 0.4115\n", "Epoch 29/200\n", "1000/1000 [==============================] - 187s 187ms/step - loss: 0.3893 - val_loss: 0.4076\n", "Epoch 30/200\n", "1000/1000 [==============================] - 171s 171ms/step - loss: 0.3695 - val_loss: 0.4259\n", "Epoch 31/200\n", "1000/1000 [==============================] - 183s 183ms/step - loss: 0.3677 - val_loss: 0.4213\n", "Epoch 32/200\n", "1000/1000 [==============================] - 176s 176ms/step - loss: 0.3579 - val_loss: 1.5800\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate()\n", "# 模型训练\n", "from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint\n", "from tensorflow.keras.optimizers import *\n", "# model.load_weights('digit4to6_ctc_best.h5')\n", "\n", "train_data = CaptchaSequence(characters, batch_size=128, steps=1000) # (characters, batch_size=128, steps=1000)\n", "valid_data = CaptchaSequence(characters, batch_size=128, steps=100) # (characters, batch_size=128, steps=100)\n", "# callbacks = [EarlyStopping(patience=5), Evaluate(), \n", "# CSVLogger('ctc.csv'), ModelCheckpoint('ctc_best.h5', save_best_only=True)]\n", "callbacks = 
[EarlyStopping(patience=5),ModelCheckpoint('gru_chinese524char_ctc_best.h5', save_best_only=True)]\n", "model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=Adam(1e-3, amsgrad=True))\n", "# model.fit_generator(train_data, epochs=100, validation_data=valid_data,\n", "# callbacks=callbacks)\n", "model.fit_generator(train_data, epochs=200, validation_data=valid_data, workers=4, use_multiprocessing=True,\n", " callbacks=callbacks)\n", "\n", "# 20200721 gru_chinese524char_ctc_best.h5 训练32个epoch从 loss: 23.9300 - val_loss: 22.1291 下降到 loss: 0.3677 - val_loss: 0.4213" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/200\n", "1000/1000 [==============================] - 367s 367ms/step - loss: 0.0241 - val_loss: 0.0232\n", "Epoch 2/200\n", "1000/1000 [==============================] - 349s 349ms/step - loss: 0.0222 - val_loss: 0.0208\n", "Epoch 3/200\n", "1000/1000 [==============================] - 348s 348ms/step - loss: 0.0197 - val_loss: 0.0258\n", "Epoch 4/200\n", "1000/1000 [==============================] - 350s 350ms/step - loss: 0.0215 - val_loss: 0.0204\n", "Epoch 5/200\n", "1000/1000 [==============================] - 350s 350ms/step - loss: 0.0184 - val_loss: 0.0273\n", "Epoch 6/200\n", "1000/1000 [==============================] - 342s 342ms/step - loss: 0.0192 - val_loss: 0.0197\n", "Epoch 7/200\n", "1000/1000 [==============================] - 355s 355ms/step - loss: 0.0204 - val_loss: 0.0209\n", "Epoch 8/200\n", "1000/1000 [==============================] - 350s 350ms/step - loss: 0.0208 - val_loss: 0.0158\n", "Epoch 9/200\n", "1000/1000 [==============================] - 351s 351ms/step - loss: 0.0190 - val_loss: 0.0159\n", "Epoch 10/200\n", "1000/1000 [==============================] - 352s 352ms/step - loss: 0.0178 - val_loss: 0.0205\n", "Epoch 11/200\n", "1000/1000 [==============================] - 354s 354ms/step - loss: 0.0187 - val_loss: 
0.0192\n", "Epoch 12/200\n", "1000/1000 [==============================] - 351s 351ms/step - loss: 0.0188 - val_loss: 0.0145\n", "Epoch 13/200\n", "1000/1000 [==============================] - 347s 347ms/step - loss: 0.0188 - val_loss: 0.0184\n", "Epoch 14/200\n", "1000/1000 [==============================] - 350s 350ms/step - loss: 0.0197 - val_loss: 0.0283\n", "Epoch 15/200\n", "1000/1000 [==============================] - 345s 345ms/step - loss: 0.0198 - val_loss: 0.0207\n", "Epoch 16/200\n", "1000/1000 [==============================] - 346s 346ms/step - loss: 0.0157 - val_loss: 0.0175\n", "Epoch 17/200\n", "1000/1000 [==============================] - 361s 361ms/step - loss: 0.0160 - val_loss: 0.0199\n", "Epoch 18/200\n", "1000/1000 [==============================] - 342s 342ms/step - loss: 0.0165 - val_loss: 0.0142\n", "Epoch 19/200\n", "1000/1000 [==============================] - 338s 338ms/step - loss: 0.0156 - val_loss: 0.0164\n", "Epoch 20/200\n", "1000/1000 [==============================] - 344s 344ms/step - loss: 0.0164 - val_loss: 0.0162\n", "Epoch 21/200\n", "1000/1000 [==============================] - 347s 347ms/step - loss: 0.0164 - val_loss: 0.0176\n", "Epoch 22/200\n", "1000/1000 [==============================] - 339s 339ms/step - loss: 0.0162 - val_loss: 0.0165\n", "Epoch 23/200\n", "1000/1000 [==============================] - 348s 348ms/step - loss: 0.0167 - val_loss: 0.0145\n", "Epoch 24/200\n", "1000/1000 [==============================] - 338s 338ms/step - loss: 0.0163 - val_loss: 0.0156\n", "Epoch 25/200\n", "1000/1000 [==============================] - 335s 335ms/step - loss: 0.0148 - val_loss: 0.0147\n", "Epoch 26/200\n", "1000/1000 [==============================] - 344s 344ms/step - loss: 0.0149 - val_loss: 0.0138\n", "Epoch 27/200\n", "1000/1000 [==============================] - 437s 437ms/step - loss: 0.0149 - val_loss: 0.0140\n", "Epoch 28/200\n", "1000/1000 [==============================] - 440s 440ms/step - loss: 0.0147 - 
val_loss: 0.0223\n", "Epoch 29/200\n", "1000/1000 [==============================] - 368s 368ms/step - loss: 0.0155 - val_loss: 0.0136\n", "Epoch 30/200\n", "1000/1000 [==============================] - 375s 375ms/step - loss: 0.0161 - val_loss: 0.0120\n", "Epoch 31/200\n", "1000/1000 [==============================] - 351s 351ms/step - loss: 0.0133 - val_loss: 0.0126\n", "Epoch 32/200\n", "1000/1000 [==============================] - 346s 346ms/step - loss: 0.0146 - val_loss: 0.0219\n", "Epoch 33/200\n", "1000/1000 [==============================] - 359s 359ms/step - loss: 0.0155 - val_loss: 0.0174\n", "Epoch 34/200\n", "1000/1000 [==============================] - 359s 359ms/step - loss: 0.0143 - val_loss: 0.0145\n", "Epoch 35/200\n", "1000/1000 [==============================] - 372s 372ms/step - loss: 0.0154 - val_loss: 0.0161\n", "Epoch 36/200\n", "1000/1000 [==============================] - 351s 351ms/step - loss: 0.0152 - val_loss: 0.0128\n", "Epoch 37/200\n", "1000/1000 [==============================] - 359s 359ms/step - loss: 0.0155 - val_loss: 0.0141\n", "Epoch 38/200\n", "1000/1000 [==============================] - 335s 335ms/step - loss: 0.0143 - val_loss: 0.0132\n", "Epoch 39/200\n", "1000/1000 [==============================] - 403s 403ms/step - loss: 0.0125 - val_loss: 0.0118\n", "Epoch 40/200\n", "1000/1000 [==============================] - 342s 342ms/step - loss: 0.0135 - val_loss: 0.0131\n", "Epoch 41/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0125 - val_loss: 0.0145\n", "Epoch 42/200\n", "1000/1000 [==============================] - 330s 330ms/step - loss: 0.0134 - val_loss: 0.0134\n", "Epoch 43/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0126 - val_loss: 0.0117\n", "Epoch 44/200\n", "1000/1000 [==============================] - 330s 330ms/step - loss: 0.0122 - val_loss: 0.0117\n", "Epoch 45/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 
0.0121 - val_loss: 0.0093\n", "Epoch 46/200\n", "1000/1000 [==============================] - 332s 332ms/step - loss: 0.0120 - val_loss: 0.0099\n", "Epoch 47/200\n", "1000/1000 [==============================] - 342s 342ms/step - loss: 0.0118 - val_loss: 0.0118\n", "Epoch 48/200\n", "1000/1000 [==============================] - 336s 336ms/step - loss: 0.0112 - val_loss: 0.0145\n", "Epoch 49/200\n", "1000/1000 [==============================] - 333s 333ms/step - loss: 0.0134 - val_loss: 0.0146\n", "Epoch 50/200\n", "1000/1000 [==============================] - 342s 342ms/step - loss: 0.0129 - val_loss: 0.0110\n", "Epoch 51/200\n", "1000/1000 [==============================] - 329s 329ms/step - loss: 0.0119 - val_loss: 0.0147\n", "Epoch 52/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0108 - val_loss: 0.0152\n", "Epoch 53/200\n", "1000/1000 [==============================] - 357s 357ms/step - loss: 0.0137 - val_loss: 0.0138\n", "Epoch 54/200\n", "1000/1000 [==============================] - 345s 345ms/step - loss: 0.0111 - val_loss: 0.0095\n", "Epoch 55/200\n", "1000/1000 [==============================] - 332s 332ms/step - loss: 0.0112 - val_loss: 0.0168\n", "Epoch 56/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 0.0119 - val_loss: 0.0178\n", "Epoch 57/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 0.0129 - val_loss: 0.0117\n", "Epoch 58/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0140 - val_loss: 0.0115\n", "Epoch 59/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0125 - val_loss: 0.0108\n", "Epoch 60/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0118 - val_loss: 0.0134\n", "Epoch 61/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0117 - val_loss: 0.0089\n", "Epoch 62/200\n", "1000/1000 [==============================] - 327s 327ms/step - 
loss: 0.0108 - val_loss: 0.0093\n", "Epoch 63/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 0.0141 - val_loss: 0.0121\n", "Epoch 64/200\n", "1000/1000 [==============================] - 329s 329ms/step - loss: 0.0122 - val_loss: 0.0137\n", "Epoch 65/200\n", "1000/1000 [==============================] - 331s 331ms/step - loss: 0.0139 - val_loss: 0.0141\n", "Epoch 66/200\n", "1000/1000 [==============================] - 329s 329ms/step - loss: 0.0118 - val_loss: 0.0100\n", "Epoch 67/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 0.0120 - val_loss: 0.0118\n", "Epoch 68/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0107 - val_loss: 0.0092\n", "Epoch 69/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0094 - val_loss: 0.0097\n", "Epoch 70/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0110 - val_loss: 0.0095\n", "Epoch 71/200\n", "1000/1000 [==============================] - 331s 331ms/step - loss: 0.0105 - val_loss: 0.0088\n", "Epoch 72/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0102 - val_loss: 0.0100\n", "Epoch 73/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0103 - val_loss: 0.0082\n", "Epoch 74/200\n", "1000/1000 [==============================] - 323s 323ms/step - loss: 0.0103 - val_loss: 0.0082\n", "Epoch 75/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0104 - val_loss: 0.0080\n", "Epoch 76/200\n", "1000/1000 [==============================] - 324s 324ms/step - loss: 0.0095 - val_loss: 0.0068\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 77/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0121 - val_loss: 0.0098\n", "Epoch 78/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0100 - val_loss: 0.0085\n", "Epoch 
79/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0100 - val_loss: 0.0104\n", "Epoch 80/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0123 - val_loss: 0.0102\n", "Epoch 81/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0108 - val_loss: 0.0097\n", "Epoch 82/200\n", "1000/1000 [==============================] - 324s 324ms/step - loss: 0.0093 - val_loss: 0.0071\n", "Epoch 83/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 0.0096 - val_loss: 0.0076\n", "Epoch 84/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0086 - val_loss: 0.0095\n", "Epoch 85/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 0.0101 - val_loss: 0.0119\n", "Epoch 86/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0107 - val_loss: 0.0099\n", "Epoch 87/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0101 - val_loss: 0.0078\n", "Epoch 88/200\n", "1000/1000 [==============================] - 349s 349ms/step - loss: 0.0106 - val_loss: 0.0085\n", "Epoch 89/200\n", "1000/1000 [==============================] - 339s 339ms/step - loss: 0.0103 - val_loss: 0.0109\n", "Epoch 90/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0094 - val_loss: 0.0071\n", "Epoch 91/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0111 - val_loss: 0.0082\n", "Epoch 92/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0085 - val_loss: 0.0118\n", "Epoch 93/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0092 - val_loss: 0.0092\n", "Epoch 94/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0114 - val_loss: 0.0092\n", "Epoch 95/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0089 - val_loss: 0.0089\n", 
"Epoch 96/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0096 - val_loss: 0.0139\n", "Epoch 97/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 0.0095 - val_loss: 0.0089\n", "Epoch 98/200\n", "1000/1000 [==============================] - 324s 324ms/step - loss: 0.0110 - val_loss: 0.0092\n", "Epoch 99/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0104 - val_loss: 0.0091\n", "Epoch 100/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0089 - val_loss: 0.0097\n", "Epoch 101/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0090 - val_loss: 0.0089\n", "Epoch 102/200\n", "1000/1000 [==============================] - 324s 324ms/step - loss: 0.0097 - val_loss: 0.0104\n", "Epoch 103/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0091 - val_loss: 0.0066\n", "Epoch 104/200\n", "1000/1000 [==============================] - 336s 336ms/step - loss: 0.0088 - val_loss: 0.0087\n", "Epoch 105/200\n", "1000/1000 [==============================] - 335s 335ms/step - loss: 0.0088 - val_loss: 0.0094\n", "Epoch 106/200\n", "1000/1000 [==============================] - 333s 333ms/step - loss: 0.0085 - val_loss: 0.0110\n", "Epoch 107/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0096 - val_loss: 0.0090\n", "Epoch 108/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0088 - val_loss: 0.0088\n", "Epoch 109/200\n", "1000/1000 [==============================] - 324s 324ms/step - loss: 0.0104 - val_loss: 0.0098\n", "Epoch 110/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0094 - val_loss: 0.0092\n", "Epoch 111/200\n", "1000/1000 [==============================] - 324s 324ms/step - loss: 0.0081 - val_loss: 0.0071\n", "Epoch 112/200\n", "1000/1000 [==============================] - 323s 323ms/step - loss: 0.0094 - 
val_loss: 0.0087\n", "Epoch 113/200\n", "1000/1000 [==============================] - 326s 326ms/step - loss: 0.0089 - val_loss: 0.0106\n", "Epoch 114/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0104 - val_loss: 0.0125\n", "Epoch 115/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0113 - val_loss: 0.0077\n", "Epoch 116/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0099 - val_loss: 0.0071\n", "Epoch 117/200\n", "1000/1000 [==============================] - 325s 325ms/step - loss: 0.0111 - val_loss: 0.0104\n", "Epoch 118/200\n", "1000/1000 [==============================] - 328s 328ms/step - loss: 0.0113 - val_loss: 0.0069\n", "Epoch 119/200\n", "1000/1000 [==============================] - 327s 327ms/step - loss: 0.0085 - val_loss: 0.0073\n", "Epoch 120/200\n", "1000/1000 [==============================] - 331s 331ms/step - loss: 0.0092 - val_loss: 0.0101\n", "Epoch 121/200\n", "1000/1000 [==============================] - 330s 330ms/step - loss: 0.0096 - val_loss: 0.0083\n", "Epoch 122/200\n", "1000/1000 [==============================] - 354s 354ms/step - loss: 0.0090 - val_loss: 0.0084\n", "Epoch 123/200\n", "1000/1000 [==============================] - 340s 340ms/step - loss: 0.0089 - val_loss: 0.0090\n", "Epoch 124/200\n", "1000/1000 [==============================] - 349s 349ms/step - loss: 0.0085 - val_loss: 0.0083\n", "Epoch 125/200\n", "1000/1000 [==============================] - 386s 386ms/step - loss: 0.0098 - val_loss: 0.0081\n", "Epoch 126/200\n", "1000/1000 [==============================] - 367s 367ms/step - loss: 0.0090 - val_loss: 0.0113\n", "Epoch 127/200\n", "1000/1000 [==============================] - 376s 376ms/step - loss: 0.0095 - val_loss: 0.0107\n", "Epoch 128/200\n", "1000/1000 [==============================] - 374s 374ms/step - loss: 0.0097 - val_loss: 0.0110\n", "Epoch 129/200\n", "1000/1000 [==============================] - 396s 
396ms/step - loss: 0.0080 - val_loss: 0.0081\n", "Epoch 130/200\n", "1000/1000 [==============================] - 389s 389ms/step - loss: 0.0076 - val_loss: 0.0093\n", "Epoch 131/200\n", "1000/1000 [==============================] - 417s 417ms/step - loss: 0.0082 - val_loss: 0.0121\n", "Epoch 132/200\n", "1000/1000 [==============================] - 436s 436ms/step - loss: 0.0081 - val_loss: 0.0079\n", "Epoch 133/200\n", "1000/1000 [==============================] - 448s 448ms/step - loss: 0.0116 - val_loss: 0.0092\n", "Epoch 134/200\n", "1000/1000 [==============================] - 412s 412ms/step - loss: 0.0083 - val_loss: 0.0074\n", "Epoch 135/200\n", "1000/1000 [==============================] - 439s 439ms/step - loss: 0.0089 - val_loss: 0.0102\n", "Epoch 136/200\n", "1000/1000 [==============================] - 265s 265ms/step - loss: 0.0082 - val_loss: 0.0087\n", "Epoch 137/200\n", "1000/1000 [==============================] - 222s 222ms/step - loss: 0.0076 - val_loss: 0.0088\n", "Epoch 138/200\n", "1000/1000 [==============================] - 236s 236ms/step - loss: 0.0088 - val_loss: 0.0069\n", "Epoch 139/200\n", "1000/1000 [==============================] - 219s 219ms/step - loss: 0.0078 - val_loss: 0.0076\n", "Epoch 140/200\n", "1000/1000 [==============================] - 212s 212ms/step - loss: 0.0104 - val_loss: 0.0080\n", "Epoch 141/200\n", "1000/1000 [==============================] - 203s 203ms/step - loss: 0.0083 - val_loss: 0.0087\n", "Epoch 142/200\n", "1000/1000 [==============================] - 189s 189ms/step - loss: 0.0105 - val_loss: 0.0109\n", "Epoch 143/200\n", "1000/1000 [==============================] - 217s 217ms/step - loss: 0.0093 - val_loss: 0.0066\n", "Epoch 144/200\n", "1000/1000 [==============================] - 274s 274ms/step - loss: 0.0093 - val_loss: 0.0085\n", "Epoch 145/200\n", "1000/1000 [==============================] - 280s 280ms/step - loss: 0.0085 - val_loss: 0.0069\n", "Epoch 146/200\n", "1000/1000 
[==============================] - 254s 254ms/step - loss: 0.0074 - val_loss: 0.0075\n", "Epoch 147/200\n", "1000/1000 [==============================] - 268s 268ms/step - loss: 0.0086 - val_loss: 0.0139\n", "Epoch 148/200\n", "1000/1000 [==============================] - 273s 273ms/step - loss: 0.0078 - val_loss: 0.0109\n", "Epoch 149/200\n", "1000/1000 [==============================] - 274s 274ms/step - loss: 0.0084 - val_loss: 0.0079\n", "Epoch 150/200\n", "1000/1000 [==============================] - 284s 284ms/step - loss: 0.0086 - val_loss: 0.0075\n", "Epoch 151/200\n", "1000/1000 [==============================] - 280s 280ms/step - loss: 0.0089 - val_loss: 0.0062\n", "Epoch 152/200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "1000/1000 [==============================] - 284s 284ms/step - loss: 0.0085 - val_loss: 0.0123\n", "Epoch 153/200\n", "1000/1000 [==============================] - 249s 249ms/step - loss: 0.0082 - val_loss: 0.0068\n", "Epoch 154/200\n", "1000/1000 [==============================] - 263s 263ms/step - loss: 0.0085 - val_loss: 0.0081\n", "Epoch 155/200\n", "1000/1000 [==============================] - 270s 270ms/step - loss: 0.0082 - val_loss: 0.0083\n", "Epoch 156/200\n", "1000/1000 [==============================] - 255s 255ms/step - loss: 0.0066 - val_loss: 0.0068\n", "Epoch 157/200\n", "1000/1000 [==============================] - 269s 269ms/step - loss: 0.0070 - val_loss: 0.0060\n", "Epoch 158/200\n", "1000/1000 [==============================] - 273s 273ms/step - loss: 0.0065 - val_loss: 0.0132\n", "Epoch 159/200\n", "1000/1000 [==============================] - 255s 255ms/step - loss: 0.0069 - val_loss: 0.0054\n", "Epoch 160/200\n", "1000/1000 [==============================] - 270s 270ms/step - loss: 0.0074 - val_loss: 0.0068\n", "Epoch 161/200\n", "1000/1000 [==============================] - 257s 257ms/step - loss: 0.0069 - val_loss: 0.0086\n", "Epoch 162/200\n", "1000/1000 [==============================] - 
264s 264ms/step - loss: 0.0076 - val_loss: 0.0062\n", "Epoch 163/200\n", "1000/1000 [==============================] - 260s 260ms/step - loss: 0.0076 - val_loss: 0.0072\n", "Epoch 164/200\n", "1000/1000 [==============================] - 259s 259ms/step - loss: 0.0078 - val_loss: 0.0089\n", "Epoch 165/200\n", "1000/1000 [==============================] - 247s 247ms/step - loss: 0.0080 - val_loss: 0.0064\n", "Epoch 166/200\n", "1000/1000 [==============================] - 272s 272ms/step - loss: 0.0069 - val_loss: 0.0081\n", "Epoch 167/200\n", "1000/1000 [==============================] - 271s 271ms/step - loss: 0.0069 - val_loss: 0.0067\n", "Epoch 168/200\n", "1000/1000 [==============================] - 269s 269ms/step - loss: 0.0071 - val_loss: 0.0056\n", "Epoch 169/200\n", "1000/1000 [==============================] - 271s 271ms/step - loss: 0.0073 - val_loss: 0.0048\n", "Epoch 170/200\n", "1000/1000 [==============================] - 259s 259ms/step - loss: 0.0071 - val_loss: 0.0080\n", "Epoch 171/200\n", "1000/1000 [==============================] - 274s 274ms/step - loss: 0.0083 - val_loss: 0.0073\n", "Epoch 172/200\n", "1000/1000 [==============================] - 276s 276ms/step - loss: 0.0068 - val_loss: 0.0064\n", "Epoch 173/200\n", "1000/1000 [==============================] - 265s 265ms/step - loss: 0.0069 - val_loss: 0.0068\n", "Epoch 174/200\n", "1000/1000 [==============================] - 262s 262ms/step - loss: 0.0064 - val_loss: 0.0076\n", "Epoch 175/200\n", "1000/1000 [==============================] - 271s 271ms/step - loss: 0.0083 - val_loss: 0.0070\n", "Epoch 176/200\n", "1000/1000 [==============================] - 278s 278ms/step - loss: 0.0080 - val_loss: 0.0113\n", "Epoch 177/200\n", "1000/1000 [==============================] - 275s 275ms/step - loss: 0.0067 - val_loss: 0.0098\n", "Epoch 178/200\n", "1000/1000 [==============================] - 230s 230ms/step - loss: 0.0064 - val_loss: 0.0053\n", "Epoch 179/200\n", "1000/1000 
[==============================] - 182s 182ms/step - loss: 0.0063 - val_loss: 0.0063\n", "Epoch 180/200\n", "1000/1000 [==============================] - 175s 175ms/step - loss: 0.0066 - val_loss: 0.0067\n", "Epoch 181/200\n", "1000/1000 [==============================] - 182s 182ms/step - loss: 0.0068 - val_loss: 0.0082\n", "Epoch 182/200\n", "1000/1000 [==============================] - 172s 172ms/step - loss: 0.0078 - val_loss: 0.0068\n", "Epoch 183/200\n", "1000/1000 [==============================] - 188s 188ms/step - loss: 0.0074 - val_loss: 0.0063\n", "Epoch 184/200\n", "1000/1000 [==============================] - 182s 182ms/step - loss: 0.0066 - val_loss: 0.0084\n", "Epoch 185/200\n", "1000/1000 [==============================] - 172s 172ms/step - loss: 0.0066 - val_loss: 0.0067\n", "Epoch 186/200\n", "1000/1000 [==============================] - 193s 193ms/step - loss: 0.0066 - val_loss: 0.0054\n", "Epoch 187/200\n", "1000/1000 [==============================] - 179s 179ms/step - loss: 0.0083 - val_loss: 0.0054\n", "Epoch 188/200\n", "1000/1000 [==============================] - 194s 194ms/step - loss: 0.0071 - val_loss: 0.0089\n", "Epoch 189/200\n", "1000/1000 [==============================] - 194s 194ms/step - loss: 0.0064 - val_loss: 0.0085\n", "Epoch 190/200\n", "1000/1000 [==============================] - 180s 180ms/step - loss: 0.0072 - val_loss: 0.0077\n", "Epoch 191/200\n", "1000/1000 [==============================] - 203s 203ms/step - loss: 0.0077 - val_loss: 0.0059\n", "Epoch 192/200\n", "1000/1000 [==============================] - 185s 185ms/step - loss: 0.0066 - val_loss: 0.0064\n", "Epoch 193/200\n", "1000/1000 [==============================] - 163s 163ms/step - loss: 0.0063 - val_loss: 0.0056\n", "Epoch 194/200\n", "1000/1000 [==============================] - 200s 200ms/step - loss: 0.0072 - val_loss: 0.0068\n", "Epoch 195/200\n", "1000/1000 [==============================] - 227s 227ms/step - loss: 0.0056 - val_loss: 0.0079\n", "Epoch 
196/200\n", "1000/1000 [==============================] - 183s 183ms/step - loss: 0.0074 - val_loss: 0.0054\n", "Epoch 197/200\n", "1000/1000 [==============================] - 166s 166ms/step - loss: 0.0062 - val_loss: 0.0061\n", "Epoch 198/200\n", "1000/1000 [==============================] - 171s 171ms/step - loss: 0.0068 - val_loss: 0.0060\n", "Epoch 199/200\n", "1000/1000 [==============================] - 190s 190ms/step - loss: 0.0081 - val_loss: 0.0073\n", "Epoch 200/200\n", "1000/1000 [==============================] - 165s 165ms/step - loss: 0.0080 - val_loss: 0.0068\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 载入最好的模型继续训练一会\n", "from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint\n", "from tensorflow.keras.optimizers import *\n", "model.load_weights('gru_chinese524char_ctc_best.h5')\n", "train_data = CaptchaSequence(characters, batch_size=128, steps=1000) # (characters, batch_size=128, steps=1000)\n", "valid_data = CaptchaSequence(characters, batch_size=128, steps=100) # (characters, batch_size=128, steps=100)\n", "# callbacks = [EarlyStopping(patience=5),\n", "# CSVLogger('ctc.csv', append=True), ModelCheckpoint('ctc_best.h5', save_best_only=True)]\n", "callbacks = [CSVLogger('ctc.csv', append=True), ModelCheckpoint('gru_chinese524char_ctc_best.h5', save_best_only=True)]\n", "\n", "model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=Adam(1e-4, amsgrad=True))\n", "model.fit_generator(train_data, epochs=200, validation_data=valid_data, workers=4, use_multiprocessing=True,\n", " callbacks=callbacks)\n", "\n", "# 继续训练近200个epoch,模型gru_chinese524char_ctc_best.h5最终损失loss: 0.0361 - val_loss: 0.0320 四个字的验证码 调整生成验证码的颜色值\n", "#让背景色与与字体颜色有区分度,再跑200epoch从loss: 0.0361 - val_loss: 0.0320 降到loss: 0.0080 - val_loss: 0.0068" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# 
model.load_weights('ctc_best.h5')\n", "# base_model.save('gru_chinese_base_model.h5')\n", "base_model.save('gru_chinese524char_base_model.h5')" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pred:得得第穿\n", "true:得第穿军\n", "pred:枪枪世代\n", "true:枪世代儿\n", "pred:前会准怕\n", "true:前会传怕\n", "总耗时: 25.031147241592407\n", "2557 3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/python/anaconda3/envs/dl_nlp/lib/python3.5/site-packages/matplotlib/font_manager.py:1241: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.\n", " (prop.get_family(), self.defaultFamily[fontext]))\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 171, "width": 370 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 测试模型\n", "# model.load_weights('gru_chinese524char_ctc_best.h5')\n", "characters2 = characters + ' '\n", "import time\n", "import re\n", "def get_test_data():\n", " '''\n", " 从本地获取验证码图片并生成测试数据\n", " ''' \n", " X = []\n", " Y = []\n", " for path in glob.glob('/data/captcha/total_chinese/*.jpg'): # Digit5/*.jpg #Chinese/*.jpg /data/captcha/total_chinese/*.jpg\n", "# random_str = path.split('_')[-1][:-4] # #Chinese/*.jpg\n", " name = path.split('/')[-1].strip()[:-4]\n", " random_str = label_dic.get(name, None)\n", " if not random_str:\n", " continue\n", " if random_str.isalpha() and len(random_str) ==4:\n", " random_str = random_str.upper()\n", "# print(random_str)\n", "# if random_str.isdigit() and len(random_str) ==4:\n", " img = Image.open(path)\n", " X.append(np.array(img.resize((120,40), Image.BILINEAR))/255.0)\n", " label_idx = [characters.find(x) for x in random_str]\n", " if len(random_str) < n_len:\n", " label_idx += [n_class-1]*(n_len-len(random_str)) \n", " Y.append(label_idx)\n", " return [np.array(X), np.array(Y), np.ones(len(X)), np.ones(len(X))],np.ones(len(X))\n", "\n", "# data = [get_test_data()]\n", "\n", "data = CaptchaSequence(characters, batch_size=128, steps=20, chars_len=(4,4))\n", "\n", "pos = neg = 0\n", "t1 = time.time()\n", "\n", "neg_img = []\n", "neg_str = []\n", "for i in range(len(data)): \n", " flag = False\n", " [X_test, y_test, _, _], _ = data[i]\n", " y_pred = base_model.predict(X_test)\n", "# print(y_pred.shape)\n", " out_pre = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0])*y_pred.shape[1])[0][0])[:, :4]\n", "# print(out_pre.shape)\n", "# print(out)\n", " for j in range(out_pre.shape[0]):\n", " out = ''.join([characters[x] for x in out_pre[j] if x < len(characters)]) \n", " y_true = ''.join([characters[x] for x in y_test[j] if x < len(characters)])\n", "# out = 
''.join([characters2[x] for x in out_pre[j] if x < len(characters)]) \n", "# if re.sub(' ','',out) != y_true:\n", " if out != y_true:\n", " neg_img.append(X_test[j])\n", " neg_str.append('pred:' + str(out) + '\\ttrue:' + str(y_true))\n", " plt.imshow(X_test[j])\n", " plt.title('pred:' + str(out) + '\\ntrue: ' + str(y_true))\n", " print('pred:' + str(out) + '\\ntrue:' + str(y_true))\n", " neg += 1\n", " flag = True\n", "# break\n", "# time.sleep(1)\n", "# argmax = np.argmax(y_pred, axis=2)[j]\n", "# print(list(zip(argmax, ''.join([characters2[x] for x in argmax]))))\n", " else:\n", "# print('pred:' + str(out) + '\\ntrue: ' + str(y_true))\n", " pos += 1 \n", "\n", "# if flag:\n", "# break\n", "t2 = time.time()\n", "print('总耗时:',t2-t1)\n", "print(pos,neg)\n", "# # plt.imshow(X_test[0])\n", "# # plt.title('pred:' + str(out) + '\\ntrue: ' + str(y_true))\n", "\n", "# argmax = np.argmax(y_pred, axis=2)[0]\n", "# list(zip(argmax, ''.join([characters2[x] for x in argmax])))\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pred:得得第穿\ttrue:得第穿军\n", "pred:枪枪世代\ttrue:枪世代儿\n", "pred:前会准怕\ttrue:前会传怕\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "image/png": { "height": 114, "width": 860 }, "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import math\n", "f = plt.figure(figsize=(20,10))\n", "n = math.ceil(len(neg_img)/4)\n", "for i in range(len(neg_img)):\n", " plt.subplot(n,4,i+1)\n", " plt.imshow(neg_img[i])\n", " print(neg_str[i])\n", "plt.show()\n", "\n", "# with open('/data/captcha/total_chinese/total_chinese.txt', encoding='utf-8') as f:\n", "# line = f.readlines()\n", "# print(len(line))\n", "# print(label_dic['f15d93fc-aa3d-11ea-8951-5254009c362b'])\n", "# print(len(label_dic))\n", "# 'f15d93fc-aa3d-11ea-8951-5254009c362b' in label_dic\n", "# line[:3]" ] }, { "cell_type": "code", "execution_count": 199, "metadata": {}, "outputs": [], "source": [ "evaluate(base_model,batch_size=128, steps=10)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 展示损失下降图\n", "import pandas as pd\n", "\n", "df = pd.read_csv('ctc.csv')\n", "df[['loss', 'val_loss']].plot()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.0" } }, "nbformat": 4, "nbformat_minor": 2 }