BertCRF.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. '''
  2. Created on 2019年12月31日
  3. @author: User
  4. '''
  5. # from tensorflow.contrib import rnn
  6. # from tensorflow.contrib.crf import crf_log_likelihood
  7. # from tensorflow.contrib.layers.python.layers import initializers
  8. from zipfile import ZipFile
  9. from BiddingKG.dl_dev.BertNer.BertModel import *
  10. from BiddingKG.dl.common.Utils import *
  11. def _load_map_file(path, char_map_name, id_map_name):
  12. with ZipFile(path) as myzip:
  13. with myzip.open('all_map.json') as myfile:
  14. content = myfile.readline()
  15. content = content.decode()
  16. data = json.loads(content)
  17. return data.get(char_map_name), data.get(id_map_name)
  18. def shape_list(x):
  19. """Return list of dims, statically where possible."""
  20. x = tf.convert_to_tensor(x)
  21. # If unknown rank, return dynamic shape
  22. if x.get_shape().dims is None:
  23. return tf.shape(x)
  24. static = x.get_shape().as_list()
  25. shape = tf.shape(x)
  26. ret = []
  27. for i in range(len(static)):
  28. dim = static[i]
  29. if dim is None:
  30. dim = shape[i]
  31. ret.append(dim)
  32. return ret
  33. class BertCRF(object):
  34. def __init__(self):
  35. config = {'lstm_dim':100,
  36. 'num_chars':6591,
  37. 'num_tags':25,
  38. 'char_dim':100,
  39. 'lr':0.00002,
  40. 'input_dropout_keep':1.0,
  41. 'optimizer':'adam',
  42. 'clip':5,
  43. 'bert_hidden':100}
  44. self.config = config
  45. self.lstm_dim = config["lstm_dim"]
  46. self.num_chars = config["num_chars"]
  47. self.num_tags = config["num_tags"]
  48. self.char_dim = config["char_dim"]
  49. self.lr = config["lr"]
  50. self.bert_hidden = config["bert_hidden"]
  51. self.graph = tf.Graph()
  52. with self.graph.as_default():
  53. self.char_to_id, self.id_to_seg = _load_map_file(os.path.dirname(__file__)+"/data/map.zip", "char_map", "ner_map")
  54. self.id_to_tag = {int(k):v for k,v in self.id_to_seg.items()}
  55. self.tag_to_id = {v:int(k) for k,v in self.id_to_seg.items()}
  56. #self.char_embeding = tf.get_variable(name="char_embeding", initializer=embeddings)
  57. self.char_embeding = tf.get_variable(name="char_embeding",shape=(self.num_chars,self.char_dim))
  58. #添加一串全0的坑,fool发行版和源代码不一样
  59. self.const = tf.constant(value=0,dtype=tf.float32,shape=[1,100])
  60. self.char_embeding = tf.concat([self.const,self.char_embeding],0)
  61. self.global_step = tf.Variable(0, trainable=False)
  62. self.initializer = initializers.xavier_initializer()
  63. self.char_inputs = tf.placeholder(dtype=tf.int32, shape=[None, None], name="char_inputs")
  64. self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="targets")
  65. self.dropout = tf.placeholder(dtype=tf.float32, name="dropout")
  66. self.lengths = tf.placeholder(dtype=tf.int32, shape=[None, ], name="lengths")
  67. # self.middle_dropout_keep_prob = tf.placeholder_with_default(1.0, [], name="middle_dropout_keep_prob")
  68. # self.hidden_dropout_keep_prob = tf.placeholder_with_default(1.0, [], name="hidden_dropout_keep_prob")
  69. self.input_dropout_keep_prob = tf.placeholder_with_default(config["input_dropout_keep"], [], name="input_dropout_keep_prob")
  70. self.batch_size = tf.shape(self.char_inputs)[0]
  71. self.num_steps = tf.shape(self.char_inputs)[-1]
  72. # forward
  73. embedding = self.embedding_layer(self.char_inputs)
  74. bert_outputs = self.bert_layer(embedding)
  75. #lstm_inputs = tf.nn.dropout(bert_outputs, self.input_dropout_keep_prob)
  76. ## bi-directional lstm layer
  77. #lstm_outputs = self.bilstm_layer(lstm_inputs)
  78. ## logits for tags
  79. self.project_layer(bert_outputs)
  80. ## loss of the model
  81. self.loss = self.loss_layer(self.logits, self.lengths)
  82. with tf.variable_scope("optimizer"):
  83. optimizer = self.config["optimizer"]
  84. if optimizer == "sgd":
  85. self.opt = tf.train.GradientDescentOptimizer(self.lr)
  86. elif optimizer == "adam":
  87. self.opt = tf.train.AdamOptimizer(self.lr)
  88. elif optimizer == "adgrad":
  89. self.opt = tf.train.AdagradOptimizer(self.lr)
  90. else:
  91. raise KeyError
  92. grads_vars = self.opt.compute_gradients(self.loss)
  93. capped_grads_vars = [[tf.clip_by_value(g, -self.config["clip"], self.config["clip"]), v] for g, v in grads_vars]
  94. self.train_op = self.opt.apply_gradients(capped_grads_vars, self.global_step)
  95. config = tf.ConfigProto()
  96. config.gpu_options.per_process_gpu_memory_fraction = 0.7
  97. config.gpu_options.allow_growth = True
  98. self.sess = tf.Session(graph=self.graph,config=config)
  99. def embedding_layer(self, char_inputs):
  100. with tf.variable_scope("char_embedding"), tf.device('/cpu:0'):
  101. embed = tf.nn.embedding_lookup(self.char_embeding, char_inputs)
  102. return embed
  103. def bilstm_layer(self, lstm_inputs, name=None):
  104. with tf.variable_scope("char_bilstm" if not name else name):
  105. lstm_fw_cell = rnn.BasicLSTMCell(self.lstm_dim, state_is_tuple=True)
  106. lstm_bw_cell = rnn.BasicLSTMCell(self.lstm_dim, state_is_tuple=True)
  107. outputs, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, lstm_inputs, dtype=tf.float32, sequence_length=self.lengths)
  108. return tf.concat(outputs, axis=2)
  109. def get_timing_signal_1d(self,length,
  110. channels,
  111. min_timescale=1.0,
  112. max_timescale=1.0e4,
  113. start_index=0):
  114. """Gets a bunch of sinusoids of different frequencies.
  115. Each channel of the input Tensor is incremented by a sinusoid of a different
  116. frequency and phase.
  117. This allows attention to learn to use absolute and relative positions.
  118. Timing signals should be added to some precursors of both the query and the
  119. memory inputs to attention.
  120. The use of relative position is possible because sin(x+y) and cos(x+y) can be
  121. expressed in terms of y, sin(x) and cos(x).
  122. In particular, we use a geometric sequence of timescales starting with
  123. min_timescale and ending with max_timescale. The number of different
  124. timescales is equal to channels / 2. For each timescale, we
  125. generate the two sinusoidal signals sin(timestep/timescale) and
  126. cos(timestep/timescale). All of these sinusoids are concatenated in
  127. the channels dimension.
  128. Args:
  129. length: scalar, length of timing signal sequence.
  130. channels: scalar, size of timing embeddings to create. The number of
  131. different timescales is equal to channels / 2.
  132. min_timescale: a float
  133. max_timescale: a float
  134. start_index: index of first position
  135. Returns:
  136. a Tensor of timing signals [1, length, channels]
  137. """
  138. position = tf.to_float(tf.range(length) + start_index)
  139. num_timescales = channels // 2
  140. log_timescale_increment = (
  141. math.log(float(max_timescale) / float(min_timescale)) /
  142. (tf.to_float(num_timescales) - 1))
  143. inv_timescales = min_timescale * tf.exp(
  144. tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  145. scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
  146. signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
  147. signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
  148. signal = tf.reshape(signal, [1, length, channels])
  149. return signal
  150. def add_timing_signal_1d(self,x,
  151. min_timescale=1.0,
  152. max_timescale=1.0e4,
  153. start_index=0):
  154. """Adds a bunch of sinusoids of different frequencies to a Tensor.
  155. Each channel of the input Tensor is incremented by a sinusoid of a different
  156. frequency and phase.
  157. This allows attention to learn to use absolute and relative positions.
  158. Timing signals should be added to some precursors of both the query and the
  159. memory inputs to attention.
  160. The use of relative position is possible because sin(x+y) and cos(x+y) can be
  161. experessed in terms of y, sin(x) and cos(x).
  162. In particular, we use a geometric sequence of timescales starting with
  163. min_timescale and ending with max_timescale. The number of different
  164. timescales is equal to channels / 2. For each timescale, we
  165. generate the two sinusoidal signals sin(timestep/timescale) and
  166. cos(timestep/timescale). All of these sinusoids are concatenated in
  167. the channels dimension.
  168. Args:
  169. x: a Tensor with shape [batch, length, channels]
  170. min_timescale: a float
  171. max_timescale: a float
  172. start_index: index of first position
  173. Returns:
  174. a Tensor the same shape as x.
  175. """
  176. length = shape_list(x)[1]
  177. channels = shape_list(x)[2]
  178. signal = self.get_timing_signal_1d(length, channels, min_timescale, max_timescale,
  179. start_index)
  180. return x + signal
  181. def bert_layer(self,tensor_embedding):
  182. #增加位置向量信息
  183. tensor_after_position = self.add_timing_signal_1d(tensor_embedding)
  184. with tf.variable_scope("encoder"):
  185. # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
  186. # mask of shape [batch_size, seq_length, seq_length] which is used
  187. # for the attention scores.
  188. attention_mask = tf.tile(tf.expand_dims(tf.cast(tf.sequence_mask(self.lengths,tf.reduce_max(self.lengths)),"int32"),1),[1,tf.reduce_max(self.lengths),1])#create_attention_mask_from_input_mask(input_ids, input_mask)
  189. tf.Print(attention_mask,[],"attention_mask")
  190. # Run the stacked transformer.
  191. # `sequence_output` shape = [batch_size, seq_length, hidden_size].
  192. self.all_encoder_layers = transformer_model(
  193. input_tensor=tensor_after_position,
  194. attention_mask=attention_mask,
  195. hidden_size=self.bert_hidden,
  196. num_hidden_layers=4,
  197. num_attention_heads=10,
  198. intermediate_size=128,
  199. intermediate_act_fn=get_activation("gelu"),
  200. hidden_dropout_prob=0.1,
  201. attention_probs_dropout_prob=0.1,
  202. initializer_range=0.02,
  203. do_return_all_layers=False)
  204. print(self.all_encoder_layers)
  205. return self.all_encoder_layers
  206. def project_layer(self, lstm_outputs, name=None):
  207. """
  208. """
  209. with tf.variable_scope("project" if not name else name):
  210. with tf.variable_scope("hidden"):
  211. w_tanh = tf.get_variable("w_tanh", shape=[self.bert_hidden, self.lstm_dim],
  212. dtype=tf.float32, initializer=self.initializer, regularizer=tf.contrib.layers.l2_regularizer(0.001))
  213. b_tanh = tf.get_variable("b_tanh", shape=[self.lstm_dim], dtype=tf.float32,
  214. initializer=tf.zeros_initializer())
  215. output = tf.reshape(lstm_outputs, shape=[-1, self.bert_hidden])
  216. hidden = tf.tanh(tf.nn.xw_plus_b(output, w_tanh, b_tanh))
  217. drop_hidden = tf.nn.dropout(hidden, self.dropout)
  218. # project to score of tags
  219. with tf.variable_scope("output"):
  220. w_out = tf.get_variable("w_out", shape=[self.bert_hidden, self.num_tags],
  221. dtype=tf.float32, initializer=self.initializer, regularizer=tf.contrib.layers.l2_regularizer(0.001))
  222. b_out = tf.get_variable("b_out", shape=[self.num_tags], dtype=tf.float32,
  223. initializer=tf.zeros_initializer())
  224. pred = tf.nn.xw_plus_b(drop_hidden, w_out, b_out, name="pred")
  225. self.logits = tf.reshape(pred, [-1, self.num_steps, self.num_tags], name="logits")
  226. def loss_layer(self, project_logits, lengths, name=None):
  227. with tf.variable_scope("crf_loss" if not name else name):
  228. small = -1000.0
  229. start_logits = tf.concat(
  230. [small * tf.ones(shape=[self.batch_size, 1, self.num_tags]), tf.zeros(shape=[self.batch_size, 1, 1])],
  231. axis=-1)
  232. pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32)
  233. logits = tf.concat([project_logits, pad_logits], axis=-1)
  234. logits = tf.concat([start_logits, logits], axis=1)
  235. targets = tf.concat(
  236. [tf.cast(self.num_tags * tf.ones([self.batch_size, 1]), tf.int32), self.targets], axis=-1)
  237. self.trans = tf.get_variable(
  238. "transitions",
  239. shape=[self.num_tags + 1, self.num_tags + 1],
  240. initializer=self.initializer)
  241. log_likelihood, self.trans = crf_log_likelihood(
  242. inputs=logits,
  243. tag_indices=targets,
  244. transition_params=self.trans,
  245. sequence_lengths=lengths + 1)
  246. return tf.reduce_mean(-log_likelihood)
  247. def initVariables(self):
  248. dict_tensor_values = load(os.path.dirname(__file__)+"/dict_tensor_values.pk")
  249. with self.graph.as_default():
  250. init_op = tf.global_variables_initializer()
  251. self.sess.run(init_op)
  252. '''
  253. trainable_variables = tf.trainable_variables()
  254. for item in trainable_variables:
  255. print(item.name,"prefix/"+item.name in dict_tensor_values.keys())
  256. self.sess.run(tf.assign(item,dict_tensor_values["prefix/"+item.name]))
  257. print((self.sess.run(item)==dict_tensor_values["prefix/"+item.name]).all())
  258. '''
  259. ''''''
  260. for _key in dict_tensor_values.keys():
  261. print("init variable %s"%(_key))
  262. self.sess.run(tf.assign(self.graph.get_tensor_by_name(_key[7:]),dict_tensor_values[_key]))
  263. #print(self.sess.run(tf.nn.embedding_lookup(self.char_embeding, np.array([[1]], dtype=np.int32))))
  264. #print(self.sess.run(self.char_embeding))
  265. return self
  266. def restore(self,path=None):
  267. print("restore weights")
  268. with self.graph.as_default():
  269. saver = tf.train.Saver()
  270. if path is None:
  271. path_add = "0-12/"
  272. saver.restore(self.sess, os.path.dirname(__file__)+'/model/'+path_add+'model.ckpt')
  273. '''
  274. path_add = "0-4/"
  275. saver.restore(self.sess, os.path.dirname(__file__)+'/model-server/'+path_add+'model.ckpt')
  276. '''
  277. else:
  278. saver.restore(self.sess,path)
  279. list_v = [v.name for v in tf.trainable_variables()]
  280. print(list_v)
  281. list_value = bert.sess.run(list_v)
  282. dict_key_value = dict()
  283. for k,v in zip(list_v,list_value):
  284. if re.search("encoder",k) is not None or re.search("char_embeding",k) is not None:
  285. dict_key_value[k] = v
  286. print(k,v.shape)
  287. save(self.char_to_id, "fool_char_to_id.pk")
  288. save(dict_key_value, "dict_key_value.pk")
  289. return self
  290. def getNodes(self):
  291. return self.char_inputs,self.targets,self.lengths,self.dropout,self.logits,self.trans,self.loss,self.train_op
  292. def predict(self,sess,sents):
  293. inputs = []
  294. lengths = [len(text) for text in sents]
  295. max_len = max(lengths)
  296. for sent in sents:
  297. sent_ids = [self.char_to_id.get(w) if w in self.char_to_id else self.char_to_id.get("<OOV>") for w in sent]
  298. padding = [0] * (max_len - len(sent_ids))
  299. sent_ids += padding
  300. inputs.append(sent_ids)
  301. inputs = np.array(inputs, dtype=np.int32)
  302. if USE_PAI_EAS:
  303. request = tf_predict_pb2.PredictRequest()
  304. request.inputs["char_inputs"].dtype = tf_predict_pb2.DT_INT32
  305. request.inputs["char_inputs"].array_shape.dim.extend(np.shape(inputs))
  306. request.inputs["char_inputs"].int_val.extend(np.array(inputs,dtype=np.int32).reshape(-1))
  307. request.inputs["lengths"].dtype = tf_predict_pb2.DT_INT32
  308. request.inputs["lengths"].array_shape.dim.extend(np.shape(lengths))
  309. request.inputs["lengths"].int_val.extend(np.array(lengths,dtype=np.int32).reshape(-1))
  310. request.inputs["dropout"].dtype = tf_predict_pb2.DT_FLOAT
  311. request.inputs["dropout"].float_val.extend([1.0])
  312. request_data = request.SerializeToString()
  313. list_outputs = ["logits","trans"]
  314. result = vpc_requests(selffool_url, selffool_authorization, request_data, list_outputs)
  315. if result is not None:
  316. logits = result["logits"]
  317. trans = result["trans"]
  318. else:
  319. feed_dict = {
  320. self.char_inputs: inputs,
  321. self.lengths: lengths,
  322. self.dropout: 1.0
  323. }
  324. logits, trans = sess.run([self.logits, self.trans], feed_dict=feed_dict)
  325. else:
  326. feed_dict = {
  327. self.char_inputs: inputs,
  328. self.lengths: lengths,
  329. self.dropout: 1.0
  330. }
  331. logits, trans = sess.run([self.logits, self.trans], feed_dict=feed_dict)
  332. path = decode(logits, trans, lengths, self.num_tags)
  333. labels = [[self.id_to_tag.get(l) for l in p] for p in path]
  334. return labels
  335. def ner(self, text_list):
  336. text_list = _check_input(text_list)
  337. ner_labels = self.predict(self.sess,text_list)
  338. #print(ner_labels)
  339. all_entitys = []
  340. for ti, text in enumerate(text_list):
  341. ens = []
  342. entity = ""
  343. i = 0
  344. ner_label = ner_labels[ti]
  345. chars = list(text)
  346. for label, word in zip(ner_label, chars):
  347. i += 1
  348. if label == "O":
  349. continue
  350. lt = label.split("_")[1]
  351. lb = label.split("_")[0]
  352. if lb == "S":
  353. ens.append((i, i + 1, lt, word))
  354. elif lb == "B":
  355. entity = ""
  356. entity += word
  357. elif lb == "M":
  358. entity += word
  359. elif lb == "E":
  360. entity += word
  361. ens.append((i - len(entity), i + 1, lt, entity))
  362. entity = ""
  363. if entity:
  364. ens.append((i - len(entity), i + 1, lt, entity))
  365. all_entitys.append(ens)
  366. return all_entitys
  367. def decode(logits, trans, sequence_lengths, tag_num):
  368. viterbi_sequences = []
  369. small = -1000.0
  370. start = np.asarray([[small] * tag_num + [0]])
  371. for logit, length in zip(logits, sequence_lengths):
  372. score = logit[:length]
  373. pad = small * np.ones([length, 1])
  374. score = np.concatenate([score, pad], axis=1)
  375. score = np.concatenate([start, score], axis=0)
  376. viterbi_seq, viterbi_score = viterbi_decode(score, trans)
  377. viterbi_sequences.append(viterbi_seq[1:])
  378. return viterbi_sequences
  379. def _check_input(text, ignore=False):
  380. if not text:
  381. return []
  382. if not isinstance(text, list):
  383. text = [text]
  384. null_index = [i for i, t in enumerate(text) if not t]
  385. if null_index and not ignore:
  386. raise Exception("null text in input ")
  387. return text
  388. if __name__=="__main__":
  389. bert = BertCRF()
  390. bert.restore()