# train_phrase.py — training script for the phrase-recognition model.
  1. import os
  2. import random
  3. import sys
  4. from glob import glob
  5. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  6. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  8. from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
  9. from keras.losses import BinaryCrossentropy, mse, CategoricalCrossentropy
  10. from keras.metrics import Precision, Recall
  11. from keras.optimizer_v2.adam import Adam
  12. import keras.backend as K
  13. from click_captcha.metrics import precision, recall, f1
  14. from click_captcha.loss import focal_loss, contrastive_loss, l2_focal_loss, l1_focal_loss, l2_loss
  15. from click_captcha.model import siamese_net, mobile_net, cnn_net, cnn_net_drag, u_net_drag, lstm_phrase, text_cnn_phrase
  16. from click_captcha.pre_process import gen_mobile, gen_drag, gen_phrase
  17. PRETRAINED = False
  18. random.seed(42)
  19. project_root = os.path.dirname(os.path.abspath(__file__)) + "/../"
  20. vocabulary_len = 5792
  21. sequence_len = 6
  22. if __name__ == "__main__":
  23. model = text_cnn_phrase((sequence_len, vocabulary_len))
  24. if PRETRAINED:
  25. _path = "./models/e08-f10.86-phrase.h5"
  26. model.load_weights(_path, skip_mismatch=True, by_name=True)
  27. print("read pretrained model", _path)
  28. else:
  29. print("no pretrained")
  30. with open("../data/phrase/map2.txt", "r") as f:
  31. map_list = f.readlines()
  32. # data path split into train,test
  33. random.shuffle(map_list)
  34. testP = map_list[:int(len(map_list)*0.1)]
  35. trainP = map_list[int(len(map_list)*0.1):]
  36. print('total:', len(map_list), 'train:', len(trainP), 'test:', len(testP))
  37. # batch num
  38. batch_size = 64
  39. steps_per_epoch = max(1, len(trainP) // batch_size)
  40. validation_steps = max(1, len(testP) // batch_size)
  41. # 模型权重存放位置
  42. filepath = 'models/e{epoch:02d}-f1{val_f1:.2f}-phrase.h5'
  43. check_pointer = ModelCheckpoint(filepath=filepath, monitor='val_f1', verbose=0,
  44. save_weights_only=True, save_best_only=True,
  45. mode="max", save_freq='epoch')
  46. rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.5, patience=10,
  47. verbose=1, mode='max', cooldown=0, min_lr=0)
  48. model.compile(optimizer=Adam(lr=0.003), loss=BinaryCrossentropy(),
  49. metrics=['acc', f1])
  50. # data loader
  51. train_loader = gen_phrase(trainP, batch_size=batch_size, shape=(vocabulary_len, sequence_len))
  52. test_loader = gen_phrase(testP, batch_size=batch_size, shape=(vocabulary_len, sequence_len))
  53. # train
  54. model.fit_generator(train_loader,
  55. steps_per_epoch=steps_per_epoch,
  56. callbacks=[check_pointer],
  57. validation_data=test_loader,
  58. validation_steps=validation_steps,
  59. epochs=1000)