BertCRF.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. '''
  2. Created on 2019年12月31日
  3. @author: User
  4. '''
  5. # from tensorflow.contrib import rnn
  6. # from tensorflow.contrib.crf import crf_log_likelihood
  7. # from tensorflow.contrib.layers.python.layers import initializers
  8. import numpy as np
  9. from BiddingKG.dl.common.Utils import viterbi_decode
  10. from zipfile import ZipFile
  11. import tensorflow as tf
  12. import os
  13. import json
  14. import math
  15. from BiddingKG.dl.BertNer.BertModel import *
  16. from BiddingKG.dl.common.Utils import *
  17. def _load_map_file(path, char_map_name, id_map_name):
  18. with ZipFile(path) as myzip:
  19. with myzip.open('all_map.json') as myfile:
  20. content = myfile.readline()
  21. content = content.decode()
  22. data = json.loads(content)
  23. return data.get(char_map_name), data.get(id_map_name)
  24. def shape_list(x):
  25. """Return list of dims, statically where possible."""
  26. x = tf.convert_to_tensor(x)
  27. # If unknown rank, return dynamic shape
  28. if x.get_shape().dims is None:
  29. return tf.shape(x)
  30. static = x.get_shape().as_list()
  31. shape = tf.shape(x)
  32. ret = []
  33. for i in range(len(static)):
  34. dim = static[i]
  35. if dim is None:
  36. dim = shape[i]
  37. ret.append(dim)
  38. return ret
  39. class BertCRF(object):
  40. def __init__(self):
  41. config = {'lstm_dim':100,
  42. 'num_chars':6591,
  43. 'num_tags':25,
  44. 'char_dim':100,
  45. 'lr':0.00002,
  46. 'input_dropout_keep':1.0,
  47. 'optimizer':'adam',
  48. 'clip':5,
  49. 'bert_hidden':100}
  50. self.config = config
  51. self.lstm_dim = config["lstm_dim"]
  52. self.num_chars = config["num_chars"]
  53. self.num_tags = config["num_tags"]
  54. self.char_dim = config["char_dim"]
  55. self.lr = config["lr"]
  56. self.bert_hidden = config["bert_hidden"]
  57. self.graph = tf.Graph()
  58. with self.graph.as_default():
  59. self.char_to_id, self.id_to_seg = _load_map_file(os.path.dirname(__file__)+"/data/map.zip", "char_map", "ner_map")
  60. self.id_to_tag = {int(k):v for k,v in self.id_to_seg.items()}
  61. self.tag_to_id = {v:int(k) for k,v in self.id_to_seg.items()}
  62. #self.char_embeding = tf.get_variable(name="char_embeding", initializer=embeddings)
  63. self.char_embeding = tf.get_variable(name="char_embeding",shape=(self.num_chars,self.char_dim))
  64. #添加一串全0的坑,fool发行版和源代码不一样
  65. self.const = tf.constant(value=0,dtype=tf.float32,shape=[1,100])
  66. self.char_embeding = tf.concat([self.const,self.char_embeding],0)
  67. self.global_step = tf.Variable(0, trainable=False)
  68. self.initializer = initializers.xavier_initializer()
  69. self.char_inputs = tf.placeholder(dtype=tf.int32, shape=[None, None], name="char_inputs")
  70. self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="targets")
  71. self.dropout = tf.placeholder(dtype=tf.float32, name="dropout")
  72. self.lengths = tf.placeholder(dtype=tf.int32, shape=[None, ], name="lengths")
  73. # self.middle_dropout_keep_prob = tf.placeholder_with_default(1.0, [], name="middle_dropout_keep_prob")
  74. # self.hidden_dropout_keep_prob = tf.placeholder_with_default(1.0, [], name="hidden_dropout_keep_prob")
  75. self.input_dropout_keep_prob = tf.placeholder_with_default(config["input_dropout_keep"], [], name="input_dropout_keep_prob")
  76. self.batch_size = tf.shape(self.char_inputs)[0]
  77. self.num_steps = tf.shape(self.char_inputs)[-1]
  78. # forward
  79. embedding = self.embedding_layer(self.char_inputs)
  80. bert_outputs = self.bert_layer(embedding)
  81. #lstm_inputs = tf.nn.dropout(bert_outputs, self.input_dropout_keep_prob)
  82. ## bi-directional lstm layer
  83. #lstm_outputs = self.bilstm_layer(lstm_inputs)
  84. ## logits for tags
  85. self.project_layer(bert_outputs)
  86. ## loss of the model
  87. self.loss = self.loss_layer(self.logits, self.lengths)
  88. with tf.variable_scope("optimizer"):
  89. optimizer = self.config["optimizer"]
  90. if optimizer == "sgd":
  91. self.opt = tf.train.GradientDescentOptimizer(self.lr)
  92. elif optimizer == "adam":
  93. self.opt = tf.train.AdamOptimizer(self.lr)
  94. elif optimizer == "adgrad":
  95. self.opt = tf.train.AdagradOptimizer(self.lr)
  96. else:
  97. raise KeyError
  98. grads_vars = self.opt.compute_gradients(self.loss)
  99. capped_grads_vars = [[tf.clip_by_value(g, -self.config["clip"], self.config["clip"]), v] for g, v in grads_vars]
  100. self.train_op = self.opt.apply_gradients(capped_grads_vars, self.global_step)
  101. config = tf.ConfigProto()
  102. config.gpu_options.per_process_gpu_memory_fraction = 0.7
  103. config.gpu_options.allow_growth = True
  104. self.sess = tf.Session(graph=self.graph,config=config)
  105. def embedding_layer(self, char_inputs):
  106. with tf.variable_scope("char_embedding"), tf.device('/cpu:0'):
  107. embed = tf.nn.embedding_lookup(self.char_embeding, char_inputs)
  108. return embed
  109. def bilstm_layer(self, lstm_inputs, name=None):
  110. with tf.variable_scope("char_bilstm" if not name else name):
  111. lstm_fw_cell = rnn.BasicLSTMCell(self.lstm_dim, state_is_tuple=True)
  112. lstm_bw_cell = rnn.BasicLSTMCell(self.lstm_dim, state_is_tuple=True)
  113. outputs, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, lstm_inputs, dtype=tf.float32, sequence_length=self.lengths)
  114. return tf.concat(outputs, axis=2)
  115. def get_timing_signal_1d(self,length,
  116. channels,
  117. min_timescale=1.0,
  118. max_timescale=1.0e4,
  119. start_index=0):
  120. """Gets a bunch of sinusoids of different frequencies.
  121. Each channel of the input Tensor is incremented by a sinusoid of a different
  122. frequency and phase.
  123. This allows attention to learn to use absolute and relative positions.
  124. Timing signals should be added to some precursors of both the query and the
  125. memory inputs to attention.
  126. The use of relative position is possible because sin(x+y) and cos(x+y) can be
  127. expressed in terms of y, sin(x) and cos(x).
  128. In particular, we use a geometric sequence of timescales starting with
  129. min_timescale and ending with max_timescale. The number of different
  130. timescales is equal to channels / 2. For each timescale, we
  131. generate the two sinusoidal signals sin(timestep/timescale) and
  132. cos(timestep/timescale). All of these sinusoids are concatenated in
  133. the channels dimension.
  134. Args:
  135. length: scalar, length of timing signal sequence.
  136. channels: scalar, size of timing embeddings to create. The number of
  137. different timescales is equal to channels / 2.
  138. min_timescale: a float
  139. max_timescale: a float
  140. start_index: index of first position
  141. Returns:
  142. a Tensor of timing signals [1, length, channels]
  143. """
  144. position = tf.to_float(tf.range(length) + start_index)
  145. num_timescales = channels // 2
  146. log_timescale_increment = (
  147. math.log(float(max_timescale) / float(min_timescale)) /
  148. (tf.to_float(num_timescales) - 1))
  149. inv_timescales = min_timescale * tf.exp(
  150. tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  151. scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
  152. signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
  153. signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
  154. signal = tf.reshape(signal, [1, length, channels])
  155. return signal
  156. def add_timing_signal_1d(self,x,
  157. min_timescale=1.0,
  158. max_timescale=1.0e4,
  159. start_index=0):
  160. """Adds a bunch of sinusoids of different frequencies to a Tensor.
  161. Each channel of the input Tensor is incremented by a sinusoid of a different
  162. frequency and phase.
  163. This allows attention to learn to use absolute and relative positions.
  164. Timing signals should be added to some precursors of both the query and the
  165. memory inputs to attention.
  166. The use of relative position is possible because sin(x+y) and cos(x+y) can be
  167. experessed in terms of y, sin(x) and cos(x).
  168. In particular, we use a geometric sequence of timescales starting with
  169. min_timescale and ending with max_timescale. The number of different
  170. timescales is equal to channels / 2. For each timescale, we
  171. generate the two sinusoidal signals sin(timestep/timescale) and
  172. cos(timestep/timescale). All of these sinusoids are concatenated in
  173. the channels dimension.
  174. Args:
  175. x: a Tensor with shape [batch, length, channels]
  176. min_timescale: a float
  177. max_timescale: a float
  178. start_index: index of first position
  179. Returns:
  180. a Tensor the same shape as x.
  181. """
  182. length = shape_list(x)[1]
  183. channels = shape_list(x)[2]
  184. signal = self.get_timing_signal_1d(length, channels, min_timescale, max_timescale,
  185. start_index)
  186. return x + signal
  187. def bert_layer(self,tensor_embedding):
  188. #增加位置向量信息
  189. tensor_after_position = self.add_timing_signal_1d(tensor_embedding)
  190. with tf.variable_scope("encoder"):
  191. # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
  192. # mask of shape [batch_size, seq_length, seq_length] which is used
  193. # for the attention scores.
  194. attention_mask = tf.tile(tf.expand_dims(tf.cast(tf.sequence_mask(self.lengths,tf.reduce_max(self.lengths)),"int32"),1),[1,tf.reduce_max(self.lengths),1])#create_attention_mask_from_input_mask(input_ids, input_mask)
  195. tf.Print(attention_mask,[],"attention_mask")
  196. # Run the stacked transformer.
  197. # `sequence_output` shape = [batch_size, seq_length, hidden_size].
  198. self.all_encoder_layers = transformer_model(
  199. input_tensor=tensor_after_position,
  200. attention_mask=attention_mask,
  201. hidden_size=self.bert_hidden,
  202. num_hidden_layers=4,
  203. num_attention_heads=10,
  204. intermediate_size=128,
  205. intermediate_act_fn=get_activation("gelu"),
  206. hidden_dropout_prob=0.1,
  207. attention_probs_dropout_prob=0.1,
  208. initializer_range=0.02,
  209. do_return_all_layers=False)
  210. print(self.all_encoder_layers)
  211. return self.all_encoder_layers
  212. def project_layer(self, lstm_outputs, name=None):
  213. """
  214. """
  215. with tf.variable_scope("project" if not name else name):
  216. with tf.variable_scope("hidden"):
  217. w_tanh = tf.get_variable("w_tanh", shape=[self.bert_hidden, self.lstm_dim],
  218. dtype=tf.float32, initializer=self.initializer, regularizer=tf.contrib.layers.l2_regularizer(0.001))
  219. b_tanh = tf.get_variable("b_tanh", shape=[self.lstm_dim], dtype=tf.float32,
  220. initializer=tf.zeros_initializer())
  221. output = tf.reshape(lstm_outputs, shape=[-1, self.bert_hidden])
  222. hidden = tf.tanh(tf.nn.xw_plus_b(output, w_tanh, b_tanh))
  223. drop_hidden = tf.nn.dropout(hidden, self.dropout)
  224. # project to score of tags
  225. with tf.variable_scope("output"):
  226. w_out = tf.get_variable("w_out", shape=[self.bert_hidden, self.num_tags],
  227. dtype=tf.float32, initializer=self.initializer, regularizer=tf.contrib.layers.l2_regularizer(0.001))
  228. b_out = tf.get_variable("b_out", shape=[self.num_tags], dtype=tf.float32,
  229. initializer=tf.zeros_initializer())
  230. pred = tf.nn.xw_plus_b(drop_hidden, w_out, b_out, name="pred")
  231. self.logits = tf.reshape(pred, [-1, self.num_steps, self.num_tags], name="logits")
  232. def loss_layer(self, project_logits, lengths, name=None):
  233. with tf.variable_scope("crf_loss" if not name else name):
  234. small = -1000.0
  235. start_logits = tf.concat(
  236. [small * tf.ones(shape=[self.batch_size, 1, self.num_tags]), tf.zeros(shape=[self.batch_size, 1, 1])],
  237. axis=-1)
  238. pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32)
  239. logits = tf.concat([project_logits, pad_logits], axis=-1)
  240. logits = tf.concat([start_logits, logits], axis=1)
  241. targets = tf.concat(
  242. [tf.cast(self.num_tags * tf.ones([self.batch_size, 1]), tf.int32), self.targets], axis=-1)
  243. self.trans = tf.get_variable(
  244. "transitions",
  245. shape=[self.num_tags + 1, self.num_tags + 1],
  246. initializer=self.initializer)
  247. log_likelihood, self.trans = crf_log_likelihood(
  248. inputs=logits,
  249. tag_indices=targets,
  250. transition_params=self.trans,
  251. sequence_lengths=lengths + 1)
  252. return tf.reduce_mean(-log_likelihood)
  253. def initVariables(self):
  254. dict_tensor_values = load(os.path.dirname(__file__)+"/dict_tensor_values.pk")
  255. with self.graph.as_default():
  256. init_op = tf.global_variables_initializer()
  257. self.sess.run(init_op)
  258. '''
  259. trainable_variables = tf.trainable_variables()
  260. for item in trainable_variables:
  261. print(item.name,"prefix/"+item.name in dict_tensor_values.keys())
  262. self.sess.run(tf.assign(item,dict_tensor_values["prefix/"+item.name]))
  263. print((self.sess.run(item)==dict_tensor_values["prefix/"+item.name]).all())
  264. '''
  265. ''''''
  266. for _key in dict_tensor_values.keys():
  267. print("init variable %s"%(_key))
  268. self.sess.run(tf.assign(self.graph.get_tensor_by_name(_key[7:]),dict_tensor_values[_key]))
  269. #print(self.sess.run(tf.nn.embedding_lookup(self.char_embeding, np.array([[1]], dtype=np.int32))))
  270. #print(self.sess.run(self.char_embeding))
  271. return self
  272. def restore(self,path=None):
  273. print("restore weights")
  274. with self.graph.as_default():
  275. saver = tf.train.Saver()
  276. if path is None:
  277. path_add = "0-12/"
  278. saver.restore(self.sess, os.path.dirname(__file__)+'/model/'+path_add+'model.ckpt')
  279. '''
  280. path_add = "0-4/"
  281. saver.restore(self.sess, os.path.dirname(__file__)+'/model-server/'+path_add+'model.ckpt')
  282. '''
  283. else:
  284. saver.restore(self.sess,path)
  285. list_v = [v.name for v in tf.trainable_variables()]
  286. print(list_v)
  287. list_value = bert.sess.run(list_v)
  288. dict_key_value = dict()
  289. for k,v in zip(list_v,list_value):
  290. if re.search("encoder",k) is not None or re.search("char_embeding",k) is not None:
  291. dict_key_value[k] = v
  292. print(k,v.shape)
  293. save(self.char_to_id,"fool_char_to_id.pk")
  294. save(dict_key_value,"dict_key_value.pk")
  295. return self
  296. def getNodes(self):
  297. return self.char_inputs,self.targets,self.lengths,self.dropout,self.logits,self.trans,self.loss,self.train_op
  298. def predict(self,sess,sents):
  299. inputs = []
  300. lengths = [len(text) for text in sents]
  301. max_len = max(lengths)
  302. for sent in sents:
  303. sent_ids = [self.char_to_id.get(w) if w in self.char_to_id else self.char_to_id.get("<OOV>") for w in sent]
  304. padding = [0] * (max_len - len(sent_ids))
  305. sent_ids += padding
  306. inputs.append(sent_ids)
  307. inputs = np.array(inputs, dtype=np.int32)
  308. if USE_PAI_EAS:
  309. request = tf_predict_pb2.PredictRequest()
  310. request.inputs["char_inputs"].dtype = tf_predict_pb2.DT_INT32
  311. request.inputs["char_inputs"].array_shape.dim.extend(np.shape(inputs))
  312. request.inputs["char_inputs"].int_val.extend(np.array(inputs,dtype=np.int32).reshape(-1))
  313. request.inputs["lengths"].dtype = tf_predict_pb2.DT_INT32
  314. request.inputs["lengths"].array_shape.dim.extend(np.shape(lengths))
  315. request.inputs["lengths"].int_val.extend(np.array(lengths,dtype=np.int32).reshape(-1))
  316. request.inputs["dropout"].dtype = tf_predict_pb2.DT_FLOAT
  317. request.inputs["dropout"].float_val.extend([1.0])
  318. request_data = request.SerializeToString()
  319. list_outputs = ["logits","trans"]
  320. result = vpc_requests(selffool_url, selffool_authorization, request_data, list_outputs)
  321. if result is not None:
  322. logits = result["logits"]
  323. trans = result["trans"]
  324. else:
  325. feed_dict = {
  326. self.char_inputs: inputs,
  327. self.lengths: lengths,
  328. self.dropout: 1.0
  329. }
  330. logits, trans = sess.run([self.logits, self.trans], feed_dict=feed_dict)
  331. else:
  332. feed_dict = {
  333. self.char_inputs: inputs,
  334. self.lengths: lengths,
  335. self.dropout: 1.0
  336. }
  337. logits, trans = sess.run([self.logits, self.trans], feed_dict=feed_dict)
  338. path = decode(logits, trans, lengths, self.num_tags)
  339. labels = [[self.id_to_tag.get(l) for l in p] for p in path]
  340. return labels
  341. def ner(self, text_list):
  342. text_list = _check_input(text_list)
  343. ner_labels = self.predict(self.sess,text_list)
  344. #print(ner_labels)
  345. all_entitys = []
  346. for ti, text in enumerate(text_list):
  347. ens = []
  348. entity = ""
  349. i = 0
  350. ner_label = ner_labels[ti]
  351. chars = list(text)
  352. for label, word in zip(ner_label, chars):
  353. i += 1
  354. if label == "O":
  355. continue
  356. lt = label.split("_")[1]
  357. lb = label.split("_")[0]
  358. if lb == "S":
  359. ens.append((i, i + 1, lt, word))
  360. elif lb == "B":
  361. entity = ""
  362. entity += word
  363. elif lb == "M":
  364. entity += word
  365. elif lb == "E":
  366. entity += word
  367. ens.append((i - len(entity), i + 1, lt, entity))
  368. entity = ""
  369. if entity:
  370. ens.append((i - len(entity), i + 1, lt, entity))
  371. all_entitys.append(ens)
  372. return all_entitys
  373. def decode(logits, trans, sequence_lengths, tag_num):
  374. viterbi_sequences = []
  375. small = -1000.0
  376. start = np.asarray([[small] * tag_num + [0]])
  377. for logit, length in zip(logits, sequence_lengths):
  378. score = logit[:length]
  379. pad = small * np.ones([length, 1])
  380. score = np.concatenate([score, pad], axis=1)
  381. score = np.concatenate([start, score], axis=0)
  382. viterbi_seq, viterbi_score = viterbi_decode(score, trans)
  383. viterbi_sequences.append(viterbi_seq[1:])
  384. return viterbi_sequences
  385. def _check_input(text, ignore=False):
  386. if not text:
  387. return []
  388. if not isinstance(text, list):
  389. text = [text]
  390. null_index = [i for i, t in enumerate(text) if not t]
  391. if null_index and not ignore:
  392. raise Exception("null text in input ")
  393. return text
# Script entry point: build the model graph, then restore the default checkpoint.
# restore() also writes fool_char_to_id.pk and dict_key_value.pk to the current
# working directory.
# NOTE(review): keep the module-level name `bert`; older versions of restore()
# reference it.
if __name__=="__main__":
    bert = BertCRF()
    bert.restore()