# qwen.py
  1. """
  2. Qwen 模型配置和加载
  3. """
  4. from dataclasses import dataclass
  5. from typing import List
  6. import os
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import logging
  11. from transformers import AutoModelForCausalLM, AutoTokenizer
  12. from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
  13. from finetunex.models.base import BaseModelConfig
  14. logger = logging.getLogger(__name__)
  15. @dataclass
  16. class QwenConfig(BaseModelConfig):
  17. """Qwen 模型专用配置"""
  18. model_name: str = "Qwen/Qwen3.5-0.5B"
  19. target_modules: List[str] = None
  20. def __post_init__(self):
  21. if self.target_modules is None:
  22. self.target_modules = [
  23. "q_proj",
  24. "k_proj",
  25. "v_proj",
  26. "o_proj",
  27. "gate_proj",
  28. "up_proj",
  29. "down_proj",
  30. ]
  31. def is_qwen3_5(self) -> bool:
  32. return "qwen3.5" in self.model_name.lower() or "qwen3_5" in self.model_name.lower()
  33. def is_qwen3(self) -> bool:
  34. name = self.model_name.lower()
  35. return "qwen3" in name and "qwen3.5" not in name and "qwen3_5" not in name
  36. class _NPUConv1d(nn.Module):
  37. """华为升腾 NPU 兼容的 Conv1d 实现
  38. 使用 unfold + einsum 替代 F.conv1d,
  39. 避免 NPU 上 Conv2D 算子编译失败的问题。
  40. """
  41. def __init__(self, original_conv1d: nn.Conv1d):
  42. super().__init__()
  43. self.stride = original_conv1d.stride[0]
  44. self.padding = original_conv1d.padding[0]
  45. self.dilation = original_conv1d.dilation[0]
  46. self.groups = original_conv1d.groups
  47. self.kernel_size = original_conv1d.kernel_size[0]
  48. self.in_channels = original_conv1d.in_channels
  49. self.out_channels = original_conv1d.out_channels
  50. self.weight = original_conv1d.weight
  51. self.bias = original_conv1d.bias
  52. def forward(self, input: torch.Tensor) -> torch.Tensor:
  53. if self.padding > 0:
  54. input = F.pad(input, (self.padding, 0))
  55. if self.kernel_size == 1 and self.stride == 1 and self.dilation == 1:
  56. output = F.linear(input.transpose(1, 2),
  57. self.weight.squeeze(-1),
  58. self.bias)
  59. return output.transpose(1, 2)
  60. unfolded = input.unfold(2, self.kernel_size, self.stride)
  61. weight = self.weight
  62. output = torch.einsum('bci,oci->bo', unfolded, weight)
  63. if self.bias is not None:
  64. output = output + self.bias.unsqueeze(0)
  65. return output
  66. def _patch_conv1d_for_npu():
  67. """Monkey-patch Conv1d 使其在华为升腾 NPU 上使用纯 PyTorch 实现"""
  68. original_forward = nn.Conv1d.forward
  69. if hasattr(nn.Conv1d, '_npu_patched'):
  70. return
  71. def npu_conv1d_forward(self, input):
  72. try:
  73. return original_forward(self, input)
  74. except RuntimeError as e:
  75. if "Conv2D" in str(e) or "500001" in str(e):
  76. logger.info(f"Conv1d 在 NPU 上失败,回退到纯 PyTorch 实现: {e}")
  77. fallback = _NPUConv1d(self)
  78. return fallback(input)
  79. raise
  80. nn.Conv1d.forward = npu_conv1d_forward
  81. nn.Conv1d._npu_patched = True
  82. logger.info("已应用 Conv1d NPU 兼容补丁")
  83. def _patch_qwen3_5_for_npu(model_path: str):
  84. """修补 Qwen3.5 模型配置以兼容华为升腾 NPU"""
  85. import json
  86. config_path = os.path.join(model_path, "config.json")
  87. if not os.path.exists(config_path):
  88. logger.warning(f"未找到模型配置文件:{config_path}")
  89. return
  90. with open(config_path, "r", encoding="utf-8") as f:
  91. model_config = json.load(f)
  92. if model_config.get("model_type") != "qwen3_5":
  93. return
  94. changed = False
  95. if "linear_attn" in model_config:
  96. logger.info("检测到 linear_attn 配置,NPU 不支持,将替换为 sdpa attention")
  97. del model_config["linear_attn"]
  98. changed = True
  99. if model_config.get("_attn_implementation", "") == "linear":
  100. logger.info("检测到 _attn_implementation=linear,将替换为 eager")
  101. model_config["_attn_implementation"] = "eager"
  102. changed = True
  103. attn_layers = model_config.get("attention_layers", None)
  104. if attn_layers:
  105. if any(v == "linear_attn" for v in attn_layers.values()):
  106. logger.info("检测到 attention_layers 中包含 linear_attn,将替换为 eager")
  107. model_config["attention_layers"] = {
  108. k: "eager" if v == "linear_attn" else v
  109. for k, v in attn_layers.items()
  110. }
  111. changed = True
  112. if changed:
  113. backup_path = config_path + ".bak"
  114. if not os.path.exists(backup_path):
  115. import shutil
  116. shutil.copy2(config_path, backup_path)
  117. logger.info(f"原始配置已备份到:{backup_path}")
  118. with open(config_path, "w", encoding="utf-8") as f:
  119. json.dump(model_config, f, indent=2, ensure_ascii=False)
  120. logger.info("模型配置已修改,linear attention 已替换为 sdpa attention")
  121. def load_qwen_model(config: QwenConfig):
  122. """加载 Qwen 模型"""
  123. print(f"正在加载模型:{config.model_name}")
  124. is_npu = hasattr(torch, 'npu') and torch.npu.is_available()
  125. is_qwen3_5 = config.is_qwen3_5()
  126. is_qwen3 = config.is_qwen3()
  127. if is_npu:
  128. logger.info("检测到华为升腾 NPU,应用兼容性补丁...")
  129. if is_qwen3_5:
  130. _patch_conv1d_for_npu()
  131. _patch_qwen3_5_for_npu(config.model_name)
  132. elif is_qwen3:
  133. logger.info("Qwen3 模型使用标准 attention,NPU 兼容性良好,无需额外补丁")
  134. compute_dtype = config.get_compute_dtype()
  135. use_quantization = config.use_4bit
  136. if use_quantization:
  137. try:
  138. from transformers import BitsAndBytesConfig
  139. bnb_config = BitsAndBytesConfig(
  140. load_in_4bit=config.use_4bit,
  141. bnb_4bit_quant_type=config.bnb_4bit_quant_type,
  142. bnb_4bit_compute_dtype=compute_dtype,
  143. bnb_4bit_use_double_quant=config.use_nested_quant,
  144. )
  145. print("使用 4bit 量化加载模型(需要 NVIDIA GPU)")
  146. except (ImportError, Exception) as e:
  147. logger.warning(f"无法使用 4bit 量化: {e}")
  148. logger.warning("将使用 bf16/fp16 加载模型")
  149. use_quantization = False
  150. bnb_config = None
  151. else:
  152. bnb_config = None
  153. print(f"使用 {compute_dtype} 精度加载模型")
  154. tokenizer = AutoTokenizer.from_pretrained(
  155. config.model_name,
  156. trust_remote_code=config.trust_remote_code,
  157. padding_side="right",
  158. )
  159. tokenizer.pad_token = tokenizer.eos_token
  160. model_kwargs = {
  161. "quantization_config": bnb_config if use_quantization else None,
  162. "device_map": "auto",
  163. "trust_remote_code": config.trust_remote_code,
  164. "torch_dtype": compute_dtype,
  165. }
  166. if is_npu:
  167. model_kwargs["attn_implementation"] = "eager"
  168. model = AutoModelForCausalLM.from_pretrained(
  169. config.model_name,
  170. **model_kwargs,
  171. )
  172. if use_quantization:
  173. model = prepare_model_for_kbit_training(model)
  174. peft_config = LoraConfig(
  175. lora_alpha=config.lora_alpha,
  176. lora_dropout=config.lora_dropout,
  177. r=config.lora_r,
  178. bias="none",
  179. task_type="CAUSAL_LM",
  180. target_modules=config.target_modules,
  181. )
  182. model = get_peft_model(model, peft_config)
  183. print(f"模型加载完成!可训练参数:{model.print_trainable_parameters()}")
  184. return model, tokenizer, peft_config