train.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from tensorflow.keras.layers import *
  2. from tensorflow.keras.models import *
  3. from tensorflow.keras.optimizers import *
  4. from tensorflow.keras.losses import *
  5. from BiddingKG.dl.common.Utils import *
  6. import numpy as np
  7. from random import random
  8. import json
  def _normalize_matrix(matrix):
      """Select and rescale the entries of one decoded feature list.

      Positions below 56: only odd positions are kept; a value of -1 is
      replaced by 0 and the result is divided by 10.
      Positions 56 to 62: every value is kept, divided by 10.
      Positions 63 and above: every value is kept unchanged.

      NOTE(review): the indentation of the original source was lost. Keeping
      only the odd positions below 56 is the reading that yields the 36-wide
      vector getModel() expects (28 + 7 + 1, assuming a 64-entry input);
      confirm against the upstream file.

      Returns a new list; the argument is not modified.
      """
      features = []
      for position, value in enumerate(matrix):
          if position < 56:
              if position % 2 == 1:
                  # -1 is treated as "no value" and mapped to 0 before scaling.
                  features.append((0 if value == -1 else value) / 10)
          elif position < 63:
              features.append(value / 10)
          else:
              features.append(value)
      return features


  def getData():
      """Load the merge-training records and split them into train/test sets.

      For every record the JSON string under "json_matrix" is decoded, reduced
      by _normalize_matrix() and written back onto the record as a numpy array
      (so the returned list_data holds arrays, not JSON strings). The label is
      [1, 0] when the record's "prob" is None and [0, 1] otherwise.

      A record goes to the training set when random() > 0.2 and to the test set
      otherwise, so the split differs between runs unless the random module is
      seeded elsewhere.

      Returns:
          A tuple (train_x, train_y, test_x, test_y, list_data, test_index)
          where the first four items are numpy arrays and test_index lists the
          positions in list_data of the records placed in the test set.
      """
      # `load` is not defined in this file; presumably it comes from the
      # BiddingKG.dl.common.Utils star import.
      list_data = load("./data/2021-06-25-mergeTrain.pk")
      train_x, train_y = [], []
      test_x, test_y = [], []
      test_index = []
      for _index, _data in enumerate(list_data):
          matrix = np.array(_normalize_matrix(json.loads(_data["json_matrix"])))
          _data["json_matrix"] = matrix
          label = [1, 0] if _data["prob"] is None else [0, 1]
          if random() > 0.2:
              train_x.append(matrix)
              train_y.append(label)
          else:
              # Remember where the sample came from so it can be printed later.
              test_index.append(_index)
              test_x.append(matrix)
              test_y.append(label)
      return (
          np.array(train_x),
          np.array(train_y),
          np.array(test_x),
          np.array(test_y),
          list_data,
          test_index,
      )
  def getModel():
      """Build, compile and return the two-class classifier.

      The network takes a 36-value input vector, applies one Dense layer with
      2 units and tanh activation, then a Softmax layer. It is compiled with
      the Adadelta optimizer (default arguments), categorical cross-entropy
      loss and the `precision` / `recall` metrics, and its summary is printed.

      `precision` and `recall` are not defined in this file; presumably they
      come from the BiddingKG.dl.common.Utils star import.
      """
      # Named `features` rather than `input` so the builtin is not shadowed.
      features = Input(shape=(36,))
      scores = Dense(2, activation="tanh")(features)
      probabilities = Softmax()(scores)
      model = Model(inputs=features, outputs=probabilities)
      model.compile(
          optimizer=Adadelta(),
          loss=categorical_crossentropy,
          metrics=[precision, recall],
      )
      model.summary()
      return model
  def train():
      """Train the classifier and print every misclassified test sample.

      Builds the model with getModel(), obtains a train/test split from
      getData(), fits for 30 epochs with batch size 300 using the test split
      as validation data, then predicts on the test split. For each test
      sample whose predicted class (argmax) differs from its label, the
      source record, the prediction and the label are printed; the total
      number of such samples is printed at the end.
      """
      model = getModel()
      train_x, train_y, test_x, test_y, list_data, test_index = getData()
      model.fit(
          x=train_x,
          y=train_y,
          batch_size=300,
          epochs=30,
          validation_data=(test_x, test_y),
      )
      predictions = model.predict(test_x)
      mismatches = 0
      for scores, label, sample_index in zip(predictions, test_y, test_index):
          if np.argmax(scores) == np.argmax(label):
              continue
          mismatches += 1
          print("===================")
          # test_index maps the test sample back to its record in list_data.
          print(list_data[sample_index])
          print(scores)
          print(label)
      print('diff count:%d'%mismatches)
  # Script entry point: start training only when this file is executed
  # directly, not when it is imported.
  if __name__ == "__main__":
      train()