qwen.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. """
  2. Qwen 模型配置和加载
  3. """
from dataclasses import dataclass
from typing import List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

from finetunex.models.base import BaseModelConfig
  10. @dataclass
  11. class QwenConfig(BaseModelConfig):
  12. """Qwen 模型专用配置"""
  13. model_name: str = "Qwen/Qwen3.5-0.5B"
  14. target_modules: List[str] = None
  15. def __post_init__(self):
  16. # Qwen 模型的默认 target_modules
  17. if self.target_modules is None:
  18. self.target_modules = [
  19. "q_proj",
  20. "k_proj",
  21. "v_proj",
  22. "o_proj",
  23. "gate_proj",
  24. "up_proj",
  25. "down_proj",
  26. ]
  27. def load_qwen_model(config: QwenConfig):
  28. """加载 Qwen 模型"""
  29. print(f"正在加载模型:{config.model_name}")
  30. # 配置量化
  31. compute_dtype = config.get_compute_dtype()
  32. bnb_config = BitsAndBytesConfig(
  33. load_in_4bit=config.use_4bit,
  34. bnb_4bit_quant_type=config.bnb_4bit_quant_type,
  35. bnb_4bit_compute_dtype=compute_dtype,
  36. bnb_4bit_use_double_quant=config.use_nested_quant,
  37. )
  38. # 加载 tokenizer
  39. tokenizer = AutoTokenizer.from_pretrained(
  40. config.model_name,
  41. trust_remote_code=config.trust_remote_code,
  42. padding_side="right",
  43. )
  44. tokenizer.pad_token = tokenizer.eos_token
  45. # 加载模型
  46. model = AutoModelForCausalLM.from_pretrained(
  47. config.model_name,
  48. quantization_config=bnb_config if config.use_4bit else None,
  49. device_map="auto",
  50. trust_remote_code=config.trust_remote_code,
  51. torch_dtype=compute_dtype,
  52. )
  53. # 准备模型用于 k-bit 训练
  54. if config.use_4bit:
  55. model = prepare_model_for_kbit_training(model)
  56. # 配置 LoRA
  57. peft_config = LoraConfig(
  58. lora_alpha=config.lora_alpha,
  59. lora_dropout=config.lora_dropout,
  60. r=config.lora_r,
  61. bias="none",
  62. task_type="CAUSAL_LM",
  63. target_modules=config.target_modules,
  64. )
  65. # 应用 LoRA
  66. model = get_peft_model(model, peft_config)
  67. print(f"模型加载完成!可训练参数:{model.print_trainable_parameters()}")
  68. return model, tokenizer, peft_config