`ElectraModel` ships without a pooler or an NSP head (ELECTRA is not pretrained with next-sentence prediction), so this model borrows `BertPooler` and `BertOnlyNSPHead` from BERT and wires them on top of the ELECTRA encoder:

```python
from typing import Optional

import torch
import torch.nn as nn
from transformers import ElectraModel, ElectraPreTrainedModel
from transformers.modeling_outputs import NextSentencePredictorOutput
from transformers.models.bert.modeling_bert import BertOnlyNSPHead, BertPooler


class ElectraNSPModel(ElectraPreTrainedModel):
    """ELECTRA encoder with a next-sentence-prediction head borrowed from BERT."""

    def __init__(self, config):
        super().__init__(config)
        self.electra = ElectraModel(config)
        # ELECTRA has no pooler or NSP head of its own, so reuse BERT's.
        self.cls = BertOnlyNSPHead(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    # @torchsnooper.snoop()  # uncomment (and `import torchsnooper`) to trace tensor shapes
    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> NextSentencePredictorOutput:
        outputs = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_output = outputs[0]                           # (batch, seq_len, hidden)
        pooled_output = self.pooler(last_output)           # [CLS] representation, (batch, hidden)
        seq_relationship_scores = self.cls(pooled_output)  # (batch, 2)

        next_sentence_loss = None
        if labels is not None:
            # labels: 0 = sentence B follows sentence A, 1 = random sentence
            loss_fct = nn.CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
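A minimal usage sketch follows, assuming the `google/electra-base-discriminator` checkpoint and a transformers version exposing `ElectraTokenizerFast`. Note that the pooler and NSP head receive no weights from the ELECTRA checkpoint, so they are freshly initialized and the logits are meaningless until the model is fine-tuned on NSP-style pairs:

```python
import torch
from transformers import ElectraTokenizerFast

tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-base-discriminator")
# from_pretrained will warn that cls/pooler weights are newly initialized.
model = ElectraNSPModel.from_pretrained("google/electra-base-discriminator")
model.eval()

encoding = tokenizer(
    "The weather was terrible.",   # sentence A
    "So we stayed home all day.",  # sentence B
    return_tensors="pt",
)
labels = torch.tensor([0])  # 0 = B follows A, matching BERT's NSP convention

with torch.no_grad():
    output = model(**encoding, labels=labels)

print(output.loss)    # scalar cross-entropy loss
print(output.logits)  # shape (1, 2): [is_next, is_random] scores
```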