import sys
sys.path.insert(0, '../..')

import torch
from transformers import (
    ElectraTokenizer,
    LineByLineTextDataset,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from dev.bert.BIDIBert import BIDIBert, BIDIBertForMaskedLM

base_model_name = "chinese-electra-180g-small-discriminator"
save_name = "BIDIBert"

# Tokenizer of the Chinese ELECTRA small discriminator; only its vocabulary is reused here.
tokenizer = ElectraTokenizer.from_pretrained(base_model_name)

# Small BIDIBert MLM model: 3 layers, 3 heads, hidden size 128.
lm_model = BIDIBertForMaskedLM(
    len(tokenizer.get_vocab()),
    d_word_vec=128,
    n_layers=3,
    n_head=3,
    d_k=128,
    d_v=128,
    d_model=128,
    d_inner=128,
    pad_idx=tokenizer.pad_token_id,
)

# Warm-start from a previously pre-trained checkpoint.
model_pretrain_path = "/data2/luojiehua/BIDINLTK/dev/bert/BIDIBert-bidi/checkpoint-4802000/pytorch_model.bin"
lm_model.load_state_dict(torch.load(model_pretrain_path))

# # Freeze the backbone (optional):
# for param in lm_model.parameters():
#     param.requires_grad = False

# Training corpus: one example per line.
# inputfile = r"G:\NLPDatasets\bidi_industry\bidi_industry_knowledge.txt"
inputfile = r"/data2/luojiehua/BIDINLTK/dev/data/bidi_industry_knowledge.txt"
# inputfile = r"/data2/luojiehua/BIDINLTK/dev/data/wikiouta.txt"

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path=inputfile,
    block_size=512,
)
print("dataset loaded")

# Masked-language-modeling collator: randomly mask 15% of the tokens.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

training_args = TrainingArguments(
    output_dir="./%s-bidi" % save_name,
    overwrite_output_dir=True,
    num_train_epochs=40,
    per_device_train_batch_size=3,  # replaces the deprecated per_gpu_train_batch_size
    save_steps=2000,
    learning_rate=5e-6,
    save_total_limit=2,
)

trainer = Trainer(
    model=lm_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()
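
# A minimal, hedged sketch (not part of the original run): reload the latest
# checkpoint written by the Trainer, mirroring the load_state_dict() pattern
# used above for the pre-trained weights. It assumes the same
# "checkpoint-<step>/pytorch_model.bin" layout as the warm-start path above;
# other transformers versions may serialize the weights differently.
import glob
import os

checkpoint_dirs = sorted(
    glob.glob(os.path.join(training_args.output_dir, "checkpoint-*")),
    key=lambda p: int(p.rsplit("-", 1)[-1]),
)
if checkpoint_dirs:
    latest_weights = os.path.join(checkpoint_dirs[-1], "pytorch_model.bin")
    lm_model.load_state_dict(torch.load(latest_weights, map_location="cpu"))
    lm_model.eval()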