# qwen3.5_0.8b_local_finetune.py (6.4 KB)
  1. """
  2. Qwen3.5-0.8B 本地模型微调示例
  3. 这个脚本用于微调本地已下载的 Qwen3.5-0.8B 模型。
  4. 使用方法:
  5. python examples/qwen3.5_0.8b_local_finetune.py
  6. 前提条件:
  7. - 已在本地下载 Qwen3.5-0.8B 模型
  8. - 模型路径配置在 local_model_path 变量中
  9. """
  10. import os
  11. import sys
  12. # 添加项目根目录到 Python 路径
  13. sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
  14. from finetunex.models import QwenConfig, load_qwen_model
  15. from finetunex.data import load_dataset, format_dataset, InstructionDataset
  16. from finetunex.trainer import FineTuneTrainer
  17. from finetunex.utils import setup_environment, get_gpu_info, setup_logger
  18. def main():
  19. # 设置环境和日志
  20. setup_environment(seed=42)
  21. logger = setup_logger("Qwen3.5-0.8B_Local_FineTuning")
  22. logger.info("=" * 60)
  23. logger.info("Qwen3.5-0.8B 本地模型微调")
  24. logger.info("=" * 60)
  25. # ==================== 配置区域 ====================
  26. # 本地模型路径(请根据实际情况修改)
  27. # 可以是绝对路径或相对于项目根目录的路径
  28. local_model_path = "./Qwen3.5-0.8B"
  29. # 如果模型在其他位置,请修改这里
  30. # 例如:
  31. # local_model_path = "/path/to/your/models/Qwen3.5-0.8B"
  32. # local_model_path = "D:\\AI_Models\\Qwen3.5-0.8B"
  33. # 检查模型路径是否存在
  34. if not os.path.exists(local_model_path):
  35. logger.error(f"模型路径不存在:{local_model_path}")
  36. logger.error("请修改脚本中的 local_model_path 变量为正确的模型路径")
  37. sys.exit(1)
  38. logger.info(f"使用本地模型:{local_model_path}")
  39. # 数据集路径
  40. dataset_path = os.path.join(os.path.dirname(__file__), "..", "data", "sample_dataset.json")
  41. # 输出目录
  42. output_dir = "./outputs/qwen3.5-0.8b-finetuned"
  43. # ==================== 模型配置 ====================
  44. config = QwenConfig(
  45. model_name=local_model_path, # 使用本地模型路径
  46. lora_r=16,
  47. lora_alpha=32,
  48. lora_dropout=0.05,
  49. target_modules=[
  50. "q_proj",
  51. "k_proj",
  52. "v_proj",
  53. "o_proj",
  54. "gate_proj",
  55. "up_proj",
  56. "down_proj",
  57. ],
  58. per_device_train_batch_size=1,
  59. gradient_accumulation_steps=4,
  60. learning_rate=2e-4,
  61. num_train_epochs=3,
  62. max_seq_length=512,
  63. output_dir=output_dir,
  64. use_4bit=True, # 使用 4bit 量化以节省显存
  65. trust_remote_code=True,
  66. )
  67. logger.info(f"模型配置:{config.model_name}")
  68. logger.info(f"LoRA 配置:r={config.lora_r}, alpha={config.lora_alpha}")
  69. logger.info(f"训练配置:epochs={config.num_train_epochs}, lr={config.learning_rate}")
  70. logger.info(f"输出目录:{output_dir}")
  71. # ==================== GPU 信息 ====================
  72. gpu_info = get_gpu_info()
  73. if gpu_info["available"]:
  74. logger.info(f"GPU 可用:{gpu_info['device_count']} 个设备")
  75. for i, dev in enumerate(gpu_info["devices"]):
  76. logger.info(f" GPU {i}: {dev['name']} ({dev['max_memory']:.2f} GB)")
  77. else:
  78. logger.warning("GPU 不可用,将使用 CPU 训练(不推荐)")
  79. # ==================== 加载数据 ====================
  80. logger.info("\n加载数据集...")
  81. dataset = load_dataset(dataset_path, format="json")
  82. # 格式化数据集
  83. formatted_dataset = format_dataset(
  84. dataset,
  85. instruction_column="instruction",
  86. input_column="input",
  87. output_column="output",
  88. )
  89. logger.info(f"数据集大小:{len(formatted_dataset)} 样本")
  90. # ==================== 加载模型 ====================
  91. logger.info("\n加载本地模型...")
  92. model, tokenizer, peft_config = load_qwen_model(config)
  93. # ==================== 创建训练数据集 ====================
  94. logger.info("\n创建训练数据集...")
  95. train_dataset = InstructionDataset(
  96. formatted_dataset,
  97. tokenizer,
  98. max_length=config.max_seq_length,
  99. )
  100. # ==================== 创建训练器 ====================
  101. logger.info("\n创建训练器...")
  102. trainer = FineTuneTrainer(
  103. model=model,
  104. tokenizer=tokenizer,
  105. config=config,
  106. train_dataset=train_dataset,
  107. )
  108. # ==================== 设置训练参数 ====================
  109. trainer.setup_training(
  110. output_dir=config.output_dir,
  111. num_train_epochs=config.num_train_epochs,
  112. per_device_train_batch_size=config.per_device_train_batch_size,
  113. gradient_accumulation_steps=config.gradient_accumulation_steps,
  114. learning_rate=config.learning_rate,
  115. warmup_ratio=0.03,
  116. weight_decay=0.01,
  117. logging_steps=10,
  118. save_steps=50,
  119. fp16=True,
  120. )
  121. # ==================== 开始训练 ====================
  122. logger.info("\n" + "=" * 60)
  123. logger.info("开始训练...")
  124. logger.info("=" * 60)
  125. trainer.train()
  126. # ==================== 保存模型 ====================
  127. logger.info("\n保存模型...")
  128. trainer.save_model()
  129. logger.info("=" * 60)
  130. logger.info("训练完成!")
  131. logger.info(f"模型已保存到:{config.output_dir}")
  132. logger.info("=" * 60)
  133. # ==================== 测试推理 ====================
  134. logger.info("\n测试推理...")
  135. test_prompts = [
  136. "请解释什么是机器学习",
  137. "写一首关于春天的诗",
  138. ]
  139. for test_prompt in test_prompts:
  140. logger.info(f"\n输入:{test_prompt}")
  141. inputs = tokenizer(test_prompt, return_tensors="pt")
  142. if gpu_info["available"]:
  143. inputs = inputs.to("cuda")
  144. with torch.no_grad():
  145. outputs = model.generate(
  146. **inputs,
  147. max_new_tokens=150,
  148. temperature=0.7,
  149. do_sample=True,
  150. top_p=0.9,
  151. )
  152. response = tokenizer.decode(outputs[0], skip_special_tokens=True)
  153. logger.info(f"输出:{response}")
  154. logger.info("\n" + "=" * 60)
  155. logger.info("所有任务完成!")
  156. logger.info("使用以下命令进行推理:")
  157. logger.info(f" python scripts/inference.py --model_path {output_dir} --interactive")
  158. logger.info("=" * 60)
  159. if __name__ == "__main__":
  160. import torch
  161. main()