""" 训练回调函数 """ from typing import Dict, Any from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments class TrainingCallback(TrainerCallback): """ 训练回调类 """ def __init__(self): self.best_loss = float("inf") self.training_history = [] def on_log( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs: Dict[str, float], **kwargs ): """日志回调""" if logs is not None: self.training_history.append(logs) print(f"Step {state.global_step}: loss={logs.get('loss', 'N/A'):.4f}") def on_epoch_end( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs ): """epoch 结束回调""" print(f"Epoch {state.epoch:.2f} 完成") def on_save( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs ): """保存回调""" print(f"模型已保存到 step {state.global_step}") def on_train_end( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs ): """训练结束回调""" print("训练完成!") print(f"总步数:{state.global_step}") print(f"最佳 loss: {self.best_loss:.4f}") class EarlyStoppingCallback(TrainerCallback): """ 早停回调 """ def __init__(self, early_stopping_patience: int = 3, early_stopping_threshold: float = 0.01): self.patience = early_stopping_patience self.threshold = early_stopping_threshold self.best_loss = float("inf") self.patience_counter = 0 def on_evaluate( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics: Dict[str, float], **kwargs ): eval_loss = metrics.get("eval_loss", float("inf")) if eval_loss < self.best_loss - self.threshold: self.best_loss = eval_loss self.patience_counter = 0 else: self.patience_counter += 1 if self.patience_counter >= self.patience: print(f"早停触发:{self.patience} 个 epoch 没有改善") control.should_training_stop = True