123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- from tensorflow.keras.layers import *
- from tensorflow.keras.models import *
- from tensorflow.keras.optimizers import *
- from tensorflow.keras.losses import *
- from BiddingKG.dl.common.Utils import *
- import numpy as np
- from random import random
- import json
# Index boundaries inside one raw feature vector (see _normalizeMatrix).
_PAIRED_END = 56   # indices [0, 56): only odd positions are kept
_SCALED_END = 63   # indices [56, 63): every position is kept, scaled by 1/10


def _normalizeMatrix(matrix):
    """Turn one decoded feature list into the numeric vector fed to the model.

    NOTE(review): the original listing lost its indentation. This follows the
    reading in which, below index 56, only odd positions are emitted, because
    that yields 28 + 7 + 1 = 36 values for a 64-long input, matching
    ``Input(shape=(36,))`` in getModel. Confirm against the original file.

    Rules, by index ``i``:
      * ``i < 56``: odd ``i`` only -- a value of -1 is treated as 0, then the
        value is divided by 10 and kept; even positions are dropped.
      * ``56 <= i < 63``: value divided by 10 and kept.
      * ``i >= 63``: value kept unchanged.
    """
    kept = []
    for i, value in enumerate(matrix):
        if i < _PAIRED_END:
            if i % 2 == 1:
                kept.append((0 if value == -1 else value) / 10)
        elif i < _SCALED_END:
            kept.append(value / 10)
        else:
            kept.append(value)
    return np.array(kept)


def getData(path="./data/2021-06-25-mergeTrain.pk", test_ratio=0.2):
    """Load the samples and randomly split them into train and test sets.

    Each record must provide ``"json_matrix"`` (a JSON-encoded list of
    numbers) and ``"prob"``. The record's ``"json_matrix"`` entry is replaced
    in place by the normalized NumPy vector. The label is ``[1, 0]`` when
    ``"prob"`` is None and ``[0, 1]`` otherwise.

    Args:
        path: File handed to ``load``. ``load`` is not defined in this file;
            presumably it comes from the BiddingKG Utils star import -- verify.
        test_ratio: A record goes to the training set when ``random()`` is
            greater than this value, otherwise to the test set. The split is
            unseeded, so it differs between runs.

    Returns:
        Tuple ``(train_x, train_y, test_x, test_y, list_data, test_index)``:
        the first four are NumPy arrays, ``list_data`` is the loaded record
        list, and ``test_index`` holds the positions in ``list_data`` of the
        records placed in the test set (same order as ``test_x``).
    """
    list_data = load(path)
    train_x, train_y = [], []
    test_x, test_y = [], []
    test_index = []

    for _index, _data in enumerate(list_data):
        matrix = _normalizeMatrix(json.loads(_data["json_matrix"]))
        # Keep the normalized vector on the record; callers print the record.
        _data["json_matrix"] = matrix
        label = [1, 0] if _data["prob"] is None else [0, 1]

        if random() > test_ratio:
            train_x.append(matrix)
            train_y.append(label)
        else:
            test_index.append(_index)
            test_x.append(matrix)
            test_y.append(label)

    return (
        np.array(train_x),
        np.array(train_y),
        np.array(test_x),
        np.array(test_y),
        list_data,
        test_index,
    )
def getModel():
    """Build, compile and return the two-class classifier.

    Architecture: a 36-value input, one Dense layer with 2 units and tanh
    activation, followed by a Softmax layer. The model is compiled with a
    default-constructed Adadelta optimizer, categorical cross-entropy loss,
    and the ``precision`` / ``recall`` metrics (not defined in this file;
    presumably supplied by the star imports -- verify). The layer summary is
    printed before the model is returned.
    """
    features = Input(shape=(36,))
    scores = Dense(2, activation="tanh")(features)
    probabilities = Softmax()(scores)

    network = Model(inputs=features, outputs=probabilities)
    network.compile(Adadelta(), categorical_crossentropy, metrics=[precision, recall])
    network.summary()
    return network
def train():
    """Train the model and report every misclassified test sample.

    Builds the model, loads a fresh train/test split, fits for 30 epochs with
    batch size 300 while validating on the test split, then predicts on the
    test split. For each test sample whose predicted class (argmax) differs
    from its label, the source record, prediction and label are printed.
    The number of mismatches is printed at the end.
    """
    classifier = getModel()
    train_x, train_y, test_x, test_y, samples, heldout_positions = getData()

    classifier.fit(
        x=train_x,
        y=train_y,
        batch_size=300,
        epochs=30,
        validation_data=(test_x, test_y),
    )
    predictions = classifier.predict(test_x)

    mismatches = 0
    for predicted, expected, position in zip(predictions, test_y, heldout_positions):
        if np.argmax(predicted) == np.argmax(expected):
            continue
        mismatches += 1
        print("===================")
        # heldout_positions maps the test row back to its source record.
        print(samples[position])
        print(predicted)
        print(expected)
    print('diff count:%d'%mismatches)
# Run a full training pass only when executed as a script, not on import.
if __name__ == "__main__":
    train()
|