helpers.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. """
  2. 辅助工具函数
  3. """
  4. import os
  5. import time
  6. import torch
  7. from typing import Dict, Any
  def setup_environment(seed: int = 42):
      """Seed the random number generators and set related environment variables.

      Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (the
      CPU generator plus all CUDA devices). It then sets the
      ``TOKENIZERS_PARALLELISM`` and ``PYTHONHASHSEED`` environment variables
      and prints a confirmation message.

      Note: assigning ``PYTHONHASHSEED`` at runtime does not change hash
      randomization in the interpreter that is already running. It only
      affects child processes launched afterwards.

      Args:
          seed: Seed value applied to every generator.
      """
      import random
      import numpy as np

      # Apply the same seed to each generator, in a fixed order.
      seeders = (
          random.seed,
          np.random.seed,
          torch.manual_seed,
          torch.cuda.manual_seed_all,
      )
      for apply_seed in seeders:
          apply_seed(seed)

      # Environment variables (set in this order: TOKENIZERS_PARALLELISM, PYTHONHASHSEED).
      os.environ.update(
          TOKENIZERS_PARALLELISM="false",
          PYTHONHASHSEED=str(seed),
      )
      print(f"环境设置完成,随机种子:{seed}")
  def get_gpu_info() -> Dict[str, Any]:
      """Return a snapshot of CUDA device information.

      Returns:
          ``{"available": False}`` when CUDA is not available. Otherwise, a dict
          containing ``"available": True``, ``"device_count"``, and ``"devices"``.
          ``"devices"`` is a list with one dict per device, holding ``"name"``
          and three memory figures in GB (bytes / 1e9): ``"memory_allocated"``,
          ``"memory_reserved"``, and ``"max_memory"`` (the device's
          ``total_memory`` property).
      """
      if not torch.cuda.is_available():
          return {"available": False}

      def _describe(idx: int) -> Dict[str, Any]:
          # Per-device entry; the key order matches the public output format.
          bytes_per_gb = 1e9
          props = torch.cuda.get_device_properties(idx)
          return {
              "name": torch.cuda.get_device_name(idx),
              "memory_allocated": torch.cuda.memory_allocated(idx) / bytes_per_gb,
              "memory_reserved": torch.cuda.memory_reserved(idx) / bytes_per_gb,
              "max_memory": props.total_memory / bytes_per_gb,
          }

      n_devices = torch.cuda.device_count()
      return {
          "available": True,
          "device_count": n_devices,
          "devices": [_describe(idx) for idx in range(n_devices)],
      }
  def count_parameters(model) -> Dict[str, int]:
      """Count a model's parameter elements, split by whether they require grad.

      Args:
          model: Object whose ``parameters()`` yields items exposing ``numel()``
              and ``requires_grad`` (e.g. a ``torch.nn.Module``).

      Returns:
          Dict of element counts: ``"total"``, ``"trainable"`` (parameters with
          ``requires_grad`` set), and ``"frozen"`` (all remaining parameters).
      """
      trainable = 0
      frozen = 0
      # A single pass is enough: each parameter falls into exactly one bucket.
      for param in model.parameters():
          if param.requires_grad:
              trainable += param.numel()
          else:
              frozen += param.numel()
      total = trainable + frozen
      return {"total": total, "trainable": trainable, "frozen": frozen}
  def format_time(seconds: float) -> str:
      """Format a duration in seconds as a short human-readable string.

      The value is shown with two decimals in a single unit: seconds ("s") below
      one minute, decimal minutes ("m") below one hour, and decimal hours ("h")
      otherwise. For example, 30 -> "30.00s", 90 -> "1.50m", and 7200 -> "2.00h".

      Args:
          seconds: Duration in seconds.

      Returns:
          The formatted duration string.
      """
      # Choose the unit based on the value *as displayed* (rounded to 2 decimals).
      # Otherwise, e.g. 59.996 s would render as "60.00s" and 3599.999 s as
      # "60.00m" instead of rolling over to the next unit.
      if round(seconds, 2) < 60:
          return f"{seconds:.2f}s"
      minutes = seconds / 60
      if round(minutes, 2) < 60:
          return f"{minutes:.2f}m"
      hours = seconds / 3600
      return f"{hours:.2f}h"
  77. def print_memory_usage():
  78. """打印内存使用情况"""
  79. if torch.cuda.is_available():
  80. allocated = torch.cuda.memory_allocated() / 1e9
  81. reserved = torch.cuda.memory_reserved() / 1e9
  82. print(f"GPU 内存 - 已分配:{allocated:.2f}GB, 已保留:{reserved:.2f}GB")