model.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #!/usr/bin/env python
  2. #-*-coding:utf-8-*-
  3. import tensorflow as tf
  4. import pickle
  5. import numpy as np
  6. from BiddingKG.dl.common.Utils import viterbi_decode
  7. def decode(logits, trans, sequence_lengths, tag_num):
  8. viterbi_sequences = []
  9. small = -1000.0
  10. start = np.asarray([[small] * tag_num + [0]])
  11. for logit, length in zip(logits, sequence_lengths):
  12. score = logit[:length]
  13. pad = small * np.ones([length, 1])
  14. logits = np.concatenate([score, pad], axis=1)
  15. logits = np.concatenate([start, logits], axis=0)
  16. viterbi_seq, viterbi_score = viterbi_decode(logits, trans)
  17. viterbi_sequences .append(viterbi_seq[1:])
  18. return viterbi_sequences
  19. def load_map(path):
  20. with open(path, 'rb') as f:
  21. char_to_id, tag_to_id, id_to_tag = pickle.load(f)
  22. return char_to_id, id_to_tag
  23. def load_graph(path):
  24. with tf.gfile.GFile(path, mode='rb') as f:
  25. graph_def = tf.GraphDef()
  26. graph_def.ParseFromString(f.read())
  27. with tf.Graph().as_default() as graph:
  28. tf.import_graph_def(graph_def, name="prefix")
  29. return graph
  30. class Model(object):
  31. def __init__(self, map_file, model_file):
  32. self.char_to_id, self.id_to_tag = load_map(map_file)
  33. self.graph = load_graph(model_file)
  34. self.input_x = self.graph.get_tensor_by_name("prefix/char_inputs:0")
  35. self.lengths = self.graph.get_tensor_by_name("prefix/lengths:0")
  36. self.dropout = self.graph.get_tensor_by_name("prefix/dropout:0")
  37. self.logits = self.graph.get_tensor_by_name("prefix/project/logits:0")
  38. self.trans = self.graph.get_tensor_by_name("prefix/crf_loss/transitions:0")
  39. self.sess = tf.Session(graph=self.graph)
  40. self.sess.as_default()
  41. self.num_class = len(self.id_to_tag)
  42. def predict(self, sents):
  43. inputs = []
  44. lengths = [len(text) for text in sents]
  45. max_len = max(lengths)
  46. for sent in sents:
  47. sent_ids = [self.char_to_id.get(w) if w in self.char_to_id else self.char_to_id.get("<OOV>") for w in sent]
  48. padding = [0] * (max_len - len(sent_ids))
  49. sent_ids += padding
  50. inputs.append(sent_ids)
  51. inputs = np.array(inputs, dtype=np.int32)
  52. feed_dict = {
  53. self.input_x: inputs,
  54. self.lengths: lengths,
  55. self.dropout: 1.0
  56. }
  57. logits, trans = self.sess.run([self.logits, self.trans], feed_dict=feed_dict)
  58. path = decode(logits, trans, lengths, self.num_class)
  59. labels = [[self.id_to_tag.get(l) for l in p] for p in path]
  60. return labels