quantize.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. """
  2. 模型量化工具
  3. """
  4. import os
  5. import json
  6. import torch
  7. from typing import Dict, Any, Optional
  8. from transformers import AutoModelForCausalLM, AutoTokenizer
  9. from peft import PeftModel
  10. def quantize_to_gguf(
  11. model_path: str,
  12. output_path: str,
  13. quantization_type: str = "Q4_K_M",
  14. **kwargs
  15. ):
  16. """
  17. 将模型量化为 GGUF 格式
  18. Args:
  19. model_path: 模型路径(微调后的模型)
  20. output_path: 输出路径
  21. quantization_type: 量化类型
  22. - Q2_K, Q3_K_S, Q3_K_M, Q3_K_L
  23. - Q4_0, Q4_1, Q4_K_S, Q4_K_M
  24. - Q5_0, Q5_1, Q5_K_S, Q5_K_M
  25. - Q6_K, Q8_0
  26. """
  27. print(f"开始 GGUF 量化:{quantization_type}")
  28. print(f"模型路径:{model_path}")
  29. print(f"输出路径:{output_path}")
  30. # 使用 llama.cpp 的 convert-hf-to-gguf.py 脚本
  31. # 这里提供调用示例
  32. import subprocess
  33. try:
  34. # 首先需要克隆 llama.cpp
  35. llama_cpp_path = kwargs.get("llama_cpp_path", "./llama.cpp")
  36. if not os.path.exists(llama_cpp_path):
  37. print("正在克隆 llama.cpp...")
  38. subprocess.run(
  39. ["git", "clone", "https://github.com/ggerganov/llama.cpp.git", llama_cpp_path],
  40. check=True
  41. )
  42. # 运行转换脚本
  43. convert_script = os.path.join(llama_cpp_path, "convert-hf-to-gguf.py")
  44. cmd = [
  45. "python",
  46. convert_script,
  47. model_path,
  48. "--outfile", output_path,
  49. "--outtype", quantization_type
  50. ]
  51. print(f"执行命令:{' '.join(cmd)}")
  52. subprocess.run(cmd, check=True)
  53. print(f"GGUF 量化完成!输出:{output_path}")
  54. except subprocess.CalledProcessError as e:
  55. print(f"GGUF 量化失败:{e}")
  56. raise
  57. def quantize_to_awq(
  58. model_path: str,
  59. output_path: str,
  60. quantization_config: Optional[Dict[str, Any]] = None,
  61. **kwargs
  62. ):
  63. """
  64. 使用 AWQ (Activation-aware Weight Quantization) 量化
  65. Args:
  66. model_path: 模型路径
  67. output_path: 输出路径
  68. quantization_config: AWQ 量化配置
  69. """
  70. try:
  71. from awq import AutoAWQForCausalLM
  72. except ImportError:
  73. print("错误:需要安装 autoawq")
  74. print("运行:pip install autoawq")
  75. raise
  76. print("开始 AWQ 量化...")
  77. print(f"模型路径:{model_path}")
  78. print(f"输出路径:{output_path}")
  79. # 默认配置
  80. if quantization_config is None:
  81. quantization_config = {
  82. "zero_point": True,
  83. "q_group_size": 128,
  84. "w_bit": 4,
  85. "version": "GEMM",
  86. }
  87. # 加载模型并量化
  88. model = AutoAWQForCausalLM.from_pretrained(
  89. model_path,
  90. device_map="auto",
  91. trust_remote_code=True,
  92. )
  93. # 执行量化
  94. model.quantize(
  95. tokenizer=AutoTokenizer.from_pretrained(model_path),
  96. quant_config=quantization_config,
  97. )
  98. # 保存量化后的模型
  99. model.save_quantized(output_path)
  100. print(f"AWQ 量化完成!输出:{output_path}")
  101. print(f"量化配置:{quantization_config}")
  102. def quantize_to_gptq(
  103. model_path: str,
  104. output_path: str,
  105. quantization_config: Optional[Dict[str, Any]] = None,
  106. **kwargs
  107. ):
  108. """
  109. 使用 GPTQ 量化
  110. Args:
  111. model_path: 模型路径
  112. output_path: 输出路径
  113. quantization_config: GPTQ 量化配置
  114. """
  115. try:
  116. from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
  117. except ImportError:
  118. print("错误:需要安装 auto-gptq")
  119. print("运行:pip install auto-gptq")
  120. raise
  121. print("开始 GPTQ 量化...")
  122. print(f"模型路径:{model_path}")
  123. print(f"输出路径:{output_path}")
  124. # 默认配置
  125. if quantization_config is None:
  126. quantize_config = BaseQuantizeConfig(
  127. bits=4,
  128. group_size=128,
  129. damp_percent=0.01,
  130. desc_act=False,
  131. )
  132. else:
  133. quantize_config = BaseQuantizeConfig(**quantization_config)
  134. # 加载数据用于校准(可选)
  135. calibration_data = kwargs.get("calibration_data", None)
  136. # 加载模型
  137. model = AutoGPTQForCausalLM.from_pretrained(
  138. model_path,
  139. quantize_config=quantize_config,
  140. device_map="auto",
  141. trust_remote_code=True,
  142. )
  143. # 如果有校准数据,执行量化
  144. if calibration_data:
  145. model.quantize(calibration_data)
  146. else:
  147. print("警告:未提供校准数据,将跳过量化步骤")
  148. # 保存量化后的模型
  149. model.save_quantized(output_path)
  150. print(f"GPTQ 量化完成!输出:{output_path}")
  151. def quantize_model(
  152. model_path: str,
  153. output_path: str,
  154. method: str = "awq",
  155. **kwargs
  156. ):
  157. """
  158. 模型量化的统一接口
  159. Args:
  160. model_path: 模型路径
  161. output_path: 输出路径
  162. method: 量化方法 (awq, gptq, gguf)
  163. **kwargs: 其他参数
  164. Returns:
  165. 量化结果信息
  166. """
  167. print("=" * 60)
  168. print("模型量化")
  169. print("=" * 60)
  170. print(f"量化方法:{method}")
  171. print(f"源模型:{model_path}")
  172. print(f"目标路径:{output_path}")
  173. # 创建输出目录
  174. os.makedirs(output_path, exist_ok=True)
  175. if method.lower() == "awq":
  176. quantize_to_awq(model_path, output_path, **kwargs)
  177. elif method.lower() == "gptq":
  178. quantize_to_gptq(model_path, output_path, **kwargs)
  179. elif method.lower() == "gguf":
  180. quant_type = kwargs.get("quantization_type", "Q4_K_M")
  181. quantize_to_gguf(model_path, output_path, quant_type, **kwargs)
  182. else:
  183. raise ValueError(f"不支持的量化方法:{method}")
  184. # 保存量化信息
  185. info_path = os.path.join(output_path, "quantization_info.json")
  186. with open(info_path, "w", encoding="utf-8") as f:
  187. json.dump({
  188. "method": method,
  189. "source_model": model_path,
  190. "output_path": output_path,
  191. "config": kwargs,
  192. }, f, indent=2, ensure_ascii=False)
  193. print("=" * 60)
  194. print("量化完成!")
  195. print("=" * 60)
  196. return {
  197. "success": True,
  198. "method": method,
  199. "output_path": output_path,
  200. }