model.py

import torch
import torch.nn as nn
# import torchsnooper  # optional: uncomment together with the snoop() decorator below
from transformers import ElectraPreTrainedModel, ElectraModel
from transformers.modeling_outputs import NextSentencePredictorOutput
from transformers.models.bert.modeling_bert import BertOnlyNSPHead, BertPooler


class ElectraNSPModel(ElectraPreTrainedModel):
    """ELECTRA encoder with a next-sentence-prediction (NSP) head.

    ElectraModel ships without a pooler, so BERT's pooler and NSP head are
    reused on top of the ELECTRA hidden states.
    """

    def __init__(self, config):
        super().__init__(config)
        self.electra = ElectraModel(config)
        self.cls = BertOnlyNSPHead(config)  # linear head: hidden_size -> 2 classes
        self.pooler = BertPooler(config)    # tanh projection of the first ([CLS]) token
        self.init_weights()

    # @torchsnooper.snoop()  # uncomment (with the import above) to trace tensor shapes
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        outputs = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_output = outputs[0]                   # (batch, seq_len, hidden_size)
        pooled_output = self.pooler(last_output)   # (batch, hidden_size)
        seq_relationship_scores = self.cls(pooled_output)  # (batch, 2)

        next_sentence_loss = None
        if labels is not None:
            # labels: 0 = sentence B follows A, 1 = B is random (BERT NSP convention)
            loss_fct = nn.CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
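

# --- Usage sketch (not part of the original file) ---
# A minimal, assumed example of how this model might be driven. The checkpoint
# name "google/electra-base-discriminator" and the sample sentences are
# placeholders; from_pretrained() loads the ELECTRA encoder weights while the
# NSP head and pooler start freshly initialized, so they would still need
# fine-tuning on an NSP-style objective before the logits mean anything.
if __name__ == "__main__":
    from transformers import ElectraTokenizerFast

    name = "google/electra-base-discriminator"
    model = ElectraNSPModel.from_pretrained(name)
    tokenizer = ElectraTokenizerFast.from_pretrained(name)

    # Encode a sentence pair; token_type_ids mark which tokens belong to A vs. B.
    enc = tokenizer("How old are you?", "I am 20 years old.", return_tensors="pt")
    labels = torch.tensor([0])  # 0 = B follows A, 1 = B is a random sentence

    out = model(
        input_ids=enc["input_ids"],
        token_type_ids=enc["token_type_ids"],
        attention_mask=enc["attention_mask"],
        labels=labels,
    )
    print(out.loss, out.logits)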