import torch
import torch.nn as nn
import torch.nn.functional as F
import codecs
import re
import numpy as np
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class CRF(nn.Module):
    def __init__(self, num_tags, use_mask=False, id_to_tag=None):
        super(CRF, self).__init__()
        self.start_transition = nn.Parameter(torch.empty(num_tags))
        self.transition = nn.Parameter(torch.empty(num_tags, num_tags))
        self.reset_parameters()
        self.decode_transition = None
        self.use_mask = use_mask
        if use_mask:
            self.transition_mask = torch.as_tensor(
                self.init_transition_mask(id_to_tag, (num_tags, num_tags), _type="BMESO"),
                dtype=torch.float32).to(device)

    def get_transition(self):
        if self.use_mask:
            # The mask holds +100 for legal transitions and -100 for illegal ones, so the
            # element-wise minimum caps illegal transitions at -100 while leaving the
            # learned scores of legal transitions untouched.
            trans = torch.minimum(self.transition, self.transition_mask)
            return trans
        return self.transition
    def init_transition_mask(self, id2tag, _shape, _type, _split="_"):
        '''
        Build a transition mask of shape _shape: legal transitions get +100, illegal
        ones get -100. Pairs whose prefixes end with _split (e.g. "B_" -> "M_")
        additionally require that both tags carry the same entity-type suffix.
        '''
        list_legal_trans = []
        if _type == "BMESO":
            list_legal_trans.append(("B%s" % _split, "M%s" % _split))
            list_legal_trans.append(("B%s" % _split, "E%s" % _split))
            list_legal_trans.append(("M%s" % _split, "M%s" % _split))
            list_legal_trans.append(("M%s" % _split, "E%s" % _split))
            list_legal_trans.append(("E", "S"))
            list_legal_trans.append(("E", "O"))
            list_legal_trans.append(("E", "B"))
            list_legal_trans.append(("S", "O"))
            list_legal_trans.append(("S", "B"))
            list_legal_trans.append(("O", "B"))
            list_legal_trans.append(("O", "S"))
            list_legal_trans.append(("O", "O"))
        if _type == "BISO":
            list_legal_trans.append(("B%s" % _split, "I%s" % _split))
            list_legal_trans.append(("B", "S"))
            list_legal_trans.append(("B", "O"))
            list_legal_trans.append(("I%s" % _split, "S%s" % _split))
            list_legal_trans.append(("I", "S"))
            list_legal_trans.append(("I", "O"))
            list_legal_trans.append(("I", "B"))
            list_legal_trans.append(("S", "O"))
            list_legal_trans.append(("S", "B"))
            list_legal_trans.append(("O", "B"))
            list_legal_trans.append(("O", "S"))
            list_legal_trans.append(("O", "O"))
        mask = np.ones(_shape)
        for col_index, col_tag in id2tag.items():
            for row_index, row_tag in id2tag.items():
                is_legal = False
                for legal_a, legal_b in list_legal_trans:
                    if col_tag.startswith(legal_a) and row_tag.startswith(legal_b):
                        if legal_a[-1] != _split or legal_b[-1] != _split:
                            is_legal = True
                            break
                        else:
                            # Typed pair: both tags must share the same entity-type suffix.
                            if len(col_tag.split(_split)) > 1 and len(row_tag.split(_split)) > 1:
                                if col_tag.split(_split)[-1] == row_tag.split(_split)[-1]:
                                    is_legal = True
                                    break
                if not is_legal:
                    mask[col_index, row_index] = -1.0
        mask *= 100
        return mask
    def reset_parameters(self):
        '''
        Keep the initial scores in a small range; otherwise the loss can become
        very large and produce NaNs.
        :return:
        '''
        nn.init.xavier_normal_(self.transition)
        nn.init.uniform_(self.start_transition, -0.1, 0.1)
    def forward(self, logits, labels, seq_length):
        return self.forward_path(logits, labels, seq_length)
        # Per-sample, per-path computation of the loss: easier to follow, but slower.
        # return self.forward_batch_path(logits, labels, seq_length)
    def forward_batch_path(self, logits, labels, seq_length):
        '''
        :param logits: the observation scores, shape [batch, max_length, num_tags]
        :param labels: the labels, shape [batch, max_length]
        :param seq_length: the sequence lengths, shape [batch]
        :return: the summed loss Z(x) minus the gold-path scores
        '''
        # generate mask
        batch, max_length, num_tags = logits.shape
        mask = (torch.arange(max_length, device=logits.device) <= (seq_length - 1).unsqueeze(1)).to(torch.float).unsqueeze(2)
        logits = logits * mask
        whole_loss = 0
        for i in range(batch):
            _loss = self._compute_path_loss(logits[i, ...], labels[i, ...], seq_length[i])
            whole_loss += _loss
        return whole_loss
    def _compute_path_loss(self, logits, labels, seq_length):
        '''
        compute the loss for a single sequence
        :param logits: [max_length, num_tags]
        :param labels: [max_length]
        :param seq_length: scalar, the true length of the sequence
        :return: Z(x) minus the gold-path score
        '''
        max_length, num_tags = logits.shape
        mask = (torch.arange(max_length, device=logits.device) <= (seq_length - 1))
        # compute Z(x) with the forward algorithm; transition[i, j] is the score of moving from tag i to tag j
        prob = logits[0] + self.start_transition
        for _i in range(1, max_length):
            prob = torch.where(mask[_i],
                               torch.logsumexp(prob.unsqueeze(1) + self.get_transition(), dim=0) + logits[_i],
                               prob)
            # The following form is wrong: a score of 0 is not a probability of 0.
            # prob = torch.logsumexp(prob+self.get_transition()*mask[_i],dim=1)+logits[_i]*mask[_i]
            # Writing it in the following form is equivalent:
            # prob = torch.where(mask[_i],torch.logsumexp(prob.unsqueeze(1)+self.get_transition()+logits[_i],dim=0),prob)
        prob = torch.logsumexp(prob, dim=0)
        # score of the gold label path
        path_prob = logits[0, labels[0]] + self.start_transition[labels[0]]
        for _i in range(1, max_length):
            path_prob += (self.get_transition()[labels[_i - 1], labels[_i]] + logits[_i, labels[_i]]) * mask[_i]
        # path_prob = torch.clamp_max(path_prob,1000)
        return prob - path_prob
    def forward_path(self, logits, labels, seq_length):
        """Compute the negative log-likelihood of the given label sequences.
        Args:
            logits (torch.Tensor): emission scores, shape (batch_size, seq_len, num_tags),
                the score of each tag for every token in the sequence
            labels (torch.LongTensor): label sequences, shape (batch_size, seq_len)
            seq_length (torch.LongTensor): true sequence lengths, shape (batch_size,)
        Returns:
            torch.Tensor: mean negative log-likelihood of the label sequences
        """
        max_length = labels.size(1)
        mask = (torch.arange(max_length).to(device) <= (seq_length - 1).unsqueeze(1)).to(torch.float)
        # switch to seq-first layout
        emissions = logits.transpose(0, 1)  # (seq_len, batch_size, num_tags)
        tags = labels.transpose(0, 1)  # (seq_len, batch_size)
        mask = mask.transpose(0, 1)  # (seq_len, batch_size)
        # score of the gold label sequences
        score = self._compute_score(emissions, tags, mask)
        # partition function Z(x)
        partition = self._compute_partition(emissions, mask)
        return torch.mean(partition - score)
    def _compute_score(
        self,
        emissions: torch.Tensor,
        tags: torch.LongTensor,
        mask: torch.ByteTensor,
    ) -> torch.Tensor:
        """Compute the score of the label sequences tags.
        Args:
            emissions (torch.Tensor): emission scores, shape (seq_len, batch_size, num_tags)
            tags (torch.LongTensor): label sequences, shape (seq_len, batch_size)
            mask (torch.ByteTensor): marks which elements are padding, shape (seq_len, batch_size)
        Returns:
            torch.Tensor: score of the label sequences in the batch, shape (batch_size,)
        """
        seq_len, batch_size = tags.shape
        # first_tags (batch_size,)
        first_tags = tags[0]
        # score of transitioning from the start tag to the first tag of every sequence in the batch
        score = self.start_transition[first_tags]
        # add the first (index=0) emission score, i.e. the score of the 0-th token producing the 0-th tag
        score += emissions[0, torch.arange(batch_size), first_tags]
        mask = mask.type_as(emissions)  # keep dtypes consistent
        # start at index 1, i.e. the second time step
        for i in range(1, seq_len):
            # transition score from tag i-1 to tag i + emission score of token i producing tag i
            # multiplying by mask[i] skips the scores of padding tokens
            score += (
                self.get_transition()[tags[i - 1], tags[i]]
                + emissions[i, torch.arange(batch_size), tags[i]]
            ) * mask[i]
        # This CRF has no end-transition parameter, so no transition to an end tag is added here.
        # (If it had one, the last valid tag would be tags[mask.long().sum(dim=0) - 1, arange(batch_size)],
        # not tags[-1], because the trailing positions may be padding.)
        return score
    def _compute_partition(
        self, emissions: torch.Tensor, mask: torch.ByteTensor
    ) -> torch.Tensor:
        """Compute the partition score with the CRF forward algorithm.
        Args:
            emissions (torch.Tensor): emission scores, shape (seq_len, batch_size, num_tags)
            mask (torch.ByteTensor): marks which elements are padding, shape (seq_len, batch_size)
        Returns:
            torch.Tensor: partition scores for the batch, shape (batch_size,)
        """
        seq_len = emissions.shape[0]
        # score (batch_size, num_tags): for each sample, the i-th entry holds the score of
        # all possible sequences seen so far that end in tag i
        score = self.start_transition.unsqueeze(0) + emissions[0]
        for i in range(1, seq_len):
            # broadcast_score: (batch_size, num_tags, 1) = (batch_size, pre_tag, current_tag)
            # broadcast over all possible current tags
            broadcast_score = score.unsqueeze(2)
            # broadcast to (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)
            # (batch_size, num_tags, num_tags) = (batch_size, num_tags, 1) + (num_tags, num_tags) + (batch_size, 1, num_tags)
            current_score = broadcast_score + self.get_transition() + broadcast_emissions
            # logsumexp over the previous tag -> (batch_size, num_tags)
            # for each sample, the i-th entry holds the score of all possible sequences so far ending in tag i
            current_score = torch.logsumexp(current_score, dim=1)
            # mask[i].unsqueeze(1) -> (batch_size, 1)
            # only positions where mask[i] is valid take current_score; padded positions keep the old score
            score = torch.where(mask[i].bool().unsqueeze(1), current_score, score)
        # This CRF has no end-transition parameter, so nothing is added here.
        # score (batch_size, num_tags): logsumexp over all tags -> (batch_size,)
        return torch.logsumexp(score, dim=1)
    def viterbi_decode(self, score, transition_params):
        """Decode the highest scoring sequence of tags (pure NumPy implementation).
        This should only be used at test time.
        Args:
            score: A [seq_len, num_tags] matrix of unary potentials.
            transition_params: A [num_tags, num_tags] matrix of binary potentials.
        Returns:
            viterbi: A [seq_len] list of integers containing the highest scoring tag
                indices.
            viterbi_score: A float containing the score for the Viterbi sequence.
        """
        trellis = np.zeros_like(score)
        backpointers = np.zeros_like(score, dtype=np.int32)
        trellis[0] = score[0]
        for t in range(1, score.shape[0]):
            v = np.expand_dims(trellis[t - 1], 1) + transition_params
            trellis[t] = score[t] + np.max(v, 0)
            backpointers[t] = np.argmax(v, 0)
        viterbi = [np.argmax(trellis[-1])]
        for bp in reversed(backpointers[1:]):
            viterbi.append(bp[viterbi[-1]])
        viterbi.reverse()
        viterbi_score = np.max(trellis[-1])
        return viterbi, viterbi_score
    def get_decode_transition(self):
        # Build (and cache) the transition matrix used for decoding: the learned transitions
        # plus one extra virtual start tag whose outgoing scores are the start transitions;
        # every other transition to or from the virtual tag is heavily penalized.
        if self.decode_transition is None:
            transition = self.get_transition().detach().cpu().numpy()
            start_transition = self.start_transition.detach().cpu().numpy()
            num_tags = transition.shape[0]
            trans = np.ones((num_tags + 1, num_tags + 1))
            trans *= -1000
            trans[:num_tags, :num_tags] = transition
            trans[num_tags, :num_tags] = start_transition
            self.decode_transition = trans
        return self.decode_transition
    def decode(self, logits, seq_length):
        '''
        :param logits: the observation scores, shape [batch, max_length, num_tags]
        :param seq_length: the sequence lengths, shape [batch]
        :return: the best tag-id path for every sequence in the batch
        '''
        batch_size, max_length, num_tags = logits.shape
        batch_path = []
        logits = logits.detach().cpu().numpy()
        small = -1000
        for _i in range(batch_size):
            seq_len = seq_length[_i].item()
            scores = logits[_i, :seq_len, ...]
            # Prepend a virtual start step and an extra column for the virtual start tag,
            # matching the augmented transition matrix from get_decode_transition().
            start = np.asarray([[small] * num_tags + [0]])
            pad = small * np.ones((seq_len, 1))
            scores = np.concatenate([scores, pad], 1)
            scores = np.concatenate([start, scores], 0)
            path, path_score = self.viterbi_decode(scores, self.get_decode_transition())
            # Drop the virtual start step from the decoded path.
            batch_path.append(path[1:])
        return batch_path
    def getAcc(self, y_batch, logits, lengths):
        index = 0
        preds = []
        true_tags = []
        for path in self.decode(logits, lengths):
            preds.append(path)
            index += 1
        for y, length in zip(y_batch, lengths):
            y = y.tolist()
            true_tags.append(y[:length])
        acc_char = 0
        acc_seq = 0
        char_len = 0
        for pred, y in zip(preds, true_tags):
            is_same = True
            for a, b in zip(pred, y):
                char_len += 1
                if a != b:
                    is_same = False
                else:
                    acc_char += 1
            if is_same:
                acc_seq += 1
        return acc_char / char_len, acc_seq / len(preds)
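
# ----------------------------------------------------------------------
# Illustrative sketch (not part of the original code): how the CRF layer
# above is typically called. The batch size, sequence lengths and number
# of tags here are made-up example values.
# ----------------------------------------------------------------------
def _crf_usage_example():
    num_tags = 5
    crf = CRF(num_tags).to(device)
    logits = torch.randn(2, 7, num_tags, device=device)            # emission scores [batch, max_len, num_tags]
    labels = torch.randint(0, num_tags, (2, 7), device=device)     # gold tags [batch, max_len]
    seq_length = torch.tensor([7, 4], device=device)               # true length of each sequence
    loss = crf(logits, labels, seq_length)                         # mean negative log-likelihood
    paths = crf.decode(logits, seq_length)                         # list of best tag-id paths (Viterbi)
    return loss, paths
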
class BiLSTMCRF(nn.Module):
    def __init__(self):
        super(BiLSTMCRF, self).__init__()
        config = {'lstm_dim': 100,
                  'num_chars': 6591,
                  'num_tags': 25,
                  'char_dim': 100,
                  'lr': 0.00002,
                  'input_dropout_keep': 1.0,
                  'optimizer': 'adam',
                  'clip': 5}
        # (self.char_to_id,embedding,self.id_to_tag,self.tag_to_id) = load_fool_map()
        # config["num_chars"] = len(self.char_to_id.keys())
        config["lstm_dim"] = 128
        # config["num_tags"] = len(self.id_to_tag.keys())
        config["char_dim"] = 128
        config["lr"] = 1e-2
        self.embedding = nn.Embedding(num_embeddings=config["num_chars"], embedding_dim=config["char_dim"])
        self.bilstm = nn.LSTM(input_size=config["char_dim"], hidden_size=config["lstm_dim"], num_layers=2, bidirectional=True)
        self.linear = nn.Linear(config["lstm_dim"] * 2, config["lstm_dim"])
        self.out = nn.Linear(config["lstm_dim"], config["num_tags"])
        self.crf = CRF(config["num_tags"])
    def forward(self, text_ids):
        B, N = text_ids.shape
        seq_length = (text_ids != 0).sum(1)  # id 0 is treated as padding
        x = self.embedding(text_ids)
        # pack into the variable-length format expected by the LSTM
        _seq_length = seq_length.cpu()
        x = torch.nn.utils.rnn.pack_padded_sequence(x, _seq_length, batch_first=True, enforce_sorted=False)
        x, (hidden, cell) = self.bilstm(x)
        # pad back to the original shape
        x, output_length = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        # x = self.norm(x.transpose(1,2)).transpose(1,2)
        x = self.linear(x)
        # x = torch.relu(self.norm(x.transpose(1,2)).transpose(1,2))
        x = self.out(x)
        return x
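
# ----------------------------------------------------------------------
# Illustrative training step (not part of the original code): BiLSTMCRF
# produces emission scores and the CRF layer turns them into a loss.
# The random batch, optimizer and learning rate below are made up for
# the example.
# ----------------------------------------------------------------------
def _bilstm_crf_train_step_example():
    model = BiLSTMCRF().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    text_ids = torch.randint(1, 6591, (4, 20), device=device)  # [B, N]; id 0 is reserved for padding
    labels = torch.randint(0, 25, (4, 20), device=device)      # [B, N]
    seq_length = (text_ids != 0).sum(1)                        # here every sequence has length 20
    logits = model(text_ids)                                   # emission scores [B, N, num_tags]
    loss = model.crf(logits, labels, seq_length)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
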
from transformers import ElectraModel, ElectraTokenizer


class ElectraCRF(nn.Module):
    def __init__(self, hidden_dim=256, num_tags=26):
        super(ElectraCRF, self).__init__()
        model_path = "../bert/chinese-electra-180g-small-discriminator"
        self.tokenizers = ElectraTokenizer.from_pretrained(model_path)
        self.eletrac = ElectraModel.from_pretrained(model_path)
        self.pad_id = self.tokenizers.pad_token_id
        self.crf = CRF(num_tags)
        self.out = nn.Linear(hidden_dim, num_tags)

    def forward(self, inputs):
        '''
        :param inputs: token ids, shape [B, N]
        :return: emission scores, shape [B, N, num_tags]
        '''
        B, N = inputs.shape
        sequence_length = (inputs != self.pad_id).sum(1)
        # attention_mask of shape [B, N]: 1 for real tokens, 0 for padding
        mask = (torch.arange(N).unsqueeze(dim=0).to(device) < sequence_length.unsqueeze(dim=1)).to(torch.float32)
        x = self.eletrac(inputs, attention_mask=mask, output_hidden_states=True).hidden_states[-1]
        out = self.out(x)
        return out
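
# ----------------------------------------------------------------------
# Illustrative inference sketch (not part of the original code): tokenize
# a sentence with the model's tokenizer, compute emission scores and
# Viterbi-decode tag ids. It assumes the pretrained checkpoint referenced
# by model_path above is available locally; the sentence is made up.
# ----------------------------------------------------------------------
def _electra_crf_decode_example(sentence="今天天气不错"):
    model = ElectraCRF().to(device)
    model.eval()
    input_ids = torch.tensor([model.tokenizers.encode(sentence)], device=device)  # [1, N]
    seq_length = (input_ids != model.pad_id).sum(1)
    with torch.no_grad():
        logits = model(input_ids)                       # emission scores [1, N, num_tags]
        paths = model.crf.decode(logits, seq_length)    # best tag-id path for each sample
    return paths
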
from dev.bert.BIDIBert import BIDIBert


class BIDIBertCRF(nn.Module):
    def __init__(self, n_src_vocab=6500, d_word_vec=128, n_layers=3, n_head=3, d_k=128, d_v=128,
                 d_model=128, d_inner=128, num_tags=26):
        super(BIDIBertCRF, self).__init__()
        model_path = "../bert/chinese-electra-180g-small-discriminator"
        self.tokenizers = ElectraTokenizer.from_pretrained(model_path)
        self.pad_id = self.tokenizers.pad_token_id
        self.eletrac = BIDIBert(n_src_vocab=len(self.tokenizers.get_vocab()), d_word_vec=128, n_layers=4, n_head=2, d_k=128, d_v=128,
                                d_model=128, d_inner=128, pad_idx=self.pad_id)
        self.crf = CRF(num_tags)
        self.out = nn.Linear(d_model, num_tags)

    def forward(self, inputs):
        '''
        :param inputs: token ids, shape [B, N]
        :return: emission scores, shape [B, N, num_tags]
        '''
        B, N = inputs.shape
        sequence_length = (inputs != self.pad_id).sum(1)
        mask = (torch.arange(N).unsqueeze(dim=0).to(device) < sequence_length.unsqueeze(dim=1)).to(torch.float32).unsqueeze(dim=-1)
        x, = self.eletrac(inputs, attention_mask=mask)
        out = self.out(x)
        return out
def test_speed():
    # model = ElectraCRF()
    model = BIDIBertCRF()
    a = [[i for i in range(500)]] * 10
    a = np.array(a, dtype=np.int64)  # embedding/encoder indices must be int64 (LongTensor)
    print(a.shape)
    import time
    start_time = time.time()
    for _i in range(100):
        model(torch.from_numpy(a))
    print("cost", time.time() - start_time)
    model = BiLSTMCRF()
    start_time = time.time()
    for _i in range(100):
        model(torch.from_numpy(a))
    print("cost", time.time() - start_time)


if __name__ == '__main__':
    test_speed()