#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 9 23:11:51 2020
@author: chineseocr
"""
import json
import os
import random
import sys
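
# Make the sibling modules (table_line, image, config, metrics) importable
# regardless of which directory the script is launched from.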
sys.path.append(os.path.dirname(__file__))
print(os.path.dirname(__file__))

from table_line import model, focal_loss, dice_coef, dice_coef_loss, zero_loss
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from glob import glob
from image import gen
from config import tableModeLinePath
import tensorflow.keras as keras
from metrics import f1
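
# PRETRAINED: warm-start from a previously saved weight file (.h5);
# CHECKPOINT: resume from a TensorFlow checkpoint prefix instead.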
PRETRAINED = False
CHECKPOINT = False
random.seed(42)

if __name__ == '__main__':
    if PRETRAINED:
        # model.load_weights(tableModeLinePath)
        _path = "./models/e12-f10.76.h5"
        model.load_weights(_path)
        print("read pretrained model", _path)
    else:
        print("no pretrained")

    if CHECKPOINT:
        model.load_weights("checkpoint/table-line")
        print("read checkpoint model")
    else:
        print("no checkpoint")

    # Location for saving the model weights
    # filepath = 'models/table-line.h5'
    filepath = 'e{epoch:02d}-f1{val_f1:.2f}.h5'
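    # ModelCheckpoint expands the template per epoch, e.g. 'e12-f10.76.h5'
    # for epoch 12 with val_f1 = 0.76 (matching the pretrained file above).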
    checkpointer = ModelCheckpoint(filepath=filepath, monitor='val_f1', verbose=0,
                                   save_weights_only=True, save_best_only=True, mode="max")
    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=5,
                            verbose=0, mode='max', cooldown=0, min_lr=0)
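
    # focal_loss(2., 0.25): assuming the table_line implementation follows the
    # usual (gamma, alpha) argument ordering of the focal-loss formulation.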
    model.compile(optimizer=Adam(learning_rate=0.003), loss=focal_loss(2., 0.25),
                  metrics=['acc', keras.metrics.Precision(), keras.metrics.Recall(), f1])

    # for layer in model.layers:
    #     layer.trainable = False
    #     print(layer.trainable)

    # Table-line dataset labelled with labelme
    # real_data_path = 'train/dataset-line/8/*.json'
    real_data_path = 'C:/Table_Label/clean_data/*.json'
    paths = glob('train/dataset-line/7/*.json') + glob(real_data_path)
    # print("paths", paths)
    print("len(paths)", len(paths))
    # paths = paths[:3000]  # limit dataset size
    # print('len(paths)', len(paths))
    # new_paths = []
    # for path in paths:
    #     with open(path) as f:
    #         jsonData = json.loads(f.read())
    #     image_height = jsonData.get('imageHeight')
    #     if image_height <= 1200:
    #         new_paths.append(path)
    # paths = new_paths
    # print('len(new_paths)', len(new_paths))
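
    # Hold out roughly 10% of the files for validation; random.seed(42) above
    # makes the sampled split reproducible across runs.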
    # trainP, testP = train_test_split(paths, test_size=0.1)
    testP = random.sample(paths, int(len(paths) * 0.1))
    trainP = []
    for p in paths:
        if p not in testP:
            trainP.append(p)
    random.shuffle(trainP)
    print('total:', len(paths), 'train:', len(trainP), 'test:', len(testP))
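
    # gen() from image.py yields endless (image, mask) batches; linetype=2 is
    # assumed to select both horizontal and vertical line masks.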
    batchsize = 3
    trainloader = gen(trainP, batchsize=batchsize, linetype=2)
    testloader = gen(testP, batchsize=batchsize, linetype=2)
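
    # fit_generator is deprecated; on TF >= 2.1 the same call also works as
    # model.fit(trainloader, steps_per_epoch=..., validation_data=testloader, ...).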
    model.fit_generator(trainloader,
                        steps_per_epoch=max(1, len(trainP) // batchsize),
                        callbacks=[checkpointer, rlu],
                        validation_data=testloader,
                        validation_steps=max(1, len(testP) // batchsize),
                        epochs=2000)