callbacks.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
"""
Training callback functions.
"""
  4. from typing import Dict, Any
  5. from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
  6. class TrainingCallback(TrainerCallback):
  7. """
  8. 训练回调类
  9. """
  10. def __init__(self):
  11. self.best_loss = float("inf")
  12. self.training_history = []
  13. def on_log(
  14. self,
  15. args: TrainingArguments,
  16. state: TrainerState,
  17. control: TrainerControl,
  18. logs: Dict[str, float],
  19. **kwargs
  20. ):
  21. """日志回调"""
  22. if logs is not None:
  23. self.training_history.append(logs)
  24. print(f"Step {state.global_step}: loss={logs.get('loss', 'N/A'):.4f}")
  25. def on_epoch_end(
  26. self,
  27. args: TrainingArguments,
  28. state: TrainerState,
  29. control: TrainerControl,
  30. **kwargs
  31. ):
  32. """epoch 结束回调"""
  33. print(f"Epoch {state.epoch:.2f} 完成")
  34. def on_save(
  35. self,
  36. args: TrainingArguments,
  37. state: TrainerState,
  38. control: TrainerControl,
  39. **kwargs
  40. ):
  41. """保存回调"""
  42. print(f"模型已保存到 step {state.global_step}")
  43. def on_train_end(
  44. self,
  45. args: TrainingArguments,
  46. state: TrainerState,
  47. control: TrainerControl,
  48. **kwargs
  49. ):
  50. """训练结束回调"""
  51. print("训练完成!")
  52. print(f"总步数:{state.global_step}")
  53. print(f"最佳 loss: {self.best_loss:.4f}")
  54. class EarlyStoppingCallback(TrainerCallback):
  55. """
  56. 早停回调
  57. """
  58. def __init__(self, early_stopping_patience: int = 3, early_stopping_threshold: float = 0.01):
  59. self.patience = early_stopping_patience
  60. self.threshold = early_stopping_threshold
  61. self.best_loss = float("inf")
  62. self.patience_counter = 0
  63. def on_evaluate(
  64. self,
  65. args: TrainingArguments,
  66. state: TrainerState,
  67. control: TrainerControl,
  68. metrics: Dict[str, float],
  69. **kwargs
  70. ):
  71. eval_loss = metrics.get("eval_loss", float("inf"))
  72. if eval_loss < self.best_loss - self.threshold:
  73. self.best_loss = eval_loss
  74. self.patience_counter = 0
  75. else:
  76. self.patience_counter += 1
  77. if self.patience_counter >= self.patience:
  78. print(f"早停触发:{self.patience} 个 epoch 没有改善")
  79. control.should_training_stop = True