| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- """
- 模型推理脚本
- """
- import argparse
- import os
- import sys
- import torch
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from peft import PeftModel
def load_finetuned_model(model_path, base_model_name="Qwen/Qwen3.5-0.5B"):
    """Load the base causal LM and apply the LoRA adapter saved at ``model_path``.

    Args:
        model_path: Path (or hub id) of the fine-tuned LoRA adapter. The
            tokenizer is read from here first.
        base_model_name: Hub id or local path of the base model the adapter
            was trained on.

    Returns:
        A ``(model, tokenizer)`` tuple: a ``PeftModel`` wrapping the base
        model (switched to eval mode) and the tokenizer.
    """
    print(f"加载基础模型:{base_model_name}")

    # Prefer the tokenizer saved alongside the adapter. If the fine-tuning run
    # only saved adapter weights, fall back to the base model's tokenizer
    # (presumably identical unless tokens were added during fine-tuning).
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    except (OSError, ValueError):
        print(f"未能从 {model_path} 加载 tokenizer, 改用基础模型的 tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # Half precision only pays off on GPU; on CPU float16 kernels are slow or,
    # on older PyTorch builds, not implemented — use float32 there.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        device_map="auto",
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    # Wrap the base model with the LoRA adapter weights stored at model_path.
    model = PeftModel.from_pretrained(base_model, model_path)
    # Inference only: disable dropout and other training-time behavior.
    model.eval()

    print("模型加载完成!")
    return model, tokenizer
def generate_response(model, tokenizer, prompt, max_length=512, *,
                      temperature=0.7, top_p=0.9, do_sample=True):
    """Generate a completion for ``prompt`` and return only the new text.

    Args:
        model: Causal LM exposing ``generate`` and ``device`` (e.g. the
            ``PeftModel`` returned by ``load_finetuned_model``).
        tokenizer: Tokenizer matching ``model``.
        prompt: Raw text fed to the model as-is (no chat template is applied).
        max_length: Maximum number of *new* tokens to generate; it is passed
            as ``max_new_tokens``, so the prompt length is not counted.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        do_sample: Sample when True; use non-sampling decoding when False.

    Returns:
        The decoded continuation, excluding the prompt tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move inputs to wherever the model actually lives (device_map="auto" may
    # put it on CPU or a particular GPU) rather than assuming "cuda".
    inputs = inputs.to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Some tokenizers define no pad token; reuse EOS so generate() does not
    # have to guess (and warn) on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            pad_token_id=pad_token_id,
        )

    # For decoder-only models generate() returns prompt + continuation; drop
    # the echoed prompt so callers receive only the model's answer.
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def main():
    """Command-line entry point: parse arguments, load the model, run inference.

    With ``--interactive`` a prompt loop runs until ``quit``, Ctrl-C, or end
    of input; otherwise the single ``--prompt`` is answered. If neither is
    given, a usage hint is printed and the model is never loaded.
    """
    parser = argparse.ArgumentParser(description="模型推理工具")
    parser.add_argument("--model_path", type=str, required=True, help="微调模型路径")
    parser.add_argument("--base_model", type=str, default="Qwen/Qwen3.5-0.5B",
                        help="基础模型名称")
    parser.add_argument("--prompt", type=str, help="输入提示")
    parser.add_argument("--max_length", type=int, default=512, help="最大生成长度")
    parser.add_argument("--interactive", action="store_true", help="交互模式")

    args = parser.parse_args()

    # Validate the mode first: loading the model is slow and memory-hungry
    # and is wasted when there is nothing to generate.
    if not args.interactive and not args.prompt:
        print("请提供 --prompt 或使用 --interactive 模式")
        return

    model, tokenizer = load_finetuned_model(args.model_path, args.base_model)

    if not args.interactive:
        response = generate_response(model, tokenizer, args.prompt, args.max_length)
        print(f"输出:{response}")
        return

    print("\n进入交互模式(输入 'quit' 退出)\n")
    while True:
        try:
            prompt = input("输入:")
            if prompt.strip().lower() == "quit":
                break
            # An empty prompt yields no input tokens to condition on; skip it.
            if not prompt.strip():
                continue
            response = generate_response(model, tokenizer, prompt, args.max_length)
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C (also during generation) or Ctrl-D / exhausted piped
            # stdin ends the session; newline keeps the shell prompt tidy.
            print()
            break
        print(f"输出:{response}\n")


if __name__ == "__main__":
    main()
|