models.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. '''
  2. Created on 2019年4月11日
  3. @author: User
  4. '''
  5. from keras import layers,models,losses,optimizers
  6. from BiddingKG.dl.common.Utils import *
  7. def getTextCNNModel(vocab,embedding_weights,input_shape=(3,50,100),classes=3):
  8. input_left = layers.Input(shape=(input_shape[1],))
  9. input_center = layers.Input(shape=(input_shape[1],))
  10. input_right = layers.Input(shape=(input_shape[1],))
  11. list_kernel = [3,6]
  12. out_left = []
  13. out_center = []
  14. out_right = []
  15. embedding = layers.Embedding(len(vocab),input_shape[2],weights=[embedding_weights] if embedding_weights is not None else None,trainable=True,name="char_embeding")
  16. for kernel in list_kernel:
  17. out_left.append(layers.Conv1D(10, kernel, activation="relu",padding='same')(embedding(input_left)))
  18. concat_left = layers.merge(out_left,mode="concat")
  19. for kernel in list_kernel:
  20. out_center.append(layers.Conv1D(4,kernel,activation="relu",padding="same")(embedding(input_center)))
  21. concat_center = layers.merge(out_center,mode="concat")
  22. for kernel in list_kernel:
  23. out_right.append(layers.Conv1D(10,kernel,activation="relu",padding="same")(embedding(input_right)))
  24. concat_right = layers.merge(out_right,mode="concat")
  25. matrix_left = layers.Dense(12,activation="relu")(concat_left)
  26. matrix_center = layers.Dense(12,activation="relu")(concat_center)
  27. matrix_right = layers.Dense(12,activation="relu")(concat_right)
  28. #layers.average(inputs)
  29. concat_matrix = layers.merge([matrix_left,matrix_center,matrix_right],mode="ave")
  30. flattern = layers.Flatten()(concat_matrix)
  31. out = layers.Dense(classes,activation="softmax")(flattern)
  32. model = models.Model([input_left,input_center,input_right],out)
  33. model.compile(optimizer=optimizers.Adam(lr=0.001), loss=losses.categorical_crossentropy, metrics=[precision,recall,f1_score])
  34. model.summary()
  35. return model
  36. import tensorflow as tf
  37. if __name__=="__main__":
  38. with tf.Graph().as_default() as g:
  39. getTextCNNModel()
  40. for _vars in tf.global_variables():
  41. print(_vars.name,_vars)