train.py
import sys
import os

sys.path.append(os.path.abspath("../../.."))
os.environ['KERAS_BACKEND'] = 'tensorflow'

from keras.metrics import categorical_accuracy
from BiddingKG.dl.table_head.metrics import precision, recall, f1
from keras import optimizers, Model
from BiddingKG.dl.table_head.models.model import get_model
from BiddingKG.dl.table_head.loss import focal_loss
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from BiddingKG.dl.table_head.pre_process import get_data_from_file, get_data_from_sql, my_data_loader
from keras import backend as K

input_shape = (6, 10, 60)
output_shape = (1,)
batch_size = 32
epochs = 1000
pretrained_path = "checkpoints/best.hdf5"
checkpoint_path = "checkpoints/"
PRETRAINED = True
CHECKPOINT = False
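
# NOTE: the real my_data_loader comes from BiddingKG.dl.table_head.pre_process;
# the sketch below is only a hypothetical illustration of the generator contract
# that fit_generator expects (yield (batch_x, batch_y) tuples indefinitely).
# The name and batching details here are assumptions, not the project's code.
import numpy as np

def my_data_loader_sketch(data_x, data_y, batch_size=32):
    # Loop forever so Keras can draw steps_per_epoch batches every epoch
    while True:
        for i in range(0, len(data_x) - batch_size + 1, batch_size):
            yield (np.array(data_x[i:i + batch_size]),
                   np.array(data_y[i:i + batch_size]))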
def train():
    # GPU available
    print("gpus", K.tensorflow_backend._get_available_gpus())

    # Data
    data_x, data_y = get_data_from_file('txt')
    # data_x = data_x[:60000]
    # data_y = data_y[:60000]
    print("finish read data", len(data_x))

    # Split -> Train, Test
    split_size = int(len(data_x) * 0.1)
    test_x, test_y = data_x[:split_size], data_y[:split_size]
    train_x, train_y = data_x[split_size:], data_y[split_size:]

    # Data Loader
    train_data_loader = my_data_loader(train_x, train_y, batch_size=batch_size)
    test_data_loader = my_data_loader(test_x, test_y, batch_size=batch_size)

    # Model
    model = get_model(input_shape, output_shape)

    if PRETRAINED:
        model.load_weights(pretrained_path)
        print("read pretrained model", pretrained_path)
    else:
        print("no pretrained")

    if CHECKPOINT:
        model.load_weights(checkpoint_path)
        print("read checkpoint model", checkpoint_path)
    else:
        print("no checkpoint")

    filepath = 'e{epoch:02d}-f1{val_f1:.2f}'
    checkpoint = ModelCheckpoint(checkpoint_path + filepath + ".hdf5", monitor='val_f1',
                                 verbose=1, save_best_only=True, mode='max')

    model.compile(optimizer=optimizers.Adam(lr=0.005), loss=focal_loss(),
                  # model.compile(optimizer=optimizers.Adam(lr=0.005), loss='binary_crossentropy',
                  metrics=['acc', precision, recall, f1])

    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.1, patience=5,
                            verbose=1, mode='max', cooldown=0, min_lr=0)

    model.fit_generator(train_data_loader,
                        steps_per_epoch=max(1, len(train_x) // batch_size),
                        callbacks=[checkpoint, rlu],
                        validation_data=test_data_loader,
                        validation_steps=max(1, len(test_x) // batch_size),
                        epochs=epochs)
    # model.fit(x=[train_x[0], train_x[1], train_x[2]], y=train_y,
    #           validation_data=([test_x[0], test_x[1], test_x[2]], test_y),
    #           epochs=epochs, batch_size=256, shuffle=True,
    #           callbacks=[checkpoint, rlu])
    return model, test_x
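
# NOTE: focal_loss() is imported from BiddingKG.dl.table_head.loss; the factory
# below is only a hedged sketch of the standard binary focal loss (Lin et al.)
# that such a factory typically returns. The gamma/alpha defaults are assumptions.
def focal_loss_sketch(gamma=2.0, alpha=0.25):
    def _loss(y_true, y_pred):
        # Clip to avoid log(0), then down-weight easy, well-classified examples
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        pt = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
        w = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)
        return -K.mean(w * K.pow(1.0 - pt, gamma) * K.log(pt))
    return _loss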
def print_layer_output(model, data):
    middle_layer = Model(inputs=model.inputs,
                         outputs=model.get_layer('input_2').output)
    middle_layer_output = middle_layer.predict([data[0], data[1]])
    print(middle_layer_output)
    return
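
# NOTE: precision/recall/f1 are imported from BiddingKG.dl.table_head.metrics;
# the functions below are only a hedged sketch of the usual batch-wise Keras
# backend implementations for binary labels with sigmoid outputs, and may differ
# from the project's own definitions.
def precision_sketch(y_true, y_pred):
    true_pos = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    pred_pos = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return true_pos / (pred_pos + K.epsilon())

def recall_sketch(y_true, y_pred):
    true_pos = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    actual_pos = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_pos / (actual_pos + K.epsilon())

def f1_sketch(y_true, y_pred):
    p = precision_sketch(y_true, y_pred)
    r = recall_sketch(y_true, y_pred)
    return 2 * p * r / (p + r + K.epsilon())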
if __name__ == '__main__':
    model, data = train()
    # Commented-out helper: writes lines of the form '中国电信<place>分公司'
    # ("China Telecom <place> branch") to a text file.
    # place_list = get_place_list()
    # _str1 = '中国电信'
    # _str2 = '分公司'
    # _list = []
    # for place in place_list:
    #     _list.append(_str1 + place + _str2 + "\n")
    # # print(_list)
    # with open("电信分公司.txt", "w") as f:
    #     f.writelines(_list)
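    # Hypothetical follow-up (assumes test_x is a list of per-input arrays that
    # matches the model's inputs):
    # print_layer_output(model, data)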