train_equation_denoise.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. import random
  3. import sys
  4. from glob import glob
  5. os.environ["CUDA_VISIBLE_DEVICES"] = "1"
  6. import tensorflow as tf
  7. # tf.compat.v1.disable_eager_execution()
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  10. from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
  11. from keras.losses import BinaryCrossentropy, mse, CategoricalCrossentropy, MSE
  12. from keras.optimizer_v2.adam import Adam
  13. from click_captcha.metrics import precision, recall, f1
  14. from click_captcha.loss import focal_loss, contrastive_loss, l2_focal_loss, l1_focal_loss, perceptual_loss
  15. from click_captcha.model import u_net_denoise
  16. from click_captcha.pre_process import gen_equation, gen_equation2, gen_equation_denoise
  # When True, initialize the model from a saved checkpoint before training.
  PRETRAINED = False
  # Seeds Python's `random` module only; TensorFlow's RNG is not seeded here.
  # NOTE(review): presumably this makes the data generator reproducible --
  # confirm that gen_equation_denoise draws from `random`.
  random.seed(42)
  # Model input shape; presumably (height, width, channels) under Keras's
  # default channels-last layout -- TODO confirm.
  image_shape = (32, 192, 1)
  # Parent directory of this script's folder (kept for importers; not used below).
  project_root = os.path.dirname(os.path.abspath(__file__)) + "/../"


  def main():
      """Build the U-Net denoising model and train it on generated images.

      Optionally loads pretrained weights (see ``PRETRAINED``), then trains
      with ``perceptual_loss``. During training it checkpoints the
      best-``val_loss`` weights under ``models/``, a path relative to the
      current working directory. It also halves the learning rate when
      ``val_loss`` plateaus.
      """
      # class_num is set to the input's channel count (image_shape[2] == 1).
      model = u_net_denoise(input_shape=image_shape, class_num=image_shape[2])
      if PRETRAINED:
          _path = "./models/e130-acc0.87-char.h5"
          # by_name=True loads weights into layers with matching names;
          # skip_mismatch=True skips layers whose weight shapes differ, so a
          # checkpoint from a different model can be partially reused.
          model.load_weights(_path, skip_mismatch=True, by_name=True)
          print("read pretrained model", _path)
      else:
          print("no pretrained")

      # Samples per batch.
      batch_size = 32
      # Where model weights are saved; ModelCheckpoint fills in {epoch} and
      # {val_loss} at save time.
      filepath = 'models/e{epoch:02d}-loss{val_loss:.2f}-denoise.h5'
      # Create the checkpoint directory up front: depending on the Keras
      # version, saving HDF5 weights into a missing directory raises.
      os.makedirs(os.path.dirname(filepath), exist_ok=True)
      # Save weights only, and only when val_loss improves.
      check_pointer = ModelCheckpoint(filepath=filepath, monitor='val_loss', verbose=0,
                                      save_weights_only=True, save_best_only=True,
                                      mode="min", save_freq='epoch')
      # Halve the learning rate after 10 epochs without val_loss improvement.
      rlu = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10,
                              verbose=1, mode='min', cooldown=0, min_lr=0)
      # `learning_rate` replaces the deprecated `lr` alias, which newer Keras
      # optimizers reject outright.
      model.compile(optimizer=Adam(learning_rate=0.0003), loss=perceptual_loss(),
                    metrics=['acc', precision, recall, f1])

      # Data loaders: validation uses the same generator call as training.
      train_loader = gen_equation_denoise(None, batch_size=batch_size, shape=image_shape)
      test_loader = gen_equation_denoise(None, batch_size=batch_size, shape=image_shape)

      # Epoch length is set explicitly by step counts; validation runs 10% as
      # many steps as training.
      steps_per_epoch = 1000
      validation_steps = int(steps_per_epoch * 0.1)
      # Model.fit accepts Python generators directly; fit_generator is deprecated.
      model.fit(train_loader,
                steps_per_epoch=steps_per_epoch,
                callbacks=[check_pointer, rlu],
                validation_data=test_loader,
                validation_steps=validation_steps,
                epochs=1000,
                max_queue_size=100,
                use_multiprocessing=False)


  if __name__ == "__main__":
      main()