# BIDIBert_train.py
  1. import sys
  2. sys.path.insert(0,'../..')
  3. from transformers import ElectraTokenizer
  4. base_model_name = "chinese-electra-180g-small-discriminator"
  5. model_name = "chinese-electra-180g-small-discriminator"
  6. save_name = "BIDIBert"
  7. tokenizer = ElectraTokenizer.from_pretrained(base_model_name)
  8. from dev.bert.BIDIBert import BIDIBert,BIDIBertForMaskedLM
  9. lm_model = BIDIBertForMaskedLM(len(tokenizer.get_vocab()),d_word_vec=128, n_layers=3, n_head=3, d_k=128, d_v=128,
  10. d_model=128, d_inner=128,pad_idx=tokenizer.pad_token_id)
  11. model_pretrain_path = "/data2/luojiehua/BIDINLTK/dev/bert/BIDIBert-bidi/checkpoint-4802000/pytorch_model.bin"
  12. import torch
  13. lm_model.load_state_dict(torch.load(model_pretrain_path))
  14. # #freeze the backbone
  15. # for param in lm_model.parameters():
  16. # param.requires_grad = False
  17. from transformers import LineByLineTextDataset
  18. # inputfile = r"G:\NLPDatasets\bidi_industry\bidi_industry_knowledge.txt"
  19. inputfile = r"/data2/luojiehua/BIDINLTK/dev/data/bidi_industry_knowledge.txt"
  20. # inputfile = r'/data2/luojiehua/BIDINLTK/dev/data/wikiouta.txt'
  21. dataset = LineByLineTextDataset(
  22. tokenizer=tokenizer,
  23. file_path=inputfile,
  24. block_size=512,
  25. )
  26. print("dataset loaded")
  27. from transformers import DataCollatorForLanguageModeling
  28. data_collator = DataCollatorForLanguageModeling(
  29. tokenizer=tokenizer, mlm=True, mlm_probability=0.15
  30. )
  31. from transformers import Trainer, TrainingArguments
  32. training_args = TrainingArguments(
  33. output_dir="./%s-bidi"%save_name,
  34. overwrite_output_dir=True,
  35. num_train_epochs=40,
  36. per_gpu_train_batch_size=3,
  37. save_steps=2000,
  38. learning_rate=5e-6,
  39. save_total_limit=2,
  40. )
  41. trainer = Trainer(
  42. model=lm_model,
  43. args=training_args,
  44. data_collator=data_collator,
  45. train_dataset=dataset,
  46. )
  47. trainer.train()