5.py

'''
Created on January 15, 2019
@author: User
'''
import tensorflow as tf
from tensorflow.contrib.crf import crf_log_likelihood

# Frozen POS-tagging graph (pos.pb) shipped with the fool package.
path = "D://Anaconda3.4//envs//dl_nlp//fool//pos.pb"


def loss_layer(project_logits, y_target, trans, max_steps):
    # CRF loss: mean negative log-likelihood of the gold tag sequences.
    with tf.variable_scope("crf_loss1"):
        log_likelihood, trans = crf_log_likelihood(inputs=project_logits, tag_indices=y_target,
                                                   transition_params=trans, sequence_lengths=max_steps)
        return tf.reduce_mean(-log_likelihood)
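

# A minimal sketch, not part of the original file, showing the shapes that
# loss_layer (via tf.contrib.crf.crf_log_likelihood) expects: logits of shape
# [batch, max_steps, num_tags], gold tags [batch, max_steps], and a per-example
# length vector [batch]. The helper name and num_tags value are illustrative.
def _loss_layer_shape_example(num_tags=10):
    logits_ph = tf.placeholder(tf.float32, shape=[None, None, num_tags])
    tags_ph = tf.placeholder(tf.int32, shape=[None, None])
    lengths_ph = tf.placeholder(tf.int32, shape=[None])
    trans_var = tf.get_variable("transitions_demo", shape=[num_tags, num_tags])
    # Despite its name, the fourth argument is forwarded as sequence_lengths.
    return loss_layer(logits_ph, tags_ph, trans_var, lengths_ph)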


def load_graph(path):
    # Read the serialized GraphDef from the .pb file and import it under the
    # "prefix" name scope so its tensors can be looked up by name.
    with tf.gfile.GFile(path, mode='rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    for i, n in enumerate(graph_def.node):
        print("Name of the node - %s" % n.name)
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="prefix")
    return graph
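

# A minimal sketch, not part of the original file: when the exact tensor names
# inside the frozen graph are unknown, listing graph.get_operations() is the
# standard TF 1.x way to discover names such as "prefix/project/logits:0".
# The helper name list_ops is introduced here only for illustration.
def list_ops(graph):
    for op in graph.get_operations():
        print(op.name)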


# Earlier module-level experiment, kept disabled as a string literal:
'''
with tf.gfile.GFile(path, mode='rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
for i, n in enumerate(graph_def.node):
    print("Name of the node - %s" % n.name)
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)
    trans = graph.get_tensor_by_name("prefix/crf_loss/transitions:0")
    logits = graph.get_tensor_by_name("prefix/project/logits:0")
    y_target = tf.placeholder()
    loss = loss_layer(logits, y_target, trans, 100)
    summaryWriter = tf.summary.FileWriter('log/', graph)
    #tf.Graph().get_operations()
'''


def buildModel():
    graph = load_graph(path)
    with graph.as_default():
        # Tensors exported by the frozen model.
        trans = graph.get_tensor_by_name("prefix/crf_loss/transitions:0")
        lengths = graph.get_tensor_by_name("prefix/lengths:0")
        logits = graph.get_tensor_by_name("prefix/project/logits:0")
        print(logits)
        print(trans)
        # Placeholder for the gold tag indices, shape [batch, max_steps].
        y_target = tf.placeholder(dtype=tf.int32, shape=[None, None], name='y_target')
        #loss = loss_layer(logits, y_target, trans, lengths)
        # Write the graph out so it can be inspected in TensorBoard.
        summaryWriter = tf.summary.FileWriter('log/', graph)
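

# A minimal sketch, not part of the original file: reading the trained CRF
# transition matrix back out of the frozen graph with a TF 1.x session. The
# tensor name is the one used in buildModel() above; the helper name
# inspect_transitions is introduced here only for illustration.
def inspect_transitions():
    graph = load_graph(path)
    trans = graph.get_tensor_by_name("prefix/crf_loss/transitions:0")
    with tf.Session(graph=graph) as sess:
        # In a frozen .pb the variables are baked in as constants, so no
        # initialization or feed_dict should be needed to read this tensor.
        print(sess.run(trans))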


if __name__ == "__main__":
    # import fool
    # a = fool.LEXICAL_ANALYSER
    # a._load_ner_model()
    # _dict = a.ner_model.id_to_tag
    # for _key in _dict.keys():
    #     print(_key, _dict[_key])
    # load_graph(path)
    a = [1, 2, 3, 44]
    # A slice whose start falls before index 0 is clamped, so this prints the whole list.
    print(a[-100:])