| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
"""
Qwen3.5 0.8B 模型微调示例

这个脚本演示了如何使用 FineTuneX 框架微调 Qwen3.5 模型。
注意：Qwen3.5-0.8B 尚未正式发布，这里使用 Qwen3.5-0.5B 作为示例。

使用方法：
    python examples/qwen3.5_0.8b_finetune.py
"""
import os
import sys

import torch

# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from finetunex.models import QwenConfig, load_qwen_model
from finetunex.data import load_dataset, format_dataset, InstructionDataset
from finetunex.trainer import FineTuneTrainer
from finetunex.utils import setup_environment, get_gpu_info, setup_logger
def _run_test_inference(model, tokenizer, logger, prompt, max_new_tokens=100):
    """Sample one completion for ``prompt`` from the fine-tuned model and log it.

    Args:
        model: The causal LM returned by ``load_qwen_model`` (after training).
        tokenizer: The tokenizer returned alongside ``model``.
        logger: Logger used to report the prompt and the generated text.
        prompt: Raw text prompt fed to the model.
        max_new_tokens: Upper bound on the number of newly generated tokens.
    """
    # trainer.train() leaves the model in train mode, which keeps LoRA dropout
    # active; switch to eval mode so sampling is not perturbed by dropout.
    model.eval()

    # Place the inputs on the same device as the model's first parameter rather
    # than guessing "cuda"/"npu:0"; this also covers CPU-only runs.
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
        )

    # generate() returns prompt + continuation; decode only the new tokens so the
    # logged output is not prefixed with the prompt again.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    logger.info(f"输入:{prompt}")
    logger.info(f"输出:{response}")


def main():
    """Run the end-to-end LoRA fine-tuning example, then a short sampling check.

    Steps: seed/environment and logger setup, GPU report, model/LoRA/training
    configuration, dataset loading and formatting, model + tokenizer loading,
    training, saving via ``trainer.save_model()``, and finally one sampled
    generation for a fixed prompt.
    """
    # Environment (seed) and logging setup.
    setup_environment(seed=42)
    logger = setup_logger("Qwen3.5_FineTuning")

    logger.info("=" * 60)
    logger.info("Qwen3.5 0.8B 微调示例")
    logger.info("=" * 60)

    # Report GPU availability.
    gpu_info = get_gpu_info()
    if gpu_info["available"]:
        logger.info(f"GPU 可用:{gpu_info['device_count']} 个设备")
        for i, dev in enumerate(gpu_info["devices"]):
            logger.info(f" GPU {i}: {dev['name']} ({dev['max_memory']:.2f} GB)")
    else:
        logger.warning("GPU 不可用,将使用 CPU 训练(不推荐)")

    # 1. Model configuration.
    # Per the original note, Qwen3.5-0.8B is not released yet, so Qwen3.5-0.5B is
    # used as a stand-in; switch model_name to "Qwen/Qwen3.5-0.8B" once available.
    config = QwenConfig(
        model_name="Qwen/Qwen3.5-0.5B",
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        num_train_epochs=3,
        max_seq_length=512,
        output_dir="./outputs/qwen3.5-0.5b-finetuned",
        use_4bit=False,  # 4-bit quantization is not supported on Huawei Ascend NPUs
    )

    logger.info(f"模型配置:{config.model_name}")
    logger.info(f"LoRA 配置:r={config.lora_r}, alpha={config.lora_alpha}")
    logger.info(f"训练配置:epochs={config.num_train_epochs}, lr={config.learning_rate}")

    # 2. Load the raw dataset shipped with the repository.
    dataset_path = os.path.join(os.path.dirname(__file__), "..", "data", "sample_dataset.json")
    dataset = load_dataset(dataset_path, format="json")

    # Map the instruction/input/output columns into the framework's format.
    formatted_dataset = format_dataset(
        dataset,
        instruction_column="instruction",
        input_column="input",
        output_column="output",
    )

    logger.info(f"数据集大小:{len(formatted_dataset)} 样本")

    # 3. Load model and tokenizer (the returned PEFT config is not needed here).
    model, tokenizer, _peft_config = load_qwen_model(config)

    # 4. Build the tokenized training dataset.
    train_dataset = InstructionDataset(
        formatted_dataset,
        tokenizer,
        max_length=config.max_seq_length,
    )

    # 5. Create the trainer.
    trainer = FineTuneTrainer(
        model=model,
        tokenizer=tokenizer,
        config=config,
        train_dataset=train_dataset,
    )

    # 6. Training arguments.
    trainer.setup_training(
        output_dir=config.output_dir,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        warmup_ratio=0.03,
        weight_decay=0.01,
        logging_steps=10,
        save_steps=50,
        # NOTE(review): bf16=True assumes bf16-capable hardware; confirm for the
        # target device (older GPUs or CPU-only runs may need bf16=False).
        bf16=True,
        fp16=False,
    )

    # 7. Train.
    logger.info("开始训练...")
    trainer.train()

    # 8. Save the result.
    trainer.save_model()

    logger.info("=" * 60)
    logger.info("训练完成!")
    logger.info(f"模型已保存到:{config.output_dir}")
    logger.info("=" * 60)

    # 9. Quick sampling check (optional).
    logger.info("\n测试推理...")
    test_prompt = "请解释什么是人工智能"
    _run_test_inference(model, tokenizer, logger, test_prompt)
if __name__ == "__main__":
    # torch is imported at module level, so main() also works when this module
    # is imported rather than executed as a script.
    main()
|