123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- #!/usr/bin/env python
- # -*-coding:utf-8-*-
- import tensorflow as tf
- import numpy as np
- from BiddingKG.dl.common.Utils import viterbi_decode
- from BiddingKG.dl.common.Utils import *
def decode(logits, trans, sequence_lengths, tag_num):
    """Viterbi-decode a batch of CRF emission scores.

    Args:
        logits: per-sentence emission scores, shape (batch, max_len, tag_num).
        trans: CRF transition matrix (including the artificial start tag).
        sequence_lengths: true (unpadded) length of each sentence.
        tag_num: number of real tags (the start tag is appended as index tag_num).

    Returns:
        A list of tag-id sequences, one per input sentence.
    """
    small = -1000.0
    # Virtual first step: every real tag is masked out, only the start tag (index
    # tag_num) carries score 0, forcing the path to begin there.
    start_row = np.asarray([[small] * tag_num + [0]])
    decoded = []
    for logit, seq_len in zip(logits, sequence_lengths):
        emissions = logit[:seq_len]
        # Extra column for the start tag, masked so it is never chosen mid-sequence.
        mask_col = small * np.ones([seq_len, 1])
        scores = np.concatenate([emissions, mask_col], axis=1)
        scores = np.concatenate([start_row, scores], axis=0)
        best_path, _ = viterbi_decode(scores, trans)
        # Drop the artificial start step before returning.
        decoded.append(best_path[1:])
    return decoded
def list_to_array(data_list, dtype=np.int32):
    """Convert a flat Python list into a (1, len(data_list)) numpy row vector."""
    return np.array(data_list, dtype).reshape(1, len(data_list))
def load_graph(path):
    """Load a frozen TensorFlow GraphDef from *path* into a fresh graph.

    All imported ops are namespaced under "prefix" (tensor names become
    "prefix/<name>:0"), matching the lookups done in Predictor.__init__.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(path, "rb") as graph_file:
        graph_def.ParseFromString(graph_file.read())
    imported = tf.Graph()
    with imported.as_default():
        tf.import_graph_def(graph_def, name="prefix")
    return imported
class Predictor(object):
    """Sequence tagger backed by a frozen BiLSTM-CRF TensorFlow graph.

    Inference runs either against a remote PAI-EAS endpoint (when USE_PAI_EAS
    is enabled and url/authorization are provided) or a local tf.Session built
    from the frozen graph; the remote path falls back to the local session on
    failure.
    """

    def __init__(self, model_file, char_to_id, id_to_tag, url, authorization):
        """
        Args:
            model_file: path to the frozen GraphDef (.pb) file.
            char_to_id: character vocabulary; must contain an "<OOV>" entry.
            id_to_tag: tag-id -> label mapping (keys may be strings, e.g. from JSON).
            url: PAI-EAS service URL (may be falsy to force local inference).
            authorization: PAI-EAS authorization token.
        """
        self.char_to_id = char_to_id
        # JSON round-trips stringify keys; normalize back to int -> tag.
        self.id_to_tag = {int(k): v for k, v in id_to_tag.items()}
        self.graph = load_graph(model_file)

        self.url = url
        self.authorization = authorization
        self.input_x = self.graph.get_tensor_by_name("prefix/char_inputs:0")
        self.char_embedding = self.graph.get_tensor_by_name("prefix/char_embeding:0")
        self.lengths = self.graph.get_tensor_by_name("prefix/lengths:0")
        self.dropout = self.graph.get_tensor_by_name("prefix/dropout:0")
        self.logits = self.graph.get_tensor_by_name("prefix/project/logits:0")
        self.trans = self.graph.get_tensor_by_name("prefix/crf_loss/transitions:0")
        self.sess = tf.Session(graph=self.graph)
        print("====", model_file, self.sess.run(self.trans))
        print("@@@@", self.id_to_tag)
        self.sess.as_default()
        self.num_class = len(self.id_to_tag)

    def _run_local(self, inputs, lengths):
        """Run the local TF session; returns (logits, transition matrix)."""
        feed_dict = {
            self.input_x: inputs,
            self.lengths: lengths,
            self.dropout: 1.0,  # inference: keep-probability 1, no dropout
        }
        return self.sess.run([self.logits, self.trans], feed_dict=feed_dict)

    def predict(self, sents):
        """Tag a batch of sentences.

        Args:
            sents: list of strings (or char sequences).

        Returns:
            A list (one per sentence) of label lists, same length as each input.
        """
        if not sents:
            # max() over an empty batch would raise ValueError; no input, no labels.
            return []
        lengths = [len(text) for text in sents]
        max_len = max(lengths)
        # Hoist the OOV id: dict.get(w, oov_id) avoids the original's
        # double lookup (`w in d` then `d.get(w)`) per character.
        oov_id = self.char_to_id.get("<OOV>")
        inputs = []
        for sent in sents:
            sent_ids = [self.char_to_id.get(w, oov_id) for w in sent]
            # Right-pad with id 0 up to the batch's longest sentence.
            sent_ids += [0] * (max_len - len(sent_ids))
            inputs.append(sent_ids)
        inputs = np.array(inputs, dtype=np.int32)

        if USE_PAI_EAS and self.url and self.authorization:
            request = tf_predict_pb2.PredictRequest()
            request.inputs["char_inputs"].dtype = tf_predict_pb2.DT_INT32
            request.inputs["char_inputs"].array_shape.dim.extend(np.shape(inputs))
            request.inputs["char_inputs"].int_val.extend(np.array(inputs, dtype=np.int32).reshape(-1))
            request.inputs["lengths"].dtype = tf_predict_pb2.DT_INT32
            request.inputs["lengths"].array_shape.dim.extend(np.shape(lengths))
            request.inputs["lengths"].int_val.extend(np.array(lengths, dtype=np.int32).reshape(-1))
            request.inputs["dropout"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["dropout"].float_val.extend([1.0])
            request_data = request.SerializeToString()
            list_outputs = ["logits", "trans"]
            result = vpc_requests(self.url, self.authorization, request_data, list_outputs)
            if result is not None:
                logits = result["logits"]
                trans = result["trans"]
            else:
                # Remote call failed: fall back to the local session.
                logits, trans = self._run_local(inputs, lengths)
        else:
            logits, trans = self._run_local(inputs, lengths)
        path = decode(logits, trans, lengths, self.num_class)
        labels = [[self.id_to_tag.get(l) for l in p] for p in path]
        return labels
|