import sys
sys.path.insert(0, '../..')

from transformers import ElectraForMaskedLM
from transformers import ElectraModel, ElectraTokenizer, ElectraConfig

base_model_name = "chinese-electra-180g-small-discriminator"
model_name = "chinese-electra-180g-small-discriminator"

# eleconfig = ElectraConfig.from_pretrained(model_name)
# model = ElectraModel.from_pretrained(model_name)
tokenizer = ElectraTokenizer.from_pretrained(base_model_name)
lm_model = ElectraForMaskedLM.from_pretrained(model_name)

# Freeze the ELECTRA backbone so only the MLM head is updated during fine-tuning.
for param in lm_model.electra.parameters():
    param.requires_grad = False
print(lm_model.electra.num_parameters())
from transformers import LineByLineTextDataset

# inputfile = r"G:\NLPDatasets\bidi_industry\bidi_industry_knowledge.txt"
inputfile = r"/data2/luojiehua/BIDINLTK/dev/data/bidi_industry_knowledge.txt"
inputfile = r"/data2/luojiehua/BIDINLTK/dev/data/wikiouta.txt"

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path=inputfile,
    block_size=512,
)
print("dataset loaded")

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./%s-bidi" % model_name,
    overwrite_output_dir=True,
    num_train_epochs=20,
    per_device_train_batch_size=3,  # replaces the deprecated per_gpu_train_batch_size
    save_steps=2000,
    save_total_limit=2,
)

trainer = Trainer(
    model=lm_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()
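
# Sketch (assumption, not part of the original script): persist the fine-tuned weights
# and tokenizer so the MLM can be reloaded later; paths reuse training_args.output_dir.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
# The result could then be reloaded with, e.g.:
#   lm_model = ElectraForMaskedLM.from_pretrained(training_args.output_dir)
#   tokenizer = ElectraTokenizer.from_pretrained(training_args.output_dir)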