#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Train the table-line segmentation model on labelme-annotated data.

Created on Thu Sep 9 23:11:51 2020
@author: chineseocr
"""
import json
import os
import sys

# Make sibling modules (table_line, image, config, metrics) importable when
# this script is run from another working directory.
sys.path.append(os.path.dirname(__file__))
print(os.path.dirname(__file__))

from table_line import model, focal_loss, dice_coef, dice_coef_loss
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from glob import glob
from image import gen
from config import tableModeLinePath
import tensorflow.keras as keras
from metrics import f1

PRETRAINED = True   # start from the shipped pretrained weights
CHECKPOINT = False  # resume from a local training checkpoint instead

if __name__ == '__main__':
    if PRETRAINED:
        model.load_weights(tableModeLinePath)
        print("read pretrained model")
    else:
        print("no pretrained")

    if CHECKPOINT:
        model.load_weights("checkpoint/table-line")
        print("read checkpoint model")
    else:
        print("no checkpoint")

    # Where model weights are written (one file per new best epoch).
    filepath = 'e{epoch:02d}-f1{val_f1:.2f}.h5'

    # BUG FIX: 'val_f1' is a custom metric name, so Keras' default
    # mode='auto' assumes "lower is better" and save_best_only would keep
    # the WORST checkpoint.  f1 must be maximized -> mode='max'.
    checkpointer = ModelCheckpoint(filepath=filepath,
                                   monitor='val_f1',
                                   mode='max',
                                   verbose=0,
                                   save_weights_only=True,
                                   save_best_only=True)
    rlu = ReduceLROnPlateau(monitor='val_f1',
                            factor=0.1,
                            patience=10,
                            verbose=0,
                            mode='max',
                            cooldown=0,
                            min_lr=0)

    # `lr` is a deprecated TF2 alias; `learning_rate` is the supported name.
    model.compile(optimizer=Adam(learning_rate=0.0003),
                  loss=focal_loss(),
                  metrics=['acc',
                           keras.metrics.Precision(),
                           keras.metrics.Recall(),
                           f1])

    # Table-line dataset labelled with labelme (one JSON per image).
    paths = glob('train/dataset-line/7/*.json')
    print("len(paths)", len(paths))

    trainP, testP = train_test_split(paths, test_size=0.1)
    print('total:', len(paths), 'train:', len(trainP), 'test:', len(testP))

    batchsize = 3
    trainloader = gen(trainP, batchsize=batchsize, linetype=1)
    testloader = gen(testP, batchsize=batchsize, linetype=1)

    # BUG FIX: `rlu` was constructed but never registered, so the
    # learning-rate schedule never ran — it is now in `callbacks`.
    # `fit_generator` is deprecated in TF2; `fit` accepts generators.
    model.fit(trainloader,
              steps_per_epoch=max(1, len(trainP) // batchsize),
              callbacks=[checkpointer, rlu],
              validation_data=testloader,
              validation_steps=max(1, len(testP) // batchsize),
              epochs=2000)