train.py 2.6 KB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 9 23:11:51 2020
@author: chineseocr
"""
import json
import os
import sys

sys.path.append(os.path.dirname(__file__))
print(os.path.dirname(__file__))

from table_line import model, focal_loss, dice_coef, dice_coef_loss
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from glob import glob
from image import gen
from config import tableModeLinePath
import tensorflow.keras as keras
from metrics import f1

PRETRAINED = True
CHECKPOINT = False

if __name__ == '__main__':
    if PRETRAINED:
        model.load_weights(tableModeLinePath)
        print("read pretrained model")
    else:
        print("no pretrained")

    if CHECKPOINT:
        model.load_weights("checkpoint/table-line")
        print("read checkpoint model")
    else:
        print("no checkpoint")

    # Where the model weights are saved
    # filepath = 'models/table-line.h5'
    filepath = 'e{epoch:02d}-f1{val_f1:.2f}.h5'
    checkpointer = ModelCheckpoint(filepath=filepath, monitor='val_f1', verbose=0,
                                   save_weights_only=True, save_best_only=True)
    # Note: rlu is defined but never added to the callbacks list below,
    # so learning-rate decay is effectively disabled as written.
    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=10,
                            verbose=0, mode='max', cooldown=0, min_lr=0)

    # `lr` is the older Keras argument name; newer TF versions use `learning_rate`.
    model.compile(optimizer=Adam(lr=0.0003), loss=focal_loss(),
                  metrics=['acc', keras.metrics.Precision(), keras.metrics.Recall(), f1])

    # Table-line dataset labeled with labelme
    paths = glob('train/dataset-line/7/*.json')
    # print("paths", paths)
    print("len(paths)", len(paths))

    # Optionally cap the dataset size and drop oversized images:
    # paths = paths[:3000]
    # new_paths = []
    # for path in paths:
    #     with open(path) as f:
    #         jsonData = json.loads(f.read())
    #     image_height = jsonData.get('imageHeight')
    #     if image_height <= 1200:
    #         new_paths.append(path)
    # paths = new_paths
    # print('len(new_paths)', len(new_paths))

    trainP, testP = train_test_split(paths, test_size=0.1)
    print('total:', len(paths), 'train:', len(trainP), 'test:', len(testP))

    batchsize = 3
    trainloader = gen(trainP, batchsize=batchsize, linetype=1)
    testloader = gen(testP, batchsize=batchsize, linetype=1)
    # fit_generator is deprecated in newer TF; model.fit also accepts generators.
    model.fit_generator(trainloader,
                        steps_per_epoch=max(1, len(trainP) // batchsize),
                        callbacks=[checkpointer],
                        validation_data=testloader,
                        validation_steps=max(1, len(testP) // batchsize),
                        epochs=2000)
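
Both ModelCheckpoint and ReduceLROnPlateau monitor val_f1, which comes from the custom f1 imported from metrics. That module is not shown here; a minimal sketch of a Keras-compatible batchwise F1, built from precision and recall over thresholded pixel masks, could look like this (assumptions: the real metrics.py may weight or threshold differently):

# Minimal sketch of a batchwise F1 metric -- an assumption, not the
# repo's actual metrics.f1 implementation.
import tensorflow.keras.backend as K

def f1(y_true, y_pred):
    # Binarize predictions at 0.5
    y_pred = K.round(K.clip(y_pred, 0, 1))
    tp = K.sum(y_true * y_pred)
    precision = tp / (K.sum(y_pred) + K.epsilon())
    recall = tp / (K.sum(y_true) + K.epsilon())
    # Harmonic mean of precision and recall
    return 2 * precision * recall / (precision + recall + K.epsilon())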
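Likewise, gen comes from the repo's image module and is not shown. From the way fit_generator consumes it, train.py assumes an endless iterator of (batch_images, batch_masks) pairs built from labelme JSON files. The sketch below illustrates only that interface; render_line_mask is a hypothetical placeholder for the repo's actual step of rasterizing the annotated line polygons into mask channels:

# Sketch of the generator interface assumed by train.py; not the real
# image.gen. render_line_mask is a hypothetical helper standing in for
# the repo's labelme-polygon-to-mask rendering.
import json
import numpy as np

def gen(paths, batchsize=3, linetype=1):
    while True:  # Keras expects the generator to loop forever
        np.random.shuffle(paths)
        for i in range(0, len(paths) - batchsize + 1, batchsize):
            images, masks = [], []
            for path in paths[i:i + batchsize]:
                with open(path) as f:
                    ann = json.load(f)  # labelme annotation
                img, mask = render_line_mask(ann, linetype)  # hypothetical
                images.append(img)
                masks.append(mask)
            yield np.array(images), np.array(masks)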