| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- """
- Qwen 模型配置和加载
- """
- from dataclasses import dataclass
- from typing import List
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
- from finetunex.models.base import BaseModelConfig
@dataclass
class QwenConfig(BaseModelConfig):
    """Configuration specific to Qwen models.

    Fields declared here:
        model_name: Model identifier; defaults to ``"Qwen/Qwen3.5-0.5B"``.
        target_modules: Module names used as LoRA targets. ``None`` (the
            default) is replaced in ``__post_init__`` by the seven
            ``*_proj`` names listed there.
    """

    model_name: str = "Qwen/Qwen3.5-0.5B"
    # ``None`` is a sentinel for "use the defaults": a dataclass field cannot
    # take a list literal as its default value.
    target_modules: List[str] = None

    def __post_init__(self):
        """Fill in the default ``target_modules`` when none were supplied."""
        # An explicitly supplied value (including an empty list) is kept as-is.
        if self.target_modules is not None:
            return

        # Default target modules for Qwen models.
        # NOTE(review): the grouping below assumes the usual Hugging Face
        # layer naming (attention vs. MLP projections) -- confirm against
        # the model actually loaded.
        attention_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
        mlp_names = ["gate_proj", "up_proj", "down_proj"]
        self.target_modules = attention_names + mlp_names
def load_qwen_model(config: QwenConfig):
    """Load a Qwen causal LM and its tokenizer, then wrap the model with LoRA.

    Args:
        config: Supplies the model name, ``trust_remote_code``, the
            quantization settings (``use_4bit``, ``bnb_4bit_quant_type``,
            ``use_nested_quant``, ``get_compute_dtype()``) and the LoRA
            hyper-parameters (``lora_r``, ``lora_alpha``, ``lora_dropout``,
            ``target_modules``).

    Returns:
        A ``(model, tokenizer, peft_config)`` tuple: ``model`` is the result
        of ``get_peft_model``, ``peft_config`` is the ``LoraConfig`` that was
        applied to it.
    """
    print(f"正在加载模型:{config.model_name}")

    compute_dtype = config.get_compute_dtype()

    # Build the quantization config only when it is going to be used, so an
    # unquantized load never constructs (or validates) unused 4-bit settings.
    quantization_config = None
    if config.use_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=config.use_4bit,
            bnb_4bit_quant_type=config.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=config.use_nested_quant,
        )

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(
        config.model_name,
        trust_remote_code=config.trust_remote_code,
        padding_side="right",
    )
    # Fall back to EOS for padding only when the tokenizer has no pad token.
    # Overwriting an existing, distinct pad token with EOS makes padding
    # indistinguishable from end-of-sequence for anything that masks on the
    # pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the model.
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=config.trust_remote_code,
        torch_dtype=compute_dtype,
    )

    # Prepare the quantized model for k-bit training.
    if config.use_4bit:
        model = prepare_model_for_kbit_training(model)

    # Configure LoRA.
    peft_config = LoraConfig(
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        r=config.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=config.target_modules,
    )

    # Apply LoRA.
    model = get_peft_model(model, peft_config)

    # print_trainable_parameters() prints its own report and returns None, so
    # interpolating its result into an f-string would render "None". Emit the
    # prefix without a newline and let the method finish the line.
    print("模型加载完成!可训练参数:", end="")
    model.print_trainable_parameters()

    return model, tokenizer, peft_config
|