retrain.py

import sys
sys.path.insert(0, '../..')  # make the project root importable

from transformers import ElectraForMaskedLM
from transformers import ElectraModel, ElectraTokenizer, ElectraConfig

# The same discriminator checkpoint supplies both the tokenizer and the MLM weights.
base_model_name = "chinese-electra-180g-small-discriminator"
model_name = "chinese-electra-180g-small-discriminator"

# eleconfig = ElectraConfig.from_pretrained(model_name)
# model = ElectraModel.from_pretrained(model_name)
tokenizer = ElectraTokenizer.from_pretrained(base_model_name)
lm_model = ElectraForMaskedLM.from_pretrained(model_name)
# Freeze the ELECTRA backbone so only the MLM head is trained.
for param in lm_model.electra.parameters():
    param.requires_grad = False
print(lm_model.electra.num_parameters())
from transformers import LineByLineTextDataset

# inputfile = r"G:\NLPDatasets\bidi_industry\bidi_industry_knowledge.txt"
# inputfile = r"/data2/luojiehua/BIDINLTK/dev/data/bidi_industry_knowledge.txt"
inputfile = r'/data2/luojiehua/BIDINLTK/dev/data/wikiouta.txt'

# One training example per non-empty line, truncated to block_size tokens.
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path=inputfile,
    block_size=512,
)
print("dataset loaded")
from transformers import DataCollatorForLanguageModeling

# Dynamic MLM masking: 15% of tokens are masked when each batch is collated.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./%s-bidi" % model_name,
    overwrite_output_dir=True,
    num_train_epochs=20,
    per_device_train_batch_size=3,  # current name for the deprecated per_gpu_train_batch_size
    save_steps=2000,
    save_total_limit=2,
)
trainer = Trainer(
    model=lm_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
trainer.train()
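
# A minimal inference sketch (assumptions: this runs right after trainer.train(),
# so lm_model already holds the fine-tuned weights in memory, and the example
# sentence is purely illustrative). A fill-mask pipeline reports the top
# predictions for the [MASK] token of the ELECTRA tokenizer.
from transformers import pipeline

fill_mask = pipeline("fill-mask", model=lm_model, tokenizer=tokenizer)
print(fill_mask("北京是中国的[MASK]都。"))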