# train_yolo_char.py
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Tue Jun 24 10:53:51 2022
  4. train
  5. @author: fangjiasheng
  6. """
  7. import copy
  8. import json
  9. import os
  10. import random
  11. import sys
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  13. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  14. from click_captcha.utils import get_classes, get_anchors
  15. from click_captcha.model import yolo_net
  16. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  17. # train need keras low version
  18. # from tensorflow.keras.optimizers import Adam
  19. # from keras.optimizer_v2.adam import Adam
  20. from keras.optimizers import Adam
  21. # from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
  22. from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
  23. from glob import glob
  24. from click_captcha.pre_process import gen_yolo_char
  25. from keras.metrics import mse
  26. # from keras.losses import categorical_crossentropy as CategoricalCrossentropy
  27. PRETRAINED = False
  28. CHECKPOINT = False
  29. random.seed(42)
  30. if __name__ == '__main__':
  31. annotation_path = '../data/detect/map.txt'
  32. weight_path = 'models/e20-loss116.38.h5'
  33. log_dir = 'yolo_data/logs/000/'
  34. classes_path = 'yolo_data/my_classes.txt'
  35. anchors_path = 'yolo_data/my_anchors.txt'
  36. class_names = get_classes(classes_path)
  37. num_classes = len(class_names)
  38. anchors = get_anchors(anchors_path)
  39. # multiple of 32, hw
  40. input_shape = (160, 256)
  41. # default setting
  42. is_tiny_version = len(anchors) == 6
  43. model = yolo_net(input_shape, anchors, num_classes,
  44. load_pretrained=PRETRAINED,
  45. weights_path=weight_path)
  46. val_split = 0.1
  47. with open(annotation_path) as f:
  48. lines = f.readlines()
  49. random.shuffle(lines)
  50. # lines = lines[:100]
  51. num_val = int(len(lines)*val_split)
  52. num_train = len(lines) - num_val
  53. print('total:', len(lines), 'train:', num_train, 'test:', num_val)
  54. file_path = 'models/e{epoch:02d}-loss{val_loss:.2f}.h5'
  55. checkpoint = ModelCheckpoint(file_path, monitor='val_loss',
  56. save_weights_only=True, save_best_only=True, period=2)
  57. reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, verbose=1)
  58. # model.compile(optimizer=Adam(lr=0.003), loss={'yolo_loss': lambda y_true, y_pred: y_pred},
  59. # metrics=['acc', mse])
  60. # batch num
  61. batch_size = 32
  62. steps_per_epoch = max(1, num_train//batch_size)
  63. # data loader
  64. train_loader = gen_yolo_char(lines[:num_train], batch_size, input_shape, anchors, num_classes)
  65. test_loader = gen_yolo_char(lines[num_train:], batch_size, input_shape, anchors, num_classes)
  66. # Train with frozen layers first, to get a stable loss.
  67. # model.fit_generator(train_loader,
  68. # steps_per_epoch=steps_per_epoch,
  69. # callbacks=[reduce_lr],
  70. # validation_data=test_loader,
  71. # validation_steps=max(1, num_val // batch_size),
  72. # epochs=20,
  73. # max_queue_size=50)
  74. # Unfreeze and continue training, to fine-tune.
  75. for i in range(len(model.layers)):
  76. model.layers[i].trainable = True
  77. print('Unfreeze all of the layers.')
  78. model.compile(optimizer=Adam(lr=0.003), loss={'yolo_loss': lambda y_true, y_pred: y_pred},
  79. metrics=[mse])
  80. model.fit_generator(train_loader,
  81. steps_per_epoch=steps_per_epoch,
  82. validation_data=test_loader,
  83. validation_steps=max(1, num_val//batch_size),
  84. epochs=500,
  85. callbacks=[checkpoint, reduce_lr])