# train_origin.py — training script for the table-line segmentation model.
# (Recovered from a web file-viewer paste; the viewer's size caption and
# line-number gutter were removed.)
  1. import sys
  2. from table_line import model, dice_coef_loss, dice_coef
  3. from tensorflow.keras.optimizers import Adam
  4. from tensorflow.keras.callbacks import ModelCheckpoint,ReduceLROnPlateau
  5. from sklearn.model_selection import train_test_split
  6. from glob import glob
  7. from image import gen, gen_origin
  8. if __name__=='__main__':
  9. filepath = 'models/table-line.h5'##模型权重存放位置
  10. checkpointer = ModelCheckpoint(filepath=filepath,monitor='loss',verbose=0,save_weights_only=True, save_best_only=True)
  11. rlu = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=5, verbose=0, mode='auto', cooldown=0, min_lr=0)
  12. model.compile(optimizer=Adam(lr=0.0003), loss=dice_coef_loss(), metrics=['acc', dice_coef])
  13. paths = glob('./train/dataset-line/6/*.json') ##table line dataset label with labelme
  14. trainP,testP = train_test_split(paths,test_size=0.1)
  15. print('total:',len(paths),'train:',len(trainP),'test:',len(testP))
  16. batchsize=4
  17. trainloader = gen_origin(trainP,batchsize=batchsize,linetype=1)
  18. testloader = gen_origin(testP,batchsize=batchsize,linetype=1)
  19. model.fit_generator(trainloader,
  20. steps_per_epoch=max(1,len(trainP)//batchsize),
  21. callbacks=[checkpointer],
  22. validation_data=testloader,
  23. validation_steps=max(1,len(testP)//batchsize),
  24. epochs=30)