check_env.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """
  2. 检查环境依赖
  3. """
  4. import sys
  5. import subprocess
  6. from packaging import version
  7. def check_package(package_name, min_version=None):
  8. """检查包是否安装及版本"""
  9. try:
  10. import importlib
  11. module = importlib.import_module(package_name)
  12. if hasattr(module, "__version__"):
  13. installed_version = module.__version__
  14. if min_version and version.parse(installed_version) < version.parse(min_version):
  15. print(f"❌ {package_name}: {installed_version} (需要 >= {min_version})")
  16. return False
  17. else:
  18. print(f"✓ {package_name}: {installed_version}")
  19. else:
  20. print(f"✓ {package_name}: 已安装")
  21. return True
  22. except ImportError:
  23. print(f"❌ {package_name}: 未安装")
  24. return False
  25. def check_cuda():
  26. """检查 CUDA 是否可用"""
  27. try:
  28. import torch
  29. if torch.cuda.is_available():
  30. print(f"✓ CUDA: 可用 ({torch.cuda.device_count()} 个 GPU)")
  31. for i in range(torch.cuda.device_count()):
  32. print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
  33. return True
  34. else:
  35. print("⚠ CUDA: 不可用 (将使用 CPU 训练)")
  36. return False
  37. except ImportError:
  38. print("❌ PyTorch: 未安装")
  39. return False
  40. def main():
  41. print("=" * 60)
  42. print("FineTuneX 环境检查")
  43. print("=" * 60)
  44. print()
  45. required_packages = {
  46. "torch": "2.0.0",
  47. "transformers": "4.40.0",
  48. "datasets": "2.14.0",
  49. "accelerate": "0.25.0",
  50. "peft": "0.7.0",
  51. "bitsandbytes": "0.41.0",
  52. "trl": "0.7.0",
  53. "fastapi": "0.104.0",
  54. "uvicorn": "0.24.0",
  55. "pydantic": "2.0.0",
  56. }
  57. all_ok = True
  58. print("检查 Python 包:")
  59. print("-" * 60)
  60. for package, min_ver in required_packages.items():
  61. if not check_package(package, min_ver):
  62. all_ok = False
  63. print()
  64. print("检查 CUDA:")
  65. print("-" * 60)
  66. if not check_cuda():
  67. print("⚠ 警告:CPU 训练速度较慢,建议使用 GPU")
  68. print()
  69. print("=" * 60)
  70. if all_ok:
  71. print("✓ 所有依赖已安装!")
  72. else:
  73. print("❌ 部分依赖缺失,请运行:pip install -r requirements.txt")
  74. print("=" * 60)
  75. return 0 if all_ok else 1
  76. if __name__ == "__main__":
  77. sys.exit(main())