# train_char.py
  1. import os
  2. import random
  3. import sys
  4. from glob import glob
  5. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  6. import tensorflow as tf
  7. tf.compat.v1.disable_eager_execution()
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  10. from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
  11. from keras.losses import BinaryCrossentropy, mse, CategoricalCrossentropy, MSE
  12. from keras.optimizer_v2.adam import Adam
  13. from click_captcha.metrics import precision, recall, f1
  14. from click_captcha.loss import focal_loss, contrastive_loss, l2_focal_loss, l1_focal_loss
  15. from click_captcha.model import siamese_net, mobile_net, cnn_net, cnn_net_small, cnn_net_tiny, cnn_net_tiny_dropout
  16. from click_captcha.pre_process import gen_char
  17. PRETRAINED = False
  18. random.seed(42)
  19. image_shape = (40, 40, 1)
  20. project_root = os.path.dirname(os.path.abspath(__file__)) + "/../"
  21. class_num = 5649
  22. data_path = 'click_simple'
  23. if __name__ == "__main__":
  24. model = cnn_net_small(input_shape=image_shape, output_shape=class_num)
  25. if PRETRAINED:
  26. _path = "./models/e130-acc0.87-char.h5"
  27. model.load_weights(_path, skip_mismatch=True, by_name=True)
  28. print("read pretrained model", _path)
  29. else:
  30. print("no pretrained")
  31. # with open(project_root + "data/click/map.txt", "r") as f:
  32. # paths = f.readlines()
  33. # print("len(paths)", len(paths))
  34. paths = glob("../data/" + data_path + "/*.jpg")
  35. # data path split into train,test
  36. random.shuffle(paths)
  37. # paths = paths[:100000]
  38. trainP = paths[:int(len(paths)*0.9)]
  39. testP = paths[int(len(paths)*0.9):]
  40. print('total:', len(paths), 'train:', len(trainP), 'test:', len(testP))
  41. # batch num
  42. batch_size = 32
  43. steps_per_epoch = max(1, len(trainP) // batch_size)
  44. validation_steps = max(1, len(testP) // batch_size)
  45. # 模型权重存放位置
  46. filepath = 'models/e{epoch:02d}-acc{val_acc:.2f}-char.h5'
  47. check_pointer = ModelCheckpoint(filepath=filepath, monitor='val_acc', verbose=0,
  48. save_weights_only=True, save_best_only=True,
  49. mode="max", save_freq='epoch')
  50. rlu = ReduceLROnPlateau(monitor='val_acc', factor=0.5, patience=10,
  51. verbose=1, mode='max', cooldown=0, min_lr=0)
  52. model.compile(optimizer=Adam(lr=0.0003), loss=CategoricalCrossentropy(),
  53. metrics=['acc', f1])
  54. # data loader
  55. train_loader = gen_char(trainP, batch_size=batch_size, shape=image_shape, cls_num=class_num, data_path=data_path)
  56. test_loader = gen_char(testP, batch_size=batch_size, shape=image_shape, cls_num=class_num, data_path=data_path)
  57. # train
  58. model.fit(train_loader,
  59. steps_per_epoch=steps_per_epoch,
  60. callbacks=[check_pointer, rlu],
  61. validation_data=test_loader,
  62. validation_steps=validation_steps,
  63. epochs=1000,
  64. max_queue_size=1000,
  65. use_multiprocessing=True,
  66. workers=10)
  67. # model.fit_generator(train_loader,
  68. # steps_per_epoch=steps_per_epoch,
  69. # callbacks=[check_pointer, rlu],
  70. # validation_data=test_loader,
  71. # validation_steps=validation_steps,
  72. # epochs=1000,
  73. # max_queue_size=1000,
  74. # use_multiprocessing=True,
  75. # workers=8)