""" API 路由 """ from fastapi import APIRouter, HTTPException from pydantic import BaseModel from typing import Optional, List, Dict, Any import json router = APIRouter() class TrainingConfig(BaseModel): """训练配置""" model_name: str = "Qwen/Qwen3.5-0.5B" dataset_path: str output_dir: str = "./outputs" num_train_epochs: float = 3.0 learning_rate: float = 2e-4 batch_size: int = 1 lora_r: int = 16 lora_alpha: int = 32 class TrainingResponse(BaseModel): """训练响应""" status: str message: str job_id: Optional[str] = None @router.get("/status") async def get_status(): """获取服务状态""" return { "status": "running", "service": "FineTuneX API", } @router.post("/train", response_model=TrainingResponse) async def start_training(config: TrainingConfig): """ 开始训练任务 Args: config: 训练配置 Returns: 训练响应 """ try: # 这里应该启动训练任务 # 实际实现中会使用异步任务队列 job_id = "job_001" return TrainingResponse( status="started", message=f"训练任务已启动:{config.model_name}", job_id=job_id, ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/train/{job_id}") async def get_training_status(job_id: str): """ 获取训练任务状态 Args: job_id: 任务 ID Returns: 任务状态 """ return { "job_id": job_id, "status": "running", "progress": 0.5, "metrics": { "loss": 0.5, "step": 100, }, } @router.post("/inference") async def inference( model_path: str, prompt: str, max_length: int = 512, temperature: float = 0.7, ): """ 模型推理 Args: model_path: 模型路径 prompt: 输入提示 max_length: 最大生成长度 temperature: 温度参数 Returns: 生成结果 """ return { "prompt": prompt, "generation": "这是一个示例响应", "model": model_path, } @router.get("/models") async def list_models(): """获取支持的模型列表""" return { "models": [ {"name": "Qwen/Qwen3.5-0.5B", "type": "causal_lm"}, {"name": "Qwen/Qwen2.5-0.5B", "type": "causal_lm"}, {"name": "meta-llama/Llama-3.2-1B", "type": "causal_lm"}, ], }