import torch
import torchsnooper
import torch.nn as nn
from transformers import ElectraPreTrainedModel, ElectraModel
from transformers.modeling_outputs import NextSentencePredictorOutput
from transformers.models.bert.modeling_bert import BertOnlyNSPHead, BertPooler


class ElectraNSPModel(ElectraPreTrainedModel):
    """ELECTRA encoder with a BERT-style pooler and a next-sentence-prediction head on top."""

    def __init__(self, config):
        super(ElectraNSPModel, self).__init__(config)
        self.electra = ElectraModel(config)
        # ElectraModel has no pooler of its own, so reuse BERT's pooler and NSP head.
        self.cls = BertOnlyNSPHead(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    # @torchsnooper.snoop()
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        outputs = self.electra(input_ids=input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids)
        last_output = outputs[0]                        # last hidden states: (batch, seq_len, hidden)
        pooled_output = self.pooler(last_output)        # [CLS] representation: (batch, hidden)
        seq_relationship_scores = self.cls(pooled_output)  # NSP logits: (batch, 2)

        next_sentence_loss = None
        if labels is not None:
            # Labels follow the BERT NSP convention: 0 = sentence B follows A, 1 = random pair.
            loss_fct = nn.CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
            # next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels)

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
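

# Minimal usage sketch (an illustrative assumption, not part of the original module):
# load the "google/electra-base-discriminator" checkpoint and run one NSP-style forward
# pass. The pooler and NSP head here are newly initialized, so the logits are only
# meaningful after fine-tuning on a sentence-pair task.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
    model = ElectraNSPModel.from_pretrained("google/electra-base-discriminator")

    # Encode a sentence pair; the tokenizer supplies input_ids, token_type_ids, attention_mask.
    encoding = tokenizer("How old are you?", "I am 20 years old.", return_tensors="pt")
    labels = torch.tensor([0])  # 0 = sentence B follows sentence A, 1 = random pair

    with torch.no_grad():
        output = model(**encoding, labels=labels)
    print(output.loss, output.logits)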