import torch
import torch.nn as nn
import torch.nn.functional as F
import codecs
import re
import numpy as np
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CRF(nn.Module):
    def __init__(self, num_tags, use_mask=False, id_to_tag=None):
        super(CRF, self).__init__()
        self.start_transition = nn.Parameter(torch.empty(num_tags))
        self.transition = nn.Parameter(torch.empty(num_tags, num_tags))
        self.reset_parameters()
        self.decode_transition = None
        self.use_mask = use_mask
        if use_mask:
            self.transition_mask = torch.as_tensor(
                self.init_transition_mask(id_to_tag, (num_tags, num_tags), _type="BMESO"),
                dtype=torch.float32).to(device)

    def get_transition(self):
        if self.use_mask:
            # cap illegal transitions at the large negative values encoded in the mask
            trans = torch.minimum(self.transition, self.transition_mask)
            return trans
        return self.transition

    def init_transition_mask(self, id2tag, _shape, _type, _split="_"):
        # Build a [num_tags, num_tags] matrix that is +100 for legal transitions and
        # -100 for illegal ones, based on the tagging scheme (BMESO or BISO).
        list_legal_trans = []
        if _type == "BMESO":
            list_legal_trans.append(("B%s" % _split, "M%s" % _split))
            list_legal_trans.append(("B%s" % _split, "E%s" % _split))
            list_legal_trans.append(("M%s" % _split, "M%s" % _split))
            list_legal_trans.append(("M%s" % _split, "E%s" % _split))
            list_legal_trans.append(("E", "S"))
            list_legal_trans.append(("E", "O"))
            list_legal_trans.append(("E", "B"))
            list_legal_trans.append(("S", "O"))
            list_legal_trans.append(("S", "B"))
            list_legal_trans.append(("O", "B"))
            list_legal_trans.append(("O", "S"))
            list_legal_trans.append(("O", "O"))
        if _type == "BISO":
            list_legal_trans.append(("B%s" % _split, "I%s" % _split))
            list_legal_trans.append(("B", "S"))
            list_legal_trans.append(("B", "O"))
            list_legal_trans.append(("I%s" % _split, "S%s" % _split))
            list_legal_trans.append(("I", "S"))
            list_legal_trans.append(("I", "O"))
            list_legal_trans.append(("I", "B"))
            list_legal_trans.append(("S", "O"))
            list_legal_trans.append(("S", "B"))
            list_legal_trans.append(("O", "B"))
            list_legal_trans.append(("O", "S"))
            list_legal_trans.append(("O", "O"))
        mask = np.ones(_shape)
        for col_index, col_tag in id2tag.items():
            for row_index, row_tag in id2tag.items():
                is_legal = False
                for legal_a, legal_b in list_legal_trans:
                    if col_tag.startswith(legal_a) and row_tag.startswith(legal_b):
                        if legal_a[-1] != _split or legal_b[-1] != _split:
                            is_legal = True
                            break
                        else:
                            # prefixed patterns (e.g. "B_", "M_") must also share the same entity type
                            if len(col_tag.split(_split)) > 1 and len(row_tag.split(_split)) > 1:
                                if col_tag.split(_split)[-1] == row_tag.split(_split)[-1]:
                                    is_legal = True
                                    break
                if not is_legal:
                    mask[col_index, row_index] = -1.0
        mask *= 100
        return mask

    def reset_parameters(self):
        '''
        Keep the initial values in a small range, otherwise the loss becomes very large
        and turns into NaN.
        :return:
        '''
        nn.init.xavier_normal_(self.transition)
        nn.init.uniform_(self.start_transition, -0.1, 0.1)

    def forward(self, logits, labels, seq_length):
        return self.forward_path(logits, labels, seq_length)
        # Per-sample, per-path loss: easier to follow the internal computation, but slower.
        # return self.forward_batch_path(logits, labels, seq_length)

    def forward_batch_path(self, logits, labels, seq_length):
        '''
        :param logits: the emission scores, shape [batch, max_length, num_tags]
        :param labels: the gold labels, shape [batch, max_length]
        :param seq_length: the sequence lengths, shape [batch]
        :return: the summed negative log-likelihood (log Z(x) minus the path score)
        '''
        # generate mask
        batch, max_length, num_tags = logits.shape
        mask = (torch.arange(max_length).to(logits.device) <= (seq_length - 1).unsqueeze(1)).to(torch.float).unsqueeze(2)
        logits = logits * mask
        whole_loss = 0
        for i in range(batch):
            _loss = self._compute_path_loss(logits[i, ...], labels[i, ...], seq_length[i])
            whole_loss += _loss
        return whole_loss

    def _compute_path_loss(self, logits, labels, seq_length):
        '''
        compute the path loss for a single sequence
        :param logits: [max_length, num_tags]
        :param labels: [max_length]
        :return:
        '''
        max_length, num_tags = logits.shape
        mask = (torch.arange(max_length).to(logits.device) <= (seq_length - 1))
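        # Forward algorithm over a single sequence: fold in one time step at a time with a
        # log-sum-exp over the previous tag, so prob[j] always holds the log-sum of the scores
        # of every partial path ending in tag j; padded positions are skipped via the mask.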
        # compute log Z(x) with the forward algorithm
        prob = logits[0] + self.start_transition
        for _i in range(1, max_length):
            # prob[prev] + transition[prev, curr], log-sum-exp over prev, plus the emission at step _i
            prob = torch.where(mask[_i],
                               torch.logsumexp(prob.unsqueeze(1) + self.get_transition(), dim=0) + logits[_i],
                               prob)
            # The following form is wrong: 0 is a log score, not "probability zero", so
            # multiplying by the mask does not skip padded positions.
            # prob = torch.logsumexp(prob + self.get_transition() * mask[_i], dim=1) + logits[_i] * mask[_i]
            # The following form is equivalent to the one above:
            # prob = torch.where(mask[_i], torch.logsumexp(prob.unsqueeze(1) + self.get_transition() + logits[_i], dim=0), prob)
        prob = torch.logsumexp(prob, dim=0)

        # score of the gold label path [max_length, num_tags]
        path_prob = logits[0, labels[0]] + self.start_transition[labels[0]]
        for _i in range(1, max_length):
            path_prob += (self.get_transition()[labels[_i - 1], labels[_i]] + logits[_i, labels[_i]]) * mask[_i]
            # path_prob = torch.clamp_max(path_prob, 1000)
        return prob - path_prob

    def forward_path(self, logits, labels, seq_length):
        """Compute the negative log-likelihood of the given tag sequences.

        Args:
            logits (torch.Tensor): emission scores, shape (batch_size, max_length, num_tags),
                the score of each tag for every token in the sequence
            labels (torch.LongTensor): gold tag sequences, shape (batch_size, max_length)
            seq_length (torch.LongTensor): valid length of each sequence, shape (batch_size,);
                positions beyond the length are treated as padding

        Returns:
            torch.Tensor: the mean negative log-likelihood over the batch
        """
        max_length = labels.size(1)
        mask = (torch.arange(max_length).to(logits.device) <= (seq_length - 1).unsqueeze(1)).to(torch.float)

        # switch to time-major layout
        emissions = logits.transpose(0, 1)  # (seq_len, batch_size, num_tags)
        tags = labels.transpose(0, 1)       # (seq_len, batch_size)
        mask = mask.transpose(0, 1)         # (seq_len, batch_size)

        # score of the gold tag sequences
        score = self._compute_score(emissions, tags, mask)
        # partition function Z(x)
        partition = self._compute_partition(emissions, mask)
        return torch.mean(partition - score)

    def _compute_score(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: torch.ByteTensor,
    ) -> torch.Tensor:
        """Compute the score of the gold tag sequences.

        Args:
            emissions (torch.Tensor): emission scores, shape (seq_len, batch_size, num_tags)
            tags (torch.LongTensor): tag sequences, shape (seq_len, batch_size)
            mask (torch.ByteTensor): marks which positions are padding, shape (seq_len, batch_size)

        Returns:
            torch.Tensor: score of the tag sequences in the batch, shape (batch_size,)
        """
        seq_len, batch_size = tags.shape
        # first_tags (batch_size,)
        first_tags = tags[0]

        # transition score from the start state to the first tag of every sequence in the batch
        score = self.start_transition[first_tags]
        # plus the emission score at index 0, i.e. the score of the first token producing its first tag
        score += emissions[0, torch.arange(batch_size), first_tags]

        mask = mask.type_as(emissions)  # keep dtypes consistent

        # start from index 1, i.e. the second time step
        for i in range(1, seq_len):
            # transition score from tag i-1 to tag i, plus the emission score of token i for tag i;
            # multiplying by mask[i] skips the scores of padded tokens
            score += (
                self.get_transition()[tags[i - 1], tags[i]]
                + emissions[i, torch.arange(batch_size), tags[i]]
            ) * mask[i]

        # last_tags = tags[-1] would be wrong, because it may point at a padded position
        valid_last_idx = mask.long().sum(dim=0) - 1  # index of the last valid token
        last_tags = tags[valid_last_idx, torch.arange(batch_size)]
        # This implementation has no end-transition parameter, so last_tags is currently unused;
        # with an end transition the score of transitioning from the last tag to "end" would be added here.
        return score

    def _compute_partition(
            self, emissions: torch.Tensor, mask: torch.ByteTensor
    ) -> torch.Tensor:
        """Compute the partition function with the CRF forward algorithm.

        Args:
            emissions (torch.Tensor): emission scores, shape (seq_len, batch_size, num_tags)
            mask (torch.ByteTensor): marks which positions are padding, shape (seq_len, batch_size)

        Returns:
            torch.Tensor: partition scores of the batch, shape (batch_size,)
        """
        seq_len = emissions.shape[0]

        # score (batch_size, num_tags): for each batch element, entry i holds the score of all
        # possible sequences so far that end in tag i
        score = self.start_transition.unsqueeze(0) + emissions[0]

        for i in range(1, seq_len):
            # broadcast_score: (batch_size, num_tags, 1) = (batch_size, pre_tag, current_tag)
            # broadcast over all possible current tags
            broadcast_score = score.unsqueeze(2)

            # broadcast to (batch_size, 1, num_tags)
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # (batch_size, num_tags, num_tags) = (batch_size, num_tags, 1) + (num_tags, num_tags) + (batch_size, 1, num_tags)
            current_score = broadcast_score + self.get_transition() + broadcast_emissions

            # log-sum-exp over the previous-tag dimension -> (batch_size, num_tags)
            # for each batch element, entry i now holds the score of all tag sequences so far that end in tag i
            current_score = torch.logsumexp(current_score, dim=1)

            # mask[i].unsqueeze(1) -> (batch_size, 1)
            # only write current_score into score where mask[i] marks a valid position; keep the old score for padding
            score = torch.where(mask[i].bool().unsqueeze(1), current_score, score)

        # This implementation has no end-transition parameters, so nothing is added here; with
        # end transitions, the score of every tag reaching the end state would be added now.
        # score (batch_size, num_tags)
        # log-sum-exp over all tags
        # return (batch_size,)
        return torch.logsumexp(score, dim=1)

    def viterbi_decode(self, score, transition_params):
        """Decode the highest scoring sequence of tags outside of TensorFlow.

        This should only be used at test time.

        Args:
            score: A [seq_len, num_tags] matrix of unary potentials.
            transition_params: A [num_tags, num_tags] matrix of binary potentials.

        Returns:
            viterbi: A [seq_len] list of integers containing the highest scoring tag indices.
            viterbi_score: A float containing the score for the Viterbi sequence.
        """
        trellis = np.zeros_like(score)
        backpointers = np.zeros_like(score, dtype=np.int32)
        trellis[0] = score[0]
        for t in range(1, score.shape[0]):
            v = np.expand_dims(trellis[t - 1], 1) + transition_params
            trellis[t] = score[t] + np.max(v, 0)
            backpointers[t] = np.argmax(v, 0)
        viterbi = [np.argmax(trellis[-1])]
        for bp in reversed(backpointers[1:]):
            viterbi.append(bp[viterbi[-1]])
        viterbi.reverse()
        viterbi_score = np.max(trellis[-1])
        return viterbi, viterbi_score

    def get_decode_transition(self):
        # Build and cache a (num_tags + 1) x (num_tags + 1) decoding matrix: the learned
        # transitions in the top-left block, the start transitions in the extra row, and a
        # large negative value everywhere else. Cached after the first call, so it should
        # only be used once training has finished.
        if self.decode_transition is None:
            transition = self.get_transition().detach().cpu().numpy()
            start_transition = self.start_transition.detach().cpu().numpy()
            num_tags = transition.shape[0]
            trans = np.ones((num_tags + 1, num_tags + 1))
            trans *= -1000
            trans[:num_tags, :num_tags] = transition
            trans[num_tags, :num_tags] = start_transition
            self.decode_transition = trans
        return self.decode_transition

    def decode(self, logits, seq_length):
        '''
        :param logits: the emission scores, shape [batch, max_length, num_tags]
        :param seq_length: the sequence lengths, shape [batch]
        :return: the best tag path for every sequence in the batch
        '''
        batch_size, max_length, num_tags = logits.shape
        batch_path = []
        logits = logits.detach().cpu().numpy()
        small = -1000
        for _i in range(batch_size):
            seq_len = seq_length[_i].item()
            scores = logits[_i, :seq_len, ...]
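            # Augment the emission scores with one extra "virtual start" tag so that the
            # (num_tags + 1) x (num_tags + 1) matrix from get_decode_transition() applies:
            # only the prepended first row lets the virtual start tag score 0, everything
            # else involving it stays at `small`.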
            start = np.asarray([[small] * num_tags + [0]])
            pad = small * np.ones((seq_len, 1))
            scores = np.concatenate([scores, pad], 1)
            scores = np.concatenate([start, scores], 0)
            path, path_score = self.viterbi_decode(scores, self.get_decode_transition())
            # drop the virtual start tag from the decoded path
            batch_path.append(path[1:])
        return batch_path

    def getAcc(self, y_batch, logits, lengths):
        # character-level and sequence-level accuracy of the decoded paths against the gold tags
        index = 0
        preds = []
        true_tags = []
        for path in self.decode(logits, lengths):
            preds.append(path)
            index += 1
        for y, length in zip(y_batch, lengths):
            y = y.tolist()
            true_tags.append(y[:length])
        acc_char = 0
        acc_seq = 0
        char_len = 0
        for pred, y in zip(preds, true_tags):
            is_same = True
            for a, b in zip(pred, y):
                char_len += 1
                if a != b:
                    is_same = False
                else:
                    acc_char += 1
            if is_same:
                acc_seq += 1
        return acc_char / char_len, acc_seq / len(preds)


class BiLSTMCRF(nn.Module):
    def __init__(self):
        super(BiLSTMCRF, self).__init__()
        config = {'lstm_dim': 100,
                  'num_chars': 6591,
                  'num_tags': 25,
                  'char_dim': 100,
                  'lr': 0.00002,
                  'input_dropout_keep': 1.0,
                  'optimizer': 'adam',
                  'clip': 5}
        # (self.char_to_id, embedding, self.id_to_tag, self.tag_to_id) = load_fool_map()
        # config["num_chars"] = len(self.char_to_id.keys())
        config["lstm_dim"] = 128
        # config["num_tags"] = len(self.id_to_tag.keys())
        config["char_dim"] = 128
        config["lr"] = 1e-2
        self.embedding = nn.Embedding(num_embeddings=config["num_chars"], embedding_dim=config["char_dim"])
        self.bilstm = nn.LSTM(input_size=config["char_dim"], hidden_size=config["lstm_dim"], num_layers=2,
                              bidirectional=True)
        self.linear = nn.Linear(config["lstm_dim"] * 2, config["lstm_dim"])
        self.out = nn.Linear(config["lstm_dim"], config["num_tags"])
        self.crf = CRF(config["num_tags"])

    def forward(self, text_ids):
        B, N = text_ids.shape
        seq_length = (text_ids != 0).sum(1)
        x = self.embedding(text_ids)
        # pack into the variable-length format expected by the LSTM (lengths must live on the CPU)
        _seq_length = seq_length.cpu()
        x = torch.nn.utils.rnn.pack_padded_sequence(x, _seq_length, batch_first=True, enforce_sorted=False)
        x, (hidden, cell) = self.bilstm(x)
        # pad back to the original layout
        x, output_length = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        # x = self.norm(x.transpose(1, 2)).transpose(1, 2)
        x = self.linear(x)
        # x = torch.relu(self.norm(x.transpose(1, 2)).transpose(1, 2))
        x = self.out(x)
        return x


from transformers import ElectraModel, ElectraTokenizer


class ElectraCRF(nn.Module):
    def __init__(self, hidden_dim=256, num_tags=26):
        super(ElectraCRF, self).__init__()
        model_path = "../bert/chinese-electra-180g-small-discriminator"
        self.tokenizers = ElectraTokenizer.from_pretrained(model_path)
        self.eletrac = ElectraModel.from_pretrained(model_path)
        self.pad_id = self.tokenizers.pad_token_id
        self.crf = CRF(num_tags)
        self.out = nn.Linear(hidden_dim, num_tags)

    def forward(self, inputs):
        '''
        :param inputs: token ids, shape [B, N]
        :return: emission scores, shape [B, N, num_tags]
        '''
        B, N = inputs.shape
        sequence_length = (inputs != self.pad_id).sum(1)
        # attention_mask for ElectraModel is expected with shape [B, N]
        mask = (torch.arange(N).unsqueeze(dim=0).to(inputs.device) < sequence_length.unsqueeze(dim=1)).to(torch.float32)
        x = self.eletrac(inputs, attention_mask=mask, output_hidden_states=True)["hidden_states"][-1]
        out = self.out(x)
        return out


from dev.bert.BIDIBert import BIDIBert


class BIDIBertCRF(nn.Module):
    def __init__(self, n_src_vocab=6500, d_word_vec=128, n_layers=3, n_head=3, d_k=128, d_v=128,
                 d_model=128, d_inner=128, num_tags=26):
        super(BIDIBertCRF, self).__init__()
        model_path = "../bert/chinese-electra-180g-small-discriminator"
        self.tokenizers = ElectraTokenizer.from_pretrained(model_path)
        self.pad_id = self.tokenizers.pad_token_id
        self.eletrac = BIDIBert(n_src_vocab=len(self.tokenizers.get_vocab()), d_word_vec=128, n_layers=4, n_head=2,
                                d_k=128, d_v=128, d_model=128, d_inner=128, pad_idx=self.pad_id)
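        # Note: the encoder hyperparameters are hardcoded in the call above (n_layers=4, n_head=2,
        # vocab size taken from the tokenizer), so the corresponding __init__ arguments are
        # effectively ignored.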
        self.crf = CRF(num_tags)
        self.out = nn.Linear(d_model, num_tags)

    def forward(self, inputs):
        '''
        :param inputs: token ids, shape [B, N]
        :return: emission scores, shape [B, N, num_tags]
        '''
        B, N = inputs.shape
        sequence_length = (inputs != self.pad_id).sum(1)
        mask = (torch.arange(N).unsqueeze(dim=0).to(inputs.device) < sequence_length.unsqueeze(dim=1)).to(torch.float32).unsqueeze(dim=-1)
        x, = self.eletrac(inputs, attention_mask=mask)
        out = self.out(x)
        return out


def test_speed():
    # model = ElectraCRF()
    model = BIDIBertCRF()
    a = [[i for i in range(500)]] * 10
    a = np.array(a, dtype=np.int64)  # int64: embedding lookups expect LongTensor indices
    print(a.shape)
    import time

    start_time = time.time()
    for _i in range(100):
        model(torch.from_numpy(a))
    print("BIDIBertCRF cost", time.time() - start_time)

    model = BiLSTMCRF()
    start_time = time.time()
    for _i in range(100):
        model(torch.from_numpy(a))
    print("BiLSTMCRF cost", time.time() - start_time)


if __name__ == '__main__':
    test_speed()
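

# Illustrative sketch (not part of the original module): exercises the CRF layer above on its
# own with random emission scores, to show how loss computation and decoding fit together.
# The tag count, batch size and sequence lengths below are made-up values for the demo only.
def _crf_usage_example():
    num_tags = 5
    crf = CRF(num_tags)
    logits = torch.randn(2, 6, num_tags)          # fake emissions: batch of 2, padded length 6
    labels = torch.randint(0, num_tags, (2, 6))   # fake gold tags
    seq_length = torch.tensor([6, 4])             # true lengths; positions past them are padding
    loss = crf(logits, labels, seq_length)        # mean negative log-likelihood
    paths = crf.decode(logits, seq_length)        # best tag sequence per sample
    print("loss:", loss.item())
    print("decoded paths:", paths)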