train.py

import sys
import os
sys.path.append(os.path.abspath("../../.."))
os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras
from keras.metrics import categorical_accuracy
from BiddingKG.dl.table_head.metrics import precision, recall, f1
from keras import optimizers, Model
from BiddingKG.dl.table_head.models.model import get_model
from BiddingKG.dl.table_head.loss import focal_loss
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from BiddingKG.dl.table_head.pre_process import get_data_from_file, get_data_from_sql, my_data_loader
import numpy as np
from keras import backend as K

input_shape = (3, 10, 60)
output_shape = (3,)
batch_size = 1024
epochs = 1000
pretrained_path = "checkpoints/best.hdf5"
checkpoint_path = "checkpoints/"
PRETRAINED = False
CHECKPOINT = False
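

# The real focal loss lives in BiddingKG.dl.table_head.loss and is imported
# above. For orientation only: a generic binary focal loss is usually
# FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t). The sketch below is an
# illustrative stand-in (its name and the alpha/gamma defaults are
# assumptions) and is never called by train().
def _focal_loss_sketch(gamma=2.0, alpha=0.25):
    def loss(y_true, y_pred):
        # p_t: predicted probability of the true class, clipped for log-safety
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        pt = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
        return -K.mean(alpha * K.pow(1.0 - pt, gamma) * K.log(pt))
    return loss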


def train():
    # GPU available (private helper from the Keras 2.x / TF 1.x backend)
    print("gpus", K.tensorflow_backend._get_available_gpus())

    # Data
    data_x, data_y = get_data_from_file('txt')
    print("finish read data", len(data_x))

    # Split -> Train, Test (first 10% held out for validation)
    split_size = int(len(data_x) * 0.1)
    test_x, test_y = data_x[:split_size], data_y[:split_size]
    train_x, train_y = data_x[split_size:], data_y[split_size:]

    # Data loaders (endless batch generators consumed by fit_generator)
    train_data_loader = my_data_loader(train_x, train_y, batch_size=batch_size)
    test_data_loader = my_data_loader(test_x, test_y, batch_size=batch_size)

    # Model
    model = get_model(input_shape, output_shape)
    if PRETRAINED:
        model.load_weights(pretrained_path)
        print("read pretrained model", pretrained_path)
    else:
        print("no pretrained")
    if CHECKPOINT:
        # NOTE: checkpoint_path is a directory, but load_weights() needs a
        # concrete .hdf5 file; point it at one before enabling CHECKPOINT.
        model.load_weights(checkpoint_path)
        print("read checkpoint model", checkpoint_path)
    else:
        print("no checkpoint")

    # Keep only the best weights per epoch, judged by validation F1
    filepath = 'e{epoch:02d}-loss{val_loss:.2f}'
    checkpoint = ModelCheckpoint(checkpoint_path + filepath + ".hdf5",
                                 monitor='val_f1', verbose=1,
                                 save_best_only=True, mode='max')

    model.compile(optimizer=optimizers.Adam(lr=0.005), loss=focal_loss(),
                  metrics=[categorical_accuracy, precision, recall, f1])

    # Shrink the learning rate tenfold when validation F1 plateaus
    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=5,
                            verbose=1, mode='max', cooldown=0, min_lr=0)

    model.fit_generator(train_data_loader,
                        steps_per_epoch=max(1, len(train_x) // batch_size),
                        callbacks=[checkpoint, rlu],
                        validation_data=test_data_loader,
                        validation_steps=max(1, len(test_x) // batch_size),
                        epochs=epochs)

    # model.fit(x=[train_x[0], train_x[1], train_x[2]], y=train_y,
    #           validation_data=([test_x[0], test_x[1], test_x[2]], test_y),
    #           epochs=epochs, batch_size=256, shuffle=True,
    #           callbacks=[checkpoint, rlu])
    return model, test_x
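

# 'val_f1' above refers to the f1 metric imported from
# BiddingKG.dl.table_head.metrics. A common batch-wise formulation (the
# classic Keras 1.x fmeasure) looks like the sketch below; it is an assumed
# stand-in for reference, not the project's actual implementation, and is
# not used by train().
def _f1_sketch(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    p = true_positives / (predicted_positives + K.epsilon())
    r = true_positives / (possible_positives + K.epsilon())
    return 2 * p * r / (p + r + K.epsilon())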


def print_layer_output(model, data):
    # Debug helper: build a sub-model that stops at the 'input_2' layer and
    # print what it emits. Note it feeds only two inputs, while the
    # commented-out fit() call above passes three; adjust to the model used.
    middle_layer = Model(inputs=model.inputs,
                         outputs=model.get_layer('input_2').output)
    middle_layer_output = middle_layer.predict([data[0], data[1]])
    print(middle_layer_output)
    return
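

# fit_generator() requires my_data_loader to yield (inputs, targets) batches
# indefinitely. The real loader is in BiddingKG.dl.table_head.pre_process;
# this sketch only illustrates the assumed contract (a multi-input model
# would need a list of input arrays instead) and is never called here.
def _data_loader_sketch(data_x, data_y, batch_size):
    data_x, data_y = np.array(data_x), np.array(data_y)
    while True:
        # Reshuffle on every pass over the data
        idx = np.random.permutation(len(data_x))
        for start in range(0, len(idx), batch_size):
            batch = idx[start:start + batch_size]
            yield data_x[batch], data_y[batch]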


if __name__ == '__main__':
    model, data = train()