routes.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
"""
API routes: service status, training jobs, inference, and model listing.
"""
import json
import uuid
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
# Module-level router; every endpoint in this module is registered on it via
# @router.get / @router.post. Presumably included into the FastAPI app
# elsewhere (e.g. app.include_router) — confirm against the app setup.
router = APIRouter()
  9. class TrainingConfig(BaseModel):
  10. """训练配置"""
  11. model_name: str = "Qwen/Qwen3.5-0.5B"
  12. dataset_path: str
  13. output_dir: str = "./outputs"
  14. num_train_epochs: float = 3.0
  15. learning_rate: float = 2e-4
  16. batch_size: int = 1
  17. lora_r: int = 16
  18. lora_alpha: int = 32
class TrainingResponse(BaseModel):
    """Training response (the ``response_model`` of ``POST /train``)."""
    # Job state string; start_training sets "started".
    status: str
    # Human-readable description of the outcome.
    message: str
    # Identifier of the created job; defaults to None when not provided.
    job_id: Optional[str] = None
  24. @router.get("/status")
  25. async def get_status():
  26. """获取服务状态"""
  27. return {
  28. "status": "running",
  29. "service": "FineTuneX API",
  30. }
  31. @router.post("/train", response_model=TrainingResponse)
  32. async def start_training(config: TrainingConfig):
  33. """
  34. 开始训练任务
  35. Args:
  36. config: 训练配置
  37. Returns:
  38. 训练响应
  39. """
  40. try:
  41. # 这里应该启动训练任务
  42. # 实际实现中会使用异步任务队列
  43. job_id = "job_001"
  44. return TrainingResponse(
  45. status="started",
  46. message=f"训练任务已启动:{config.model_name}",
  47. job_id=job_id,
  48. )
  49. except Exception as e:
  50. raise HTTPException(status_code=500, detail=str(e))
  51. @router.get("/train/{job_id}")
  52. async def get_training_status(job_id: str):
  53. """
  54. 获取训练任务状态
  55. Args:
  56. job_id: 任务 ID
  57. Returns:
  58. 任务状态
  59. """
  60. return {
  61. "job_id": job_id,
  62. "status": "running",
  63. "progress": 0.5,
  64. "metrics": {
  65. "loss": 0.5,
  66. "step": 100,
  67. },
  68. }
  69. @router.post("/inference")
  70. async def inference(
  71. model_path: str,
  72. prompt: str,
  73. max_length: int = 512,
  74. temperature: float = 0.7,
  75. ):
  76. """
  77. 模型推理
  78. Args:
  79. model_path: 模型路径
  80. prompt: 输入提示
  81. max_length: 最大生成长度
  82. temperature: 温度参数
  83. Returns:
  84. 生成结果
  85. """
  86. return {
  87. "prompt": prompt,
  88. "generation": "这是一个示例响应",
  89. "model": model_path,
  90. }
  91. @router.get("/models")
  92. async def list_models():
  93. """获取支持的模型列表"""
  94. return {
  95. "models": [
  96. {"name": "Qwen/Qwen3.5-0.5B", "type": "causal_lm"},
  97. {"name": "Qwen/Qwen2.5-0.5B", "type": "causal_lm"},
  98. {"name": "meta-llama/Llama-3.2-1B", "type": "causal_lm"},
  99. ],
  100. }