# qwen3.5_0.8b_local_finetune.py
  1. """
  2. Qwen3.5-0.8B 本地模型微调示例
  3. 这个脚本用于微调本地已下载的 Qwen3.5-0.8B 模型。
  4. 使用方法:
  5. python examples/qwen3.5_0.8b_local_finetune.py
  6. 前提条件:
  7. - 已在本地下载 Qwen3.5-0.8B 模型
  8. - 模型路径配置在 local_model_path 变量中
  9. """
  10. import os
  11. import sys
  12. if hasattr(os, 'environ'):
  13. os.environ.setdefault("ASCEND_LAUNCH_BLOCKING", "1")
  14. sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
  15. from finetunex.models import QwenConfig, load_qwen_model
  16. from finetunex.data import load_dataset, format_dataset, InstructionDataset
  17. from finetunex.trainer import FineTuneTrainer
  18. from finetunex.utils import setup_environment, get_gpu_info, setup_logger
  19. def main():
  20. # 设置环境和日志
  21. setup_environment(seed=42)
  22. logger = setup_logger("Qwen3.5-0.8B_Local_FineTuning")
  23. logger.info("=" * 60)
  24. logger.info("Qwen3.5-0.8B 本地模型微调")
  25. logger.info("=" * 60)
  26. # ==================== 配置区域 ====================
  27. # 本地模型路径(请根据实际情况修改)
  28. # 可以是绝对路径或相对于项目根目录的路径
  29. local_model_path = "./Qwen3.5-0.8B"
  30. # 如果模型在其他位置,请修改这里
  31. # 例如:
  32. # local_model_path = "/path/to/your/models/Qwen3.5-0.8B"
  33. # local_model_path = "D:\\AI_Models\\Qwen3.5-0.8B"
  34. # 检查模型路径是否存在
  35. if not os.path.exists(local_model_path):
  36. logger.error(f"模型路径不存在:{local_model_path}")
  37. logger.error("请修改脚本中的 local_model_path 变量为正确的模型路径")
  38. sys.exit(1)
  39. logger.info(f"使用本地模型:{local_model_path}")
  40. # 数据集路径
  41. dataset_path = os.path.join(os.path.dirname(__file__), "..", "data", "sample_dataset.json")
  42. # 输出目录
  43. output_dir = "./outputs/qwen3.5-0.8b-finetuned"
  44. # ==================== 模型配置 ====================
  45. config = QwenConfig(
  46. model_name=local_model_path, # 使用本地模型路径
  47. lora_r=16,
  48. lora_alpha=32,
  49. lora_dropout=0.05,
  50. target_modules=[
  51. "q_proj",
  52. "k_proj",
  53. "v_proj",
  54. "o_proj",
  55. "gate_proj",
  56. "up_proj",
  57. "down_proj",
  58. ],
  59. per_device_train_batch_size=1,
  60. gradient_accumulation_steps=4,
  61. learning_rate=2e-4,
  62. num_train_epochs=3,
  63. max_seq_length=512,
  64. output_dir=output_dir,
  65. use_4bit=False, # 华为升腾 NPU 不支持 4bit 量化
  66. trust_remote_code=True,
  67. )
  68. logger.info(f"模型配置:{config.model_name}")
  69. logger.info(f"LoRA 配置:r={config.lora_r}, alpha={config.lora_alpha}")
  70. logger.info(f"训练配置:epochs={config.num_train_epochs}, lr={config.learning_rate}")
  71. logger.info(f"输出目录:{output_dir}")
  72. # ==================== GPU 信息 ====================
  73. gpu_info = get_gpu_info()
  74. if gpu_info["available"]:
  75. logger.info(f"GPU 可用:{gpu_info['device_count']} 个设备")
  76. for i, dev in enumerate(gpu_info["devices"]):
  77. logger.info(f" GPU {i}: {dev['name']} ({dev['max_memory']:.2f} GB)")
  78. else:
  79. logger.warning("GPU 不可用,将使用 CPU 训练(不推荐)")
  80. # ==================== 加载数据 ====================
  81. logger.info("\n加载数据集...")
  82. dataset = load_dataset(dataset_path, format="json")
  83. # 格式化数据集
  84. formatted_dataset = format_dataset(
  85. dataset,
  86. instruction_column="instruction",
  87. input_column="input",
  88. output_column="output",
  89. )
  90. logger.info(f"数据集大小:{len(formatted_dataset)} 样本")
  91. # ==================== 加载模型 ====================
  92. logger.info("\n加载本地模型...")
  93. model, tokenizer, peft_config = load_qwen_model(config)
  94. # ==================== 创建训练数据集 ====================
  95. logger.info("\n创建训练数据集...")
  96. train_dataset = InstructionDataset(
  97. formatted_dataset,
  98. tokenizer,
  99. max_length=config.max_seq_length,
  100. )
  101. # ==================== 创建训练器 ====================
  102. logger.info("\n创建训练器...")
  103. trainer = FineTuneTrainer(
  104. model=model,
  105. tokenizer=tokenizer,
  106. config=config,
  107. train_dataset=train_dataset,
  108. )
  109. # ==================== 设置训练参数 ====================
  110. trainer.setup_training(
  111. output_dir=config.output_dir,
  112. num_train_epochs=config.num_train_epochs,
  113. per_device_train_batch_size=config.per_device_train_batch_size,
  114. gradient_accumulation_steps=config.gradient_accumulation_steps,
  115. learning_rate=config.learning_rate,
  116. warmup_ratio=0.03,
  117. weight_decay=0.01,
  118. logging_steps=10,
  119. save_steps=50,
  120. bf16=True,
  121. fp16=False,
  122. )
  123. # ==================== 开始训练 ====================
  124. logger.info("\n" + "=" * 60)
  125. logger.info("开始训练...")
  126. logger.info("=" * 60)
  127. trainer.train()
  128. # ==================== 保存模型 ====================
  129. logger.info("\n保存模型...")
  130. trainer.save_model()
  131. logger.info("=" * 60)
  132. logger.info("训练完成!")
  133. logger.info(f"模型已保存到:{config.output_dir}")
  134. logger.info("=" * 60)
  135. # ==================== 测试推理 ====================
  136. logger.info("\n测试推理...")
  137. test_prompts = [
  138. "请解释什么是机器学习",
  139. "写一首关于春天的诗",
  140. ]
  141. for test_prompt in test_prompts:
  142. logger.info(f"\n输入:{test_prompt}")
  143. inputs = tokenizer(test_prompt, return_tensors="pt")
  144. if torch.cuda.is_available():
  145. inputs = inputs.to("cuda")
  146. elif hasattr(torch, 'npu') and torch.npu.is_available():
  147. inputs = inputs.to("npu:0")
  148. with torch.no_grad():
  149. outputs = model.generate(
  150. **inputs,
  151. max_new_tokens=150,
  152. temperature=0.7,
  153. do_sample=True,
  154. top_p=0.9,
  155. )
  156. response = tokenizer.decode(outputs[0], skip_special_tokens=True)
  157. logger.info(f"输出:{response}")
  158. logger.info("\n" + "=" * 60)
  159. logger.info("所有任务完成!")
  160. logger.info("使用以下命令进行推理:")
  161. logger.info(f" python scripts/inference.py --model_path {output_dir} --interactive")
  162. logger.info("=" * 60)
  163. if __name__ == "__main__":
  164. import torch
  165. main()