| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- """
- FineTuneX 测试套件
- """
- import unittest
- import os
- import sys
- import json
- class TestDataLoading(unittest.TestCase):
- """测试数据加载"""
-
- def test_load_json_dataset(self):
- """测试加载 JSON 数据集"""
- from finetunex.data.dataset import load_dataset
-
- # 创建临时测试数据
- test_data = [
- {"instruction": "test1", "output": "output1"},
- {"instruction": "test2", "output": "output2"},
- ]
-
- test_file = "test_data.json"
- with open(test_file, "w", encoding="utf-8") as f:
- json.dump(test_data, f, ensure_ascii=False)
-
- try:
- dataset = load_dataset(test_file, format="json")
- self.assertEqual(len(dataset), 2)
- self.assertEqual(dataset[0]["instruction"], "test1")
- finally:
- if os.path.exists(test_file):
- os.remove(test_file)
-
- def test_format_dataset(self):
- """测试数据集格式化"""
- from finetunex.data.dataset import format_dataset
- from datasets import Dataset
-
- test_data = {
- "instruction": ["test1", "test2"],
- "input": ["", "input2"],
- "output": ["output1", "output2"],
- }
-
- dataset = Dataset.from_dict(test_data)
- formatted = format_dataset(dataset)
-
- self.assertEqual(len(formatted), 2)
- self.assertIn("text", formatted.column_names)
- class TestPreprocessing(unittest.TestCase):
- """测试数据预处理"""
-
- def test_create_prompt_default(self):
- """测试创建默认 prompt"""
- from finetunex.data.preprocess import create_prompt
-
- prompt = create_prompt("instruction", "input")
- self.assertIn("instruction", prompt)
- self.assertIn("输入:input", prompt)
-
- def test_create_prompt_no_input(self):
- """测试创建无输入的 prompt"""
- from finetunex.data.preprocess import create_prompt
-
- prompt = create_prompt("instruction")
- self.assertIn("instruction", prompt)
- self.assertNotIn("输入:", prompt)
-
- def test_create_prompt_alpaca(self):
- """测试创建 Alpaca 格式 prompt"""
- from finetunex.data.preprocess import create_prompt
-
- prompt = create_prompt("instruction", "input", template="alpaca")
- self.assertIn("### Instruction:", prompt)
- self.assertIn("### Input:", prompt)
- class TestModelConfig(unittest.TestCase):
- """测试模型配置"""
-
- def test_qwen_config(self):
- """测试 Qwen 配置"""
- from finetunex.models.qwen import QwenConfig
-
- config = QwenConfig(
- model_name="Qwen/Qwen3.5-0.5B",
- lora_r=16,
- lora_alpha=32,
- )
-
- self.assertEqual(config.lora_r, 16)
- self.assertEqual(config.lora_alpha, 32)
- self.assertIsNotNone(config.target_modules)
-
- def test_base_config(self):
- """测试基础配置"""
- from finetunex.models.base import BaseModelConfig
-
- config = BaseModelConfig()
- self.assertEqual(config.lora_r, 16)
- self.assertEqual(config.learning_rate, 2e-4)
- class TestUtils(unittest.TestCase):
- """测试工具函数"""
-
- def test_format_time(self):
- """测试时间格式化"""
- from finetunex.utils.helpers import format_time
-
- self.assertEqual(format_time(30), "30.00s")
- self.assertEqual(format_time(90), "1.50m")
- self.assertEqual(format_time(3600), "1.00h")
-
- def test_count_parameters(self):
- """测试参数统计"""
- from finetunex.utils.helpers import count_parameters
- import torch.nn as nn
-
- model = nn.Linear(10, 5)
- params = count_parameters(model)
-
- self.assertEqual(params["total"], 55) # 10*5 + 5
- self.assertEqual(params["trainable"], 55)
- class TestAPI(unittest.TestCase):
- """测试 API"""
-
- def test_health_check(self):
- """测试健康检查端点"""
- from fastapi.testclient import TestClient
- from finetunex.api.server import app
-
- client = TestClient(app)
- response = client.get("/health")
-
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.json()["status"], "healthy")
-
- def test_root_endpoint(self):
- """测试根端点"""
- from fastapi.testclient import TestClient
- from finetunex.api.server import app
-
- client = TestClient(app)
- response = client.get("/")
-
- self.assertEqual(response.status_code, 200)
- self.assertIn("message", response.json())
def run_tests():
    """Run every test case class in this module with verbose output.

    Returns:
        bool: The ``wasSuccessful()`` value of the runner's result.
    """
    case_classes = (
        TestDataLoading,
        TestPreprocessing,
        TestModelConfig,
        TestUtils,
        TestAPI,
    )

    # Collect the tests of each class, in order, into a single suite.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case_class in case_classes:
        suite.addTests(loader.loadTestsFromTestCase(case_class))

    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    return outcome.wasSuccessful()
if __name__ == "__main__":
    # Exit with status 0 when all tests passed and 1 otherwise.
    if run_tests():
        sys.exit(0)
    sys.exit(1)
|