"""Transformer encoder ("BIDIBert") with a masked-language-modelling head.

Every sub-layer below uses a post-norm layout: the sub-layer output is added
to its residual input and then passed through ``LayerNorm``.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # ``MaskedLMOutput`` is only needed for ``return_dict=True`` in
    # ``BIDIBertForMaskedLM``; ``PreTrainedModel`` is unused here and kept only
    # so existing imports of it from this module keep working.
    from transformers import PreTrainedModel  # noqa: F401
    from transformers.modeling_outputs import MaskedLMOutput
except ImportError:  # transformers not installed: core model still usable
    PreTrainedModel = None
    MaskedLMOutput = None

# Module-level default device, kept for callers that import it. The modules
# below follow the device of their inputs/parameters rather than this global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a ``(B, N, dim_hid)`` input.

    Feature ``j`` uses the angle ``pos / 10000 ** (2 * (j // 2) / dim_hid)``;
    even features take its ``sin`` and odd features its ``cos``.
    """

    def __init__(self, dim_hid):
        super(PositionalEncoding, self).__init__()
        base_array = np.array(
            [np.power(10000, 2 * (hid_j // 2) / dim_hid) for hid_j in range(dim_hid)]
        )
        # Registered as a non-persistent buffer so it follows ``module.to(...)``
        # like parameters do, while (like the former plain attribute) it is not
        # written to ``state_dict`` -- existing checkpoints still load.
        self.register_buffer(
            "base_tensor",
            torch.from_numpy(base_array).to(torch.float32),
            persistent=False,
        )  # [D]

    def forward(self, x):
        """Return ``x + PE`` for ``x`` of shape ``(B, N, d)`` with ``d == dim_hid``."""
        _, n, _ = x.shape
        # Build positions directly on the input's device instead of a global one.
        pos = torch.arange(n, dtype=torch.float32, device=x.device).unsqueeze(-1)  # [N, 1]
        angles = pos / self.base_tensor.to(x.device)  # [N, D]
        enc = torch.empty_like(angles)
        enc[:, 0::2] = torch.sin(angles[:, 0::2])
        enc[:, 1::2] = torch.cos(angles[:, 1::2])
        return x + enc.unsqueeze(0)  # broadcast over the batch axis


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Compute ``dropout(softmax(q @ k^T / temperature)) @ v``.

        ``k`` must be at least 4-D (``k.transpose(2, 3)``); presumably all inputs
        are ``(B, heads, len, dim)``. Scores where ``mask == 0`` are set to
        ``-1e9`` before the softmax, so ``mask`` must broadcast against the
        ``(..., len_q, len_k)`` score tensor. Returns ``(output, attn)``.
        """
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(torch.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``, each of shape ``(B, len, d_model)``.

        ``mask`` is expanded with ``unsqueeze(1).unsqueeze(2)``; this assumes a
        2-D ``(B, len_k)`` key mask (0 = ignore that key), which becomes
        ``(B, 1, 1, len_k)`` and broadcasts over heads and queries.

        Returns ``(output, attn)``: ``output`` is ``(B, len_q, d_model)`` (after
        residual add + LayerNorm) and ``attn`` is ``(B, n_head, len_q, len_k)``.
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q

        # Project, then split heads: (B, len, n_head, d).
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Put the head axis before the sequence axis: (B, n_head, len, d).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)  # add head and query axes

        q, attn = self.attention(q, k, v, mask=mask)  # q: (B, n_head, len_q, d_v)

        # Concatenate the heads back together: (B, len_q, n_head * d_v).
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn


class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply ``LayerNorm(x + dropout(w_2(relu(w_1(x)))))`` on the last axis."""
        residual = x

        x = self.w_2(torch.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x


class EncoderLayer(nn.Module):
    ''' Compose with two layers: self-attention followed by a feed-forward block '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Return ``(enc_output, enc_slf_attn)`` for one encoder layer."""
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


class Encoder(nn.Module):
    ''' A encoder model with self attention mechanism. '''

    def __init__(
            self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False,
            embedding=None):
        """Build the embedding, positional encoding and ``n_layers`` encoder layers.

        ``embedding`` replaces the default ``nn.Embedding`` when given (then
        ``n_src_vocab``/``pad_idx`` are unused). ``n_position`` is accepted for
        compatibility but not used. The positional encoding is sized by
        ``d_word_vec`` while the LayerNorm/layers use ``d_model``, so the two are
        expected to be equal.
        """
        super().__init__()

        if embedding is None:
            self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
        else:
            self.src_word_emb = embedding
        self.position_enc = PositionalEncoding(d_word_vec)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, src_mask, return_attns=False):
        """Encode ``src_seq``.

        ``src_mask`` is passed to every layer's self-attention as its key mask.
        Returns ``(enc_output,)``, or ``(enc_output, attn_list)`` when
        ``return_attns`` is true.
        """
        enc_slf_attn_list = []

        # -- Forward
        enc_output = self.src_word_emb(src_seq)
        if self.scale_emb:
            # Out-of-place: a caller-supplied embedding may return a tensor that
            # autograd still needs, which an in-place ``*=`` would corrupt.
            enc_output = enc_output * self.d_model ** 0.5
        enc_output = self.dropout(self.position_enc(enc_output))
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            if return_attns:
                enc_slf_attn_list.append(enc_slf_attn)

        if return_attns:
            return enc_output, enc_slf_attn_list
        return (enc_output,)


class BIDIBert(nn.Module):
    """Thin wrapper around ``Encoder`` with default dropout/positional settings."""

    def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v, d_model, d_inner, pad_idx):
        super(BIDIBert, self).__init__()
        self.encoder = Encoder(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                               d_model, d_inner, pad_idx)

    def forward(self, inputs, attention_mask):
        """Return the encoder output tuple ``(hidden_states,)``."""
        out = self.encoder(inputs, attention_mask)
        return out


class BIDIBertForMaskedLM(nn.Module):
    """``BIDIBert`` encoder followed by a dense -> GELU -> LayerNorm -> vocab head."""

    def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v, d_model, d_inner, pad_idx):
        super(BIDIBertForMaskedLM, self).__init__()
        self.electra = BIDIBert(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                                d_model, d_inner, pad_idx)
        self.n_src_vocab = n_src_vocab
        self.LayerNorm = nn.LayerNorm(d_model)
        self.dense = nn.Linear(d_model, d_model)
        self.generator_lm_head = nn.Linear(d_model, n_src_vocab)

    def forward(
        self,
        input_ids,
        attention_mask,
        labels=None,
        return_dict=None
    ):
        r"""Run the encoder and the LM head.

        Args:
            input_ids: token ids fed to the encoder's embedding.
            attention_mask: key mask forwarded to every self-attention layer
                (0 = ignore that position as a key).
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Targets for the masked-LM loss, in `[0, ..., n_src_vocab - 1]`.
                Positions set to `-100` are ignored by the loss.
            return_dict: when truthy, return a ``transformers`` ``MaskedLMOutput``
                (requires the ``transformers`` package); otherwise a tuple.

        Returns:
            ``(loss, prediction_scores, hidden_states)`` when ``labels`` is given,
            else ``(prediction_scores, hidden_states)``; or a ``MaskedLMOutput``
            with ``attentions=None`` (attention maps are not collected here).

        Raises:
            ImportError: ``return_dict`` is truthy but ``transformers`` is missing.
        """
        generator_hidden_states, = self.electra(input_ids, attention_mask=attention_mask)

        hidden_states = self.dense(generator_hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        prediction_scores = self.generator_lm_head(hidden_states)

        loss = None
        if labels is not None:
            # CrossEntropyLoss ignores targets equal to -100 (its default ignore_index).
            loss_fct = nn.CrossEntropyLoss()
            # ``reshape`` (not ``view``) so non-contiguous label tensors also work.
            loss = loss_fct(prediction_scores.view(-1, self.n_src_vocab), labels.reshape(-1))

        if not return_dict:
            output = (prediction_scores, generator_hidden_states)
            return ((loss,) + output) if loss is not None else output

        if MaskedLMOutput is None:
            raise ImportError("return_dict=True requires the `transformers` package")
        # The hidden states are a plain tensor with no ``.attentions`` attribute
        # (the old code raised AttributeError here) and attention maps are not
        # requested from the encoder, so there is nothing to report.
        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=generator_hidden_states,
            attentions=None,
        )