# BIDIBert.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PositionalEncoding(nn.Module):
    ''' Sinusoidal positional encoding added to the input embeddings:
    PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). '''

    def __init__(self, dim_hid):
        super(PositionalEncoding, self).__init__()
        # Denominator 10000^(2i/d) for every hidden dimension index.
        base_array = np.array([np.power(10000, 2 * (hid_j // 2) / dim_hid) for hid_j in range(dim_hid)])
        self.base_tensor = torch.from_numpy(base_array).to(torch.float32).to(device)  # [D]

    def forward(self, x):
        # x: (B, N, D)
        B, N, d = x.shape
        pos = torch.arange(N).unsqueeze(-1).to(torch.float32).to(device)  # [N, 1]
        pos = pos / self.base_tensor  # [N, D] via broadcasting
        pos = pos.unsqueeze(0)        # [1, N, D]
        pos[:, :, 0::2] = torch.sin(pos[:, :, 0::2])
        pos[:, :, 1::2] = torch.cos(pos[:, :, 1::2])
        return x + pos


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # q, k, v: (B, n_head, len, d); attn scores: (B, n_head, len_q, len_k)
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(torch.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
        residual = q

        # Pass through the pre-attention projections: b x lq x (n*dk)
        # and separate the heads: b x lq x n x dk
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for the attention dot product: b x n x lq x dk
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)  # (b, lk) -> (b, 1, 1, lk) for head/query broadcasting

        # q: (sz_b, n_head, len_q, d_v), attn: (sz_b, n_head, len_q, len_k)
        q, attn = self.attention(q, k, v, mask=mask)

        # Move the head dimension back (b x lq x n x dv) and concatenate
        # all heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual
        q = self.layer_norm(q)
        return q, attn
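

# A minimal shape walk-through of the masking path above. The (B, N) padding-mask
# layout (1 = real token, 0 = padding) is an assumption based on how Encoder and
# BIDIBert pass `attention_mask` straight through to this module:
#   mask:              (B, N)
#   after unsqueezing: (B, 1, 1, N)
#   attention scores:  (B, n_head, N, N), masked positions filled with -1e9
#     so that softmax assigns them near-zero attention weight.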


class PositionwiseFeedForward(nn.Module):
    ''' A two-layer position-wise feed-forward module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.w_2(torch.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual
        x = self.layer_norm(x)
        return x


class EncoderLayer(nn.Module):
    ''' Composed of self-attention and a position-wise feed-forward sublayer '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


class Encoder(nn.Module):
    ''' An encoder model with a self-attention mechanism. '''

    def __init__(
            self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False, embedding=None):
        super().__init__()
        if embedding is None:
            self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
        else:
            self.src_word_emb = embedding
        # n_position is kept for interface compatibility; PositionalEncoding computes positions on the fly.
        self.position_enc = PositionalEncoding(d_word_vec)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, src_mask, return_attns=False):
        enc_slf_attn_list = []

        # -- Forward
        enc_output = self.src_word_emb(src_seq)
        if self.scale_emb:
            enc_output *= self.d_model ** 0.5
        enc_output = self.dropout(self.position_enc(enc_output))
        enc_output = self.layer_norm(enc_output)
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            enc_slf_attn_list += [enc_slf_attn] if return_attns else []
        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,  # returned as a 1-tuple for a uniform interface


class BIDIBert(nn.Module):
    def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                 d_model, d_inner, pad_idx):
        super(BIDIBert, self).__init__()
        self.encoder = Encoder(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                               d_model, d_inner, pad_idx)

    def forward(self, inputs, attention_mask):
        out = self.encoder(inputs, attention_mask)
        return out


from transformers import PreTrainedModel
from transformers.modeling_outputs import MaskedLMOutput


class BIDIBertForMaskedLM(nn.Module):
    def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                 d_model, d_inner, pad_idx):
        super(BIDIBertForMaskedLM, self).__init__()
        self.electra = BIDIBert(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                                d_model, d_inner, pad_idx)
        self.n_src_vocab = n_src_vocab
        self.LayerNorm = nn.LayerNorm(d_model)
        self.dense = nn.Linear(d_model, d_model)
        self.generator_lm_head = nn.Linear(d_model, n_src_vocab)

    def forward(
        self,
        input_ids,
        attention_mask,
        labels=None,
        return_dict=None
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in
            `[-100, 0, ..., n_src_vocab]` (see the `input_ids` docstring). Tokens with indices set to
            `-100` are ignored (masked); the loss is only computed for tokens with labels in
            `[0, ..., n_src_vocab]`.
        """
        generator_hidden_states, = self.electra(
            input_ids,
            attention_mask=attention_mask
        )
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        generator_sequence_output = hidden_states
        prediction_scores = self.generator_lm_head(generator_sequence_output)

        loss = None
        # Masked language modeling loss
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # labels set to -100 are ignored
            loss = loss_fct(prediction_scores.view(-1, self.n_src_vocab), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + (generator_hidden_states,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=generator_hidden_states,
            attentions=None,  # the backbone returns a plain tensor, no attention maps here
        )
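

# --- Minimal usage sketch (not part of the original module). The hyperparameter
# values below are illustrative assumptions chosen only so the forward pass runs;
# they are not the configuration used to train BIDIBert.
if __name__ == "__main__":
    vocab_size, d_model, pad_idx = 1000, 64, 0
    model = BIDIBertForMaskedLM(
        n_src_vocab=vocab_size, d_word_vec=d_model, n_layers=2, n_head=4,
        d_k=16, d_v=16, d_model=d_model, d_inner=128, pad_idx=pad_idx,
    ).to(device)

    batch, seq_len = 2, 10
    input_ids = torch.randint(1, vocab_size, (batch, seq_len), device=device)
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long, device=device)
    attention_mask[:, -2:] = 0            # pretend the last two positions are padding
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100    # ignore padding positions in the loss

    out = model(input_ids, attention_mask, labels=labels, return_dict=True)
    print(out.loss.item(), out.logits.shape)  # logits: (batch, seq_len, vocab_size)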