"""Training script for the table_head model (sentence-pair inputs, focal loss)."""
import sys
import os
# Must run before the BiddingKG imports below so the package can be resolved.
sys.path.append(os.path.abspath("../.."))
from keras import optimizers
from tensorflow.contrib.metrics import f1_score
from tensorflow.python.ops.metrics_impl import precision, recall
from BiddingKG.dl.table_head.models.model import get_model
from BiddingKG.dl.table_head.loss import focal_loss
from keras.callbacks import ModelCheckpoint
from BiddingKG.dl.table_head.pre_process import get_data_from_file
import numpy as np

# Shapes handed to get_model(); their exact meaning is defined by that helper.
input_shape = (2, 10)
output_shape = (2,)

# Weight locations and switches controlling whether they are loaded.
pretrained_path = ""
checkpoint_path = "checkpoints/"
PRETRAINED = False
CHECKPOINT = False


def train():
    """Load the data, build the model and fit it, checkpointing on val_loss.

    The first 10% of the samples returned by ``get_data_from_file`` are used
    as the validation set and the remainder as the training set. The best
    weights (lowest validation loss) are written under ``checkpoint_path``.

    Returns:
        None. Side effects: prints progress information and writes
        checkpoint files to disk.
    """
    # Data
    data_x, data_y = get_data_from_file()
    data_x = np.array(data_x)
    data_y = np.array(data_y)

    # Split -> Train, Test (leading 10% is held out for validation)
    split_size = int(len(data_x) * 0.1)
    test_x, test_y = data_x[:split_size], data_y[:split_size]
    train_x, train_y = data_x[split_size:], data_y[split_size:]

    # (table_num, 2 sentences, dim characters) -> (2, table_num, dim)
    # so that index 0 / 1 on the first axis selects one model input each.
    train_x = np.transpose(train_x, (1, 0, 2))
    test_x = np.transpose(test_x, (1, 0, 2))

    # Model
    model = get_model(input_shape, output_shape)
    if PRETRAINED:
        model.load_weights(pretrained_path)
        print("read pretrained model", pretrained_path)
    else:
        print("no pretrained")
    if CHECKPOINT:
        # NOTE(review): checkpoint_path is a directory prefix, not a weights
        # file; confirm which file should be loaded before enabling CHECKPOINT.
        model.load_weights(checkpoint_path)
        print("read checkpoint model", checkpoint_path)
    else:
        print("no checkpoint")

    # Make sure the target directory exists before the callback tries to
    # write into it; otherwise the first save would fail.
    if checkpoint_path:
        os.makedirs(checkpoint_path, exist_ok=True)

    filepath = '{epoch:02d}-{val_loss:.2f}.h5'
    # ModelCheckpoint expects the *name* of a logged quantity as `monitor`.
    # Passing the loss function object meant no matching log entry was found,
    # so with save_best_only=True no checkpoint was ever written. 'val_loss'
    # is logged because validation_data is supplied to fit() below.
    # NOTE(review): the resulting file name ends in ".h5.hdf5"; kept as-is.
    checkpoint = ModelCheckpoint(checkpoint_path + filepath + ".hdf5",
                                 monitor='val_loss',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='min')

    model.compile(optimizer=optimizers.Adam(lr=0.0005),
                  loss=focal_loss(),
                  metrics=[focal_loss()])

    print(train_x.shape, train_y.shape)
    model.fit(x=[train_x[0], train_x[1]],
              y=train_y,
              validation_data=([test_x[0], test_x[1]], test_y),
              epochs=100,
              batch_size=128,
              shuffle=True,
              callbacks=[checkpoint])


if __name__ == '__main__':
    train()