inference.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. 模型推理脚本
  3. """
  4. import argparse
  5. import os
  6. import sys
  7. import torch
  8. sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
  9. from transformers import AutoModelForCausalLM, AutoTokenizer
  10. from peft import PeftModel
  11. def load_finetuned_model(model_path, base_model_name="Qwen/Qwen3.5-0.5B"):
  12. """加载微调后的模型"""
  13. print(f"加载基础模型:{base_model_name}")
  14. # 加载 tokenizer
  15. tokenizer = AutoTokenizer.from_pretrained(model_path)
  16. # 加载基础模型
  17. base_model = AutoModelForCausalLM.from_pretrained(
  18. base_model_name,
  19. device_map="auto",
  20. torch_dtype=torch.float16,
  21. trust_remote_code=True,
  22. )
  23. # 加载 LoRA 权重
  24. model = PeftModel.from_pretrained(base_model, model_path)
  25. print("模型加载完成!")
  26. return model, tokenizer
  27. def generate_response(model, tokenizer, prompt, max_length=512):
  28. """生成响应"""
  29. inputs = tokenizer(prompt, return_tensors="pt")
  30. if torch.cuda.is_available():
  31. inputs = inputs.to("cuda")
  32. with torch.no_grad():
  33. outputs = model.generate(
  34. **inputs,
  35. max_new_tokens=max_length,
  36. temperature=0.7,
  37. do_sample=True,
  38. top_p=0.9,
  39. )
  40. response = tokenizer.decode(outputs[0], skip_special_tokens=True)
  41. return response
  42. def main():
  43. parser = argparse.ArgumentParser(description="模型推理工具")
  44. parser.add_argument("--model_path", type=str, required=True, help="微调模型路径")
  45. parser.add_argument("--base_model", type=str, default="Qwen/Qwen3.5-0.5B",
  46. help="基础模型名称")
  47. parser.add_argument("--prompt", type=str, help="输入提示")
  48. parser.add_argument("--max_length", type=int, default=512, help="最大生成长度")
  49. parser.add_argument("--interactive", action="store_true", help="交互模式")
  50. args = parser.parse_args()
  51. # 加载模型
  52. model, tokenizer = load_finetuned_model(args.model_path, args.base_model)
  53. if args.interactive:
  54. print("\n进入交互模式(输入 'quit' 退出)\n")
  55. while True:
  56. try:
  57. prompt = input("输入:")
  58. if prompt.lower() == "quit":
  59. break
  60. response = generate_response(model, tokenizer, prompt, args.max_length)
  61. print(f"输出:{response}\n")
  62. except KeyboardInterrupt:
  63. break
  64. elif args.prompt:
  65. response = generate_response(model, tokenizer, args.prompt, args.max_length)
  66. print(f"输出:{response}")
  67. else:
  68. print("请提供 --prompt 或使用 --interactive 模式")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()