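# Continue masked-language-model (MLM) pretraining of BIDIBert on the bidi industry corpus,
# using the chinese-electra-180g-small-discriminator tokenizer and the Hugging Face Trainer.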
import sys
sys.path.insert(0, '../..')

from transformers import ElectraTokenizer

base_model_name = "chinese-electra-180g-small-discriminator"
model_name = "chinese-electra-180g-small-discriminator"
save_name = "BIDIBert"

tokenizer = ElectraTokenizer.from_pretrained(base_model_name)

from dev.bert.BIDIBert import BIDIBert, BIDIBertForMaskedLM
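# Build a small BIDIBert MLM model (3 layers, 128-dim embeddings/hidden size),
# with the output vocabulary sized to the ELECTRA tokenizer.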
lm_model = BIDIBertForMaskedLM(
    len(tokenizer.get_vocab()),
    d_word_vec=128, n_layers=3, n_head=3, d_k=128, d_v=128,
    d_model=128, d_inner=128, pad_idx=tokenizer.pad_token_id,
)
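# Warm-start from an earlier pretraining checkpoint before continuing training.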
model_pretrain_path = "/data2/luojiehua/BIDINLTK/dev/bert/BIDIBert-bidi/checkpoint-4802000/pytorch_model.bin"

import torch
lm_model.load_state_dict(torch.load(model_pretrain_path))
# # freeze the backbone
# for param in lm_model.parameters():
#     param.requires_grad = False
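# Build the dataset: one training example per line of the corpus, truncated to block_size tokens.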
from transformers import LineByLineTextDataset

# inputfile = r"G:\NLPDatasets\bidi_industry\bidi_industry_knowledge.txt"
inputfile = r"/data2/luojiehua/BIDINLTK/dev/data/bidi_industry_knowledge.txt"
# inputfile = r'/data2/luojiehua/BIDINLTK/dev/data/wikiouta.txt'

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path=inputfile,
    block_size=512,
)
print("dataset loaded")
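# Dynamic MLM masking: 15% of tokens are masked each time a batch is collated.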
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
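# Training configuration: checkpoints go to ./BIDIBert-bidi every 2000 steps,
# keeping only the two most recent checkpoints.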
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./%s-bidi" % save_name,
    overwrite_output_dir=True,
    num_train_epochs=40,
    per_device_train_batch_size=3,  # was per_gpu_train_batch_size, deprecated in recent transformers
    save_steps=2000,
    learning_rate=5e-6,
    save_total_limit=2,
)
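# Wire the model, collator, and dataset into the Trainer and start training.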
trainer = Trainer(
    model=lm_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()
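# Optionally (not part of the original script) persist the final weights explicitly,
# in addition to the periodic checkpoints; the output path here is illustrative:
# trainer.save_model("./%s-bidi-final" % save_name)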