train.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import sys
  2. import os
  3. sys.path.append(os.path.abspath("../../.."))
  4. os.environ['KERAS_BACKEND'] = 'tensorflow'
  5. import keras
  6. from BiddingKG.dl.table_head.metrics import precision, recall, f1
  7. from keras import optimizers, Model
  8. from BiddingKG.dl.table_head.models.model import get_model
  9. from BiddingKG.dl.table_head.loss import focal_loss
  10. from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
  11. from BiddingKG.dl.table_head.pre_process import get_data_from_file, get_data_from_sql, my_data_loader
  12. import numpy as np
  13. from keras import backend as K
  # Model I/O geometry handed to get_model().
  input_shape, output_shape = (3, 10, 60), (3,)

  # Optimisation schedule.
  epochs = 1000
  batch_size = 1024

  # Weight files: a fixed pretrained file and the directory checkpoints are saved into.
  checkpoint_path = "checkpoints/"
  pretrained_path = "checkpoints/best.hdf5"

  # Warm-start switches read by train(): load weights before training when True.
  PRETRAINED = CHECKPOINT = False
  def _latest_checkpoint(directory):
      """Return the most recently modified ``.hdf5`` file in *directory*.

      Returns None when *directory* does not exist or contains no ``.hdf5``
      files.
      """
      if not os.path.isdir(directory):
          return None
      candidates = [os.path.join(directory, name)
                    for name in os.listdir(directory)
                    if name.endswith(".hdf5")]
      if not candidates:
          return None
      return max(candidates, key=os.path.getmtime)


  def train():
      """Train the table-head model and return it with the held-out inputs.

      Reads samples with ``get_data_from_file('txt')``, holds out the first
      10% as the validation split, optionally restores weights according to
      the module flags ``PRETRAINED`` / ``CHECKPOINT``, then trains with
      ``fit_generator``. ``ModelCheckpoint`` saves only the best model by
      ``val_f1``, and ``ReduceLROnPlateau`` watches the same metric.

      Returns:
          tuple: ``(model, test_x)`` -- the trained Keras model and the
          held-out validation inputs.

      Raises:
          ValueError: if no data was read.
          FileNotFoundError: if ``CHECKPOINT`` is set but no ``.hdf5`` file
              exists in ``checkpoint_path``.
      """
      # GPU available (private helper of the Keras TensorFlow backend).
      print("gpus", K.tensorflow_backend._get_available_gpus())

      # Data
      data_x, data_y = get_data_from_file('txt')
      print("finish read data", len(data_x))
      if len(data_x) == 0:
          raise ValueError("no training data read by get_data_from_file('txt')")

      # Split -> Train, Test: the first 10% is held out, without shuffling.
      # NOTE(review): if the source file is ordered, consider shuffling first.
      split_size = int(len(data_x) * 0.1)
      test_x, test_y = data_x[:split_size], data_y[:split_size]
      train_x, train_y = data_x[split_size:], data_y[split_size:]

      # Data loaders consumed by fit_generator below (batch_size per step).
      train_data_loader = my_data_loader(train_x, train_y, batch_size=batch_size)
      test_data_loader = my_data_loader(test_x, test_y, batch_size=batch_size)

      # Model, with optional warm starts. If both flags are set, the
      # checkpoint weights are loaded last and therefore take precedence.
      model = get_model(input_shape, output_shape)
      if PRETRAINED:
          model.load_weights(pretrained_path)
          print("read pretrained model", pretrained_path)
      else:
          print("no pretrained")
      if CHECKPOINT:
          # checkpoint_path is the directory checkpoints are written into (see
          # ModelCheckpoint below); load_weights needs a file, so resume from
          # the newest .hdf5 saved there.
          latest = _latest_checkpoint(checkpoint_path)
          if latest is None:
              raise FileNotFoundError(
                  "no .hdf5 checkpoint found in %s" % checkpoint_path)
          model.load_weights(latest)
          print("read checkpoint model", latest)
      else:
          print("no checkpoint")

      # The best model is selected by validation F1, so record it in the file
      # name alongside the loss (kept for continuity with older file names).
      filepath = 'e-{epoch:02d}-loss-{val_loss:.2f}-f1-{val_f1:.2f}'
      checkpoint = ModelCheckpoint(checkpoint_path + filepath + ".hdf5",
                                   monitor='val_f1', verbose=1,
                                   save_best_only=True, mode='max')
      model.compile(optimizer=optimizers.Adam(lr=0.0005),
                    loss='binary_crossentropy',
                    metrics=['binary_crossentropy', 'acc',
                             precision, recall, f1])
      # Cut the learning rate 10x after 5 epochs without val_f1 improvement.
      rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=5,
                              verbose=1, mode='max', cooldown=0, min_lr=0)
      model.fit_generator(train_data_loader,
                          steps_per_epoch=max(1, len(train_x) // batch_size),
                          callbacks=[checkpoint, rlu],
                          validation_data=test_data_loader,
                          validation_steps=max(1, len(test_x) // batch_size),
                          epochs=epochs)
      return model, test_x
  def print_layer_output(model, data, layer_name='input_2'):
      """Print and return the output of one named layer of *model* on *data*.

      Builds a sub-model from ``model.inputs`` to the output of the layer
      called ``layer_name``, then runs ``predict`` on ``[data[0], data[1]]``.

      Args:
          model: a Keras functional model.
          data: indexable; ``data[0]`` and ``data[1]`` are passed to
              ``predict`` as a two-element input list.
              NOTE(review): presumably these must line up with the model's
              inputs -- confirm against get_model().
          layer_name: name of the layer to probe. Defaults to ``'input_2'``,
              the layer this debugging helper originally hard-coded.

      Returns:
          The array returned by ``predict`` for that layer.
      """
      middle_layer = Model(inputs=model.inputs,
                           outputs=model.get_layer(layer_name).output)
      middle_layer_output = middle_layer.predict([data[0], data[1]])
      print(middle_layer_output)
      return middle_layer_output
  if __name__ == '__main__':
      # Script entry: train, keeping the model and held-out inputs bound at
      # module level (e.g. for print_layer_output(model, data) when debugging).
      trained = train()
      model, data = trained