utils.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
"""
Utility functions for model quantization: measuring, estimating, comparing
and reporting on-disk model sizes.
"""
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
  8. def get_model_size(model_path: str) -> Dict[str, Any]:
  9. """
  10. 获取模型大小
  11. Args:
  12. model_path: 模型路径
  13. Returns:
  14. 模型大小信息
  15. """
  16. total_size = 0
  17. file_count = 0
  18. file_sizes = {}
  19. for root, dirs, files in os.walk(model_path):
  20. for file in files:
  21. if file.endswith((".bin", ".safetensors", ".pt", ".pth")):
  22. file_path = os.path.join(root, file)
  23. size = os.path.getsize(file_path)
  24. total_size += size
  25. file_count += 1
  26. file_sizes[file] = size
  27. # 转换为人类可读格式
  28. def format_size(size_bytes):
  29. for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
  30. if size_bytes < 1024.0:
  31. return f"{size_bytes:.2f} {unit}"
  32. size_bytes /= 1024.0
  33. return f"{size_bytes:.2f} PB"
  34. return {
  35. "total_size": total_size,
  36. "total_size_formatted": format_size(total_size),
  37. "file_count": file_count,
  38. "file_sizes": file_sizes,
  39. }
  40. def estimate_quantized_size(
  41. model_path: str,
  42. quantization_bits: int = 4,
  43. overhead: float = 0.1
  44. ) -> Dict[str, Any]:
  45. """
  46. 估算量化后的模型大小
  47. Args:
  48. model_path: 原始模型路径
  49. quantization_bits: 量化位数 (4, 8 等)
  50. overhead: 额外开销比例(元数据、tokenizer 等)
  51. Returns:
  52. 估算的大小信息
  53. """
  54. original_size = get_model_size(model_path)
  55. # 计算压缩比
  56. compression_ratio = quantization_bits / 16 # 假设原始是 FP16 (16bit)
  57. # 估算量化后大小
  58. estimated_size = original_size["total_size"] * compression_ratio * (1 + overhead)
  59. def format_size(size_bytes):
  60. for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
  61. if size_bytes < 1024.0:
  62. return f"{size_bytes:.2f} {unit}"
  63. size_bytes /= 1024.0
  64. return f"{size_bytes:.2f} PB"
  65. return {
  66. "original_size": original_size["total_size_formatted"],
  67. "quantization_bits": quantization_bits,
  68. "compression_ratio": f"{1/compression_ratio:.1f}x",
  69. "estimated_size": format_size(estimated_size),
  70. "estimated_size_bytes": estimated_size,
  71. "space_saved": format_size(original_size["total_size"] - estimated_size),
  72. "space_saved_percent": f"{(1 - compression_ratio * (1 + overhead)) * 100:.1f}%",
  73. }
  74. def compare_models(
  75. model_path_1: str,
  76. model_path_2: str,
  77. label_1: str = "原始模型",
  78. label_2: str = "量化模型"
  79. ) -> Dict[str, Any]:
  80. """
  81. 比较两个模型的大小
  82. Args:
  83. model_path_1: 第一个模型路径
  84. model_path_2: 第二个模型路径
  85. label_1: 第一个模型标签
  86. label_2: 第二个模型标签
  87. Returns:
  88. 比较结果
  89. """
  90. size_1 = get_model_size(model_path_1)
  91. size_2 = get_model_size(model_path_2)
  92. size_diff = size_1["total_size"] - size_2["total_size"]
  93. size_diff_percent = (size_diff / size_1["total_size"]) * 100 if size_1["total_size"] > 0 else 0
  94. def format_size(size_bytes):
  95. for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
  96. if size_bytes < 1024.0:
  97. return f"{size_bytes:.2f} {unit}"
  98. size_bytes /= 1024.0
  99. return f"{size_bytes:.2f} PB"
  100. return {
  101. label_1: {
  102. "size": size_1["total_size_formatted"],
  103. "files": size_1["file_count"],
  104. },
  105. label_2: {
  106. "size": size_2["total_size_formatted"],
  107. "files": size_2["file_count"],
  108. },
  109. "difference": format_size(abs(size_diff)),
  110. "difference_percent": f"{size_diff_percent:.1f}%",
  111. "smaller": label_1 if size_diff > 0 else label_2,
  112. }
  113. def print_model_info(model_path: str, title: str = "模型信息"):
  114. """
  115. 打印模型信息
  116. Args:
  117. model_path: 模型路径
  118. title: 标题
  119. """
  120. print(f"\n{'=' * 60}")
  121. print(f"{title}")
  122. print(f"{'=' * 60}")
  123. size_info = get_model_size(model_path)
  124. print(f"模型路径:{model_path}")
  125. print(f"总大小:{size_info['total_size_formatted']}")
  126. print(f"文件数:{size_info['file_count']}")
  127. if size_info['file_sizes']:
  128. print("\n文件列表:")
  129. for filename, size in size_info['file_sizes'].items():
  130. size_mb = size / (1024 * 1024)
  131. print(f" {filename}: {size_mb:.2f} MB")
  132. print(f"{'=' * 60}\n")
  133. def save_quantization_report(
  134. output_path: str,
  135. quantization_info: Dict[str, Any],
  136. size_info: Dict[str, Any],
  137. report_path: Optional[str] = None
  138. ):
  139. """
  140. 保存量化报告
  141. Args:
  142. output_path: 输出目录
  143. quantization_info: 量化信息
  144. size_info: 大小信息
  145. report_path: 报告路径(可选)
  146. """
  147. if report_path is None:
  148. report_path = os.path.join(output_path, "quantization_report.json")
  149. report = {
  150. "quantization": quantization_info,
  151. "size_info": size_info,
  152. "timestamp": __import__('datetime').datetime.now().isoformat(),
  153. }
  154. with open(report_path, "w", encoding="utf-8") as f:
  155. json.dump(report, f, indent=2, ensure_ascii=False)
  156. print(f"量化报告已保存到:{report_path}")
  157. return report_path