# train_yolo_char_260.py
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Tue Jun 24 10:53:51 2022
  4. train
  5. @author: fangjiasheng
  6. """
  7. import copy
  8. import json
  9. import os
  10. import random
  11. import sys
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  13. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  14. from click_captcha.utils import get_classes, get_anchors
  15. from click_captcha.model_260 import yolo_net, yolo_net_char
  16. os.environ["CUDA_VISIBLE_DEVICES"] = "1"
  17. # train need keras low version
  18. from tensorflow.keras.optimizers import Adam
  19. # from keras.optimizer_v2.adam import Adam
  20. # from keras.optimizers import Adam
  21. # from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
  22. from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
  23. from glob import glob
  24. from click_captcha.pre_process import gen_yolo_char
  25. from keras.metrics import mse
  26. PRETRAINED = False
  27. CHECKPOINT = False
  28. random.seed(42)
  29. if __name__ == '__main__':
  30. annotation_path = '../data/detect/map.txt'
  31. weight_path = 'models/e16-loss16.27-char.h5'
  32. log_dir = 'yolo_data/logs/000/'
  33. classes_path = 'yolo_data/my_classes.txt'
  34. anchors_path = 'yolo_data/my_anchors.txt'
  35. class_names = get_classes(classes_path)
  36. num_classes = len(class_names)
  37. anchors = get_anchors(anchors_path)
  38. # multiple of 32, hw
  39. input_shape = (160, 256)
  40. # default setting
  41. is_tiny_version = len(anchors) == 6
  42. model = yolo_net_char(input_shape, anchors, num_classes,
  43. load_pretrained=PRETRAINED,
  44. weights_path=weight_path)
  45. val_split = 0.1
  46. with open(annotation_path) as f:
  47. lines = f.readlines()
  48. random.shuffle(lines)
  49. # lines = lines[:100]
  50. num_val = int(len(lines)*val_split)
  51. num_train = len(lines) - num_val
  52. print('total:', len(lines), 'train:', num_train, 'test:', num_val)
  53. file_path = 'models/e{epoch:02d}-loss{val_loss:.2f}-char.h5'
  54. checkpoint = ModelCheckpoint(file_path, monitor='val_loss',
  55. save_weights_only=True, save_best_only=True, period=1)
  56. reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1)
  57. # model.compile(optimizer=Adam(lr=0.003), loss={'yolo_loss': lambda y_true, y_pred: y_pred},
  58. # metrics=['acc', mse])
  59. # batch num
  60. batch_size = 32
  61. steps_per_epoch = max(1, num_train//batch_size)
  62. # data loader
  63. train_loader = gen_yolo_char(lines[:num_train], batch_size, input_shape, anchors, num_classes)
  64. test_loader = gen_yolo_char(lines[num_train:], batch_size, input_shape, anchors, num_classes)
  65. # Train with frozen layers first, to get a stable loss.
  66. # model.fit_generator(train_loader,
  67. # steps_per_epoch=steps_per_epoch,
  68. # callbacks=[reduce_lr],
  69. # validation_data=test_loader,
  70. # validation_steps=max(1, num_val // batch_size),
  71. # epochs=20,
  72. # max_queue_size=50)
  73. # Unfreeze and continue training, to fine-tune.
  74. for i in range(len(model.layers)):
  75. model.layers[i].trainable = True
  76. print('Unfreeze all of the layers.')
  77. model.compile(optimizer=Adam(lr=0.0003), loss={'yolo_loss': lambda y_true, y_pred: y_pred},
  78. metrics=[mse])
  79. model.fit_generator(train_loader,
  80. steps_per_epoch=steps_per_epoch,
  81. validation_data=test_loader,
  82. validation_steps=max(1, num_val//batch_size),
  83. epochs=500,
  84. callbacks=[checkpoint, reduce_lr],
  85. max_queue_size=50)