"""Inference script for a LoRA fine-tuned causal language model.

Loads a base model, applies LoRA adapter weights from ``--model_path`` and
either answers a single ``--prompt`` or runs an interactive prompt loop.
"""

import argparse
import os
import sys

import torch

# Put the parent directory of this script first on the import path.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402
from peft import PeftModel  # noqa: E402

# Default base checkpoint, shared by the function default and the CLI default.
DEFAULT_BASE_MODEL = "Qwen/Qwen3.5-0.5B"


def load_finetuned_model(model_path, base_model_name=DEFAULT_BASE_MODEL):
    """Load the base model and apply the LoRA adapter stored at *model_path*.

    Args:
        model_path: Directory (or hub id) holding the tokenizer files and the
            LoRA adapter weights.
        base_model_name: Name or path of the base checkpoint the adapter is
            applied to.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"加载基础模型:{base_model_name}")

    # The tokenizer is read from the fine-tuned directory, not the base model.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Half precision is only a safe default on GPU; on a CPU-only host fall
    # back to float32, where every operator is implemented.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # SECURITY: trust_remote_code=True executes Python code shipped with the
    # checkpoint. Only point base_model_name at sources you trust.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        device_map="auto",
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    # Wrap the base model with the LoRA adapter weights.
    model = PeftModel.from_pretrained(base_model, model_path)

    print("模型加载完成!")
    return model, tokenizer


def generate_response(model, tokenizer, prompt, max_length=512):
    """Generate a sampled completion for *prompt* and return the decoded text.

    Args:
        model: Model exposing ``generate`` and ``parameters``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Input text.
        max_length: Upper bound on the number of newly generated tokens
            (passed to ``generate`` as ``max_new_tokens``).

    Returns:
        The first generated sequence decoded with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt")

    # Follow the model's actual placement instead of assuming "cuda": with
    # device_map="auto" the model may live on the CPU even when CUDA exists.
    device = next(model.parameters()).device
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )

    # NOTE(review): the whole first sequence is decoded; for decoder-only
    # models this presumably includes the prompt tokens as well as the
    # continuation - confirm that echoing the prompt is intended.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _run_interactive(model, tokenizer, max_length):
    """Read prompts from stdin and print responses until the user quits."""
    print("\n进入交互模式(输入 'quit' 退出)\n")
    while True:
        try:
            prompt = input("输入:")
            if prompt.strip().lower() == "quit":
                break
            if not prompt.strip():
                # Nothing to generate from; ask again.
                continue
            response = generate_response(model, tokenizer, prompt, max_length)
            print(f"输出:{response}\n")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C, Ctrl-D or a closed stdin all end the session cleanly.
            break


def main():
    """Parse command-line arguments and run single-shot or interactive inference."""
    parser = argparse.ArgumentParser(description="模型推理工具")
    parser.add_argument("--model_path", type=str, required=True, help="微调模型路径")
    parser.add_argument("--base_model", type=str, default=DEFAULT_BASE_MODEL, help="基础模型名称")
    parser.add_argument("--prompt", type=str, help="输入提示")
    parser.add_argument("--max_length", type=int, default=512, help="最大生成长度")
    parser.add_argument("--interactive", action="store_true", help="交互模式")

    args = parser.parse_args()

    # Validate before the expensive model load so a bad invocation fails fast
    # (parser.error prints the usage and exits with status 2).
    if not args.interactive and not args.prompt:
        parser.error("请提供 --prompt 或使用 --interactive 模式")

    model, tokenizer = load_finetuned_model(args.model_path, args.base_model)

    if args.interactive:
        _run_interactive(model, tokenizer, args.max_length)
    else:
        response = generate_response(model, tokenizer, args.prompt, args.max_length)
        print(f"输出:{response}")


if __name__ == "__main__":
    main()