train.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 9 23:11:51 2020
@author: chineseocr
"""
import json
import os
import sys

sys.path.append('train')
from table_line import model, focal_loss, dice_coef, dice_coef_loss
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from glob import glob
from image import gen
from config import tableModeLinePath
# use tf.keras here as well: mixing the standalone keras package with
# tensorflow.keras objects in one model raises errors
from tensorflow import keras
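
# Which weights to restore before training: PRETRAINED loads the released
# table-line weights from config.tableModeLinePath, CHECKPOINT resumes from a
# local training checkpoint instead.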
PRETRAINED = True
CHECKPOINT = False

if __name__ == '__main__':
    if PRETRAINED:
        model.load_weights(tableModeLinePath)
        print("read pretrained model")
    else:
        print("no pretrained")

    if CHECKPOINT:
        model.load_weights("checkpoint/table-line")
        print("read checkpoint model")
    else:
        print("no checkpoint")

    # where the trained weights are saved
    filepath = 'models/table-line.h5'
    # alternatively, keep one file per epoch:
    # filepath = 'table-line-{epoch:02d}-{val_loss:.2f}.h5'
    checkpointer = ModelCheckpoint(filepath=filepath, monitor='loss', verbose=0,
                                   save_weights_only=True, save_best_only=True)
    rlu = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=5,
                            verbose=0, mode='auto', cooldown=0, min_lr=0)
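
    # focal_loss (from table_line) downweights easy background pixels; table-line
    # masks are heavily imbalanced, so precision/recall are tracked alongside
    # accuracy, which alone would look deceptively high here.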
    model.compile(optimizer=Adam(learning_rate=0.0003), loss=focal_loss(),
                  metrics=['acc', keras.metrics.Precision(), keras.metrics.Recall()])

    # table-line dataset labelled with labelme
    paths = glob('train/dataset-line/6/*.json')
    # print("paths", paths)
    print("len(paths)", len(paths))
    paths = paths[:3000]  # cap the number of training samples
    # print('len(paths)', len(paths))
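    # Optional filter, left disabled in the original: keep only pages whose
    # labelme imageHeight is at most 1200 px.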
    # new_paths = []
    # for path in paths:
    #     with open(path) as f:
    #         jsonData = json.loads(f.read())
    #     image_height = jsonData.get('imageHeight')
    #     if image_height <= 1200:
    #         new_paths.append(path)
    # paths = new_paths
    # print('len(new_paths)', len(new_paths))
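
    # hold out 10% of the labelled files for validation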
    trainP, testP = train_test_split(paths, test_size=0.1)
    print('total:', len(paths), 'train:', len(trainP), 'test:', len(testP))

    batchsize = 3
    trainloader = gen(trainP, batchsize=batchsize, linetype=1)
    testloader = gen(testP, batchsize=batchsize, linetype=1)
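
    # gen() (from image) is assumed to yield (image_batch, mask_batch) pairs
    # indefinitely, so steps_per_epoch/validation_steps below bound each epoch.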
    # Model.fit accepts generators directly in TF >= 2.1 and replaces the
    # deprecated fit_generator; rlu is added here so the LR schedule actually runs.
    model.fit(trainloader,
              steps_per_epoch=max(1, len(trainP) // batchsize),
              callbacks=[checkpointer, rlu],
              validation_data=testloader,
              validation_steps=max(1, len(testP) // batchsize),
              epochs=2000)
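
    # A minimal post-training sketch (not in the original script): reload the
    # best checkpoint and sanity-check the output shape. The (1, 512, 512, 3)
    # input is a hypothetical placeholder; the real input size comes from
    # table_line.model and must match what gen() produces.
    import numpy as np
    model.load_weights(filepath)
    demo = np.zeros((1, 512, 512, 3), dtype='float32')  # one blank RGB page
    pred = model.predict(demo)  # per-pixel table-line probability maps
    print('prediction shape:', pred.shape)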