# selffool_ner.py
'''
Created on 2019-11-29
@author: User
'''
  5. import sys
  6. import os
  7. import numpy as np
  8. import tensorflow as tf
  9. from BiddingKG.dl.common.Utils import viterbi_decode
  10. from BiddingKG.dl.common.Utils import *
  11. from zipfile import ZipFile
  12. import json
  13. # ner_model_path = os.path.abspath("../savedModel")
  14. ner_model_path = os.path.abspath("../selffool_savedmodel_before")
  15. def _load_map_file(path, char_map_name, id_map_name):
  16. with ZipFile(path) as myzip:
  17. with myzip.open('all_map.json') as myfile:
  18. content = myfile.readline()
  19. content = content.decode()
  20. data = json.loads(content)
  21. return data.get(char_map_name), data.get(id_map_name)
  22. def decode(logits, trans, sequence_lengths, tag_num):
  23. viterbi_sequences = []
  24. small = -1000.0
  25. start = np.asarray([[small] * tag_num + [0]])
  26. for logit, length in zip(logits, sequence_lengths):
  27. score = logit[:length]
  28. pad = small * np.ones([length, 1])
  29. score = np.concatenate([score, pad], axis=1)
  30. score = np.concatenate([start, score], axis=0)
  31. viterbi_seq, viterbi_score = viterbi_decode(score, trans)
  32. viterbi_sequences.append(viterbi_seq[1:])
  33. return viterbi_sequences
class SelfNer():
    """Character-level NER tagger backed by a TensorFlow (TF1) SavedModel.

    Loads char/tag vocabularies from ``../data/map.zip`` and a SavedModel
    from ``../savedModel``, then decodes BIOES-style tag sequences into
    ``(start, end, type, text)`` entity tuples.
    """
    def __init__(self):
        # Eager load: construction does disk I/O and builds a TF session.
        self.load_ner_model()
    def load_ner_model(self):
        """Load vocab/tag maps and the SavedModel; cache tensor handles."""
        # Hyper-parameters mirrored from training; only lstm_dim, num_chars,
        # num_tags, char_dim and lr are read back below.
        config = {'lstm_dim':100,
            'num_chars':6591,
            'num_tags':25,
            'char_dim':100,
            'lr':0.00002,
            'input_dropout_keep':1.0,
            'optimizer':'adam',
            'clip':5}
        self.config = config
        self.lstm_dim = config["lstm_dim"]
        self.num_chars = config["num_chars"]
        self.num_tags = config["num_tags"]
        self.char_dim = config["char_dim"]
        self.lr = config["lr"]
        # char_to_id: char -> int id; id_to_seg: str(tag id) -> tag name.
        self.char_to_id, self.id_to_seg = _load_map_file(os.path.dirname(__file__)+"/../data/map.zip", "char_map", "ner_map")
        self.id_to_tag = {int(k):v for k,v in self.id_to_seg.items()}
        self.tag_to_id = {v:int(k) for k,v in self.id_to_seg.items()}
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.sess.as_default() as sess:
            # Resolve input/output tensors via the default serving signature
            # of the SavedModel.
            meta_graph_def = tf.saved_model.loader.load(sess, ["serve"], os.path.dirname(__file__)+"/../savedModel")
            signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
            signature = meta_graph_def.signature_def
            self.char_inputs = sess.graph.get_tensor_by_name(signature[signature_key].inputs["char_inputs"].name)
            self.lengths = sess.graph.get_tensor_by_name(signature[signature_key].inputs["lengths"].name)
            self.logits = sess.graph.get_tensor_by_name(signature[signature_key].outputs["logits"].name)
            self.trans = sess.graph.get_tensor_by_name(signature[signature_key].outputs["trans"].name)
            self.dropout = sess.graph.get_tensor_by_name(signature[signature_key].inputs["dropout"].name)
    def predict(self,sess,sents):
        """Tag a batch of sentences; returns one list of tag names per sentence.

        NOTE(review): an empty ``sents`` batch raises ValueError at
        ``max(lengths)`` — callers must not pass an empty list; verify
        against ``ner()``, which can forward [] from ``_check_input``.
        """
        inputs = []
        lengths = [len(text) for text in sents]
        max_len = max(lengths)
        for sent in sents:
            # Unknown characters map to the <OOV> id; right-pad with 0.
            sent_ids = [self.char_to_id.get(w) if w in self.char_to_id else self.char_to_id.get("<OOV>") for w in sent]
            padding = [0] * (max_len - len(sent_ids))
            sent_ids += padding
            inputs.append(sent_ids)
        inputs = np.array(inputs, dtype=np.int32)
        if USE_API:
            # USE_API / API_URL / requests presumably come from the wildcard
            # import of BiddingKG.dl.common.Utils — remote inference path.
            # TODO(review): confirm those names exist in Utils.
            requests_result = requests.post(API_URL + "/predict_selfNer", json={"inputs": inputs.tolist(),'lengths':lengths}, verify=True)
            path = json.loads(requests_result.text)['result']
        else:
            feed_dict = {
                self.char_inputs: inputs,
                self.lengths: lengths,
                self.dropout: 1.0
            }
            logits, trans = sess.run([self.logits, self.trans], feed_dict=feed_dict)
            path = decode(logits, trans, lengths, self.num_tags)
        labels = [[self.id_to_tag.get(l) for l in p] for p in path]
        return labels
    def ner(self, text_list,ignore=False):
        """Extract entities from each text as (start, end, type, text) tuples.

        Tags are expected in "<B|M|E|S>_<type>" form, with "O" for outside;
        spans are half-open character offsets into the text.
        """
        text_list = _check_input(text_list,ignore)
        ner_labels = self.predict(self.sess,text_list)
        #print(ner_labels)
        all_entitys = []
        for ti, text in enumerate(text_list):
            ens = []
            entity = ""
            i = 0
            ner_label = ner_labels[ti]
            chars = list(text)
            for label, word in zip(ner_label, chars):
                # i += 1
                if label == "O":
                    if entity:
                        # An entity was still open when an O tag arrived:
                        # flush it as-is (truncated entity, no E tag seen).
                        # print('incomplete entity:', lb, entity)
                        ens.append((i - len(entity), i, lt, entity))
                        entity = ""
                    i += 1
                    continue
                lt = label.split("_")[1]   # entity type part of the tag
                lb = label.split("_")[0]   # position part: B/M/E/S
                if lb == "S":
                    # Single-character entity.
                    ens.append((i, i + 1, lt, word))
                elif lb == "B":
                    entity = ""
                    entity += word
                elif lb == "M":
                    entity += word
                elif lb == "E":
                    entity += word
                    ens.append((i - len(entity)+1, i + 1, lt, entity))
                    entity = ""
                i += 1
            all_entitys.append(ens)
        return all_entitys
  125. def _check_input(text, ignore=False):
  126. if not text:
  127. return []
  128. if not isinstance(text, list):
  129. text = [text]
  130. null_index = [i for i, t in enumerate(text) if not t]
  131. if null_index and not ignore:
  132. raise Exception("null text in input ")
  133. return text