train.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from __future__ import division
  2. from __future__ import print_function
  3. import time
  4. import tensorflow as tf
  5. from utils import *
  6. from metrics import *
  7. from models import GCN_Align
  8. import os
  9. from LoadBestModel import loadBestModel
  10. dir_best_model = os.getcwd()+"\\data1\\100000\\zh_en\\model.ckpt"
  11. # Set random seed
  12. seed = 12306
  13. np.random.seed(seed)
  14. tf.set_random_seed(seed)
  15. # Settings
  16. flags = tf.app.flags
  17. FLAGS = flags.FLAGS
  18. flags.DEFINE_string('lang', 'zh_en', 'Dataset string.') # 'zh_en', 'ja_en', 'fr_en'
  19. flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
  20. flags.DEFINE_integer('epochs', 500, 'Number of epochs to train.')
  21. flags.DEFINE_float('dropout', 0.3, 'Dropout rate (1 - keep probability).')
  22. flags.DEFINE_float('gamma', 3.0, 'Hyper-parameter for margin based loss.')
  23. flags.DEFINE_integer('k', 4, 'Number of negative samples for each positive seed.')
  24. flags.DEFINE_float('beta', 0.3, 'Weight for structure embeddings.')
  25. flags.DEFINE_integer('se_dim', 100, 'Dimension for SE.')
  26. flags.DEFINE_integer('ae_dim', 100, 'Dimension for AE.')
  27. flags.DEFINE_integer('seed', 9, 'Proportion of seeds, 3 means 30%')
  28. # Load data
  29. # adj:邻接矩阵 structure embedding 来自triples1、triples2,并计算分数
  30. # ae_input: attribute embedding 来自training_attrs_1、training_attrs_2和ent_id_1、ent_id_2
  31. # train/test: train data 来自ref_ent_ids
  32. adj, ae_input, train, test = load_data(FLAGS.lang)
  33. # Some preprocessing
  34. support = [preprocess_adj(adj)]
  35. # print("pre adj ===================")
  36. # print(support)
  37. num_supports = 1
  38. model_func = GCN_Align
  39. k = FLAGS.k
  40. e = ae_input[2][0]
  41. # Define placeholders
  42. ph_ae = {
  43. 'support': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)],
  44. 'features': tf.sparse_placeholder(tf.float32), #tf.placeholder(tf.float32),
  45. 'dropout': tf.placeholder_with_default(0., shape=()),
  46. 'num_features_nonzero': tf.placeholder_with_default(0, shape=())
  47. }
  48. ph_se = {
  49. 'support': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)],
  50. 'features': tf.placeholder(tf.float32),
  51. 'dropout': tf.placeholder_with_default(0., shape=()),
  52. 'num_features_nonzero': tf.placeholder_with_default(0, shape=())
  53. }
  54. # Create model
  55. # attribute embedding model
  56. model_ae = model_func(ph_ae, input_dim=ae_input[2][1], output_dim=FLAGS.ae_dim, ILL=train,
  57. sparse_inputs=True, featureless=False, AE=True, logging=True)
  58. # structure embedding model
  59. model_se = model_func(ph_se, input_dim=ae_input[2][0], output_dim=FLAGS.se_dim, ILL=train,
  60. sparse_inputs=False, featureless=True, AE=False, logging=True)
  61. # Initialize session
  62. sess = tf.Session()
  63. # Init variables
  64. sess.run(tf.global_variables_initializer())
  65. cost_val = []
  66. #
  67. t = len(train)
  68. L = np.ones((t, k)) * (train[:, 0].reshape((t, 1)))
  69. neg_left = L.reshape((t * k,))
  70. L = np.ones((t, k)) * (train[:, 1].reshape((t, 1)))
  71. neg2_right = L.reshape((t * k,))
  72. # print("neg_left ===================")
  73. # print(neg_left)
  74. # print("neg2_right ===================")
  75. # print(neg2_right)
  76. # Train model
  77. saver = tf.train.Saver()
  78. AE_train_loss = 10
  79. SE_train_loss = 10
  80. for epoch in range(FLAGS.epochs):
  81. if epoch % 10 == 0:
  82. neg2_left = np.random.choice(e, t * k)
  83. neg_right = np.random.choice(e, t * k)
  84. # Construct feed dictionary
  85. # 把具体值赋给事先定义好的placeholder
  86. feed_dict_ae = construct_feed_dict(ae_input, support, ph_ae)
  87. feed_dict_ae.update({ph_ae['dropout']: FLAGS.dropout})
  88. feed_dict_ae.update({'neg_left:0': neg_left, 'neg_right:0': neg_right, 'neg2_left:0': neg2_left, 'neg2_right:0': neg2_right})
  89. feed_dict_se = construct_feed_dict(1.0, support, ph_se)
  90. feed_dict_se.update({ph_se['dropout']: FLAGS.dropout})
  91. feed_dict_se.update({'neg_left:0': neg_left, 'neg_right:0': neg_right, 'neg2_left:0': neg2_left, 'neg2_right:0': neg2_right})
  92. # print(len(feed_dict_ae))
  93. # for i in feed_dict_ae.keys():
  94. # print(i)
  95. # Training step
  96. # session动态传数据
  97. outs_ae = sess.run([model_ae.opt_op, model_ae.loss], feed_dict=feed_dict_ae)
  98. outs_se = sess.run([model_se.opt_op, model_se.loss], feed_dict=feed_dict_se)
  99. cost_val.append((outs_ae[1], outs_se[1]))
  100. # Print results
  101. print("Epoch:", '%04d' % (epoch + 1), "AE_train_loss=", "{:.5f}".format(outs_ae[1]), "SE_train_loss=", "{:.5f}".format(outs_se[1]))
  102. # save best model
  103. if (outs_ae[1] <= AE_train_loss and outs_se[1] <= SE_train_loss) :
  104. # or outs_ae[1] <= 0.02:
  105. saver.save(sess, dir_best_model)
  106. AE_train_loss = outs_ae[1]
  107. SE_train_loss = outs_se[1]
  108. print("Save best Model!")
  109. print("Optimization Finished!")
  110. print("e"+str(FLAGS.epochs), "d"+str(FLAGS.dropout), "k"+str(FLAGS.k), "s"+str(FLAGS.seed),
  111. "lr"+str(FLAGS.learning_rate), "b"+str(FLAGS.beta))
  112. # loadBestModel()
  113. # Testing
  114. # feed_dict_ae = construct_feed_dict(ae_input, support, ph_ae)
  115. # feed_dict_se = construct_feed_dict(1.0, support, ph_se)
  116. # vec_ae = sess.run(model_ae.outputs, feed_dict=feed_dict_ae)
  117. # vec_se = sess.run(model_se.outputs, feed_dict=feed_dict_se)
  118. # print("AE")
  119. # get_hits(vec_ae, test)
  120. # print("SE")
  121. # get_hits(vec_se, test)
  122. # print("SE+AE")
  123. # get_combine_hits(vec_se, vec_ae, FLAGS.beta, test)
  124. #
  125. # print("Predict Similarity")
  126. # print("AE Similarity")
  127. # predict(vec_ae, test)
  128. # print("SE Similarity")
  129. # predict(vec_se, test)
  130. # print("AE+SE Similarity")
  131. # predict(np.concatenate([vec_se*FLAGS.beta, vec_ae*(1.0-FLAGS.beta)], axis=1), test)