import sys
sys.path.insert(0, '../..')

from transformers import ElectraForMaskedLM
from transformers import ElectraModel, ElectraTokenizer, ElectraConfig

base_model_name = "chinese-electra-180g-small-discriminator"
model_name = "chinese-electra-180g-small-discriminator"

# eleconfig = ElectraConfig.from_pretrained(model_name)
# model = ElectraModel.from_pretrained(model_name)

tokenizer = ElectraTokenizer.from_pretrained(base_model_name)
lm_model = ElectraForMaskedLM.from_pretrained(model_name)

# Freeze the ELECTRA backbone so only the MLM head is updated during training.
for param in lm_model.electra.parameters():
    param.requires_grad = False

print(lm_model.electra.num_parameters())

from transformers import LineByLineTextDataset

# inputfile = r"G:\NLPDatasets\bidi_industry\bidi_industry_knowledge.txt"
# inputfile = r"/data2/luojiehua/BIDINLTK/dev/data/bidi_industry_knowledge.txt"
inputfile = r'/data2/luojiehua/BIDINLTK/dev/data/wikiouta.txt'

# One training example per line of the input file, truncated to block_size tokens.
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path=inputfile,
    block_size=512,
)
print("dataset loaded")

from transformers import DataCollatorForLanguageModeling

# Dynamically mask 15% of the tokens in each batch for masked language modeling.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./%s-bidi" % model_name,
    overwrite_output_dir=True,
    num_train_epochs=20,
    per_device_train_batch_size=3,  # replaces the deprecated per_gpu_train_batch_size
    save_steps=2000,
    save_total_limit=2,
)

trainer = Trainer(
    model=lm_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()
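
# --- Optional follow-up: a minimal sketch, not part of the original script. ---
# After training, the fine-tuned MLM head can be saved and sanity-checked with a
# fill-mask pipeline. save_dir reuses the output_dir naming above, and the Chinese
# example sentence ("the bidding [MASK] agency publishes the announcement") is
# purely illustrative; adjust both to your setup.
from transformers import pipeline

save_dir = "./%s-bidi" % model_name   # same pattern as output_dir in TrainingArguments
trainer.save_model(save_dir)          # writes model weights and config
tokenizer.save_pretrained(save_dir)   # writes the vocab so the pipeline can load both

# Predict replacements for the [MASK] token with the saved checkpoint.
fill_mask = pipeline("fill-mask", model=save_dir, tokenizer=save_dir)
for pred in fill_mask("招标[MASK]理机构负责发布公告。"):
    print(pred["token_str"], round(pred["score"], 4))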