| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- """
- 量化相关工具函数
- """
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
def get_model_size(model_path: str) -> Dict[str, Any]:
    """Measure the on-disk size of the weight files under a model directory.

    Walks ``model_path`` recursively and sums every file whose name ends in
    ``.bin``, ``.safetensors``, ``.pt`` or ``.pth``; all other files (configs,
    tokenizer files, ...) are ignored.

    Args:
        model_path: Directory to scan. A path that does not exist yields an
            empty result rather than an error, because ``os.walk`` silently
            skips directories it cannot list.

    Returns:
        A dict with:
            ``total_size``: total bytes of the matched files.
            ``total_size_formatted``: ``total_size`` as a human-readable
                string such as ``"1.96 KB"``.
            ``file_count``: number of matched files.
            ``file_sizes``: mapping from each matched file's path relative to
                ``model_path`` to its size in bytes. Files directly inside
                ``model_path`` are keyed by their bare file name.
    """
    weight_suffixes = (".bin", ".safetensors", ".pt", ".pth")
    total_size = 0
    file_sizes: Dict[str, int] = {}

    for root, dirs, files in os.walk(model_path):
        # Sort in place so the traversal, and therefore the order of
        # file_sizes, is deterministic across runs and platforms.
        dirs.sort()
        for name in sorted(files):
            if not name.endswith(weight_suffixes):
                continue
            file_path = os.path.join(root, name)
            size = os.path.getsize(file_path)
            total_size += size
            # Key by relative path: shards sharing a basename in different
            # sub-directories would otherwise overwrite each other.
            file_sizes[os.path.relpath(file_path, model_path)] = size

    def format_size(size_bytes):
        """Render a byte count with a binary (1024-based) unit suffix."""
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if size_bytes < 1024.0:
                return f"{size_bytes:.2f} {unit}"
            size_bytes /= 1024.0
        return f"{size_bytes:.2f} PB"

    return {
        "total_size": total_size,
        "total_size_formatted": format_size(total_size),
        "file_count": len(file_sizes),
        "file_sizes": file_sizes,
    }
def estimate_quantized_size(
    model_path: str,
    quantization_bits: int = 4,
    overhead: float = 0.1,
    original_bits: int = 16,
) -> Dict[str, Any]:
    """Estimate the size of a model after quantizing its weights.

    The estimate takes the weight-file size reported by ``get_model_size``.
    It scales that size by ``quantization_bits / original_bits`` and then
    inflates it by the fractional ``overhead``.

    Args:
        model_path: Directory of the original (unquantized) model.
        quantization_bits: Target bits per weight (e.g. 4 or 8). Must be > 0.
        overhead: Extra fraction added on top of the scaled size to account
            for metadata, tokenizer files, etc. (0.1 means +10%).
        original_bits: Bits per weight of the original checkpoint. The default
            of 16 assumes FP16/BF16; pass 32 for FP32 checkpoints. Must be > 0.

    Returns:
        A dict with:
            ``original_size``: formatted size of the original weight files.
            ``quantization_bits``: echo of ``quantization_bits``.
            ``compression_ratio``: ``original_bits / quantization_bits``,
                formatted as e.g. ``"4.0x"``. Overhead is not included.
            ``estimated_size``: formatted estimated size.
            ``estimated_size_bytes``: estimated size in bytes (float).
            ``space_saved``: formatted ``original - estimated``. This is
                negative when the estimate exceeds the original.
            ``space_saved_percent``: saving as a percentage string, including
                overhead.

    Raises:
        ValueError: if ``quantization_bits`` or ``original_bits`` is not
            positive.
    """
    if quantization_bits <= 0:
        raise ValueError(f"quantization_bits must be positive, got {quantization_bits}")
    if original_bits <= 0:
        raise ValueError(f"original_bits must be positive, got {original_bits}")

    original_size = get_model_size(model_path)
    original_bytes = original_size["total_size"]

    # Fraction of the original size that the quantized weights occupy.
    size_ratio = quantization_bits / original_bits

    estimated_size = original_bytes * size_ratio * (1 + overhead)

    def format_size(size_bytes):
        """Render a (possibly negative) byte count with a 1024-based unit."""
        # Handle the sign separately so negative values still scale through
        # the units instead of stopping at "B" immediately.
        sign = "-" if size_bytes < 0 else ""
        size_bytes = abs(size_bytes)
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if size_bytes < 1024.0:
                return f"{sign}{size_bytes:.2f} {unit}"
            size_bytes /= 1024.0
        return f"{sign}{size_bytes:.2f} PB"

    return {
        "original_size": original_size["total_size_formatted"],
        "quantization_bits": quantization_bits,
        "compression_ratio": f"{1/size_ratio:.1f}x",
        "estimated_size": format_size(estimated_size),
        "estimated_size_bytes": estimated_size,
        "space_saved": format_size(original_bytes - estimated_size),
        "space_saved_percent": f"{(1 - size_ratio * (1 + overhead)) * 100:.1f}%",
    }
def compare_models(
    model_path_1: str,
    model_path_2: str,
    label_1: str = "原始模型",
    label_2: str = "量化模型"
) -> Dict[str, Any]:
    """Compare the weight-file sizes of two model directories.

    Args:
        model_path_1: Directory of the first model (typically the original).
        model_path_2: Directory of the second model (typically the quantized one).
        label_1: Result key / display label for the first model.
        label_2: Result key / display label for the second model.

    Returns:
        A dict keyed by ``label_1`` and ``label_2``. Each of those entries holds
        ``size`` (the formatted total) and ``files`` (the weight-file count).
        The dict also contains:
            ``difference``: absolute size difference, human-readable.
            ``difference_percent``: ``(size_1 - size_2) / size_1`` as a
                percentage string. It is positive when the second model is
                smaller, and ``"0.0%"`` when the first model has no weight
                files.
            ``smaller``: label of the smaller model. When the sizes are equal
                it is ``label_2``.
    """
    size_1 = get_model_size(model_path_1)
    size_2 = get_model_size(model_path_2)

    bytes_1 = size_1["total_size"]
    bytes_2 = size_2["total_size"]
    size_diff = bytes_1 - bytes_2
    # Guard against division by zero when the first directory is empty.
    size_diff_percent = (size_diff / bytes_1) * 100 if bytes_1 > 0 else 0

    def format_size(size_bytes):
        """Render a byte count with a binary (1024-based) unit suffix."""
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if size_bytes < 1024.0:
                return f"{size_bytes:.2f} {unit}"
            size_bytes /= 1024.0
        return f"{size_bytes:.2f} PB"

    return {
        label_1: {
            "size": size_1["total_size_formatted"],
            "files": size_1["file_count"],
        },
        label_2: {
            "size": size_2["total_size_formatted"],
            "files": size_2["file_count"],
        },
        "difference": format_size(abs(size_diff)),
        "difference_percent": f"{size_diff_percent:.1f}%",
        # size_diff < 0 means the first model has fewer bytes, so it is the
        # smaller one; otherwise (larger or tied) report the second model.
        "smaller": label_1 if size_diff < 0 else label_2,
    }
def print_model_info(model_path: str, title: str = "模型信息"):
    """Print a size summary of the weight files under ``model_path``.

    The output is framed by 60-character ``=`` rules. It shows the path, the
    total size and the file count. When any weight files were found, it also
    lists each file with its size in MB (1024 * 1024 bytes).

    Args:
        model_path: Directory to scan (passed to ``get_model_size``).
        title: Heading printed between the top rules.
    """
    rule = "=" * 60
    bytes_per_mb = 1024 * 1024

    print(f"\n{rule}")
    print(f"{title}")
    print(rule)

    info = get_model_size(model_path)
    print(f"模型路径:{model_path}")
    print(f"总大小:{info['total_size_formatted']}")
    print(f"文件数:{info['file_count']}")

    per_file = info['file_sizes']
    if per_file:
        print("\n文件列表:")
        for name, nbytes in per_file.items():
            print(f" {name}: {nbytes / bytes_per_mb:.2f} MB")

    print(f"{rule}\n")
def save_quantization_report(
    output_path: str,
    quantization_info: Dict[str, Any],
    size_info: Dict[str, Any],
    report_path: Optional[str] = None
) -> str:
    """Write a JSON quantization report and return the path it was written to.

    The report has three top-level keys: ``"quantization"`` (``quantization_info``),
    ``"size_info"`` (``size_info``) and ``"timestamp"``. The timestamp is the
    local time in ISO-8601 format, including its UTC offset.

    Args:
        output_path: Directory for the report when ``report_path`` is not
            given; the file is then named ``quantization_report.json``.
        quantization_info: Quantization settings/results. Must be
            JSON-serializable.
        size_info: Size information (e.g. from ``get_model_size``). Must be
            JSON-serializable.
        report_path: Explicit file path for the report; overrides
            ``output_path`` when given.

    Returns:
        The report file path.

    Raises:
        TypeError: if the report contains values ``json`` cannot serialize.
            In that case nothing is written.
        OSError: if the parent directory cannot be created or the file
            cannot be written.
    """
    if report_path is None:
        report_path = os.path.join(output_path, "quantization_report.json")

    report = {
        "quantization": quantization_info,
        "size_info": size_info,
        # astimezone() attaches the local UTC offset so the timestamp is unambiguous.
        "timestamp": datetime.now().astimezone().isoformat(),
    }

    # Serialize before opening the file so a non-serializable value does not
    # leave a truncated report behind.
    payload = json.dumps(report, indent=2, ensure_ascii=False)

    # The target directory (e.g. a fresh export folder) may not exist yet.
    Path(report_path).parent.mkdir(parents=True, exist_ok=True)

    with open(report_path, "w", encoding="utf-8") as f:
        f.write(payload)

    print(f"量化报告已保存到:{report_path}")

    return report_path
|