'''
Created on 2019-11-29

@author: User
'''
import sys
import os
import json
import requests  # used by the USE_API branch in SelfNer.predict
import numpy as np
import tensorflow as tf
from zipfile import ZipFile
from BiddingKG.dl.common.Utils import viterbi_decode
# The wildcard import is also expected to supply USE_API and API_URL,
# which SelfNer.predict references below.
from BiddingKG.dl.common.Utils import *

# ner_model_path = os.path.abspath("../savedModel")
ner_model_path = os.path.abspath("../selffool_savedmodel_before")


def _load_map_file(path, char_map_name, id_map_name):
    # all_map.json inside the zip archive holds both the char->id map and
    # the id->tag map on a single JSON line.
    with ZipFile(path) as myzip:
        with myzip.open('all_map.json') as myfile:
            content = myfile.readline()
            content = content.decode()
            data = json.loads(content)
            return data.get(char_map_name), data.get(id_map_name)


def decode(logits, trans, sequence_lengths, tag_num):
    viterbi_sequences = []
    small = -1000.0
    # Virtual start step: a large negative score for every real tag and 0
    # for the extra start tag, forcing every path to begin at the start tag.
    start = np.asarray([[small] * tag_num + [0]])
    for logit, length in zip(logits, sequence_lengths):
        score = logit[:length]
        # Add a "start tag" column to every step, then prepend the start step,
        # so score is (length + 1) x (tag_num + 1) to match trans.
        pad = small * np.ones([length, 1])
        score = np.concatenate([score, pad], axis=1)
        score = np.concatenate([start, score], axis=0)
        viterbi_seq, viterbi_score = viterbi_decode(score, trans)
        # Drop the virtual start step from the decoded path.
        viterbi_sequences.append(viterbi_seq[1:])
    return viterbi_sequences
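
# Illustrative sketch (not part of the original module): exercises decode()
# with made-up shapes to show its contract — per-sentence emission scores of
# shape (length, tag_num) and a (tag_num + 1) x (tag_num + 1) transition
# matrix, since decode() augments the scores with a virtual start tag.
def _demo_decode():
    tag_num = 4
    lengths = [5, 3]
    logits = [np.random.rand(length, tag_num) for length in lengths]
    trans = np.random.rand(tag_num + 1, tag_num + 1)
    paths = decode(logits, trans, lengths, tag_num)
    # One tag-id sequence per sentence, same length as the sentence.
    assert [len(p) for p in paths] == lengths
    return paths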
class SelfNer():

    def __init__(self):
        self.load_ner_model()

    def load_ner_model(self):
        config = {'lstm_dim': 100,
                  'num_chars': 6591,
                  'num_tags': 25,
                  'char_dim': 100,
                  'lr': 0.00002,
                  'input_dropout_keep': 1.0,
                  'optimizer': 'adam',
                  'clip': 5}
        self.config = config
        self.lstm_dim = config["lstm_dim"]
        self.num_chars = config["num_chars"]
        self.num_tags = config["num_tags"]
        self.char_dim = config["char_dim"]
        self.lr = config["lr"]

        self.char_to_id, self.id_to_seg = _load_map_file(
            os.path.dirname(__file__) + "/../data/map.zip", "char_map", "ner_map")
        self.id_to_tag = {int(k): v for k, v in self.id_to_seg.items()}
        self.tag_to_id = {v: int(k) for k, v in self.id_to_seg.items()}

        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.sess.as_default() as sess:
            # Import the SavedModel into this session's graph and look up the
            # input/output tensors through the serving signature.
            with sess.graph.as_default():
                meta_graph_def = tf.saved_model.loader.load(
                    sess, ["serve"], os.path.dirname(__file__) + "/../savedModel")
                signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
                signature = meta_graph_def.signature_def
                self.char_inputs = sess.graph.get_tensor_by_name(signature[signature_key].inputs["char_inputs"].name)
                self.lengths = sess.graph.get_tensor_by_name(signature[signature_key].inputs["lengths"].name)
                self.dropout = sess.graph.get_tensor_by_name(signature[signature_key].inputs["dropout"].name)
                self.logits = sess.graph.get_tensor_by_name(signature[signature_key].outputs["logits"].name)
                self.trans = sess.graph.get_tensor_by_name(signature[signature_key].outputs["trans"].name)

    def predict(self, sess, sents):
        inputs = []
        lengths = [len(text) for text in sents]
        max_len = max(lengths)
        for sent in sents:
            # Map each character to its id ("<OOV>" is assumed to be the
            # unknown-character key in the char map), then right-pad every
            # sentence to the batch maximum with id 0.
            sent_ids = [self.char_to_id.get(w) if w in self.char_to_id
                        else self.char_to_id.get("<OOV>") for w in sent]
            padding = [0] * (max_len - len(sent_ids))
            sent_ids += padding
            inputs.append(sent_ids)
        inputs = np.array(inputs, dtype=np.int32)

        if USE_API:
            requests_result = requests.post(API_URL + "/predict_selfNer",
                                            json={"inputs": inputs.tolist(), "lengths": lengths},
                                            verify=True)
            path = json.loads(requests_result.text)['result']
        else:
            feed_dict = {
                self.char_inputs: inputs,
                self.lengths: lengths,
                self.dropout: 1.0
            }
            logits, trans = sess.run([self.logits, self.trans], feed_dict=feed_dict)
            path = decode(logits, trans, lengths, self.num_tags)
        labels = [[self.id_to_tag.get(l) for l in p] for p in path]
        return labels

    def ner(self, text_list, ignore=False):
        # Run the model, then convert each sentence's B/M/E/S/O label sequence
        # into (start, end, type, text) entity tuples.
        text_list = _check_input(text_list, ignore)
        ner_labels = self.predict(self.sess, text_list)
        # print(ner_labels)
        all_entitys = []
        for ti, text in enumerate(text_list):
            ens = []
            entity = ""
            i = 0
            ner_label = ner_labels[ti]
            chars = list(text)
            for label, word in zip(ner_label, chars):
                if label == "O":
                    if entity:
                        # An entity that never saw its E (end) label; emit it as-is.
                        # print('incomplete entity:', lb, entity)
                        ens.append((i - len(entity), i, lt, entity))
                        entity = ""
                    i += 1
                    continue
                lt = label.split("_")[1]
                lb = label.split("_")[0]
                if lb == "S":
                    # Single-character entity.
                    ens.append((i, i + 1, lt, word))
                elif lb == "B":
                    entity = ""
                    entity += word
                elif lb == "M":
                    entity += word
                elif lb == "E":
                    entity += word
                    ens.append((i - len(entity) + 1, i + 1, lt, entity))
                    entity = ""
                i += 1
            all_entitys.append(ens)
        return all_entitys


def _check_input(text, ignore=False):
    if not text:
        return []
    if not isinstance(text, list):
        text = [text]
    null_index = [i for i, t in enumerate(text) if not t]
    if null_index and not ignore:
        raise Exception("null text in input")
    return text
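
# Minimal usage sketch, assuming the ../savedModel directory and ../data/map.zip
# sit at the relative paths load_ner_model() expects; the sample sentence is
# illustrative only, not from the original module.
if __name__ == "__main__":
    self_ner = SelfNer()
    sample = ["公告显示招标人为北京大学。"]
    for spans in self_ner.ner(sample):
        # Each span is a (start, end, type, text) tuple.
        print(spans)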