model.py

'''
Created on 2019-03-26
@author: User
'''
import os
import sys
sys.path.append(os.path.abspath("../.."))
from BiddingKG.dl.common.Utils import *
from keras import layers
from keras import models
from keras import losses,optimizers
import tensorflow as tf
def getBiLSTMModel():
    # NOTE: an LSTM needs 3-D input (batch, timesteps, features); the original shape=(100,)
    # would fail, so (100, 60) is assumed here to match the input of getTextCNNModel.
    input = layers.Input(shape=(100,60))
    lstm = layers.Bidirectional(layers.LSTM(128,return_sequences=False))(input)
    matrix = layers.Dense(48,activation="relu")(lstm)
    out = layers.Dense(2,activation="softmax")(matrix)
    model = models.Model(inputs=input,outputs=out)
    model.compile(optimizer=optimizers.Adam(lr=0.01), loss=losses.categorical_crossentropy,metrics=[precision,recall,f1_score])
    model.summary()
    return model
def getTextCNNModel():
    input = layers.Input(shape=(100,60))
    list_pool = []
    list_size = [10,15,20]
    # one convolution + average-pooling branch per kernel size
    for size in list_size:
        c = layers.Conv1D(filters=4,kernel_size=size,strides=1,activation="relu")(input)
        p = layers.AvgPool1D(pool_size=int(c.shape[1]))(c)
        list_pool.append(p)
    # layers.merge(..., mode="concat") was removed in Keras 2; concatenate the pooled branches instead
    concat = layers.concatenate(list_pool)
    flatten = layers.Flatten()(concat)
    matrix = layers.Dense(12,activation="relu")(flatten)
    out = layers.Dense(2,activation="softmax")(matrix)
    model = models.Model(inputs=input,outputs=out)
    model.compile(optimizer=optimizers.Adadelta(),loss=losses.categorical_crossentropy,metrics=[precision,recall,f1_score])
    model.summary()
    return model
def get_context_form_model(vocab_len,char_dim,lstm_dim,context_dim):
    # 9 context fields per sample, each a sequence of 30 character ids
    input = tf.placeholder(tf.int32,shape=[None,9,30])
    with tf.variable_scope("embedding"):
        embedding = tf.get_variable("char_embedding",shape=[vocab_len,char_dim])
        embedding_input = tf.nn.embedding_lookup(embedding,input)
    with tf.variable_scope("bi_rnn"):
        resize_input = tf.reshape(embedding_input,[-1,30,char_dim])
        forward_cell = tf.contrib.rnn.BasicLSTMCell(lstm_dim, state_is_tuple=True)
        backward_cell = tf.contrib.rnn.BasicLSTMCell(lstm_dim,state_is_tuple=True)
        # static_bidirectional_rnn expects a list of [batch, char_dim] tensors, one per time step
        unstacked_input = tf.unstack(resize_input,30,axis=1)
        outputs,forward_state,backward_state = tf.nn.static_bidirectional_rnn(forward_cell,backward_cell,unstacked_input,dtype=tf.float32)
        # concatenate the final hidden states of both directions -> [-1,lstm_dim*2]
        bi_output = tf.concat([forward_state.h,backward_state.h],-1)
    with tf.variable_scope("context"):
        context_out = tf.reshape(bi_output,[-1,9,lstm_dim*2])
        context_v = tf.get_variable("context_v",shape=[lstm_dim*2,context_dim])
        # project each of the 9 context vectors; tensordot handles the 3-D x 2-D product
        context_emb = tf.nn.relu(tf.tensordot(context_out,context_v,axes=[[2],[0]]))
        # the 5th of the 9 fields is treated as the center element
        center_out = tf.slice(context_out,[0,4,0],[-1,1,-1])
if __name__=="__main__":
    #getBiLSTMModel()
    getTextCNNModel()
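    # Minimal smoke-test sketch (assumption: random dummy arrays stand in for the real
    # BiddingKG features); it only demonstrates the expected (100, 60) input shape and
    # 2-class one-hot targets, not an actual training setup.
    import numpy as np
    demo_model = getTextCNNModel()
    x_dummy = np.random.random((8, 100, 60)).astype("float32")
    y_dummy = np.eye(2)[np.random.randint(0, 2, 8)]
    demo_model.fit(x_dummy, y_dummy, batch_size=4, epochs=1)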