123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- import sys
- import os
- sys.path.append(os.path.abspath("../../.."))
- os.environ['KERAS_BACKEND'] = 'tensorflow'
- from BiddingKG.dl.table_head.models.layer_utils import MyModelCheckpoint
- from BiddingKG.dl.table_head.metrics import precision, recall, f1
- from keras import optimizers, Model
- from BiddingKG.dl.table_head.models.model import get_model
- from BiddingKG.dl.table_head.loss import focal_loss
- from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
- from BiddingKG.dl.table_head.pre_process import get_data_from_file, get_data_from_sql, my_data_loader, my_data_loader_2, \
- get_random
- from keras import backend as K
# Experiment configuration. ``model_id`` selects which network is trained:
# model 1 uses fixed-shape inputs and the GPU, any other id uses
# variable-shape inputs, a batch size of 1, and the CPU.
model_id = 1

# Settings that are the same for every model.
epochs = 1000
PRETRAINED = False  # when True, train() loads weights from pretrained_path
CHECKPOINT = False  # when True, train() loads weights from checkpoint_path

if model_id != 1:
    input_shape = (None, None, 20, 60)
    output_shape = (None, None)
    batch_size = 1
    # Use the CPU: hide every CUDA device from TensorFlow.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
else:
    input_shape = (6, 20, 60)
    output_shape = (1,)
    batch_size = 128
    # Use the GPU: expose only device 0.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Per-model checkpoint directory and the weights file loaded when PRETRAINED.
checkpoint_path = "checkpoints/{}/".format(model_id)
pretrained_path = checkpoint_path + "best.hdf5"
def train():
    """Train the model selected by the module-level ``model_id``.

    Reads samples with ``get_data_from_file('txt', model_id=model_id)`` and
    holds out the first 10% as the validation set. For ``model_id != 1`` the
    data is passed through ``get_random`` before splitting; model 1 keeps the
    order returned by the reader. The model from ``get_model`` is optionally
    initialised from ``pretrained_path`` / ``checkpoint_path``. It is compiled
    with Adam and ``focal_loss(3., 0.5)``, then fitted. The weights with the
    best ``val_f1`` are saved under ``checkpoint_path``, and the learning rate
    is halved when ``val_f1`` plateaus.

    Returns:
        tuple: ``(model, test_x)``, the trained Keras model and the held-out
        validation inputs.
    """
    # GPU available
    print("gpus", K.tensorflow_backend._get_available_gpus())

    # Data
    data_x, data_y = get_data_from_file('txt', model_id=model_id)
    print("finish read data", len(data_x))

    # Split -> Train, Test (first 10% is held out). Only non-1 models are
    # shuffled first.
    if model_id != 1:
        data_x, data_y = get_random(data_x, data_y)
    split_size = int(len(data_x) * 0.1)
    test_x, test_y = data_x[:split_size], data_y[:split_size]
    train_x, train_y = data_x[split_size:], data_y[split_size:]
    print("len(train_x), len(test_x)", len(train_x), len(test_x))

    # Data Loader. test_batch_size is tracked separately because
    # validation_steps below must match the batch size the *test* loader
    # actually yields. Previously it always used batch_size, even though the
    # model-2 test loader is hard-wired to batches of 1.
    if model_id == 1:
        test_batch_size = batch_size
        train_data_loader = my_data_loader(train_x, train_y, batch_size=batch_size)
        test_data_loader = my_data_loader(test_x, test_y, batch_size=test_batch_size)
    else:
        test_batch_size = 1
        train_data_loader = my_data_loader_2(train_x, train_y, batch_size=batch_size)
        test_data_loader = my_data_loader_2(test_x, test_y, batch_size=test_batch_size)

    # Model
    model = get_model(input_shape, output_shape, model_id=model_id)
    if PRETRAINED:
        model.load_weights(pretrained_path)
        print("read pretrained model", pretrained_path)
    else:
        print("no pretrained")
    if CHECKPOINT:
        # NOTE(review): checkpoint_path is a directory ("checkpoints/<id>/"),
        # not a weights file. load_weights presumably needs a concrete
        # .hdf5 file here; confirm which checkpoint is meant to be resumed.
        model.load_weights(checkpoint_path)
        print("read checkpoint model", checkpoint_path)
    else:
        print("no checkpoint")

    # ModelCheckpoint writes files into checkpoint_path; make sure the
    # directory exists so the first save does not fail.
    os.makedirs(checkpoint_path, exist_ok=True)
    filepath = 'e-{epoch:02d}_f1-{val_f1:.2f}'
    checkpoint = ModelCheckpoint(checkpoint_path + filepath + ".hdf5",
                                 monitor='val_f1',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='max')
    model.compile(optimizer=optimizers.Adam(lr=0.0005),
                  loss={"output": focal_loss(3., 0.5)},
                  metrics=['acc', precision, recall, f1])
    rlu = ReduceLROnPlateau(monitor='val_f1', factor=0.5, patience=10,
                            verbose=1, mode='max', cooldown=0, min_lr=0)
    model.fit_generator(train_data_loader,
                        steps_per_epoch=max(1, len(train_x) // batch_size),
                        callbacks=[checkpoint, rlu],
                        validation_data=test_data_loader,
                        validation_steps=max(1, len(test_x) // test_batch_size),
                        epochs=epochs)
    return model, test_x
def print_layer_output(model, data, layer_name='input_2'):
    """Print and return the output of one named layer of ``model``.

    Builds a sub-model that shares ``model``'s inputs and whose output is the
    layer named ``layer_name``, then runs ``predict`` on
    ``[data[0], data[1]]``.

    Args:
        model: A built Keras model.
        data: Indexable whose first two items are fed to the sub-model as its
            inputs. NOTE(review): this assumes ``model`` takes exactly two
            inputs; confirm against get_model.
        layer_name: Name of the layer to inspect. Defaults to ``'input_2'``,
            the previously hard-coded value. That looks like a Keras
            auto-generated name, which may change if layers are created in a
            different order, so prefer passing it explicitly.

    Returns:
        Whatever ``predict`` returns for the sub-model (also printed).
    """
    middle_layer = Model(inputs=model.inputs,
                         outputs=model.get_layer(layer_name).output)
    middle_layer_output = middle_layer.predict([data[0], data[1]])
    print(middle_layer_output)
    return middle_layer_output
if __name__ == '__main__':
    # Script entry point: train using the module-level configuration above and
    # keep the trained model and the held-out validation inputs bound at
    # module level.
    model, data = train()
|