# LoadBestModel.py
  1. import tensorflow as tf
  2. from utils import *
  3. from metrics import *
  4. from models import GCN_Align
  5. import gc
  6. import time
  7. import os
  8. def loadBestModel():
  9. dir_best_model = os.getcwd()+"\\data1\\100000\\zh_en\\model.ckpt"
  10. sess = tf.Session()
  11. # Define placeholders
  12. num_supports = 1
  13. ph_ae = {
  14. 'support': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)],
  15. 'features': tf.sparse_placeholder(tf.float32), #tf.placeholder(tf.float32),
  16. 'dropout': tf.placeholder_with_default(0., shape=()),
  17. 'num_features_nonzero': tf.placeholder_with_default(0, shape=())
  18. }
  19. ph_se = {
  20. 'support': [tf.sparse_placeholder(tf.float32) for _ in range(num_supports)],
  21. 'features': tf.placeholder(tf.float32),
  22. 'dropout': tf.placeholder_with_default(0., shape=()),
  23. 'num_features_nonzero': tf.placeholder_with_default(0, shape=())
  24. }
  25. # some flags
  26. flags = tf.app.flags
  27. FLAGS = flags.FLAGS
  28. flags.DEFINE_string('lang', 'zh_en', 'Dataset string.') # 'zh_en', 'ja_en', 'fr_en'
  29. flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
  30. flags.DEFINE_integer('epochs', 500, 'Number of epochs to train.')
  31. flags.DEFINE_float('dropout', 0.3, 'Dropout rate (1 - keep probability).')
  32. flags.DEFINE_float('gamma', 3.0, 'Hyper-parameter for margin based loss.')
  33. flags.DEFINE_integer('k', 4, 'Number of negative samples for each positive seed.')
  34. flags.DEFINE_float('beta', 0.3, 'Weight for structure embeddings.')
  35. flags.DEFINE_integer('se_dim', 100, 'Dimension for SE.')
  36. flags.DEFINE_integer('ae_dim', 100, 'Dimension for AE.')
  37. flags.DEFINE_integer('seed', 9, 'Proportion of seeds, 3 means 30%')
  38. # data process
  39. adj, ae_input, train, test = load_data(FLAGS.lang)
  40. support = [preprocess_adj(adj)]
  41. # 把具体值赋给事先定义好的placeholder
  42. feed_dict_ae = construct_feed_dict(ae_input, support, ph_ae)
  43. feed_dict_ae.update({ph_ae['dropout']: FLAGS.dropout})
  44. feed_dict_se = construct_feed_dict(1.0, support, ph_se)
  45. feed_dict_se.update({ph_se['dropout']: FLAGS.dropout})
  46. # 负样本填充placeholder
  47. t = 0
  48. k = 0
  49. e = ae_input[2][0]
  50. L = np.ones((t, k))
  51. neg_left = L.reshape((t * k,))
  52. L = np.ones((t, k))
  53. neg2_right = L.reshape((t * k,))
  54. neg2_left = np.random.choice(e, t * k)
  55. neg_right = np.random.choice(e, t * k)
  56. feed_dict_ae.update({'neg_left:0': neg_left, 'neg_right:0': neg_right, 'neg2_left:0': neg2_left, 'neg2_right:0': neg2_right})
  57. feed_dict_se.update({'neg_left:0': neg_left, 'neg_right:0': neg_right, 'neg2_left:0': neg2_left, 'neg2_right:0': neg2_right})
  58. # Create model
  59. model_func = GCN_Align
  60. # attribute embedding model
  61. model_ae = model_func(ph_ae, input_dim=ae_input[2][1], output_dim=FLAGS.ae_dim, ILL=train, sparse_inputs=True, featureless=False, logging=True)
  62. # structure embedding model
  63. model_se = model_func(ph_se, input_dim=ae_input[2][0], output_dim=FLAGS.se_dim, ILL=train, sparse_inputs=False, featureless=True, logging=True)
  64. # load model
  65. saver = tf.train.Saver()
  66. saver.restore(sess, dir_best_model)
  67. # run the last layer, get vector
  68. # print(len(feed_dict_ae))
  69. # for i in feed_dict_ae.keys():
  70. # print(i)
  71. vec_ae = sess.run(model_ae.outputs, feed_dict=feed_dict_ae)
  72. vec_se = sess.run(model_se.outputs, feed_dict=feed_dict_se)
  73. # 清内存
  74. print("清内存")
  75. del saver
  76. del model_ae
  77. del model_se
  78. del model_func
  79. del feed_dict_ae
  80. del feed_dict_se
  81. del adj
  82. del ae_input
  83. del train
  84. # del test
  85. del support
  86. del sess
  87. gc.collect()
  88. # print("AE")
  89. # get_hits(vec_ae, test)
  90. # print("SE")
  91. # get_hits(vec_se, test)
  92. # print("SE+AE")
  93. # get_combine_hits(vec_se, vec_ae, FLAGS.beta, test)
  94. #
  95. # calculate similarity
  96. # print("AE Similarity")
  97. # print(len(vec_ae), len(test))
  98. # predict(vec_ae, test)
  99. # print("SE Similarity")
  100. # predict(vec_se, test)
  101. # print("AE+SE Similarity")
  102. # predict(np.concatenate([vec_se*FLAGS.beta, vec_ae*(1.0-FLAGS.beta)], axis=1), test)
  103. print("Predict New Align Orgs")
  104. start_time = time.time()
  105. predict_new(np.concatenate([vec_se*FLAGS.beta, vec_ae*(1.0-FLAGS.beta)], axis=1))
  106. print("use time", time.time()-start_time)
  107. print("e"+str(FLAGS.epochs), "d"+str(FLAGS.dropout), "k"+str(FLAGS.k), "s"+str(FLAGS.seed),
  108. "lr"+str(FLAGS.learning_rate), "b"+str(FLAGS.beta))
# Script entry point: rebuild the graph, restore the checkpoint, and predict.
if __name__ == '__main__':
    loadBestModel()