# train.py (2.0 KB)
  1. import sys
  2. import os
  3. sys.path.append(os.path.abspath("../.."))
  4. from keras import optimizers
  5. from tensorflow.contrib.metrics import f1_score
  6. from tensorflow.python.ops.metrics_impl import precision, recall
  7. from BiddingKG.dl.table_head.models.model import get_model
  8. from BiddingKG.dl.table_head.loss import focal_loss
  9. from keras.callbacks import ModelCheckpoint
  10. from BiddingKG.dl.table_head.pre_process import get_data_from_file
  11. import numpy as np
  12. input_shape = (2, 10)
  13. output_shape = (2,)
  14. pretrained_path = ""
  15. checkpoint_path = "checkpoints/"
  16. PRETRAINED = False
  17. CHECKPOINT = False
  18. def train():
  19. # Data
  20. data_x, data_y = get_data_from_file()
  21. data_x = np.array(data_x)
  22. data_y = np.array(data_y)
  23. # Split -> Train, Test
  24. split_size = int(len(data_x)*0.1)
  25. test_x, test_y = data_x[:split_size], data_y[:split_size]
  26. train_x, train_y = data_x[split_size:], data_y[split_size:]
  27. # (table_num, 2 sentences, dim characters) -> (2, table_num, dim)
  28. train_x = np.transpose(train_x, (1, 0, 2))
  29. test_x = np.transpose(test_x, (1, 0, 2))
  30. # Model
  31. model = get_model(input_shape, output_shape)
  32. if PRETRAINED:
  33. model.load_weights(pretrained_path)
  34. print("read pretrained model", pretrained_path)
  35. else:
  36. print("no pretrained")
  37. if CHECKPOINT:
  38. model.load_weights(checkpoint_path)
  39. print("read checkpoint model", checkpoint_path)
  40. else:
  41. print("no checkpoint")
  42. filepath = '{epoch:02d}-{val_loss:.2f}.h5'
  43. checkpoint = ModelCheckpoint(checkpoint_path+filepath+".hdf5", monitor=focal_loss(),
  44. verbose=1, save_best_only=True, mode='min')
  45. model.compile(optimizer=optimizers.Adam(lr=0.0005), loss=focal_loss(),
  46. metrics=[focal_loss()])
  47. print(train_x.shape, train_y.shape)
  48. model.fit(x=[train_x[0], train_x[1]], y=train_y,
  49. validation_data=([test_x[0], test_x[1]], test_y),
  50. epochs=100, batch_size=128, shuffle=True,
  51. callbacks=[checkpoint])
  52. if __name__ == '__main__':
  53. train()