trainer.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
"""
Fine-tuning trainer.
"""
  4. import os
  5. import logging
  6. from typing import Optional, Any
  7. from dataclasses import dataclass
  8. import torch
  9. from transformers import (
  10. TrainingArguments,
  11. Trainer,
  12. DataCollatorForLanguageModeling,
  13. AutoTokenizer,
  14. )
  15. from peft import PeftModel
  16. logger = logging.getLogger(__name__)
  17. @dataclass
  18. class FineTuneTrainer:
  19. """
  20. 微调训练器
  21. """
  22. model: Any
  23. tokenizer: AutoTokenizer
  24. config: Any
  25. train_dataset: Optional[Any] = None
  26. eval_dataset: Optional[Any] = None
  27. def __post_init__(self):
  28. self.training_args = None
  29. self.trainer = None
  30. @staticmethod
  31. def check_device():
  32. """检查可用的计算设备"""
  33. if torch.cuda.is_available():
  34. return "cuda"
  35. elif hasattr(torch, 'npu') and torch.npu.is_available():
  36. return "npu"
  37. else:
  38. return "cpu"
  39. def setup_training(
  40. self,
  41. output_dir: str = "./outputs",
  42. num_train_epochs: float = 3.0,
  43. per_device_train_batch_size: int = 1,
  44. gradient_accumulation_steps: int = 4,
  45. learning_rate: float = 2e-4,
  46. warmup_ratio: float = 0.03,
  47. weight_decay: float = 0.01,
  48. logging_steps: int = 10,
  49. save_steps: int = 100,
  50. eval_strategy: str = "no",
  51. save_total_limit: int = 3,
  52. fp16: bool = False,
  53. bf16: bool = True,
  54. **kwargs
  55. ):
  56. device = self.check_device()
  57. print(f"检测到设备:{device}")
  58. if device == "npu":
  59. os.environ.setdefault("ASCEND_LAUNCH_BLOCKING", "1")
  60. print("华为升腾 NPU 设备,使用 bf16 混合精度训练")
  61. use_fp16 = fp16 if device == "cuda" else False
  62. use_bf16 = bf16 if device in ["cuda", "npu"] else False
  63. if device == "cuda":
  64. print("NVIDIA CUDA 设备,使用 bf16/fp16 混合精度训练")
  65. self.training_args = TrainingArguments(
  66. output_dir=output_dir,
  67. num_train_epochs=num_train_epochs,
  68. per_device_train_batch_size=per_device_train_batch_size,
  69. gradient_accumulation_steps=gradient_accumulation_steps,
  70. learning_rate=learning_rate,
  71. warmup_ratio=warmup_ratio,
  72. weight_decay=weight_decay,
  73. logging_steps=logging_steps,
  74. save_steps=save_steps,
  75. eval_strategy=eval_strategy,
  76. save_total_limit=save_total_limit,
  77. fp16=use_fp16,
  78. bf16=use_bf16,
  79. optim="paged_adamw_32bit" if device == "cuda" else "adamw_torch",
  80. lr_scheduler_type="cosine",
  81. report_to="none",
  82. remove_unused_columns=False,
  83. **kwargs
  84. )
  85. # 数据 collator
  86. data_collator = DataCollatorForLanguageModeling(
  87. tokenizer=self.tokenizer,
  88. mlm=False,
  89. )
  90. # 创建 Trainer
  91. self.trainer = Trainer(
  92. model=self.model,
  93. args=self.training_args,
  94. train_dataset=self.train_dataset,
  95. eval_dataset=self.eval_dataset,
  96. data_collator=data_collator,
  97. )
  98. print("训练设置完成!")
  99. def train(self, resume_from_checkpoint: Optional[str] = None):
  100. """
  101. 开始训练
  102. Args:
  103. resume_from_checkpoint: 从检查点恢复训练
  104. """
  105. if self.trainer is None:
  106. raise ValueError("请先调用 setup_training() 设置训练参数")
  107. print("开始训练...")
  108. self.trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  109. print("训练完成!")
  110. def save_model(self, output_dir: Optional[str] = None):
  111. """
  112. 保存模型
  113. Args:
  114. output_dir: 输出目录
  115. """
  116. if output_dir is None:
  117. output_dir = self.training_args.output_dir
  118. print(f"保存模型到:{output_dir}")
  119. # 保存 LoRA 权重
  120. self.model.save_pretrained(output_dir)
  121. # 保存 tokenizer
  122. self.tokenizer.save_pretrained(output_dir)
  123. print("模型保存完成!")
  124. def push_to_hub(self, repo_id: str, **kwargs):
  125. """
  126. 推送模型到 HuggingFace Hub
  127. Args:
  128. repo_id: 仓库 ID
  129. """
  130. print(f"推送模型到 HuggingFace Hub: {repo_id}")
  131. # 保存并推送
  132. self.model.push_to_hub(repo_id, **kwargs)
  133. self.tokenizer.push_to_hub(repo_id, **kwargs)
  134. print("推送完成!")