test_training_args.py 1013 B

123456789101112131415161718192021222324252627282930313233343536
  1. """
  2. 测试 TrainingArguments 参数
  3. """
  4. from transformers import TrainingArguments
  5. import torch
  6. print("测试 TrainingArguments 参数...")
  7. try:
  8. args = TrainingArguments(
  9. output_dir="./test_output",
  10. num_train_epochs=3,
  11. per_device_train_batch_size=1,
  12. gradient_accumulation_steps=4,
  13. learning_rate=2e-4,
  14. warmup_ratio=0.03,
  15. weight_decay=0.01,
  16. logging_steps=10,
  17. save_steps=50,
  18. eval_strategy="no",
  19. save_total_limit=3,
  20. fp16=torch.cuda.is_available(),
  21. optim="paged_adamw_32bit",
  22. lr_scheduler_type="cosine",
  23. report_to="none",
  24. remove_unused_columns=False,
  25. )
  26. print("✓ TrainingArguments 参数验证通过!")
  27. print(f" 输出目录:{args.output_dir}")
  28. print(f" 训练轮数:{args.num_train_epochs}")
  29. print(f" FP16: {args.fp16}")
  30. except Exception as e:
  31. print(f"✗ TrainingArguments 参数验证失败:{e}")
  32. import traceback
  33. traceback.print_exc()