"""
FineTuneX test suite.

Covers dataset loading/formatting, prompt construction, model configuration
defaults, utility helpers and the HTTP API. Project modules (and third-party
packages such as ``datasets``, ``torch`` and ``fastapi``) are imported inside
each test method, so an import failure only errors the tests that need it.

Run directly (``python <this file>``) or with any unittest-compatible runner.
"""

import json
import os
import sys
import tempfile
import unittest


class TestDataLoading(unittest.TestCase):
    """Tests for dataset loading and formatting."""

    def test_load_json_dataset(self):
        """Load a JSON file and check the record count and first record."""
        from finetunex.data.dataset import load_dataset

        # Minimal two-record instruction dataset used as the fixture.
        test_data = [
            {"instruction": "test1", "output": "output1"},
            {"instruction": "test2", "output": "output2"},
        ]

        # Write the fixture into a private temporary directory instead of the
        # current working directory: a fixed relative file name could overwrite
        # (and then delete) an unrelated "test_data.json" owned by the user, and
        # would collide between concurrent runs. The directory and its contents
        # are removed when the ``with`` block exits, even if an assertion fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            test_file = os.path.join(tmp_dir, "test_data.json")
            with open(test_file, "w", encoding="utf-8") as f:
                json.dump(test_data, f, ensure_ascii=False)

            dataset = load_dataset(test_file, format="json")
            self.assertEqual(len(dataset), 2)
            self.assertEqual(dataset[0]["instruction"], "test1")

    def test_format_dataset(self):
        """Formatting keeps the row count and produces a ``text`` column."""
        from finetunex.data.dataset import format_dataset
        from datasets import Dataset

        test_data = {
            "instruction": ["test1", "test2"],
            "input": ["", "input2"],
            "output": ["output1", "output2"],
        }
        dataset = Dataset.from_dict(test_data)
        formatted = format_dataset(dataset)

        self.assertEqual(len(formatted), 2)
        self.assertIn("text", formatted.column_names)


class TestPreprocessing(unittest.TestCase):
    """Tests for prompt construction."""

    def test_create_prompt_default(self):
        """The default template includes the instruction and the input."""
        from finetunex.data.preprocess import create_prompt

        prompt = create_prompt("instruction", "input")
        self.assertIn("instruction", prompt)
        self.assertIn("输入:input", prompt)

    def test_create_prompt_no_input(self):
        """Without an input, the default template omits the input section."""
        from finetunex.data.preprocess import create_prompt

        prompt = create_prompt("instruction")
        self.assertIn("instruction", prompt)
        self.assertNotIn("输入:", prompt)

    def test_create_prompt_alpaca(self):
        """The Alpaca template emits its Instruction and Input headers."""
        from finetunex.data.preprocess import create_prompt

        prompt = create_prompt("instruction", "input", template="alpaca")
        self.assertIn("### Instruction:", prompt)
        self.assertIn("### Input:", prompt)


class TestModelConfig(unittest.TestCase):
    """Tests for model configuration objects."""

    def test_qwen_config(self):
        """QwenConfig keeps LoRA overrides and exposes non-None target_modules."""
        from finetunex.models.qwen import QwenConfig

        config = QwenConfig(
            model_name="Qwen/Qwen3.5-0.5B",
            lora_r=16,
            lora_alpha=32,
        )
        self.assertEqual(config.lora_r, 16)
        self.assertEqual(config.lora_alpha, 32)
        self.assertIsNotNone(config.target_modules)

    def test_base_config(self):
        """BaseModelConfig defaults: lora_r == 16 and learning_rate == 2e-4."""
        from finetunex.models.base import BaseModelConfig

        config = BaseModelConfig()
        self.assertEqual(config.lora_r, 16)
        self.assertEqual(config.learning_rate, 2e-4)


class TestUtils(unittest.TestCase):
    """Tests for general helper functions."""

    def test_format_time(self):
        """format_time renders seconds, minutes and hours with two decimals."""
        from finetunex.utils.helpers import format_time

        self.assertEqual(format_time(30), "30.00s")
        self.assertEqual(format_time(90), "1.50m")
        self.assertEqual(format_time(3600), "1.00h")

    def test_count_parameters(self):
        """count_parameters reports total and trainable parameter counts."""
        from finetunex.utils.helpers import count_parameters
        import torch.nn as nn

        model = nn.Linear(10, 5)
        params = count_parameters(model)

        # 10*5 weights + 5 biases; all of them require grad by default.
        self.assertEqual(params["total"], 55)
        self.assertEqual(params["trainable"], 55)


class TestAPI(unittest.TestCase):
    """Tests for the HTTP API server."""

    def test_health_check(self):
        """GET /health returns 200 with status "healthy"."""
        from fastapi.testclient import TestClient
        from finetunex.api.server import app

        client = TestClient(app)
        response = client.get("/health")

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json()["status"], "healthy")

    def test_root_endpoint(self):
        """GET / returns 200 with a JSON body containing a "message" key."""
        from fastapi.testclient import TestClient
        from finetunex.api.server import app

        client = TestClient(app)
        response = client.get("/")

        self.assertEqual(response.status_code, 200)
        self.assertIn("message", response.json())


def run_tests() -> bool:
    """Run all test cases defined in this module with verbose output.

    Returns:
        True if every test passed, False otherwise.
    """
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    for case in (
        TestDataLoading,
        TestPreprocessing,
        TestModelConfig,
        TestUtils,
        TestAPI,
    ):
        suite.addTests(loader.loadTestsFromTestCase(case))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    return result.wasSuccessful()


if __name__ == "__main__":
    success = run_tests()
    # Propagate the outcome as the process exit status (0 means all passed).
    sys.exit(0 if success else 1)