train.py

import sys
import os
sys.path.append(os.path.abspath("../../.."))
os.environ['KERAS_BACKEND'] = 'tensorflow'
from BiddingKG.dl.table_head.models.layer_utils import MyModelCheckpoint
from BiddingKG.dl.table_head.metrics import precision, recall, f1
from keras import optimizers, Model
from BiddingKG.dl.table_head.models.model import get_model
from BiddingKG.dl.table_head.loss import focal_loss
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from BiddingKG.dl.table_head.pre_process import get_data_from_file, get_data_from_sql, my_data_loader, my_data_loader_2, \
    get_random
from keras import backend as K

# Select the model configuration: model_id == 1 trains on fixed-size inputs
# with large batches on the GPU; any other value trains on variable-size
# inputs one sample at a time on the CPU.
model_id = 1
if model_id == 1:
    input_shape = (6, 20, 60)
    output_shape = (1,)
    batch_size = 128
    epochs = 1000
    PRETRAINED = False
    CHECKPOINT = False
    # Use the GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
    input_shape = (None, None, 20, 60)
    output_shape = (None, None)
    batch_size = 1
    epochs = 1000
    PRETRAINED = False
    CHECKPOINT = False
    # Use the CPU
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

pretrained_path = "checkpoints/" + str(model_id) + "/best.hdf5"
checkpoint_path = "checkpoints/" + str(model_id) + "/"


def train():
    # Log visible GPUs (uses a private Keras/TF-backend helper)
    print("gpus", K.tensorflow_backend._get_available_gpus())

    # Data
    data_x, data_y = get_data_from_file('txt', model_id=model_id)
    print("finish read data", len(data_x))

    # Split -> Train, Test (first 10% held out as the test set)
    if model_id == 1:
        split_size = int(len(data_x)*0.1)
        test_x, test_y = data_x[:split_size], data_y[:split_size]
        train_x, train_y = data_x[split_size:], data_y[split_size:]
    else:
        # get_random presumably shuffles the pairs before splitting
        data_x, data_y = get_random(data_x, data_y)
        split_size = int(len(data_x)*0.1)
        test_x, test_y = data_x[:split_size], data_y[:split_size]
        train_x, train_y = data_x[split_size:], data_y[split_size:]
    print("len(train_x), len(test_x)", len(train_x), len(test_x))

    # Data Loader
    if model_id == 1:
        train_data_loader = my_data_loader(train_x, train_y, batch_size=batch_size)
        test_data_loader = my_data_loader(test_x, test_y, batch_size=batch_size)
    else:
        train_data_loader = my_data_loader_2(train_x, train_y, batch_size=batch_size)
        test_data_loader = my_data_loader_2(test_x, test_y, batch_size=1)
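
    # fit_generator expects each loader to yield (inputs, targets) batches
    # indefinitely. A minimal sketch of that contract, assuming my_data_loader
    # behaves roughly like this (the real implementation is in pre_process):
    #
    #     def my_data_loader(data_x, data_y, batch_size=128):
    #         import numpy as np
    #         while True:
    #             for i in range(0, len(data_x), batch_size):
    #                 yield (np.array(data_x[i:i+batch_size]),
    #                        np.array(data_y[i:i+batch_size]))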

    # Model
    model = get_model(input_shape, output_shape, model_id=model_id)
    if PRETRAINED:
        model.load_weights(pretrained_path)
        print("read pretrained model", pretrained_path)
    else:
        print("no pretrained")
    if CHECKPOINT:
        # Note: checkpoint_path is a directory; load_weights needs a
        # concrete .hdf5 file inside it to succeed.
        model.load_weights(checkpoint_path)
        print("read checkpoint model", checkpoint_path)
    else:
        print("no checkpoint")

    filepath = 'e-{epoch:02d}_f1-{val_f1:.2f}'
    # filepath = 'e-{epoch:02d}_acc-{val_loss:.2f}'
    checkpoint = ModelCheckpoint(checkpoint_path+filepath+".hdf5",
                                 monitor='val_f1',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='max')
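    # Keras fills the {epoch}/{val_f1} placeholders when saving, so the best
    # weights land at e.g. checkpoints/1/e-07_f1-0.85.hdf5 (illustrative
    # values); val_f1 comes from the custom f1 metric compiled below.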

    model.compile(optimizer=optimizers.Adam(lr=0.0005),
                  loss={"output": focal_loss(3., 0.5)},
                  # loss_weights={"output": 0.5},
                  metrics=['acc', precision, recall, f1])
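
    # focal_loss(3., 0.5) presumably passes (gamma, alpha); in the common
    # formulation, FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t), which
    # down-weights easy examples. A minimal sketch under that assumption
    # (the real definition is in BiddingKG.dl.table_head.loss):
    #
    #     def focal_loss(gamma, alpha):
    #         def _loss(y_true, y_pred):
    #             y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
    #             p_t = y_true * y_pred + (1. - y_true) * (1. - y_pred)
    #             return -K.mean(alpha * K.pow(1. - p_t, gamma) * K.log(p_t))
    #         return _loss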

    # Halve the learning rate when val_f1 plateaus for 10 epochs
    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.5, patience=10,
                            verbose=1, mode='max', cooldown=0, min_lr=0)

    model.fit_generator(train_data_loader,
                        steps_per_epoch=max(1, len(train_x) // batch_size),
                        callbacks=[checkpoint, rlu],
                        validation_data=test_data_loader,
                        validation_steps=max(1, len(test_x) // batch_size),
                        epochs=epochs)
    return model, test_x


def print_layer_output(model, data):
    # Debug helper: print the output of the layer named 'input_2'
    middle_layer = Model(inputs=model.inputs,
                         outputs=model.get_layer('input_2').output)
    middle_layer_output = middle_layer.predict([data[0], data[1]])
    print(middle_layer_output)
    return
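
# Hypothetical usage of the debug helper, assuming the model takes two inputs
# and `data` holds a matching pair of arrays:
#
#     print_layer_output(model, (input_array_1, input_array_2))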


if __name__ == '__main__':
    model, data = train()

    # One-off data-generation snippet, left commented out:
    # place_list = get_place_list()
    # _str1 = '中国电信'
    # _str2 = '分公司'
    # _list = []
    # for place in place_list:
    #     _list.append(_str1 + place + _str2 + "\n")
    # # print(_list)
    # with open("电信分公司.txt", "w") as f:
    #     f.writelines(_list)