"""Base model configuration classes."""

from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any

import torch


@dataclass
class BaseModelConfig:
    """Base configuration shared by fine-tuning model setups.

    Groups the model identity, LoRA hyper-parameters, training and optimizer
    settings, 4-bit quantization options and output/logging options in a
    single dataclass. Every field has a default, so ``BaseModelConfig()`` is
    a valid configuration.
    """

    # Model
    model_name: str = "Qwen/Qwen3.5-0.5B"
    model_revision: str = "main"
    # NOTE(review): defaults to True; presumably forwarded to the model
    # loader, where it would allow repository-supplied code to run. Confirm
    # this default is intended.
    trust_remote_code: bool = True

    # LoRA settings
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # default_factory gives every instance its own list.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])

    # Training settings
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-4
    num_train_epochs: float = 3.0
    max_seq_length: int = 512

    # Optimizer settings
    warmup_ratio: float = 0.03
    weight_decay: float = 0.01
    lr_scheduler_type: str = "cosine"

    # Quantization settings
    use_4bit: bool = True
    bnb_4bit_compute_dtype: str = "float16"
    bnb_4bit_quant_type: str = "nf4"
    use_nested_quant: bool = False

    # Miscellaneous
    output_dir: str = "./outputs"
    logging_steps: int = 10
    save_steps: int = 100
    evaluation_strategy: str = "no"

    def get_compute_dtype(self) -> torch.dtype:
        """Return the torch dtype named by ``bnb_4bit_compute_dtype``.

        The name is matched case-insensitively and ignoring surrounding
        whitespace, so ``"BFloat16"`` or ``"float16 "`` resolve as expected
        instead of silently degrading to float32.

        Returns:
            ``torch.float16`` for ``"float16"``, ``torch.bfloat16`` for
            ``"bfloat16"``, and ``torch.float32`` for any other value
            (including non-string values).
        """
        raw = self.bnb_4bit_compute_dtype
        # Non-string values keep the historical float32 fallback.
        key = raw.strip().lower() if isinstance(raw, str) else None
        known = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        return known.get(key, torch.float32)

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary with a subset of the configuration.

        Only the model name, LoRA hyper-parameters, core training settings
        and output directory are included; the remaining fields are not
        exported.

        Returns:
            A new dict. ``target_modules`` is copied so that mutating the
            returned dict does not change this configuration.
        """
        return {
            "model_name": self.model_name,
            "lora_r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "lora_dropout": self.lora_dropout,
            # Copy: the caller must not be handed the config's own list.
            "target_modules": list(self.target_modules),
            "per_device_train_batch_size": self.per_device_train_batch_size,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "learning_rate": self.learning_rate,
            "num_train_epochs": self.num_train_epochs,
            "max_seq_length": self.max_seq_length,
            "output_dir": self.output_dir,
        }


def load_model(config: BaseModelConfig):
    """Load the model described by ``config`` via a family-specific loader.

    The model family is detected by a case-insensitive substring match on
    ``config.model_name``; only names containing ``"qwen"`` are handled.

    Args:
        config: Configuration whose ``model_name`` selects the loader.

    Returns:
        Whatever ``finetunex.models.qwen.load_qwen_model(config)`` returns.

    Raises:
        ValueError: If ``config.model_name`` does not contain ``"qwen"``.
    """
    if "qwen" in config.model_name.lower():
        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import -- TODO confirm).
        from finetunex.models.qwen import load_qwen_model

        return load_qwen_model(config)

    raise ValueError(f"不支持的模型:{config.model_name}")


def get_model_config(model_name: str, **kwargs) -> BaseModelConfig:
    """Build the configuration object for ``model_name``.

    Args:
        model_name: Model identifier; matched case-insensitively against
            ``"qwen"`` to choose the configuration class.
        **kwargs: Field overrides passed to the configuration constructor.

    Returns:
        A ``QwenConfig`` from ``finetunex.models.qwen`` when the name
        contains ``"qwen"`` (presumably a ``BaseModelConfig`` subclass --
        verify), otherwise a plain ``BaseModelConfig``.
    """
    if "qwen" in model_name.lower():
        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import -- TODO confirm).
        from finetunex.models.qwen import QwenConfig

        return QwenConfig(model_name=model_name, **kwargs)

    return BaseModelConfig(model_name=model_name, **kwargs)