base.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. """
  2. 基础模型配置类
  3. """
  4. from dataclasses import dataclass, field
  5. from typing import List, Optional, Dict, Any
  6. import torch
  7. @dataclass
  8. class BaseModelConfig:
  9. """基础模型配置"""
  10. # 模型相关
  11. model_name: str = "Qwen/Qwen3.5-0.5B"
  12. model_revision: str = "main"
  13. trust_remote_code: bool = True
  14. # LoRA 配置
  15. lora_r: int = 16
  16. lora_alpha: int = 32
  17. lora_dropout: float = 0.05
  18. target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
  19. # 训练配置
  20. per_device_train_batch_size: int = 1
  21. gradient_accumulation_steps: int = 4
  22. learning_rate: float = 2e-4
  23. num_train_epochs: float = 3.0
  24. max_seq_length: int = 512
  25. # 优化器配置
  26. warmup_ratio: float = 0.03
  27. weight_decay: float = 0.01
  28. lr_scheduler_type: str = "cosine"
  29. # 量化配置
  30. use_4bit: bool = True
  31. bnb_4bit_compute_dtype: str = "float16"
  32. bnb_4bit_quant_type: str = "nf4"
  33. use_nested_quant: bool = False
  34. # 其他配置
  35. output_dir: str = "./outputs"
  36. logging_steps: int = 10
  37. save_steps: int = 100
  38. evaluation_strategy: str = "no"
  39. def get_compute_dtype(self) -> torch.dtype:
  40. """获取计算精度"""
  41. if self.bnb_4bit_compute_dtype == "float16":
  42. return torch.float16
  43. elif self.bnb_4bit_compute_dtype == "bfloat16":
  44. return torch.bfloat16
  45. else:
  46. return torch.float32
  47. def to_dict(self) -> Dict[str, Any]:
  48. """转换为字典"""
  49. return {
  50. "model_name": self.model_name,
  51. "lora_r": self.lora_r,
  52. "lora_alpha": self.lora_alpha,
  53. "lora_dropout": self.lora_dropout,
  54. "target_modules": self.target_modules,
  55. "per_device_train_batch_size": self.per_device_train_batch_size,
  56. "gradient_accumulation_steps": self.gradient_accumulation_steps,
  57. "learning_rate": self.learning_rate,
  58. "num_train_epochs": self.num_train_epochs,
  59. "max_seq_length": self.max_seq_length,
  60. "output_dir": self.output_dir,
  61. }
  62. def load_model(config: BaseModelConfig):
  63. """加载模型的通用接口"""
  64. if "qwen" in config.model_name.lower():
  65. from finetunex.models.qwen import load_qwen_model
  66. return load_qwen_model(config)
  67. else:
  68. raise ValueError(f"不支持的模型:{config.model_name}")
  69. def get_model_config(model_name: str, **kwargs) -> BaseModelConfig:
  70. """获取模型配置"""
  71. if "qwen" in model_name.lower():
  72. from finetunex.models.qwen import QwenConfig
  73. return QwenConfig(model_name=model_name, **kwargs)
  74. else:
  75. return BaseModelConfig(model_name=model_name, **kwargs)