Pretrain.py

'''
Created on 2019-12-31
@author: User
'''
# tensorflow.contrib imports required by the code below
# (rnn.BasicLSTMCell, crf_log_likelihood, initializers.xavier_initializer)
from tensorflow.contrib import rnn
from tensorflow.contrib.crf import crf_log_likelihood
from tensorflow.contrib.layers.python.layers import initializers
from zipfile import ZipFile
from BiddingKG.dl_dev.BertNer.BertModel import *
from BiddingKG.dl.common.Utils import *
import codecs
# stdlib modules used below, listed explicitly in case the wildcard
# imports above do not re-export them
import json
import math
import os
import re
def _load_map_file(path, char_map_name, id_map_name):
    with ZipFile(path) as myzip:
        with myzip.open('all_map.json') as myfile:
            content = myfile.readline()
            content = content.decode()
            data = json.loads(content)
            return data.get(char_map_name), data.get(id_map_name)
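# Assumption (inferred from the call in BertCRF.__init__ below): map.zip
# contains a single all_map.json whose first line is a JSON object with at
# least the keys "char_map" (char -> id) and "ner_map" (id -> tag).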
def shape_list(x):
    """Return list of dims, statically where possible."""
    x = tf.convert_to_tensor(x)
    # If unknown rank, return dynamic shape
    if x.get_shape().dims is None:
        return tf.shape(x)
    static = x.get_shape().as_list()
    shape = tf.shape(x)
    ret = []
    for i in range(len(static)):
        dim = static[i]
        if dim is None:
            dim = shape[i]
        ret.append(dim)
    return ret
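# Minimal usage sketch (not in the original file): shape_list mixes static
# and dynamic dimensions, e.g.
#   x = tf.placeholder(tf.float32, [None, 50])
#   dims = shape_list(x)   # [<scalar Tensor for the batch dim>, 50]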
class BertCRF(object):
    def __init__(self):
        config = {'lstm_dim': 100,
                  'num_chars': 6591,
                  'num_tags': 25,
                  'char_dim': 100,
                  'lr': 0.00002,
                  'input_dropout_keep': 1.0,
                  'optimizer': 'adam',
                  'clip': 5,
                  'bert_hidden': 100}
        self.config = config
        self.lstm_dim = config["lstm_dim"]
        self.num_chars = config["num_chars"]
        self.num_tags = config["num_tags"]
        self.char_dim = config["char_dim"]
        self.lr = config["lr"]
        self.bert_hidden = config["bert_hidden"]

        self.graph = tf.Graph()
        with self.graph.as_default():
            self.char_to_id, self.id_to_seg = _load_map_file(os.path.dirname(__file__)+"/data/map.zip", "char_map", "ner_map")
            self.id_to_tag = {int(k): v for k, v in self.id_to_seg.items()}
            self.tag_to_id = {v: int(k) for k, v in self.id_to_seg.items()}
            #self.char_embeding = tf.get_variable(name="char_embeding", initializer=embeddings)
            #self.char_embeding = tf.get_variable(name="char_embeding",shape=(self.num_chars,self.char_dim))
            # Prepend a row of all zeros as a placeholder; the fool release
            # and its source code differ here
            self.const = tf.constant(value=0, dtype=tf.float32, shape=[1, 100])
            #self.char_embeding = tf.concat([self.const,self.char_embeding],0)
            self.global_step = tf.Variable(0, trainable=False)
            self.initializer = initializers.xavier_initializer()

            self.char_inputs = tf.placeholder(dtype=tf.int32, shape=[None, None], name="char_inputs")
            self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="targets")
            self.dropout = tf.placeholder(dtype=tf.float32, name="dropout")
            self.lengths = tf.placeholder(dtype=tf.int32, shape=[None, ], name="lengths")
            # self.middle_dropout_keep_prob = tf.placeholder_with_default(1.0, [], name="middle_dropout_keep_prob")
            # self.hidden_dropout_keep_prob = tf.placeholder_with_default(1.0, [], name="hidden_dropout_keep_prob")
            self.input_dropout_keep_prob = tf.placeholder_with_default(config["input_dropout_keep"], [], name="input_dropout_keep_prob")

            self.batch_size = tf.shape(self.char_inputs)[0]
            self.num_steps = tf.shape(self.char_inputs)[-1]

            # forward
            #embedding = self.embedding_layer(self.char_inputs)
            bert_outputs = self.bert_layer(self.char_inputs)
            #lstm_inputs = tf.nn.dropout(bert_outputs, self.input_dropout_keep_prob)
            ## bi-directional lstm layer
            #lstm_outputs = self.bilstm_layer(lstm_inputs)
            ## logits for tags
            self.project_layer(bert_outputs)
            ## loss of the model
            self.loss = self.loss_layer(self.logits, self.lengths)

            with tf.variable_scope("optimizer"):
                optimizer = self.config["optimizer"]
                if optimizer == "sgd":
                    self.opt = tf.train.GradientDescentOptimizer(self.lr)
                elif optimizer == "adam":
                    self.opt = tf.train.AdamOptimizer(self.lr)
                elif optimizer == "adgrad":
                    self.opt = tf.train.AdagradOptimizer(self.lr)
                else:
                    raise KeyError("unknown optimizer: %s" % optimizer)
                grads_vars = self.opt.compute_gradients(self.loss)
                print(grads_vars)
                # clip gradients, skipping variables with no gradient
                capped_grads_vars = []
                for g, v in grads_vars:
                    if g is not None:
                        capped_grads_vars.append([tf.clip_by_value(g, -self.config["clip"], self.config["clip"]), v])
                #capped_grads_vars = [[tf.clip_by_value(g, -self.config["clip"], self.config["clip"]), v] for g, v in grads_vars]
                self.train_op = self.opt.apply_gradients(capped_grads_vars, self.global_step)
        self.sess = tf.Session(graph=self.graph)
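    # Rough training-loop sketch (assumed workflow, not part of the original
    # file); batch_ids/batch_tags/batch_lens are hypothetical numpy batches:
    #   model = BertCRF().restore()
    #   chars, targets, lengths, dropout, logits, trans, loss, train_op = model.getNodes()
    #   feed = {chars: batch_ids, targets: batch_tags,
    #           lengths: batch_lens, dropout: 1.0}
    #   _, batch_loss = model.sess.run([train_op, loss], feed_dict=feed)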
    def embedding_layer(self, char_inputs):
        with tf.variable_scope("char_embedding"), tf.device('/cpu:0'):
            embed = tf.nn.embedding_lookup(self.char_embeding, char_inputs)
            return embed

    def bilstm_layer(self, lstm_inputs, name=None):
        with tf.variable_scope("char_bilstm" if not name else name):
            lstm_fw_cell = rnn.BasicLSTMCell(self.lstm_dim, state_is_tuple=True)
            lstm_bw_cell = rnn.BasicLSTMCell(self.lstm_dim, state_is_tuple=True)
            outputs, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, lstm_inputs, dtype=tf.float32, sequence_length=self.lengths)
            return tf.concat(outputs, axis=2)
    def get_timing_signal_1d(self, length, channels,
                             min_timescale=1.0, max_timescale=1.0e4,
                             start_index=0):
        """Gets a bunch of sinusoids of different frequencies.

        Each channel of the input Tensor is incremented by a sinusoid of a
        different frequency and phase. This allows attention to learn to use
        absolute and relative positions. Timing signals should be added to
        some precursors of both the query and the memory inputs to attention.
        The use of relative position is possible because sin(x+y) and cos(x+y)
        can be expressed in terms of y, sin(x) and cos(x).

        In particular, we use a geometric sequence of timescales starting with
        min_timescale and ending with max_timescale. The number of different
        timescales is equal to channels / 2. For each timescale, we generate
        the two sinusoidal signals sin(timestep/timescale) and
        cos(timestep/timescale). All of these sinusoids are concatenated in
        the channels dimension.

        Args:
            length: scalar, length of timing signal sequence.
            channels: scalar, size of timing embeddings to create. The number
                of different timescales is equal to channels / 2.
            min_timescale: a float
            max_timescale: a float
            start_index: index of first position

        Returns:
            a Tensor of timing signals [1, length, channels]
        """
        position = tf.to_float(tf.range(length) + start_index)
        num_timescales = channels // 2
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            (tf.to_float(num_timescales) - 1))
        inv_timescales = min_timescale * tf.exp(
            tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
        scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
        signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
        signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
        signal = tf.reshape(signal, [1, length, channels])
        return signal
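    # Note: sin and cos are concatenated along the channel axis (first half
    # sin, second half cos) in the tensor2tensor style, rather than
    # interleaved; odd channel counts are handled by the zero-padding above.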
    def add_timing_signal_1d(self, x,
                             min_timescale=1.0, max_timescale=1.0e4,
                             start_index=0):
        """Adds a bunch of sinusoids of different frequencies to a Tensor.

        Each channel of the input Tensor is incremented by a sinusoid of a
        different frequency and phase. This allows attention to learn to use
        absolute and relative positions. Timing signals should be added to
        some precursors of both the query and the memory inputs to attention.
        The use of relative position is possible because sin(x+y) and cos(x+y)
        can be expressed in terms of y, sin(x) and cos(x).

        In particular, we use a geometric sequence of timescales starting with
        min_timescale and ending with max_timescale. The number of different
        timescales is equal to channels / 2. For each timescale, we generate
        the two sinusoidal signals sin(timestep/timescale) and
        cos(timestep/timescale). All of these sinusoids are concatenated in
        the channels dimension.

        Args:
            x: a Tensor with shape [batch, length, channels]
            min_timescale: a float
            max_timescale: a float
            start_index: index of first position

        Returns:
            a Tensor the same shape as x.
        """
        length = shape_list(x)[1]
        channels = shape_list(x)[2]
        signal = self.get_timing_signal_1d(length, channels, min_timescale,
                                           max_timescale, start_index)
        return x + signal
    def bert_layer(self, tensor_embedding):
        # Add positional information to the embeddings
        #tensor_after_position = self.add_timing_signal_1d(tensor_embedding)
        _config = BertConfig(vocab_size=21128, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act=gelu, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02)
        input_mask = tf.cast(tf.sequence_mask(self.lengths, tf.reduce_max(self.lengths)), tf.int32)
        _model = BertModel(_config, is_training=True, input_ids=tensor_embedding, input_mask=input_mask)
        return _model.get_sequence_output()

        # NOTE: everything below is unreachable (it follows the return) and is
        # kept from an earlier non-BERT transformer path; tensor_after_position
        # is only defined if the commented-out line above is re-enabled.
        with tf.variable_scope("encoder"):
            # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
            # mask of shape [batch_size, seq_length, seq_length] which is used
            # for the attention scores.
            attention_mask = tf.tile(tf.expand_dims(tf.cast(tf.sequence_mask(self.lengths, tf.reduce_max(self.lengths)), "int32"), 1), [1, tf.reduce_max(self.lengths), 1])  #create_attention_mask_from_input_mask(input_ids, input_mask)
            tf.Print(attention_mask, [], "attention_mask")
            # Run the stacked transformer.
            # `sequence_output` shape = [batch_size, seq_length, hidden_size].
            self.all_encoder_layers = transformer_model(
                input_tensor=tensor_after_position,
                attention_mask=attention_mask,
                hidden_size=self.bert_hidden,
                num_hidden_layers=6,
                num_attention_heads=10,
                intermediate_size=256,
                intermediate_act_fn=get_activation("gelu"),
                hidden_dropout_prob=0.1,
                attention_probs_dropout_prob=0.1,
                initializer_range=0.02,
                do_return_all_layers=False)
            print(self.all_encoder_layers)
            return self.all_encoder_layers
    def project_layer(self, lstm_outputs, name=None):
        """Projects encoder outputs to per-tag scores, stored in self.logits."""
        with tf.variable_scope("project" if not name else name):
            with tf.variable_scope("hidden"):
                w_tanh = tf.get_variable("w_tanh", shape=[self.bert_hidden, self.lstm_dim],
                                         dtype=tf.float32, initializer=self.initializer,
                                         regularizer=tf.contrib.layers.l2_regularizer(0.001))
                b_tanh = tf.get_variable("b_tanh", shape=[self.lstm_dim], dtype=tf.float32,
                                         initializer=tf.zeros_initializer())
                output = tf.reshape(lstm_outputs, shape=[-1, self.bert_hidden])
                hidden = tf.tanh(tf.nn.xw_plus_b(output, w_tanh, b_tanh))
                drop_hidden = tf.nn.dropout(hidden, self.dropout)
            # project to score of tags
            with tf.variable_scope("output"):
                w_out = tf.get_variable("w_out", shape=[self.bert_hidden, self.num_tags],
                                        dtype=tf.float32, initializer=self.initializer,
                                        regularizer=tf.contrib.layers.l2_regularizer(0.001))
                b_out = tf.get_variable("b_out", shape=[self.num_tags], dtype=tf.float32,
                                        initializer=tf.zeros_initializer())
                pred = tf.nn.xw_plus_b(drop_hidden, w_out, b_out, name="pred")
                self.logits = tf.reshape(pred, [-1, self.num_steps, self.num_tags], name="logits")
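    # Shape note (observation, not from the original comments): w_out is
    # declared [bert_hidden, num_tags] but consumes the lstm_dim-sized hidden
    # layer; this only lines up because lstm_dim == bert_hidden == 100 in the
    # config. Likewise the reshape above assumes bert_layer emits
    # bert_hidden-sized vectors, while its BertConfig uses hidden_size=768.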
    def loss_layer(self, project_logits, lengths, name=None):
        with tf.variable_scope("crf_loss" if not name else name):
            small = -1000.0
            start_logits = tf.concat(
                [small * tf.ones(shape=[self.batch_size, 1, self.num_tags]), tf.zeros(shape=[self.batch_size, 1, 1])],
                axis=-1)
            pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32)
            logits = tf.concat([project_logits, pad_logits], axis=-1)
            logits = tf.concat([start_logits, logits], axis=1)
            targets = tf.concat(
                [tf.cast(self.num_tags * tf.ones([self.batch_size, 1]), tf.int32), self.targets], axis=-1)
            self.trans = tf.get_variable(
                "transitions",
                shape=[self.num_tags + 1, self.num_tags + 1],
                initializer=self.initializer)
            log_likelihood, self.trans = crf_log_likelihood(
                inputs=logits,
                tag_indices=targets,
                transition_params=self.trans,
                sequence_lengths=lengths + 1)
            return tf.reduce_mean(-log_likelihood)
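    # The CRF runs over an augmented tag set: a synthetic start tag with
    # index num_tags is prepended to every sequence (hence the [num_tags+1,
    # num_tags+1] transition matrix and sequence_lengths=lengths+1), while
    # the column of very negative scores keeps real positions from ever
    # predicting the start tag.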
    def initVariables(self):
        dict_tensor_values = load(os.path.dirname(__file__)+"/dict_tensor_values.pk")
        with self.graph.as_default():
            init_op = tf.global_variables_initializer()
            self.sess.run(init_op)
            '''
            trainable_variables = tf.trainable_variables()
            for item in trainable_variables:
                print(item.name,"prefix/"+item.name in dict_tensor_values.keys())
                self.sess.run(tf.assign(item,dict_tensor_values["prefix/"+item.name]))
                print((self.sess.run(item)==dict_tensor_values["prefix/"+item.name]).all())
            '''
            for _key in dict_tensor_values.keys():
                print("init variable %s" % (_key))
                # strip the leading "prefix/" from the stored tensor name
                self.sess.run(tf.assign(self.graph.get_tensor_by_name(_key[7:]), dict_tensor_values[_key]))
            #print(self.sess.run(tf.nn.embedding_lookup(self.char_embeding, np.array([[1]], dtype=np.int32))))
            #print(self.sess.run(self.char_embeding))
        return self
    def restore(self, path=None):
        print("restore weights")
        with self.graph.as_default():
            saver = tf.train.Saver()
            if path is None:
                #path_add = "0-12/"
                path_add = "new_model/"
                saver.restore(self.sess, os.path.dirname(__file__)+'/model/'+path_add+'model.ckpt')
                '''
                path_add = "0-4/"
                saver.restore(self.sess, os.path.dirname(__file__)+'/model-server/'+path_add+'model.ckpt')
                '''
            else:
                saver.restore(self.sess, path)
        return self
    def getNodes(self):
        return self.char_inputs, self.targets, self.lengths, self.dropout, self.logits, self.trans, self.loss, self.train_op
def load_graph(path=os.path.dirname(__file__)+"/chinese_L-12_H-768_A-12/bert_model.ckpt.meta"):
    with tf.gfile.GFile(path, mode='rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="prefix")
        return graph
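# load_graph imports a serialized GraphDef under the name scope "prefix",
# which is why initVariables above strips the leading "prefix/" (_key[7:])
# when mapping stored tensor names back to this graph.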
def getModel():
    _config = BertConfig(vocab_size=21128, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act=gelu, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02)
    saver = tf.train.import_meta_graph(meta_graph_or_file=os.path.dirname(__file__)+"/chinese_L-12_H-768_A-12/bert_model.ckpt.meta")
    sess = tf.Session()
    saver.restore(sess, os.path.dirname(__file__)+"/chinese_L-12_H-768_A-12/bert_model.ckpt")
    summaryWriter = tf.summary.FileWriter('log/', sess.graph)
    variable_names = [v.name for v in tf.trainable_variables()]
    values = sess.run(variable_names)
    bert_key_values = dict()
    for k, v in zip(variable_names, values):
        if re.search("bert", k) is not None:
            bert_key_values[k] = v
            print("Variable: ", k)
            print("Shape: ", v.shape)
    save(bert_key_values, "bert_key_values.pk")
    print(sess.graph.get_all_collection_keys())
    #_model = BertModel(_config, is_training=True, input_ids, input_mask, token_type_ids, use_one_hot_embeddings, scope)
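# getModel above exports every trainable variable whose name contains "bert"
# from the pretrained chinese_L-12_H-768_A-12 checkpoint into
# bert_key_values.pk; the module-level restore() below reloads that pickle
# into a live session.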
def getBertModel(input_tensor, sess):
    print("11", input_tensor)
    input_tensor = tf.cast(input_tensor, tf.int32)
    _config = BertConfig(vocab_size=21128, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act=gelu, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02)
    with sess.graph.as_default():
        with sess.as_default():
            _model = BertModel(_config, True, input_tensor, scope="bert")
            return _model.get_sequence_output()
def restore(sess):
    bert_key_values = load(os.path.dirname(__file__)+"/bert_key_values.pk")
    variable_names = [v.name for v in tf.trainable_variables()]
    print(variable_names)
    for key, value in bert_key_values.items():
        print(key, value.shape)
        sess.run(tf.assign(sess.graph.get_tensor_by_name(key), value))
def getVocab():
    dict_word_index = dict()
    with codecs.open("chinese_L-12_H-768_A-12/vocab.txt", "r", encoding="utf8") as f:
        _idx = 0
        while True:
            line = re.sub("[\r\n]", "", f.readline())
            if _idx >= 21128:
                break
            print(_idx)
            dict_word_index[line] = _idx
            _idx += 1
    save(dict_word_index, "dict_word_index.pk")
if __name__ == "__main__":
    '''
    getModel()
    '''
    getVocab()