```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the token embeddings."""

    def __init__(self, dim_hid):
        super(PositionalEncoding, self).__init__()
        # 10000^(2*floor(j/2)/d) for each hidden dimension j, as in the original Transformer.
        base_array = np.array([np.power(10000, 2 * (hid_j // 2) / dim_hid) for hid_j in range(dim_hid)])
        self.base_tensor = torch.from_numpy(base_array).to(torch.float32).to(device)  # [D]

    def forward(self, x):
        # x: (B, N, D)
        B, N, d = x.shape
        pos = torch.arange(N).unsqueeze(-1).to(torch.float32).to(device)  # [N, 1]
        pos = pos / self.base_tensor  # broadcasts to [N, D]
        pos = pos.unsqueeze(0)        # [1, N, D]
        pos[:, :, 0::2] = torch.sin(pos[:, :, 0::2])
        pos[:, :, 1::2] = torch.cos(pos[:, :, 1::2])
        return x + pos
```
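As a quick sanity check (the batch size, sequence length, and hidden size below are illustrative assumptions, not values from the original code), the module builds a `(1, N, D)` sinusoidal table on the fly and adds it to the input, so the output shape matches the input:

```python
# Hypothetical shapes; the encoding is added element-wise, so shapes are preserved.
pe = PositionalEncoding(dim_hid=512)
x = torch.zeros(2, 10, 512, device=device)  # (batch, seq_len, d_model)
print(pe(x).shape)  # torch.Size([2, 10, 512])
```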
```python
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # q, k, v: (B, n_head, len, d)
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(torch.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dk)
        # Separate the heads: b x lq x n x dk
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for the attention dot product: b x n x lq x dk
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)  # (b, len_k) -> (b, 1, 1, len_k) for head/query broadcasting

        q, attn = self.attention(q, k, v, mask=mask)
        # q: (sz_b, n_head, len_q, d_v), attn: (sz_b, n_head, len_q, len_k)

        # Move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all heads: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual
        q = self.layer_norm(q)
        return q, attn
```
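A short usage sketch (the dimensions and the padding pattern are assumptions for illustration): a padding mask of shape `(batch, len_k)` with `0` marking padded positions broadcasts over heads and query positions after the two `unsqueeze` calls above.

```python
# Illustrative self-attention call; all sizes are assumed values.
mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64).to(device)
x = torch.randn(2, 10, 512, device=device)                     # (batch, seq_len, d_model)
pad_mask = torch.ones(2, 10, dtype=torch.long, device=device)  # 1 = real token, 0 = padding
pad_mask[:, -2:] = 0                                           # pretend the last two tokens are padding
out, attn = mha(x, x, x, mask=pad_mask)
print(out.shape, attn.shape)  # torch.Size([2, 10, 512]) torch.Size([2, 8, 10, 10])
```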
```python
class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.w_2(torch.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual
        x = self.layer_norm(x)
        return x


class EncoderLayer(nn.Module):
    ''' Composed of a self-attention sublayer and a position-wise feed-forward sublayer '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


class Encoder(nn.Module):
    ''' An encoder model with a self-attention mechanism. '''

    def __init__(
            self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False, embedding=None):
        super().__init__()
        if embedding is None:
            self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
        else:
            self.src_word_emb = embedding
        self.position_enc = PositionalEncoding(d_word_vec)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, src_mask, return_attns=False):
        enc_slf_attn_list = []

        # -- Forward
        enc_output = self.src_word_emb(src_seq)
        if self.scale_emb:
            enc_output *= self.d_model ** 0.5
        enc_output = self.dropout(self.position_enc(enc_output))
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            enc_slf_attn_list += [enc_slf_attn] if return_attns else []

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,
```
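A minimal forward pass through the full encoder (vocabulary size, padding index, and layer sizes are illustrative assumptions); the mask is derived from the token ids so that padded positions are ignored by self-attention:

```python
# Assumed hyperparameters for illustration only.
enc = Encoder(n_src_vocab=1000, d_word_vec=512, n_layers=2, n_head=8,
              d_k=64, d_v=64, d_model=512, d_inner=2048, pad_idx=0).to(device)
src_seq = torch.randint(1, 1000, (2, 10), device=device)  # (batch, seq_len) token ids
src_mask = (src_seq != 0).long()                           # 1 for real tokens, 0 for padding
enc_output, = enc(src_seq, src_mask)                       # forward returns a one-element tuple
print(enc_output.shape)  # torch.Size([2, 10, 512])
```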
```python
class BIDIBert(nn.Module):
    ''' Bidirectional Transformer encoder backbone. '''

    def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                 d_model, d_inner, pad_idx):
        super(BIDIBert, self).__init__()
        self.encoder = Encoder(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                               d_model, d_inner, pad_idx)

    def forward(self, inputs, attention_mask):
        out = self.encoder(inputs, attention_mask)
        return out


from transformers.modeling_outputs import MaskedLMOutput


class BIDIBertForMaskedLM(nn.Module):
    def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                 d_model, d_inner, pad_idx):
        super(BIDIBertForMaskedLM, self).__init__()
        self.electra = BIDIBert(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                                d_model, d_inner, pad_idx)
        self.n_src_vocab = n_src_vocab
        self.LayerNorm = nn.LayerNorm(d_model)
        self.dense = nn.Linear(d_model, d_model)
        self.generator_lm_head = nn.Linear(d_model, n_src_vocab)

    def forward(
        self,
        input_ids,
        attention_mask,
        labels=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in
            `[-100, 0, ..., vocab_size]`. Tokens with indices set to `-100` are ignored;
            the loss is only computed for the tokens with labels in `[0, ..., vocab_size]`.
        """
        generator_hidden_states, = self.electra(
            input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        generator_sequence_output = hidden_states

        prediction_scores = self.generator_lm_head(generator_sequence_output)

        # Masked language modeling loss
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = ignored position
            loss = loss_fct(prediction_scores.view(-1, self.n_src_vocab), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + (generator_hidden_states,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=(generator_hidden_states,),
            attentions=None,  # the encoder does not return attention weights here
        )
```
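Putting the pieces together, here is a hedged end-to-end sketch of one masked-language-modeling training step. All hyperparameters, the 15% masking rate, the assumed `[MASK]` token id of 3, and the `-100` convention for ignored label positions are illustrative choices, not values taken from the original code:

```python
# End-to-end MLM training step with assumed hyperparameters and a dummy batch.
model = BIDIBertForMaskedLM(n_src_vocab=1000, d_word_vec=512, n_layers=2, n_head=8,
                            d_k=64, d_v=64, d_model=512, d_inner=2048, pad_idx=0).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

input_ids = torch.randint(4, 1000, (2, 10), device=device)          # dummy token ids
attention_mask = torch.ones(2, 10, dtype=torch.long, device=device)

labels = input_ids.clone()
mask_positions = torch.rand(2, 10, device=device) < 0.15            # mask ~15% of tokens
mask_positions[:, 0] = True                                         # ensure at least one masked token per row
labels[~mask_positions] = -100                                      # ignored by CrossEntropyLoss
input_ids = input_ids.masked_fill(mask_positions, 3)                # 3 = assumed [MASK] id

loss, logits, hidden = model(input_ids, attention_mask, labels=labels)
loss.backward()
optimizer.step()
print(loss.item(), logits.shape)  # logits: (batch, seq_len, vocab_size) = (2, 10, 1000)
```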