#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Table-line model training script.

Created on Thu Sep 9 23:11:51 2020
@author: chineseocr
"""
import json
import os
import sys

# Make sibling modules (table_line, image, config, metrics) importable
# when this script is run directly from another working directory.
sys.path.append(os.path.dirname(__file__))
print(os.path.dirname(__file__))

from glob import glob

import tensorflow.keras as keras
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam

from config import tableModeLinePath
from image import gen
from metrics import f1
from table_line import model, focal_loss, dice_coef, dice_coef_loss

# Warm-start options: load the released pretrained weights and/or resume
# from a local training checkpoint (checkpoint, if enabled, loads last and wins).
PRETRAINED = True
CHECKPOINT = False
if __name__ == '__main__':
    if PRETRAINED:
        # Warm-start from the published table-line weights.
        model.load_weights(tableModeLinePath)
        print("read pretrained model")
    else:
        print("no pretrained")

    if CHECKPOINT:
        # Resume from a local checkpoint; overrides the pretrained weights above.
        model.load_weights("checkpoint/table-line")
        print("read checkpoint model")
    else:
        print("no checkpoint")

    # Where model weights are saved.
    # filepath = 'models/table-line.h5'
    filepath = 'e{epoch:02d}-f1{val_f1:.2f}.h5'
    # mode='max' is required: 'val_f1' is not a metric name Keras recognizes,
    # so the default mode='auto' falls back to 'min' and save_best_only=True
    # would keep the WORST checkpoint instead of the best one.
    checkpointer = ModelCheckpoint(filepath=filepath, monitor='val_f1', verbose=0,
                                   mode='max',
                                   save_weights_only=True, save_best_only=True)
    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=10,
                            verbose=0, mode='max', cooldown=0, min_lr=0)
    # 'lr' is a deprecated alias in tf.keras optimizers; use 'learning_rate'.
    model.compile(optimizer=Adam(learning_rate=0.0003), loss=focal_loss(),
                  metrics=['acc', keras.metrics.Precision(), keras.metrics.Recall(), f1])

    # table line dataset labelled with labelme
    paths = glob('train/dataset-line/7/*.json')
    # print("paths", paths)
    print("len(paths)", len(paths))

    # Optional pre-filtering kept for reference: restrict the sample count
    # and drop overly tall images.
    # paths = paths[:3000]
    # print('len(paths)', len(paths))
    # new_paths = []
    # for path in paths:
    #     with open(path) as f:
    #         jsonData = json.loads(f.read())
    #     image_height = jsonData.get('imageHeight')
    #     if image_height <= 1200:
    #         new_paths.append(path)
    # paths = new_paths
    # print('len(new_paths)', len(new_paths))

    trainP, testP = train_test_split(paths, test_size=0.1)
    print('total:', len(paths), 'train:', len(trainP), 'test:', len(testP))

    batchsize = 3
    trainloader = gen(trainP, batchsize=batchsize, linetype=1)
    testloader = gen(testP, batchsize=batchsize, linetype=1)

    # fit_generator is deprecated in tf.keras; Model.fit accepts generators
    # directly. Also register `rlu` — it was constructed above but never
    # passed to callbacks, so the LR schedule silently never ran.
    model.fit(trainloader,
              steps_per_epoch=max(1, len(trainP) // batchsize),
              callbacks=[checkpointer, rlu],
              validation_data=testloader,
              validation_steps=max(1, len(testP) // batchsize),
              epochs=2000)