# train_equation.py (3.3 KB)

  1. import os
  2. import random
  3. import sys
  4. from glob import glob
  5. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  6. import tensorflow as tf
  7. import keras.backend as K
  8. # tf.compat.v1.disable_eager_execution()
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  10. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  11. from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
  12. from keras.losses import BinaryCrossentropy, mse, CategoricalCrossentropy, MSE
  13. from keras.optimizer_v2.adam import Adam
  14. from click_captcha.metrics import precision, recall, f1
  15. from click_captcha.loss import focal_loss, contrastive_loss, l2_focal_loss, l1_focal_loss, ctc_accuracy
  16. from click_captcha.model import crnn_ctc_equation, crnn_ctc_equation_large, crnn_ctc_equation_less, \
  17. crnn_ctc_equation_loss
  18. from click_captcha.pre_process import gen_equation, gen_equation2
  19. PRETRAINED = True
  20. random.seed(42)
  21. image_shape = (32, 192, 1)
  22. project_root = os.path.dirname(os.path.abspath(__file__)) + "/../"
  23. class_num = 35 + 2
  24. data_path = 'equation2'
  25. if __name__ == "__main__":
  26. model = crnn_ctc_equation_loss(input_shape=image_shape, class_num=class_num)
  27. if PRETRAINED:
  28. _path = "./models/e83-loss0.06-equation.h5"
  29. model.load_weights(_path, skip_mismatch=True, by_name=True)
  30. print("read pretrained model", _path)
  31. else:
  32. print("no pretrained")
  33. # with open(project_root + "data/click/map.txt", "r") as f:
  34. # paths = f.readlines()
  35. # print("len(paths)", len(paths))
  36. paths = glob("../data/" + data_path + "/*.jpg")
  37. # data path split into train,test
  38. random.shuffle(paths)
  39. # paths = paths[:100000]
  40. trainP = paths[:int(len(paths)*0.9)]
  41. testP = paths[int(len(paths)*0.9):]
  42. print('total:', len(paths), 'train:', len(trainP), 'test:', len(testP))
  43. # batch num
  44. batch_size = 32
  45. steps_per_epoch = max(1, len(trainP) // batch_size)
  46. validation_steps = max(1, len(testP) // batch_size)
  47. # 模型权重存放位置
  48. filepath = 'models/e{epoch:02d}-loss{val_loss:.2f}-equation.h5'
  49. check_pointer = ModelCheckpoint(filepath=filepath, monitor='val_loss', verbose=0,
  50. save_weights_only=True, save_best_only=True,
  51. mode="min", save_freq='epoch')
  52. rlu = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10,
  53. verbose=1, mode='min', cooldown=0, min_lr=0)
  54. model.compile(optimizer=Adam(lr=0.0003), loss={'ctc': lambda y_true, y_pred: y_pred})
  55. # 使用ctc center loss 所需
  56. # sess = K.get_session()
  57. # sess.run(tf.compat.v1.global_variables_initializer())
  58. # data loader
  59. train_loader = gen_equation2(trainP, batch_size=batch_size, shape=image_shape, cls_num=class_num, data_path=data_path)
  60. test_loader = gen_equation2(testP, batch_size=batch_size, shape=image_shape, cls_num=class_num, data_path=data_path)
  61. # train
  62. steps_per_epoch = 500
  63. validation_steps = int(steps_per_epoch * 0.1)
  64. model.fit_generator(train_loader,
  65. steps_per_epoch=steps_per_epoch,
  66. callbacks=[check_pointer, rlu],
  67. validation_data=test_loader,
  68. validation_steps=validation_steps,
  69. epochs=1000,
  70. max_queue_size=100,
  71. use_multiprocessing=False)