# ElectraCRF.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import codecs
import re
import numpy as np
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class CRF(nn.Module):
    def __init__(self, num_tags, use_mask=False, id_to_tag=None):
        super(CRF, self).__init__()
        self.start_transition = nn.Parameter(torch.empty(num_tags))
        self.transition = nn.Parameter(torch.empty(num_tags, num_tags))
        self.reset_parameters()
        self.decode_transition = None
        self.use_mask = use_mask
        if use_mask:
            self.transition_mask = torch.as_tensor(
                self.init_transition_mask(id_to_tag, (num_tags, num_tags), _type="BMESO"),
                dtype=torch.float32).to(device)

    def get_transition(self):
        if self.use_mask:
            trans = torch.minimum(self.transition, self.transition_mask)
            return trans
        return self.transition
    def init_transition_mask(self, id2tag, _shape, _type, _split="_"):
        '''
        Build a mask of legal tag transitions for the given tagging scheme:
        +100 for allowed transitions, -100 for forbidden ones. Rules that carry
        the split suffix additionally require the entity types to match.
        '''
        list_legal_trans = []
        if _type == "BMESO":
            list_legal_trans.append(("B%s" % (_split), "M%s" % (_split)))
            list_legal_trans.append(("B%s" % (_split), "E%s" % (_split)))
            list_legal_trans.append(("M%s" % (_split), "M%s" % (_split)))
            list_legal_trans.append(("M%s" % (_split), "E%s" % (_split)))
            list_legal_trans.append(("E", "S"))
            list_legal_trans.append(("E", "O"))
            list_legal_trans.append(("E", "B"))
            list_legal_trans.append(("S", "O"))
            list_legal_trans.append(("S", "B"))
            list_legal_trans.append(("O", "B"))
            list_legal_trans.append(("O", "S"))
            list_legal_trans.append(("O", "O"))
        if _type == "BISO":
            list_legal_trans.append(("B%s" % (_split), "I%s" % (_split)))
            list_legal_trans.append(("B", "S"))
            list_legal_trans.append(("B", "O"))
            list_legal_trans.append(("I%s" % _split, "S%s" % _split))
            list_legal_trans.append(("I", "S"))
            list_legal_trans.append(("I", "O"))
            list_legal_trans.append(("I", "B"))
            list_legal_trans.append(("S", "O"))
            list_legal_trans.append(("S", "B"))
            list_legal_trans.append(("O", "B"))
            list_legal_trans.append(("O", "S"))
            list_legal_trans.append(("O", "O"))
        mask = np.ones(_shape)
        for col_index, col_tag in id2tag.items():
            for row_index, row_tag in id2tag.items():
                is_legal = False
                for legal_a, legal_b in list_legal_trans:
                    if col_tag.startswith(legal_a) and row_tag.startswith(legal_b):
                        if legal_a[-1] != _split or legal_b[-1] != _split:
                            is_legal = True
                            break
                        else:
                            if len(col_tag.split(_split)) > 1 and len(row_tag.split(_split)) > 1:
                                if col_tag.split(_split)[-1] == row_tag.split(_split)[-1]:
                                    is_legal = True
                                    break
                if not is_legal:
                    mask[col_index, row_index] = -1.0
        mask *= 100
        return mask
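    # A quick sketch of what the mask above encodes (illustration only; the tag
    # ids below are made up, not part of the original code). With
    #   id2tag = {0: "B_PER", 1: "M_PER", 2: "E_PER", 3: "O"}
    # the pair B_PER -> M_PER matches the ("B_", "M_") rule with identical entity
    # types, so mask[0, 1] stays +100, while B_PER -> O matches no rule and becomes
    # -100. get_transition() then applies the mask as an upper bound via
    # torch.minimum, which effectively forbids the illegal transitions.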
    def reset_parameters(self):
        '''
        Keep the initial transition scores in a small range; otherwise the loss
        becomes very large and leads to NaN.
        :return:
        '''
        nn.init.xavier_normal_(self.transition)
        nn.init.uniform_(self.start_transition, -0.1, 0.1)
    def forward(self, logits, labels, seq_length):
        return self.forward_path(logits, labels, seq_length)
        # Computing the loss per batch element and per path is easier to follow,
        # but considerably slower:
        # return self.forward_batch_path(logits,labels,seq_length)

    def forward_batch_path(self, logits, labels, seq_length):
        '''
        :param logits: the observation scores, shape [batch, max_length, num_tags]
        :param labels: the labels, shape [batch, max_length]
        :param seq_length: the sequence lengths, shape [batch]
        :return: the loss, i.e. partition Z(x) minus the gold path score, summed over the batch
        '''
        # generate mask
        batch, max_length, num_tags = logits.shape
        mask = (torch.arange(max_length, device=logits.device) <= (seq_length - 1).unsqueeze(1)).to(torch.float).unsqueeze(2)
        logits = logits * mask
        whole_loss = 0
        for i in range(batch):
            _loss = self._compute_path_loss(logits[i, ...], labels[i, ...], seq_length[i])
            whole_loss += _loss
        return whole_loss
    def _compute_path_loss(self, logits, labels, seq_length):
        '''
        compute the loss for a single sequence
        :param logits: [max_length, num_tags]
        :param labels: [max_length]
        :return:
        '''
        max_length, num_tags = logits.shape
        mask = (torch.arange(max_length, device=logits.device) <= (seq_length - 1))
        # compute the partition Z(x); get_transition() is indexed [prev, cur],
        # so the logsumexp runs over the previous tag (dim=0)
        prob = logits[0] + self.start_transition
        for _i in range(1, max_length):
            prob = torch.where(mask[_i],
                               torch.logsumexp(prob.unsqueeze(1) + self.get_transition(), dim=0) + logits[_i],
                               prob)
            # The following form is wrong: a score of 0 does not mean a probability of 0.
            # prob = torch.logsumexp(prob.unsqueeze(1)+self.get_transition()*mask[_i],dim=0)+logits[_i]*mask[_i]
            # Writing it in the following form is equivalent:
            # prob = torch.where(mask[_i],torch.logsumexp(prob.unsqueeze(1)+self.get_transition()+logits[_i],dim=0),prob)
        prob = torch.logsumexp(prob, dim=0)
        # score of the gold label path
        path_prob = logits[0, labels[0]] + self.start_transition[labels[0]]
        last_i = labels[0]
        for _i in range(1, max_length):
            path_prob += (self.get_transition()[last_i, labels[_i]] + logits[_i, labels[_i]]) * mask[_i]
            last_i = labels[_i]
            # path_prob = torch.clamp_max(path_prob,1000)
        return prob - path_prob
    def forward_path(self, logits, labels, seq_length):
        """Compute the negative log-likelihood of the given tag sequences.
        Args:
            logits (torch.Tensor): emission scores, shape (batch_size, seq_len, num_tags),
                the score of each tag for each token in the sequence
            labels (torch.LongTensor): tag sequences, shape (batch_size, seq_len)
            seq_length (torch.LongTensor): valid length of each sequence, shape (batch_size,)
        Returns:
            torch.Tensor: the mean negative log-likelihood of the given tags
        """
        max_length = labels.size(1)
        mask = (torch.arange(max_length).to(device) <= (seq_length - 1).unsqueeze(1)).to(torch.float)
        # switch to seq-first layout
        emissions = logits.transpose(0, 1)  # (seq_len, batch_size, num_tags)
        tags = labels.transpose(0, 1)       # (seq_len, batch_size)
        mask = mask.transpose(0, 1)         # (seq_len, batch_size)
        # score of the gold tag sequences
        score = self._compute_score(emissions, tags, mask)
        # partition function Z(x)
        partition = self._compute_partition(emissions, mask)
        return torch.mean(partition - score)
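    # For reference, forward_path computes, for each sequence x with gold tags y:
    #   loss(x, y) = log Z(x) - score(x, y)
    # where log Z(x) is the logsumexp over all tag paths of
    # (start + transition + emission) scores, and score(x, y) is the same quantity
    # evaluated on the gold path; the batch mean of this difference is returned.
    # The two terms come from _compute_partition and _compute_score below.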
    def _compute_score(
        self,
        emissions: torch.Tensor,
        tags: torch.LongTensor,
        mask: torch.ByteTensor,
    ) -> torch.Tensor:
        """Compute the score of the given tag sequences.
        Args:
            emissions (torch.Tensor): emission scores, shape (seq_len, batch_size, num_tags)
            tags (torch.LongTensor): tag sequences, shape (seq_len, batch_size)
            mask (torch.ByteTensor): marks which positions are padding, shape (seq_len, batch_size)
        Returns:
            torch.Tensor: score of the tag sequences in the batch, shape (batch_size,)
        """
        seq_len, batch_size = tags.shape
        # first_tags (batch_size,)
        first_tags = tags[0]
        # transition score from the start state to the first tag of each sequence
        score = self.start_transition[first_tags]
        # add the emission score at index 0, i.e. the score of the first token
        # producing its first tag
        score += emissions[0, torch.arange(batch_size), first_tags]
        mask = mask.type_as(emissions)  # keep dtypes consistent
        # start from index 1, i.e. the second time step
        for i in range(1, seq_len):
            # transition score from tag i-1 to tag i, plus the emission score of
            # token i producing tag i; multiplying by mask[i] skips padded tokens
            score += (
                self.get_transition()[tags[i - 1], tags[i]]
                + emissions[i, torch.arange(batch_size), tags[i]]
            ) * mask[i]
        # last_tags = tags[-1] would be wrong, because the last position may be padding
        valid_last_idx = mask.long().sum(dim=0) - 1  # index of the last valid token
        last_tags = tags[valid_last_idx, torch.arange(batch_size)]
        # the transition from the last tag to an end state would be added here,
        # but this CRF has no end-transition parameter, so last_tags is currently unused
        return score
    def _compute_partition(
        self, emissions: torch.Tensor, mask: torch.ByteTensor
    ) -> torch.Tensor:
        """Compute the partition score with the CRF forward algorithm.
        Args:
            emissions (torch.Tensor): emission scores, shape (seq_len, batch_size, num_tags)
            mask (torch.ByteTensor): marks which positions are padding, shape (seq_len, batch_size)
        Returns:
            torch.Tensor: partition scores for the batch, shape (batch_size,)
        """
        seq_len = emissions.shape[0]
        # score (batch_size, num_tags): for each batch element, entry i holds the
        # score of all possible sequences seen so far that end in tag i
        score = self.start_transition.unsqueeze(0) + emissions[0]
        for i in range(1, seq_len):
            # broadcast_score: (batch_size, num_tags, 1) = (batch_size, pre_tag, current_tag)
            # broadcast over all possible current tags
            broadcast_score = score.unsqueeze(2)
            # broadcast to (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)
            # (batch_size, num_tags, num_tags) = (batch_size, num_tags, 1) + (num_tags, num_tags) + (batch_size, 1, num_tags)
            current_score = broadcast_score + self.get_transition() + broadcast_emissions
            # logsumexp over the previous tag -> (batch_size, num_tags)
            # entry i again holds the score of all sequences so far ending in tag i
            current_score = torch.logsumexp(current_score, dim=1)
            # mask[i].unsqueeze(1) -> (batch_size, 1)
            # only positions where mask[i] marks a valid token take the new
            # current_score; otherwise keep the old score
            score = torch.where(mask[i].bool().unsqueeze(1), current_score, score)
        # a transition to an end tag would be added here if the model had an
        # end_transitions parameter
        # score (batch_size, num_tags)
        # logsumexp over all tags
        # return (batch_size,)
        return torch.logsumexp(score, dim=1)
    def viterbi_decode(self, score, transition_params):
        """Decode the highest scoring sequence of tags outside of TensorFlow.
        This should only be used at test time.
        Args:
            score: A [seq_len, num_tags] matrix of unary potentials.
            transition_params: A [num_tags, num_tags] matrix of binary potentials.
        Returns:
            viterbi: A [seq_len] list of integers containing the highest scoring tag
                indices.
            viterbi_score: A float containing the score for the Viterbi sequence.
        """
        trellis = np.zeros_like(score)
        backpointers = np.zeros_like(score, dtype=np.int32)
        trellis[0] = score[0]
        for t in range(1, score.shape[0]):
            v = np.expand_dims(trellis[t - 1], 1) + transition_params
            trellis[t] = score[t] + np.max(v, 0)
            backpointers[t] = np.argmax(v, 0)
        viterbi = [np.argmax(trellis[-1])]
        for bp in reversed(backpointers[1:]):
            viterbi.append(bp[viterbi[-1]])
        viterbi.reverse()
        viterbi_score = np.max(trellis[-1])
        return viterbi, viterbi_score
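    # A tiny worked example of viterbi_decode (illustration only, numbers made up):
    #   score = np.array([[1.0, 0.0],       # emission scores, two time steps, two tags
    #                     [0.0, 2.0]])
    #   trans = np.array([[0.5, -1.0],      # trans[prev, cur] transition scores
    #                     [-1.0, 0.5]])
    #   The four possible paths score 1.5 ([0, 0]), 2.0 ([0, 1]), -1.0 ([1, 0])
    #   and 2.5 ([1, 1]), so viterbi_decode(score, trans) returns ([1, 1], 2.5).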
    def get_decode_transition(self):
        if self.decode_transition is None:
            transition = self.get_transition().detach().cpu().numpy()
            start_transition = self.start_transition.detach().cpu().numpy()
            num_tags = transition.shape[0]
            trans = np.ones((num_tags + 1, num_tags + 1))
            trans *= -1000
            trans[:num_tags, :num_tags] = transition
            trans[num_tags, :num_tags] = start_transition
            self.decode_transition = trans
        return self.decode_transition
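    # Layout sketch of the augmented decoding matrix built above (for reference):
    #   trans[:num_tags, :num_tags] = learned tag-to-tag transitions
    #   trans[num_tags, :num_tags]  = start_transition (an extra virtual start tag)
    #   everything else             = -1000, so a path can never move back into start
    # decode() below pads each emission matrix with a matching extra column and a
    # leading start row so that its indices line up with this matrix.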
    def decode(self, logits, seq_length):
        '''
        :param logits: the observation scores, shape [batch, max_length, num_tags]
        :param seq_length: the sequence lengths, shape [batch,]
        :return: the best tag path for each sequence in the batch
        '''
        batch_size, max_length, num_tags = logits.shape
        batch_path = []
        logits = logits.detach().cpu().numpy()
        small = -1000
        for _i in range(batch_size):
            seq_len = seq_length[_i].item()
            scores = logits[_i, :seq_len, ...]
            # add a column for the extra start tag and prepend a start row
            start = np.asarray([[small] * num_tags + [0]])
            pad = small * np.ones((seq_len, 1))
            scores = np.concatenate([scores, pad], 1)
            scores = np.concatenate([start, scores], 0)
            path, path_score = self.viterbi_decode(scores, self.get_decode_transition())
            batch_path.append(path[1:])
        return batch_path
    def getAcc(self, y_batch, logits, lengths):
        index = 0
        preds = []
        true_tags = []
        for path in self.decode(logits, lengths):
            preds.append(path)
            index += 1
        for y, length in zip(y_batch, lengths):
            y = y.tolist()
            true_tags.append(y[:length])
        acc_char = 0
        acc_seq = 0
        char_len = 0
        for pred, y in zip(preds, true_tags):
            is_same = True
            for a, b in zip(pred, y):
                char_len += 1
                if a != b:
                    is_same = False
                else:
                    acc_char += 1
            if is_same:
                acc_seq += 1
        return acc_char / char_len, acc_seq / len(preds)
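# A minimal usage sketch of the CRF module above (illustration only; the toy
# tag count, batch shape and random emissions are made up and not part of the
# original training code).
def _crf_usage_sketch():
    num_tags = 5
    crf = CRF(num_tags).to(device)
    batch_size, max_len = 2, 7
    emissions = torch.randn(batch_size, max_len, num_tags).to(device)        # [B, N, num_tags]
    labels = torch.randint(0, num_tags, (batch_size, max_len)).to(device)    # [B, N]
    lengths = torch.tensor([7, 4]).to(device)                                # valid length per sequence
    loss = crf(emissions, labels, lengths)   # mean negative log-likelihood
    paths = crf.decode(emissions, lengths)   # Viterbi path per sequence
    return loss, paths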
class BiLSTMCRF(nn.Module):
    def __init__(self):
        super(BiLSTMCRF, self).__init__()
        config = {'lstm_dim': 100,
                  'num_chars': 6591,
                  'num_tags': 25,
                  'char_dim': 100,
                  'lr': 0.00002,
                  'input_dropout_keep': 1.0,
                  'optimizer': 'adam',
                  'clip': 5}
        # (self.char_to_id,embedding,self.id_to_tag,self.tag_to_id) = load_fool_map()
        # config["num_chars"] = len(self.char_to_id.keys())
        config["lstm_dim"] = 128
        # config["num_tags"] = len(self.id_to_tag.keys())
        config["char_dim"] = 128
        config["lr"] = 1e-2
        self.embedding = nn.Embedding(num_embeddings=config["num_chars"], embedding_dim=config["char_dim"])
        self.bilstm = nn.LSTM(input_size=config["char_dim"], hidden_size=config["lstm_dim"], num_layers=2, bidirectional=True)
        self.linear = nn.Linear(config["lstm_dim"] * 2, config["lstm_dim"])
        self.out = nn.Linear(config["lstm_dim"], config["num_tags"])
        self.crf = CRF(config["num_tags"])

    def forward(self, text_ids):
        B, N = text_ids.shape
        seq_length = (text_ids != 0).sum(1)
        x = self.embedding(text_ids)
        # pack into the variable-length format the LSTM expects
        _seq_length = seq_length.cpu()
        x = torch.nn.utils.rnn.pack_padded_sequence(x, _seq_length, batch_first=True, enforce_sorted=False)
        x, (hidden, cell) = self.bilstm(x)
        # pad back to the original shape
        x, output_length = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        # x = self.norm(x.transpose(1,2)).transpose(1,2)
        x = self.linear(x)
        # x = torch.relu(self.norm(x.transpose(1,2)).transpose(1,2))
        x = self.out(x)
        return x
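# A hedged sketch of how BiLSTMCRF is typically wired to its CRF head during
# training (illustration only; the padded batch, label ids and optimizer below
# are made up, and the real character/tag vocabularies come from elsewhere in
# the project).
def _bilstm_crf_train_step_sketch():
    model = BiLSTMCRF().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    # token id 0 is treated as padding inside BiLSTMCRF.forward
    text_ids = torch.tensor([[5, 8, 13, 2, 0, 0],
                             [7, 3, 9, 11, 6, 4]], dtype=torch.long).to(device)
    labels = torch.randint(0, 25, text_ids.shape).to(device)
    lengths = (text_ids != 0).sum(1)
    emissions = model(text_ids)                   # [B, N, num_tags]
    loss = model.crf(emissions, labels, lengths)  # CRF negative log-likelihood
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()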
from transformers import ElectraModel, ElectraTokenizer


class ElectraCRF(nn.Module):
    def __init__(self, hidden_dim=256, num_tags=26):
        super(ElectraCRF, self).__init__()
        model_path = "../bert/chinese-electra-180g-small-discriminator"
        self.tokenizers = ElectraTokenizer.from_pretrained(model_path)
        self.eletrac = ElectraModel.from_pretrained(model_path)
        self.pad_id = self.tokenizers.pad_token_id
        self.crf = CRF(num_tags)
        self.out = nn.Linear(hidden_dim, num_tags)

    def forward(self, inputs):
        '''
        :param inputs: token ids, shape [B, N]
        :return: emission scores, shape [B, N, num_tags]
        '''
        B, N = inputs.shape
        sequence_length = (inputs != self.pad_id).sum(1)
        # ElectraModel expects attention_mask of shape [B, N]
        mask = (torch.arange(N).unsqueeze(dim=0).to(device) < (sequence_length.unsqueeze(dim=1))).to(torch.float32)
        x = self.eletrac(inputs, attention_mask=mask, output_hidden_states=True)["hidden_states"][-1]
        out = self.out(x)
        return out
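# A hedged sketch of a training step with ElectraCRF (illustration only; it
# assumes the local checkpoint referenced in __init__ is available, and that the
# caller supplies label_ids already aligned with the tokenized sequence length).
def _electra_crf_train_step_sketch(texts, label_ids):
    model = ElectraCRF().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
    # tokenize and pad the raw sentences with the model's own tokenizer
    encoded = model.tokenizers(texts, padding=True, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    lengths = (input_ids != model.pad_id).sum(1)
    emissions = model(input_ids)                             # [B, N, num_tags]
    loss = model.crf(emissions, label_ids.to(device), lengths)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return model.crf.decode(emissions, lengths)              # Viterbi paths for inspection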
from dev.bert.BIDIBert import BIDIBert


class BIDIBertCRF(nn.Module):
    def __init__(self, n_src_vocab=6500, d_word_vec=128, n_layers=3, n_head=3, d_k=128, d_v=128,
                 d_model=128, d_inner=128, num_tags=26):
        super(BIDIBertCRF, self).__init__()
        model_path = "../bert/chinese-electra-180g-small-discriminator"
        self.tokenizers = ElectraTokenizer.from_pretrained(model_path)
        self.pad_id = self.tokenizers.pad_token_id
        self.eletrac = BIDIBert(n_src_vocab=len(self.tokenizers.get_vocab()), d_word_vec=128, n_layers=4, n_head=2, d_k=128, d_v=128,
                                d_model=128, d_inner=128, pad_idx=self.pad_id)
        self.crf = CRF(num_tags)
        self.out = nn.Linear(d_model, num_tags)

    def forward(self, inputs):
        '''
        :param inputs: token ids, shape [B, N]
        :return: emission scores, shape [B, N, num_tags]
        '''
        B, N = inputs.shape
        sequence_length = (inputs != self.pad_id).sum(1)
        mask = (torch.arange(N).unsqueeze(dim=0).to(device) < (sequence_length.unsqueeze(dim=1))).to(torch.float32).unsqueeze(dim=-1)
        x, = self.eletrac(inputs, attention_mask=mask)
        out = self.out(x)
        return out
def test_speed():
    # model = ElectraCRF()
    model = BIDIBertCRF().to(device)
    a = [[i for i in range(500)]] * 10
    a = np.array(a, dtype=np.int64)  # embedding indices should be integer (long) tensors
    print(a.shape)
    import time
    start_time = time.time()
    for _i in range(100):
        model(torch.from_numpy(a).to(device))
    print("cost", time.time() - start_time)
    model = BiLSTMCRF().to(device)
    start_time = time.time()
    for _i in range(100):
        model(torch.from_numpy(a).to(device))
    print("cost", time.time() - start_time)


if __name__ == '__main__':
    test_speed()