pretrain.py

import os
import logging
from typing import Dict

import torch
from transformers import AlbertConfig, BertTokenizerFast, PreTrainedTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")

# Load the tokenizer from the current directory (expects a vocab file next to this script).
# `model_max_length` replaces the deprecated `max_len` argument.
tokenizer = BertTokenizerFast.from_pretrained(".", model_max_length=512)

config = AlbertConfig(
    vocab_size=len(tokenizer.get_vocab()),
    embedding_size=256,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
)
print("tokenizer vocab length", len(tokenizer.get_vocab()))

from transformers import AlbertForMaskedLM

model = AlbertForMaskedLM(config=config)
print("num_parameters", model.num_parameters())
# => 8,554,575 parameters (ALBERT shares weights across layers, which keeps the count small)

from transformers import LineByLineTextDataset


class BertPretrainDataset(LineByLineTextDataset):
    """
    Line-by-line dataset for masked-language-model pretraining.
    This will be superseded by a framework-agnostic approach soon.
    Note: defined here for reference but not used below; the script builds a plain
    LineByLineTextDataset instead. pk_dir / pk_size are accepted but currently unused.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, pk_dir: str, pk_size=50000):
        if not os.path.isfile(file_path):
            raise ValueError(f"Input file path {file_path} not found")
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        logging.info(f"Creating features from dataset file at {file_path}")

        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

        batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
        self.examples = batch_encoding["input_ids"]
        self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return self.examples[i]
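
# Hypothetical usage of the custom dataset above (the path mirrors the one used below;
# pk_dir / pk_size are accepted but ignored by the current implementation):
# custom_dataset = BertPretrainDataset(
#     tokenizer=tokenizer,
#     file_path=r"G:\NLPDatasets\lcsts\train.trg1",
#     block_size=25,
#     pk_dir="./pickle_cache",
# )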

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path=r"G:\NLPDatasets\lcsts\train.trg1",
    block_size=25,
)
print("dataset loaded")
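
# Optional sanity check: decode the first encoded example to confirm the tokenizer
# and block_size behave as expected (purely illustrative, safe to remove).
print(tokenizer.decode(dataset[0]["input_ids"]))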

from transformers import DataCollatorForLanguageModeling

# Dynamically masks 15% of the tokens in each batch for the MLM objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
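
# Optional illustration, assuming a transformers version where the collator is callable
# (3.x and later): it pads a list of features into a batch and adds MLM labels,
# using -100 for positions that were not masked.
demo_batch = data_collator([dataset[0], dataset[1]])
print("batch shapes", demo_batch["input_ids"].shape, demo_batch["labels"].shape)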

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./lunyuAlbert",
    overwrite_output_dir=True,
    num_train_epochs=20,
    per_device_train_batch_size=3,  # per_gpu_train_batch_size is deprecated in newer transformers
    save_steps=2000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
trainer.train()
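
# Optional: persist the final weights and tokenizer alongside the checkpoints so the
# model can later be reloaded with AlbertForMaskedLM.from_pretrained("./lunyuAlbert").
trainer.save_model("./lunyuAlbert")
tokenizer.save_pretrained("./lunyuAlbert")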