test_quantization.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """
  2. 测试量化模块
  3. """
  4. import os
  5. import sys
  6. sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
  7. from finetunex.quantization import (
  8. get_model_size,
  9. estimate_quantized_size,
  10. compare_models,
  11. )
  12. def test_get_model_size():
  13. """测试获取模型大小"""
  14. print("=" * 60)
  15. print("测试:get_model_size")
  16. print("=" * 60)
  17. # 使用示例数据目录作为测试
  18. test_path = "./data"
  19. if os.path.exists(test_path):
  20. size_info = get_model_size(test_path)
  21. print(f"路径:{test_path}")
  22. print(f"总大小:{size_info['total_size_formatted']}")
  23. print(f"文件数:{size_info['file_count']}")
  24. print("✓ 测试通过\n")
  25. else:
  26. print(f"⚠ 测试路径不存在:{test_path}\n")
  27. def test_estimate_quantized_size():
  28. """测试估算量化后大小"""
  29. print("=" * 60)
  30. print("测试:estimate_quantized_size")
  31. print("=" * 60)
  32. test_path = "./data"
  33. if os.path.exists(test_path):
  34. print(f"路径:{test_path}")
  35. for bits in [4, 8]:
  36. estimate = estimate_quantized_size(test_path, quantization_bits=bits)
  37. print(f"\n{bits}bit 量化估算:")
  38. print(f" 原始大小:{estimate['original_size']}")
  39. print(f" 估算大小:{estimate['estimated_size']}")
  40. print(f" 压缩比:{estimate['compression_ratio']}")
  41. print(f" 节省空间:{estimate['space_saved']} ({estimate['space_saved_percent']})")
  42. print("\n✓ 测试通过\n")
  43. else:
  44. print(f"⚠ 测试路径不存在:{test_path}\n")
  45. def test_compare_models():
  46. """测试比较模型大小"""
  47. print("=" * 60)
  48. print("测试:compare_models")
  49. print("=" * 60)
  50. # 比较两个目录
  51. path1 = "./data"
  52. path2 = "./configs"
  53. if os.path.exists(path1) and os.path.exists(path2):
  54. comparison = compare_models(path1, path2, "数据目录", "配置目录")
  55. print(f"数据目录:{comparison['数据目录']['size']}")
  56. print(f"配置目录:{comparison['配置目录']['size']}")
  57. print(f"差异:{comparison['difference']} ({comparison['difference_percent']})")
  58. print(f"更小:{comparison['smaller']}")
  59. print("\n✓ 测试通过\n")
  60. else:
  61. print(f"⚠ 测试路径不存在\n")
  62. def main():
  63. print("\n" + "=" * 60)
  64. print("量化模块测试")
  65. print("=" * 60 + "\n")
  66. # 运行测试
  67. test_get_model_size()
  68. test_estimate_quantized_size()
  69. test_compare_models()
  70. print("=" * 60)
  71. print("所有测试完成!")
  72. print("=" * 60)
  73. print("\n提示:这些测试使用了示例目录,实际使用时请指定模型路径。")
  74. print("例如:--model_path ./outputs/qwen3.5-0.8b-finetuned\n")
  75. if __name__ == "__main__":
  76. main()