| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- """
- 基础模型配置类
- """
- from dataclasses import dataclass, field
- from typing import List, Optional, Dict, Any
- import torch
@dataclass
class BaseModelConfig:
    """Base model configuration.

    The fields are grouped into model selection, LoRA settings, training
    settings, optimizer settings, quantization settings and miscellaneous
    output/logging settings. Every field has a default, so the class can be
    instantiated without arguments.
    """

    # Model selection
    model_name: str = "Qwen/Qwen3.5-0.5B"
    model_revision: str = "main"
    trust_remote_code: bool = True

    # LoRA settings
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # default_factory gives every instance its own list instead of a shared one
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])

    # Training settings
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-4
    num_train_epochs: float = 3.0
    max_seq_length: int = 512

    # Optimizer settings
    warmup_ratio: float = 0.03
    weight_decay: float = 0.01
    lr_scheduler_type: str = "cosine"

    # Quantization settings
    use_4bit: bool = True
    bnb_4bit_compute_dtype: str = "float16"
    bnb_4bit_quant_type: str = "nf4"
    use_nested_quant: bool = False

    # Miscellaneous settings
    output_dir: str = "./outputs"
    logging_steps: int = 10
    save_steps: int = 100
    evaluation_strategy: str = "no"

    def get_compute_dtype(self) -> torch.dtype:
        """Return the torch dtype named by ``bnb_4bit_compute_dtype``.

        Returns:
            ``torch.float16`` for ``"float16"``, ``torch.bfloat16`` for
            ``"bfloat16"``, and ``torch.float32`` for any other value.
        """
        requested = self.bnb_4bit_compute_dtype
        if requested == "float16":
            return torch.float16
        if requested == "bfloat16":
            return torch.bfloat16
        # Any unrecognized name falls back to full precision.
        return torch.float32

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary with a subset of the configuration fields.

        Only the model name, the LoRA fields, the training fields and
        ``output_dir`` are included; the remaining fields are not exported.
        Values are the attribute objects themselves (no copies are made).
        """
        exported_fields = (
            "model_name",
            "lora_r",
            "lora_alpha",
            "lora_dropout",
            "target_modules",
            "per_device_train_batch_size",
            "gradient_accumulation_steps",
            "learning_rate",
            "num_train_epochs",
            "max_seq_length",
            "output_dir",
        )
        return {key: getattr(self, key) for key in exported_fields}
def load_model(config: BaseModelConfig):
    """Load a model for the given configuration.

    Dispatches on ``config.model_name``: when the lower-cased name contains
    ``"qwen"``, the call is delegated to
    ``finetunex.models.qwen.load_qwen_model`` and its result is returned.

    Raises:
        ValueError: If the model name does not contain ``"qwen"``.
    """
    is_qwen = "qwen" in config.model_name.lower()
    if not is_qwen:
        raise ValueError(f"不支持的模型:{config.model_name}")

    # Imported lazily, only once a Qwen model is actually requested.
    from finetunex.models.qwen import load_qwen_model

    return load_qwen_model(config)
def get_model_config(model_name: str, **kwargs) -> BaseModelConfig:
    """Build the configuration object for ``model_name``.

    Args:
        model_name: Model identifier; matched case-insensitively against
            the substring ``"qwen"``.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            configuration constructor.

    Returns:
        A ``QwenConfig`` (from ``finetunex.models.qwen``) when the name
        contains ``"qwen"``, otherwise a plain ``BaseModelConfig``.
    """
    if "qwen" not in model_name.lower():
        return BaseModelConfig(model_name=model_name, **kwargs)

    # Imported lazily, only when a Qwen configuration is needed.
    from finetunex.models.qwen import QwenConfig

    return QwenConfig(model_name=model_name, **kwargs)
|