quantize_model.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
"""
Model quantization script.

Quantizes a fine-tuned model; supports the AWQ, GPTQ and GGUF methods.

Usage:
    # AWQ quantization (recommended)
    python scripts/quantize_model.py --model_path ./outputs/qwen3.5-0.8b-finetuned --method awq
    # GPTQ quantization
    python scripts/quantize_model.py --model_path ./outputs/qwen3.5-0.8b-finetuned --method gptq
    # GGUF quantization
    python scripts/quantize_model.py --model_path ./outputs/qwen3.5-0.8b-finetuned --method gguf --quant_type Q4_K_M
"""
  12. import argparse
  13. import os
  14. import sys
  15. import json
  16. # 添加项目路径
  17. sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
  18. from finetunex.quantization import quantize_model, get_model_size, estimate_quantized_size
  19. def main():
  20. parser = argparse.ArgumentParser(description="模型量化工具")
  21. parser.add_argument(
  22. "--model_path",
  23. type=str,
  24. required=True,
  25. help="微调后的模型路径"
  26. )
  27. parser.add_argument(
  28. "--output_path",
  29. type=str,
  30. default=None,
  31. help="量化模型输出路径(默认:./outputs/quantized/{method})"
  32. )
  33. parser.add_argument(
  34. "--method",
  35. type=str,
  36. choices=["awq", "gptq", "gguf"],
  37. default="awq",
  38. help="量化方法(默认:awq)"
  39. )
  40. parser.add_argument(
  41. "--quant_type",
  42. type=str,
  43. default=None,
  44. help="量化类型(GGUF 专用,如 Q4_K_M)"
  45. )
  46. parser.add_argument(
  47. "--bits",
  48. type=int,
  49. choices=[4, 8],
  50. default=4,
  51. help="量化位数(默认:4)"
  52. )
  53. parser.add_argument(
  54. "--group_size",
  55. type=int,
  56. default=128,
  57. help="量化分组大小(默认:128)"
  58. )
  59. parser.add_argument(
  60. "--estimate_only",
  61. action="store_true",
  62. help="仅估算大小,不执行量化"
  63. )
  64. parser.add_argument(
  65. "--show_info",
  66. action="store_true",
  67. help="显示模型信息"
  68. )
  69. args = parser.parse_args()
  70. # 检查模型路径
  71. if not os.path.exists(args.model_path):
  72. print(f"错误:模型路径不存在:{args.model_path}")
  73. sys.exit(1)
  74. print("=" * 60)
  75. print("模型量化工具")
  76. print("=" * 60)
  77. print(f"模型路径:{args.model_path}")
  78. print(f"量化方法:{args.method}")
  79. # 显示模型信息
  80. if args.show_info:
  81. size_info = get_model_size(args.model_path)
  82. print(f"\n原始模型大小:{size_info['total_size_formatted']}")
  83. print(f"文件数:{size_info['file_count']}")
  84. # 估算量化后大小
  85. print("\n估算量化后大小:")
  86. for bits in [4, 8]:
  87. estimate = estimate_quantized_size(args.model_path, quantization_bits=bits)
  88. print(f"\n{bits}bit 量化:")
  89. print(f" 原始大小:{estimate['original_size']}")
  90. print(f" 压缩比:{estimate['compression_ratio']}")
  91. print(f" 估算大小:{estimate['estimated_size']}")
  92. print(f" 节省空间:{estimate['space_saved']} ({estimate['space_saved_percent']})")
  93. # 如果只估算,直接退出
  94. if args.estimate_only:
  95. print("\n仅估算模式,跳过量化步骤。")
  96. return
  97. # 设置输出路径
  98. if args.output_path is None:
  99. output_dir = os.path.dirname(args.model_path)
  100. model_name = os.path.basename(args.model_path)
  101. args.output_path = os.path.join(output_dir, f"{model_name}-{args.method}-quantized")
  102. print(f"\n输出路径:{args.output_path}")
  103. # 准备量化配置
  104. quant_config = {
  105. "bits": args.bits,
  106. "group_size": args.group_size,
  107. }
  108. if args.method == "gguf" and args.quant_type:
  109. quant_config["quantization_type"] = args.quant_type
  110. # 执行量化
  111. try:
  112. result = quantize_model(
  113. model_path=args.model_path,
  114. output_path=args.output_path,
  115. method=args.method,
  116. **quant_config
  117. )
  118. if result["success"]:
  119. print("\n" + "=" * 60)
  120. print("量化成功!")
  121. print("=" * 60)
  122. print(f"量化方法:{args.method}")
  123. print(f"输出路径:{args.output_path}")
  124. # 显示实际大小
  125. quantized_size = get_model_size(args.output_path)
  126. print(f"量化后大小:{quantized_size['total_size_formatted']}")
  127. print(f"文件数:{quantized_size['file_count']}")
  128. # 使用建议
  129. print("\n使用建议:")
  130. if args.method == "awq":
  131. print("- AWQ 量化模型可用于推理加速")
  132. print("- 使用 transformers + autoawq 加载")
  133. elif args.method == "gptq":
  134. print("- GPTQ 量化模型适用于 NVIDIA GPU")
  135. print("- 使用 auto-gptq 库加载")
  136. elif args.method == "gguf":
  137. print("- GGUF 格式可用于 llama.cpp")
  138. print("- 支持 CPU 推理")
  139. print("=" * 60)
  140. else:
  141. print("\n量化失败!")
  142. sys.exit(1)
  143. except Exception as e:
  144. print(f"\n量化过程出错:{e}")
  145. import traceback
  146. traceback.print_exc()
  147. sys.exit(1)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()