# finetune_sft.py (6.8 KB)
  1. import gc
  2. import os
  3. import sys
  4. # 必须在任何 import 之前
  5. os.environ["TRANSFORMERS_OFFLINE"] = "1"
  6. os.environ["HF_DATASETS_OFFLINE"] = "1"
  7. os.environ["CUDA_VISIBLE_DEVICES"] = "1"
  8. os.environ["TORCHDYNAMO_DISABLE"] = "1"
  9. os.environ["TORCH_COMPILE_DISABLE"] = "1"
  10. os.environ["UNSLOTH_DISABLE_COMPILE"] = "1"
  11. os.environ["TRITON_DISABLE_COMPILE"] = "1"
  12. # 禁用 Triton 的 persistent TMA(针对你的警告)
  13. os.environ["TRITON_ENABLE_PERSISTENT_TMA_MATMUL"] = "0"
  14. os.environ["TORCH_COMPILE_DISABLE"] = "1"
  15. os.environ["UNSLOTH_SKIP_VLLM_CHECK"] = "1"
  16. import torch
  17. import torch.nn as nn
  18. # 现在可以安全导入 Unsloth
  19. from unsloth import FastLanguageModel
  20. from unsloth.chat_templates import train_on_responses_only
  21. from trl import SFTTrainer, SFTConfig
  22. from transformers import DataCollatorForLanguageModeling
  23. from load_data import load_bid_data
  24. # 定义缺少的 set_submodule 方法
  25. def set_submodule(model, target, module):
  26. if "." not in target:
  27. setattr(model, target, module)
  28. else:
  29. parent_name, child_name = target.rsplit(".", 1)
  30. parent = model.get_submodule(parent_name)
  31. setattr(parent, child_name, module)
  32. # 强行把这个方法注入到 nn.Module 基类里,一劳永逸
  33. if not hasattr(nn.Module, "set_submodule"):
  34. nn.Module.set_submodule = set_submodule
  35. project_dir = os.path.abspath(os.path.dirname(__file__)) + '/../'
  36. MAX_SEQ_LENGTH = int(1024 * 16)
  37. train_data_path = project_dir + 'qwen_0.8B_lora_bidding_kg/data7_prefix_aug/train_data.jsonl'
  38. dev_data_path = project_dir + 'qwen_0.8B_lora_bidding_kg/data7_prefix/dev_data.jsonl'
  39. train_data_path1 = project_dir + 'qwen_0.8B_lora_bidding_kg/data4_prefix_aug/train_data.jsonl'
  40. dev_data_path1 = project_dir + 'qwen_0.8B_lora_bidding_kg/data4_prefix/dev_data.jsonl'
  41. train_data_path2 = project_dir + 'qwen_0.8B_lora_bidding_kg/data5_prefix/train_data.jsonl'
  42. dev_data_path2 = project_dir + 'qwen_0.8B_lora_bidding_kg/data5_prefix/dev_data.jsonl'
  43. train_data_path3 = project_dir + 'qwen_0.8B_lora_bidding_kg/data6_prefix_aug/train_data.jsonl'
  44. dev_data_path3 = project_dir + 'qwen_0.8B_lora_bidding_kg/data6_prefix/dev_data.jsonl'
  45. PER_DEVICE_TRAIN_BATCH_SIZE = 2 # 单卡批次大小,16G GPU建议2-4
  46. GRADIENT_ACCUMULATION_STEPS = 2 # 梯度累积,弥补批次小的问题
  47. MODEL_NAME = '/home/user/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17/'
  48. OUTPUT_DIR = project_dir + "qwen_0.8B_lora_bidding_kg/lora-sft" # 微调结果保存路径
  49. checkpoint_path = project_dir + "qwen_0.8B_lora_bidding_kg/trainer_output/checkpoint-380"
  50. print('context length', MAX_SEQ_LENGTH)
  51. # 加载模型时显式禁用编译
  52. model, tokenizer = FastLanguageModel.from_pretrained(
  53. model_name=MODEL_NAME,
  54. max_seq_length=MAX_SEQ_LENGTH,
  55. dtype=torch.bfloat16,
  56. load_in_4bit=True,
  57. )
  58. print('model_type', model.config.model_type)
  59. # 添加 LoRA(不使用任何 compile)
  60. r = 16
  61. model = FastLanguageModel.get_peft_model(
  62. model,
  63. r=r,
  64. # target_modules=["q_proj", "v_proj", ],
  65. target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
  66. lora_alpha=r*2,
  67. lora_dropout=0.1,
  68. use_gradient_checkpointing="unsloth",
  69. )
  70. # 手动确保没有 compile
  71. if hasattr(torch, "compile"):
  72. torch.compile = lambda x, *args, **kwargs: x
  73. print(f"Allocated1: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB")
  74. print('\n 加载数据 \n')
  75. train_dataset, dev_dataset = load_bid_data(
  76. [train_data_path, train_data_path1, train_data_path2, train_data_path3],
  77. [dev_data_path, dev_data_path1, dev_data_path2, dev_data_path3]
  78. )
  79. # 只计算answer部分的loss
  80. response_template = "<|im_start|>assistant\n"
  81. response_ids = tokenizer.encode(response_template, add_special_tokens=False)
  82. print(f"Template IDs: {response_ids}")
  83. def truncate_eval_dataset(example):
  84. # 使用 tokenizer 对文本进行截断
  85. tokens = tokenizer(
  86. example["text"], # 或者是你数据中的 key,如 "prompt" + "answer"
  87. truncation=True,
  88. max_length=2048, # 评估强制限制在 2k
  89. add_special_tokens=True,
  90. )
  91. # 将截断后的 token 重新转回文本,或者直接返回 token
  92. return {"text": tokenizer.decode(tokens["input_ids"], skip_special_tokens=False)}
  93. # 处理验证集
  94. dev_dataset = dev_dataset.map(truncate_eval_dataset)
  95. # 创建 Trainer
  96. # 更改源码 unsloth trainer.py is_vlm = False,才能packing
  97. trainer = SFTTrainer(
  98. model=model,
  99. # output_dir=OUTPUT_DIR,
  100. tokenizer=tokenizer,
  101. train_dataset=train_dataset,
  102. eval_dataset=dev_dataset,
  103. max_seq_length=MAX_SEQ_LENGTH,
  104. args=SFTConfig(
  105. learning_rate=2e-5,
  106. output_dir=OUTPUT_DIR,
  107. warmup_steps=5,
  108. num_train_epochs=4,
  109. per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
  110. gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
  111. fp16=False,
  112. bf16=True,
  113. optim="adamw_8bit",
  114. # 关键:禁用所有编译相关选项
  115. torch_empty_cache_steps=1,
  116. dataloader_num_workers=0, # 避免多进程问题
  117. dataset_num_proc=1,
  118. save_steps=1024,
  119. save_strategy="steps",
  120. seed=3407,
  121. save_total_limit=3, # 最多保存3个检查点
  122. weight_decay=0.01,
  123. lr_scheduler_type="cosine",
  124. logging_steps=1024,
  125. eval_strategy="steps",
  126. eval_steps=1024,
  127. per_device_eval_batch_size=1,
  128. eval_accumulation_steps=1,
  129. do_eval=True,
  130. metric_for_best_model="eval_loss",
  131. greater_is_better=False,
  132. prediction_loss_only=True,
  133. load_best_model_at_end=True,
  134. ),
  135. )
  136. num_train_samples = len(trainer.train_dataset)
  137. print(f"*** Packing 后的总样本数: {num_train_samples} ***")
  138. max_len = max(len(x) for x in trainer.train_dataset["input_ids"])
  139. print(f"训练集最大 Token 长度: {max_len}")
  140. # 简单测试代码
  141. sample_batch = next(iter(trainer.get_train_dataloader()))
  142. # 统计一下非 -100 的 token 占比,确保不是 0
  143. non_ignore = (sample_batch["labels"] != -100).sum().item()
  144. total = sample_batch["labels"].numel()
  145. print(f"有效 Loss Token 占比: {non_ignore / total:.2%}, {non_ignore}, {total}")
  146. # 只计算回答loss
  147. trainer = train_on_responses_only(
  148. trainer,
  149. instruction_part="<|im_start|>user\n",
  150. response_part="<|im_start|>assistant\n",
  151. )
  152. print('确认数据', tokenizer.decode(trainer.train_dataset[10]["input_ids"]))
  153. print('只有回答', tokenizer.decode(
  154. [tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[10]["labels"]]).replace(
  155. tokenizer.pad_token, " "))
  156. model.gradient_checkpointing_enable()
  157. trainer.train()
  158. # 继续训练 由于torch2.5有bug,内核升不了2.6,暂时不支持继续训练
  159. # trainer.train(resume_from_checkpoint=checkpoint_path)
  160. model.save_pretrained(OUTPUT_DIR)
  161. tokenizer.save_pretrained(OUTPUT_DIR)