| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123 |
- """
- API 路由
- """
import json
import logging
import uuid
from typing import Optional, List, Dict, Any

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
# Router collecting the endpoints defined in this module. Presumably it is
# registered on the FastAPI application elsewhere (e.g. via include_router);
# that wiring is not visible in this file.
router = APIRouter()
class TrainingConfig(BaseModel):
    """Training configuration accepted as the request body of the training endpoint.

    Hyperparameters are range-checked during validation. Obviously invalid
    requests (an empty dataset path, a zero batch size, a non-positive learning
    rate, ...) are rejected with a validation error instead of being accepted.
    FastAPI answers such request bodies with HTTP 422.
    """

    # Identifier of the base model to fine-tune; must be non-empty.
    # NOTE(review): with Pydantic v2 before 2.10 the `model_` prefix triggers a
    # protected-namespace warning — confirm the Pydantic version in use.
    model_name: str = Field("Qwen/Qwen3.5-0.5B", min_length=1)
    # Path to the training dataset; required and must be non-empty.
    dataset_path: str = Field(..., min_length=1)
    # Directory where training outputs are written.
    output_dir: str = "./outputs"
    # Number of training epochs; may be fractional but must be > 0.
    num_train_epochs: float = Field(3.0, gt=0)
    # Learning rate; must be > 0.
    learning_rate: float = Field(2e-4, gt=0)
    # Batch size; must be at least 1.
    batch_size: int = Field(1, ge=1)
    # LoRA `r` hyperparameter (presumably the adapter rank); must be at least 1.
    lora_r: int = Field(16, ge=1)
    # LoRA `alpha` scaling hyperparameter; must be at least 1.
    lora_alpha: int = Field(32, ge=1)
class TrainingResponse(BaseModel):
    """Response model of the training endpoint (``POST /train``)."""

    # Outcome of the request; start_training sets it to "started".
    status: str
    # Human-readable message describing the outcome.
    message: str
    # ID assigned to the training job, if any; defaults to None.
    job_id: Optional[str] = None
@router.get("/status")
async def get_status():
    """Report that the API service is up.

    Returns:
        A static mapping holding the service status ("running") and the
        service name.
    """
    payload = dict(status="running")
    payload["service"] = "FineTuneX API"
    return payload
@router.post("/train", response_model=TrainingResponse)
async def start_training(config: TrainingConfig):
    """Start a training job.

    Currently a placeholder: no training is actually launched. The endpoint
    only allocates a job ID and echoes the requested model name back.

    Args:
        config: Training configuration parsed from the request body.

    Returns:
        A ``TrainingResponse`` with status ``"started"`` and a unique job ID.

    Raises:
        HTTPException: Re-raised unchanged if raised while launching the job.
            Any other failure becomes a 500 with a generic detail message; the
            underlying error is logged server-side rather than sent to the
            client.
    """
    try:
        # TODO: hand the job to an asynchronous task queue here.
        # A random UUID keeps job IDs unique across requests.
        job_id = f"job_{uuid.uuid4().hex}"

        return TrainingResponse(
            status="started",
            message=f"训练任务已启动:{config.model_name}",
            job_id=job_id,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. a 4xx from launch code) pass through
        # instead of being rewrapped as a 500.
        raise
    except Exception as e:
        # Keep internal error details out of the client response; log them.
        logging.getLogger(__name__).exception("Failed to start training job")
        raise HTTPException(status_code=500, detail="训练任务启动失败") from e
@router.get("/train/{job_id}")
async def get_training_status(job_id: str):
    """Return the status of a training job.

    NOTE(review): placeholder implementation. Every ``job_id``, including
    IDs that were never issued, gets the same hard-coded "running" status,
    progress and metrics.

    Args:
        job_id: ID of the training job, taken from the URL path.

    Returns:
        A mapping with the echoed job ID, status, progress and metrics.
    """
    placeholder_metrics = {"loss": 0.5, "step": 100}
    return {
        "job_id": job_id,
        "status": "running",
        "progress": 0.5,
        "metrics": placeholder_metrics,
    }
@router.post("/inference")
async def inference(
    model_path: str,
    prompt: str,
    max_length: int = 512,
    temperature: float = 0.7,
):
    """Run model inference (placeholder).

    Under FastAPI's default rules these scalar parameters are read from the
    query string, not from a JSON body. No model is loaded yet, and
    ``max_length`` and ``temperature`` are accepted but currently unused.
    NOTE(review): ``model_path`` is client-controlled; validate it before any
    real implementation loads files from it.

    Args:
        model_path: Path or identifier of the model to use.
        prompt: Input prompt text.
        max_length: Maximum generation length.
        temperature: Sampling temperature.

    Returns:
        A mapping with the echoed prompt, a fixed sample generation and the
        echoed model path.
    """
    result = {"prompt": prompt}
    result["generation"] = "这是一个示例响应"
    result["model"] = model_path
    return result
@router.get("/models")
async def list_models():
    """Return the list of supported base models.

    Returns:
        A mapping whose ``"models"`` entry lists each supported model name
        with its type (all are ``"causal_lm"``).
    """
    supported = (
        "Qwen/Qwen3.5-0.5B",
        "Qwen/Qwen2.5-0.5B",
        "meta-llama/Llama-3.2-1B",
    )
    return {"models": [{"name": name, "type": "causal_lm"} for name in supported]}
|