12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import sys
- import os
- sys.path.append(os.path.abspath("../.."))
- from keras import optimizers
- from tensorflow.contrib.metrics import f1_score
- from tensorflow.python.ops.metrics_impl import precision, recall
- from BiddingKG.dl.table_head.models.model import get_model
- from BiddingKG.dl.table_head.loss import focal_loss
- from keras.callbacks import ModelCheckpoint
- from BiddingKG.dl.table_head.pre_process import get_data_from_file
- import numpy as np
# Shapes forwarded to get_model() when the network is built.
input_shape = (2, 10)
output_shape = (2,)

# Weight sources used by train(): pretrained_path is loaded when PRETRAINED is
# set; checkpoint_path is consulted when CHECKPOINT is set and is also the
# location that training checkpoints are written under.
pretrained_path = ""
checkpoint_path = "checkpoints/"
PRETRAINED, CHECKPOINT = False, False
def _split_train_test(data_x, data_y, test_ratio):
    """Split samples into a leading test slice and a trailing train slice.

    The first ``int(len(data_x) * test_ratio)`` samples become the test set and
    the remainder the train set. No shuffling is done, so the split follows the
    order in which the data was supplied.

    Returns:
        ``((train_x, train_y), (test_x, test_y))``.

    Raises:
        ValueError: if the ratio would leave the train or test split empty.
    """
    split_size = int(len(data_x) * test_ratio)
    if split_size <= 0 or split_size >= len(data_x):
        raise ValueError(
            "test_ratio=%r on %d samples leaves an empty train or test split"
            % (test_ratio, len(data_x)))
    test_part = (data_x[:split_size], data_y[:split_size])
    train_part = (data_x[split_size:], data_y[split_size:])
    return train_part, test_part


def _latest_checkpoint(path):
    """Return ``path``, or the newest weights file inside it if it is a directory.

    ``model.load_weights`` needs a file, but ``checkpoint_path`` names the
    location that ModelCheckpoint writes into. A directory is therefore
    resolved to its most recently modified ``.h5``/``.hdf5`` file. With
    ``save_best_only=True``, the newest file of a single run is also its best;
    across several runs this is only presumably true, so verify if it matters.

    Raises:
        FileNotFoundError: if ``path`` is a directory with no weights files.
    """
    if not os.path.isdir(path):
        return path
    candidates = [os.path.join(path, name) for name in os.listdir(path)
                  if name.endswith((".hdf5", ".h5"))]
    if not candidates:
        raise FileNotFoundError("no checkpoint files found in %s" % path)
    return max(candidates, key=os.path.getmtime)


def train(test_ratio=0.1, epochs=100, batch_size=128, learning_rate=0.0005):
    """Train the model on data from ``get_data_from_file``.

    The first ``test_ratio`` share of samples is held out as validation data.
    Weights are optionally loaded from ``pretrained_path`` (when ``PRETRAINED``)
    and then from a checkpoint (when ``CHECKPOINT``). During training, the best
    model by validation loss is saved under ``checkpoint_path``.

    Args:
        test_ratio: fraction of samples, taken from the front, used for validation.
        epochs: number of training epochs.
        batch_size: mini-batch size passed to ``model.fit``.
        learning_rate: Adam learning rate.

    Returns:
        The object returned by ``model.fit`` (the training History).

    Raises:
        ValueError: if ``test_ratio`` leaves the train or test split empty.
    """
    # Data
    data_x, data_y = get_data_from_file()
    data_x = np.array(data_x)
    data_y = np.array(data_y)
    (train_x, train_y), (test_x, test_y) = _split_train_test(data_x, data_y, test_ratio)

    # (table_num, 2 sentences, dim characters) -> (2, table_num, dim), so that
    # x[0] and x[1] can be fed to the model's two inputs below.
    train_x = np.transpose(train_x, (1, 0, 2))
    test_x = np.transpose(test_x, (1, 0, 2))

    # Model
    model = get_model(input_shape, output_shape)
    if PRETRAINED:
        model.load_weights(pretrained_path)
        print("read pretrained model", pretrained_path)
    else:
        print("no pretrained")
    # Loaded after the pretrained weights, so a checkpoint takes precedence.
    if CHECKPOINT:
        weights_file = _latest_checkpoint(checkpoint_path)
        model.load_weights(weights_file)
        print("read checkpoint model", weights_file)
    else:
        print("no checkpoint")

    # ModelCheckpoint's `monitor` must be the *name* of a logged quantity.
    # Passing the loss function meant the value was never found, so nothing
    # was ever saved. "val_loss" is the compiled loss on the validation split.
    if checkpoint_path:
        os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = ModelCheckpoint(
        os.path.join(checkpoint_path, "{epoch:02d}-{val_loss:.2f}.hdf5"),
        monitor="val_loss", verbose=1, save_best_only=True, mode="min")

    model.compile(optimizer=optimizers.Adam(lr=learning_rate), loss=focal_loss(),
                  metrics=[focal_loss()])
    print(train_x.shape, train_y.shape)
    return model.fit(x=[train_x[0], train_x[1]], y=train_y,
                     validation_data=([test_x[0], test_x[1]], test_y),
                     epochs=epochs, batch_size=batch_size, shuffle=True,
                     callbacks=[checkpoint])
if __name__ == '__main__':
    # Start training only when this file is executed directly, never on import.
    train()
|