check_env.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. """
  2. 检查环境依赖
  3. """
  4. import sys
  5. import subprocess
  6. from packaging import version
  7. def check_package(package_name, min_version=None):
  8. """检查包是否安装及版本"""
  9. try:
  10. import importlib
  11. module = importlib.import_module(package_name)
  12. if hasattr(module, "__version__"):
  13. installed_version = module.__version__
  14. if min_version and version.parse(installed_version) < version.parse(min_version):
  15. print(f"❌ {package_name}: {installed_version} (需要 >= {min_version})")
  16. return False
  17. else:
  18. print(f"✓ {package_name}: {installed_version}")
  19. else:
  20. print(f"✓ {package_name}: 已安装")
  21. return True
  22. except ImportError:
  23. print(f"❌ {package_name}: 未安装")
  24. return False
  25. def check_cuda():
  26. """检查 CUDA 是否可用"""
  27. try:
  28. import torch
  29. if torch.cuda.is_available():
  30. print(f"✓ CUDA: 可用 ({torch.cuda.device_count()} 个 GPU)")
  31. for i in range(torch.cuda.device_count()):
  32. print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
  33. return True
  34. else:
  35. print("⚠ CUDA: 不可用 (将使用 CPU 训练)")
  36. return False
  37. except ImportError:
  38. print("❌ PyTorch: 未安装")
  39. return False
  40. def check_npu():
  41. """检查华为升腾 NPU 是否可用"""
  42. try:
  43. import torch
  44. if hasattr(torch, 'npu') and torch.npu.is_available():
  45. print(f"✓ NPU (华为升腾): 可用 ({torch.npu.device_count()} 个 NPU)")
  46. return True
  47. else:
  48. print("⚠ NPU (华为升腾): 不可用")
  49. return False
  50. except ImportError:
  51. print("⚠ torch-npu: 未安装 (华为升腾需要)")
  52. return False
  53. def main():
  54. print("=" * 60)
  55. print("FineTuneX 环境检查")
  56. print("=" * 60)
  57. print()
  58. required_packages = {
  59. "torch": "2.1.0",
  60. "transformers": "4.40.0",
  61. "datasets": "2.14.0",
  62. "accelerate": "0.25.0",
  63. "peft": "0.7.0",
  64. "trl": "0.7.0",
  65. "fastapi": "0.104.0",
  66. "uvicorn": "0.24.0",
  67. "pydantic": "2.0.0",
  68. }
  69. all_ok = True
  70. print("检查 Python 包:")
  71. print("-" * 60)
  72. for package, min_ver in required_packages.items():
  73. if not check_package(package, min_ver):
  74. all_ok = False
  75. # 检查 bitsandbytes(可选,仅 NVIDIA GPU)
  76. print("\n检查量化工具(可选,仅 NVIDIA GPU):")
  77. print("-" * 60)
  78. check_package("bitsandbytes", "0.41.0")
  79. print("(华为升腾 NPU 不支持 bitsandbytes 量化,可忽略此项)")
  80. print()
  81. print("检查计算设备:")
  82. print("-" * 60)
  83. has_cuda = check_cuda()
  84. has_npu = check_npu()
  85. if not has_cuda and not has_npu:
  86. print("⚠ 警告:未检测到 GPU/NPU,将使用 CPU 训练(速度较慢)")
  87. print()
  88. print("=" * 60)
  89. if all_ok:
  90. print("✓ 所有依赖已安装!")
  91. else:
  92. print("❌ 部分依赖缺失,请运行:pip install -r requirements.txt")
  93. print("=" * 60)
  94. return 0 if all_ok else 1
  95. if __name__ == "__main__":
  96. sys.exit(main())