| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- """
- 训练回调函数
- """
- from typing import Dict, Any
- from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
class TrainingCallback(TrainerCallback):
    """Console-reporting callback that records every log record of a run.

    Attributes:
        best_loss: Lowest ``loss`` value seen in ``on_log`` so far; stays at
            ``float("inf")`` until a numeric loss has been logged.
        training_history: Snapshot of each log record received by ``on_log``,
            in arrival order.
    """

    def __init__(self):
        self.best_loss = float("inf")
        self.training_history = []

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Dict[str, float],
        **kwargs
    ):
        """Record a log record and print the current training loss.

        Args:
            args: Training arguments (unused).
            state: Trainer state; ``global_step`` is printed.
            control: Trainer control object (unused).
            logs: The log record; ignored entirely when ``None``.
            **kwargs: Extra objects supplied by the trainer (unused).
        """
        if logs is None:
            return

        # Keep a snapshot so later mutation of the caller's dict cannot
        # rewrite the recorded history.
        self.training_history.append(dict(logs))

        loss = logs.get("loss")
        try:
            loss_text = f"{loss:.4f}"
        except (TypeError, ValueError):
            # Not every record carries a numeric "loss" (e.g. a record that
            # only holds evaluation metrics); a float format spec cannot be
            # applied to a missing value, so fall back to a placeholder.
            loss_text = "N/A"
        else:
            # Track the running minimum so on_train_end reports a real value.
            if loss < self.best_loss:
                self.best_loss = loss

        print(f"Step {state.global_step}: loss={loss_text}")

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        """Print a notice that the current epoch has finished."""
        print(f"Epoch {state.epoch:.2f} 完成")

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        """Print the global step at which a checkpoint was saved."""
        print(f"模型已保存到 step {state.global_step}")

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        """Print a summary: total step count and the best logged loss."""
        print("训练完成!")
        print(f"总步数:{state.global_step}")
        print(f"最佳 loss: {self.best_loss:.4f}")
class EarlyStoppingCallback(TrainerCallback):
    """Stop training when ``eval_loss`` stops improving.

    An evaluation counts as an improvement only when its ``eval_loss`` is
    lower than the best value seen so far by more than the threshold. After
    ``early_stopping_patience`` consecutive non-improving evaluations the
    callback sets ``control.should_training_stop``.

    NOTE(review): this class has the same name as the ``EarlyStoppingCallback``
    shipped with ``transformers``; import it explicitly from this module to
    avoid picking up the wrong one.

    Args:
        early_stopping_patience: Number of consecutive non-improving
            evaluations tolerated before stopping.
        early_stopping_threshold: Minimum decrease in ``eval_loss`` that
            counts as an improvement.
    """

    def __init__(self, early_stopping_patience: int = 3, early_stopping_threshold: float = 0.01):
        self.patience = early_stopping_patience
        self.threshold = early_stopping_threshold
        self.best_loss = float("inf")
        self.patience_counter = 0

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: Dict[str, float],
        **kwargs
    ):
        """Update the patience counter from the latest evaluation metrics.

        Args:
            args: Training arguments (unused).
            state: Trainer state (unused).
            control: Trainer control object; ``should_training_stop`` is set
                to ``True`` once patience is exhausted.
            metrics: Evaluation metrics; only ``"eval_loss"`` is read.
            **kwargs: Extra objects supplied by the trainer (unused).
        """
        eval_loss = metrics.get("eval_loss") if metrics else None
        if eval_loss is None:
            # Without an eval_loss there is nothing to compare. Counting this
            # evaluation as "no improvement" could stop training spuriously,
            # so leave the counter untouched and tell the user.
            print("早停检查已跳过:评估指标中缺少 eval_loss")
            return

        if eval_loss < self.best_loss - self.threshold:
            self.best_loss = eval_loss
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        if self.patience_counter >= self.patience:
            # The counter tracks evaluations; the message wording is kept
            # unchanged from the original.
            print(f"早停触发:{self.patience} 个 epoch 没有改善")
            control.should_training_stop = True
|