base.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. """
  2. 基础模型配置类
  3. """
  4. from dataclasses import dataclass, field
  5. from typing import List, Optional, Dict, Any
  6. import torch
  7. @dataclass
  8. class BaseModelConfig:
  9. """基础模型配置"""
  10. # 模型相关
  11. model_name: str = "Qwen/Qwen3.5-0.5B"
  12. model_revision: str = "main"
  13. trust_remote_code: bool = True
  14. # LoRA 配置
  15. lora_r: int = 16
  16. lora_alpha: int = 32
  17. lora_dropout: float = 0.05
  18. target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
  19. # 训练配置
  20. per_device_train_batch_size: int = 1
  21. gradient_accumulation_steps: int = 4
  22. learning_rate: float = 2e-4
  23. num_train_epochs: float = 3.0
  24. max_seq_length: int = 512
  25. # 优化器配置
  26. warmup_ratio: float = 0.03
  27. weight_decay: float = 0.01
  28. lr_scheduler_type: str = "cosine"
  29. # 量化配置(仅支持 NVIDIA GPU)
  30. # 注意:华为升腾 NPU 不支持 bitsandbytes 量化
  31. use_4bit: bool = False # 华为升腾默认关闭
  32. bnb_4bit_compute_dtype: str = "bfloat16" # 华为升腾推荐 bf16
  33. bnb_4bit_quant_type: str = "nf4"
  34. use_nested_quant: bool = False
  35. # 其他配置
  36. output_dir: str = "./outputs"
  37. logging_steps: int = 10
  38. save_steps: int = 100
  39. evaluation_strategy: str = "no"
  40. def get_compute_dtype(self) -> torch.dtype:
  41. """获取计算精度"""
  42. if self.bnb_4bit_compute_dtype == "float16":
  43. return torch.float16
  44. elif self.bnb_4bit_compute_dtype == "bfloat16":
  45. return torch.bfloat16
  46. else:
  47. return torch.float32
  48. def to_dict(self) -> Dict[str, Any]:
  49. """转换为字典"""
  50. return {
  51. "model_name": self.model_name,
  52. "lora_r": self.lora_r,
  53. "lora_alpha": self.lora_alpha,
  54. "lora_dropout": self.lora_dropout,
  55. "target_modules": self.target_modules,
  56. "per_device_train_batch_size": self.per_device_train_batch_size,
  57. "gradient_accumulation_steps": self.gradient_accumulation_steps,
  58. "learning_rate": self.learning_rate,
  59. "num_train_epochs": self.num_train_epochs,
  60. "max_seq_length": self.max_seq_length,
  61. "output_dir": self.output_dir,
  62. }
  63. def load_model(config: BaseModelConfig):
  64. """加载模型的通用接口"""
  65. if "qwen" in config.model_name.lower():
  66. from finetunex.models.qwen import load_qwen_model
  67. return load_qwen_model(config)
  68. else:
  69. raise ValueError(f"不支持的模型:{config.model_name}")
  70. def get_model_config(model_name: str, **kwargs) -> BaseModelConfig:
  71. """获取模型配置"""
  72. if "qwen" in model_name.lower():
  73. from finetunex.models.qwen import QwenConfig
  74. return QwenConfig(model_name=model_name, **kwargs)
  75. else:
  76. return BaseModelConfig(model_name=model_name, **kwargs)