# selffool_ner.py
'''
Created on 2019-11-29
@author: User
'''
import sys
import os
import numpy as np
import tensorflow as tf  # TF1-style APIs (tf.Session, tf.saved_model.loader) are used below
from BiddingKG.dl.common.Utils import viterbi_decode
# The wildcard import is relied on below for USE_PAI_EAS, tf_predict_pb2,
# vpc_requests, selffool_url and selffool_authorization.
from BiddingKG.dl.common.Utils import *
from zipfile import ZipFile
import json

# ner_model_path = os.path.abspath("../savedModel")
ner_model_path = os.path.abspath("../selffool_savedmodel_before")


def _load_map_file(path, char_map_name, id_map_name):
    # Read the first line of all_map.json inside the zip archive and return
    # the requested character and id maps.
    with ZipFile(path) as myzip:
        with myzip.open('all_map.json') as myfile:
            content = myfile.readline()
            content = content.decode()
            data = json.loads(content)
            return data.get(char_map_name), data.get(id_map_name)
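
# Assumed layout of all_map.json (illustrative values only; the real ids come
# from the training vocabulary):
#   {"char_map": {"<OOV>": 1, "中": 2, ...},
#    "ner_map":  {"0": "O", "1": "B_company", ...}}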


def decode(logits, trans, sequence_lengths, tag_num):
    # Viterbi-decode each sequence in the batch, using an extra synthetic
    # "start" tag so every path is forced to begin from a common state.
    viterbi_sequences = []
    small = -1000.0
    start = np.asarray([[small] * tag_num + [0]])
    for logit, length in zip(logits, sequence_lengths):
        score = logit[:length]
        pad = small * np.ones([length, 1])
        score = np.concatenate([score, pad], axis=1)
        score = np.concatenate([start, score], axis=0)
        viterbi_seq, viterbi_score = viterbi_decode(score, trans)
        # Drop the leading synthetic start step.
        viterbi_sequences.append(viterbi_seq[1:])
    return viterbi_sequences
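
# Shape sketch: with length == 3 and tag_num == 25, `score` grows from (3, 25)
# to (4, 26) -- one padding column (scored `small` so it is never chosen) plus
# one start row whose only non-penalised entry is the synthetic start tag;
# `trans` is therefore expected to be a (26, 26) transition matrix.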


class SelfNer():
    def __init__(self):
        self.load_ner_model()

    def load_ner_model(self):
        config = {'lstm_dim': 100,
                  'num_chars': 6591,
                  'num_tags': 25,
                  'char_dim': 100,
                  'lr': 0.00002,
                  'input_dropout_keep': 1.0,
                  'optimizer': 'adam',
                  'clip': 5}
        self.config = config
        self.lstm_dim = config["lstm_dim"]
        self.num_chars = config["num_chars"]
        self.num_tags = config["num_tags"]
        self.char_dim = config["char_dim"]
        self.lr = config["lr"]
        self.char_to_id, self.id_to_seg = _load_map_file(os.path.dirname(__file__) + "/../data/map.zip", "char_map", "ner_map")
        # ner_map keys are stored as strings in the JSON; build both directions.
        self.id_to_tag = {int(k): v for k, v in self.id_to_seg.items()}
        self.tag_to_id = {v: int(k) for k, v in self.id_to_seg.items()}
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.sess.as_default() as sess:
            # Load the TF1 SavedModel and resolve the tensors declared in its
            # default serving signature.
            meta_graph_def = tf.saved_model.loader.load(sess, ["serve"], os.path.dirname(__file__) + "/../savedModel")
            signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
            signature = meta_graph_def.signature_def
            self.char_inputs = sess.graph.get_tensor_by_name(signature[signature_key].inputs["char_inputs"].name)
            self.lengths = sess.graph.get_tensor_by_name(signature[signature_key].inputs["lengths"].name)
            self.logits = sess.graph.get_tensor_by_name(signature[signature_key].outputs["logits"].name)
            self.trans = sess.graph.get_tensor_by_name(signature[signature_key].outputs["trans"].name)
            self.dropout = sess.graph.get_tensor_by_name(signature[signature_key].inputs["dropout"].name)
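            # To inspect what the serving signature actually exposes (handy if
            # the input/output names ever change), one can print the proto maps:
            #   print(signature[signature_key].inputs)
            #   print(signature[signature_key].outputs)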

    def predict(self, sess, sents):
        # Map characters to ids and right-pad every sentence to the batch max length.
        inputs = []
        lengths = [len(text) for text in sents]
        max_len = max(lengths)
        for sent in sents:
            # Fall back to the <OOV> id for unseen characters.
            sent_ids = [self.char_to_id.get(w, self.char_to_id.get("<OOV>")) for w in sent]
            padding = [0] * (max_len - len(sent_ids))
            sent_ids += padding
            inputs.append(sent_ids)
        inputs = np.array(inputs, dtype=np.int32)
        if USE_PAI_EAS:
            # Remote inference through PAI-EAS: serialize the request protobuf
            # and fall back to the local session if the call fails.
            request = tf_predict_pb2.PredictRequest()
            request.inputs["char_inputs"].dtype = tf_predict_pb2.DT_INT32
            request.inputs["char_inputs"].array_shape.dim.extend(np.shape(inputs))
            request.inputs["char_inputs"].int_val.extend(np.array(inputs, dtype=np.int32).reshape(-1))
            request.inputs["lengths"].dtype = tf_predict_pb2.DT_INT32
            request.inputs["lengths"].array_shape.dim.extend(np.shape(lengths))
            request.inputs["lengths"].int_val.extend(np.array(lengths, dtype=np.int32).reshape(-1))
            request.inputs["dropout"].dtype = tf_predict_pb2.DT_FLOAT
            request.inputs["dropout"].float_val.extend([1.0])
            request_data = request.SerializeToString()
            list_outputs = ["logits", "trans"]
            result = vpc_requests(selffool_url, selffool_authorization, request_data, list_outputs)
            if result is not None:
                logits = result["logits"]
                trans = result["trans"]
            else:
                feed_dict = {
                    self.char_inputs: inputs,
                    self.lengths: lengths,
                    self.dropout: 1.0
                }
                logits, trans = sess.run([self.logits, self.trans], feed_dict=feed_dict)
        else:
            # Local inference with the loaded SavedModel.
            feed_dict = {
                self.char_inputs: inputs,
                self.lengths: lengths,
                self.dropout: 1.0
            }
            logits, trans = sess.run([self.logits, self.trans], feed_dict=feed_dict)
        path = decode(logits, trans, lengths, self.num_tags)
        labels = [[self.id_to_tag.get(l) for l in p] for p in path]
        return labels

    def ner(self, text_list, ignore=False):
        text_list = _check_input(text_list, ignore)
        ner_labels = self.predict(self.sess, text_list)
        # print(ner_labels)
        all_entitys = []
        for ti, text in enumerate(text_list):
            ens = []
            entity = ""
            i = 0
            ner_label = ner_labels[ti]
            chars = list(text)
            # Assemble entities from BMES-style labels such as "B_company".
            for label, word in zip(ner_label, chars):
                i += 1
                if label == "O":
                    continue
                lt = label.split("_")[1]  # entity type
                lb = label.split("_")[0]  # position within the entity: B/M/E/S
                if lb == "S":
                    ens.append((i, i + 1, lt, word))
                elif lb == "B":
                    entity = ""
                    entity += word
                elif lb == "M":
                    entity += word
                elif lb == "E":
                    entity += word
                    ens.append((i - len(entity), i + 1, lt, entity))
                    entity = ""
            # An entity left open at the end of the text (B/M without a closing
            # E) is flushed with the last seen type.
            if entity:
                ens.append((i - len(entity), i + 1, lt, entity))
            all_entitys.append(ens)
        return all_entitys
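
# Worked example of the BMES assembly in SelfNer.ner (hypothetical labels):
# for the text "选自北京大学" tagged ["O", "O", "B_org", "M_org", "M_org",
# "E_org"], the method yields [(2, 7, "org", "北京大学")]; the offsets follow
# the (i - len(entity), i + 1) convention of the original fool-style code.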


def _check_input(text, ignore=False):
    # Normalise the input to a list of strings; raise on empty entries unless
    # ignore is set.
    if not text:
        return []
    if not isinstance(text, list):
        text = [text]
    null_index = [i for i, t in enumerate(text) if not t]
    if null_index and not ignore:
        raise Exception("null text in input ")
    return text
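

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes
    # ../savedModel and ../data/map.zip exist relative to this file and that
    # the Utils wildcard import provides USE_PAI_EAS and friends.
    ner_model = SelfNer()
    for entities in ner_model.ner(["北京大学于2019年11月29日发布招标公告。"]):
        for begin, end, entity_type, entity_text in entities:
            print(begin, end, entity_type, entity_text)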