predictor.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #!/usr/bin/env python
  2. # -*-coding:utf-8-*-
  3. import tensorflow as tf
  4. import numpy as np
  5. from BiddingKG.dl.common.Utils import viterbi_decode
  6. from BiddingKG.dl.common.Utils import *
  7. def decode(logits, trans, sequence_lengths, tag_num):
  8. viterbi_sequences = []
  9. small = -1000.0
  10. start = np.asarray([[small] * tag_num + [0]])
  11. for logit, length in zip(logits, sequence_lengths):
  12. score = logit[:length]
  13. pad = small * np.ones([length, 1])
  14. score = np.concatenate([score, pad], axis=1)
  15. score = np.concatenate([start, score], axis=0)
  16. viterbi_seq, viterbi_score = viterbi_decode(score, trans)
  17. viterbi_sequences.append(viterbi_seq[1:])
  18. return viterbi_sequences
  19. def list_to_array(data_list, dtype=np.int32):
  20. array = np.array(data_list, dtype).reshape(1, len(data_list))
  21. return array
  22. def load_graph(path):
  23. with tf.gfile.GFile(path, "rb") as f:
  24. graph_def = tf.GraphDef()
  25. graph_def.ParseFromString(f.read())
  26. with tf.Graph().as_default() as graph:
  27. tf.import_graph_def(graph_def, name="prefix")
  28. return graph
  29. class Predictor(object):
  30. def __init__(self, model_file, char_to_id, id_to_tag,url,authorization):
  31. self.char_to_id = char_to_id
  32. self.id_to_tag = {int(k):v for k,v in id_to_tag.items()}
  33. self.graph = load_graph(model_file)
  34. self.url = url
  35. self.authorization = authorization
  36. self.input_x = self.graph.get_tensor_by_name("prefix/char_inputs:0")
  37. self.char_embedding = self.graph.get_tensor_by_name("prefix/char_embeding:0")
  38. self.lengths = self.graph.get_tensor_by_name("prefix/lengths:0")
  39. self.dropout = self.graph.get_tensor_by_name("prefix/dropout:0")
  40. self.logits = self.graph.get_tensor_by_name("prefix/project/logits:0")
  41. self.trans = self.graph.get_tensor_by_name("prefix/crf_loss/transitions:0")
  42. self.sess = tf.Session(graph=self.graph)
  43. print("====",model_file,self.sess.run(self.trans))
  44. print("@@@@",self.id_to_tag)
  45. self.sess.as_default()
  46. self.num_class = len(self.id_to_tag)
  47. def predict(self, sents):
  48. inputs = []
  49. lengths = [len(text) for text in sents]
  50. max_len = max(lengths)
  51. for sent in sents:
  52. sent_ids = [self.char_to_id.get(w) if w in self.char_to_id else self.char_to_id.get("<OOV>") for w in sent]
  53. padding = [0] * (max_len - len(sent_ids))
  54. sent_ids += padding
  55. inputs.append(sent_ids)
  56. inputs = np.array(inputs, dtype=np.int32)
  57. if USE_API and self.url and self.authorization:
  58. requests_result = requests.post(API_URL + "/predict_tokens",json={"inputs": inputs.tolist(), 'lengths': lengths}, verify=True)
  59. path = json.loads(requests_result.text)['result']
  60. else:
  61. feed_dict = {
  62. self.input_x: inputs,
  63. self.lengths: lengths,
  64. self.dropout: 1.0
  65. }
  66. logits, trans = self.sess.run([self.logits, self.trans], feed_dict=feed_dict)
  67. path = decode(logits, trans, lengths, self.num_class)
  68. labels = [[self.id_to_tag.get(l) for l in p] for p in path]
  69. return labels