#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 9 23:11:51 2020
@author: chineseocr
"""
import json
import os
import random
import sys

sys.path.append(os.path.dirname(__file__))
print(os.path.dirname(__file__))

from table_line import model, focal_loss, dice_coef, dice_coef_loss, zero_loss
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from glob import glob
from image import gen
from config import tableModeLinePath
import tensorflow.keras as keras
from metrics import f1

PRETRAINED = False
CHECKPOINT = False

random.seed(42)

if __name__ == '__main__':
    if PRETRAINED:
        # model.load_weights(tableModeLinePath)
        _path = "./models/e12-f10.76.h5"
        model.load_weights(_path)
        print("read pretrained model", _path)
    else:
        print("no pretrained")

    if CHECKPOINT:
        model.load_weights("checkpoint/table-line")
        print("read checkpoint model")
    else:
        print("no checkpoint")

    # Where the model weights are saved; the filename encodes epoch and val F1.
    # filepath = 'models/table-line.h5'
    filepath = 'e{epoch:02d}-f1{val_f1:.2f}.h5'

    # Keep only the best weights (by validation F1) and shrink the learning
    # rate when validation F1 plateaus.
    checkpointer = ModelCheckpoint(filepath=filepath, monitor='val_f1',
                                   verbose=0, save_weights_only=True,
                                   save_best_only=True, mode='max')
    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=5,
                            verbose=0, mode='max', cooldown=0, min_lr=0)

    # `lr` is a deprecated alias in TF 2.x; use `learning_rate` instead.
    model.compile(optimizer=Adam(learning_rate=0.003),
                  loss=focal_loss(2., 0.25),
                  metrics=['acc', keras.metrics.Precision(),
                           keras.metrics.Recall(), f1])

    # for layer in model.layers:
    #     layer.trainable = False
    #     print(layer.trainable)

    # Table-line dataset labelled with labelme.
    # real_data_path = 'train/dataset-line/8/*.json'
    real_data_path = 'C:/Table_Label/clean_data/*.json'
    paths = glob('train/dataset-line/7/*.json') + glob(real_data_path)
    # print("paths", paths)
    print("len(paths)", len(paths))

    # paths = paths[:3000]  # cap the dataset size
    # print('len(paths)', len(paths))

    # new_paths = []
    # for path in paths:
    #     with open(path) as f:
    #         jsonData = json.loads(f.read())
    #     image_height = jsonData.get('imageHeight')
    #     if image_height <= 1200:
    #         new_paths.append(path)
    # paths = new_paths
    # print('len(new_paths)', len(new_paths))

    # trainP, testP = train_test_split(paths, test_size=0.1)
    # Hold out 10% of the label files for validation; the split is
    # reproducible because the RNG is seeded above.
    testP = random.sample(paths, int(len(paths) * 0.1))
    testSet = set(testP)  # O(1) membership test instead of scanning a list
    trainP = [p for p in paths if p not in testSet]
    random.shuffle(trainP)
    print('total:', len(paths), 'train:', len(trainP), 'test:', len(testP))

    batchsize = 3
    trainloader = gen(trainP, batchsize=batchsize, linetype=2)
    testloader = gen(testP, batchsize=batchsize, linetype=2)

    # fit_generator is deprecated in TF 2.x; model.fit accepts generators directly.
    model.fit(trainloader,
              steps_per_epoch=max(1, len(trainP) // batchsize),
              callbacks=[checkpointer, rlu],
              validation_data=testloader,
              validation_steps=max(1, len(testP) // batchsize),
              epochs=2000)
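
    # --- Optional post-training check (a sketch, not part of the original) ---
    # Assumptions: ModelCheckpoint has written at least one weights file
    # matching the 'e{epoch:02d}-f1{val_f1:.2f}.h5' pattern above, and `gen`
    # yields (image_batch, mask_batch) pairs indefinitely. With
    # save_best_only=True, the most recently written file is the best one,
    # so we reload it and report the compiled metrics on the held-out split.
    saved = sorted(glob('e*-f1*.h5'), key=os.path.getmtime)
    if saved:
        model.load_weights(saved[-1])  # newest = best checkpoint on val_f1
        print('eval on', saved[-1],
              model.evaluate(testloader,
                             steps=max(1, len(testP) // batchsize)))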