# test_all.py (5.4 KB)
"""
FineTuneX test suite.
"""
import json
import os
import sys
import tempfile
import unittest
  8. class TestDataLoading(unittest.TestCase):
  9. """测试数据加载"""
  10. def test_load_json_dataset(self):
  11. """测试加载 JSON 数据集"""
  12. from finetunex.data.dataset import load_dataset
  13. # 创建临时测试数据
  14. test_data = [
  15. {"instruction": "test1", "output": "output1"},
  16. {"instruction": "test2", "output": "output2"},
  17. ]
  18. test_file = "test_data.json"
  19. with open(test_file, "w", encoding="utf-8") as f:
  20. json.dump(test_data, f, ensure_ascii=False)
  21. try:
  22. dataset = load_dataset(test_file, format="json")
  23. self.assertEqual(len(dataset), 2)
  24. self.assertEqual(dataset[0]["instruction"], "test1")
  25. finally:
  26. if os.path.exists(test_file):
  27. os.remove(test_file)
  28. def test_format_dataset(self):
  29. """测试数据集格式化"""
  30. from finetunex.data.dataset import format_dataset
  31. from datasets import Dataset
  32. test_data = {
  33. "instruction": ["test1", "test2"],
  34. "input": ["", "input2"],
  35. "output": ["output1", "output2"],
  36. }
  37. dataset = Dataset.from_dict(test_data)
  38. formatted = format_dataset(dataset)
  39. self.assertEqual(len(formatted), 2)
  40. self.assertIn("text", formatted.column_names)
  41. class TestPreprocessing(unittest.TestCase):
  42. """测试数据预处理"""
  43. def test_create_prompt_default(self):
  44. """测试创建默认 prompt"""
  45. from finetunex.data.preprocess import create_prompt
  46. prompt = create_prompt("instruction", "input")
  47. self.assertIn("instruction", prompt)
  48. self.assertIn("输入:input", prompt)
  49. def test_create_prompt_no_input(self):
  50. """测试创建无输入的 prompt"""
  51. from finetunex.data.preprocess import create_prompt
  52. prompt = create_prompt("instruction")
  53. self.assertIn("instruction", prompt)
  54. self.assertNotIn("输入:", prompt)
  55. def test_create_prompt_alpaca(self):
  56. """测试创建 Alpaca 格式 prompt"""
  57. from finetunex.data.preprocess import create_prompt
  58. prompt = create_prompt("instruction", "input", template="alpaca")
  59. self.assertIn("### Instruction:", prompt)
  60. self.assertIn("### Input:", prompt)
  61. class TestModelConfig(unittest.TestCase):
  62. """测试模型配置"""
  63. def test_qwen_config(self):
  64. """测试 Qwen 配置"""
  65. from finetunex.models.qwen import QwenConfig
  66. config = QwenConfig(
  67. model_name="Qwen/Qwen3.5-0.5B",
  68. lora_r=16,
  69. lora_alpha=32,
  70. )
  71. self.assertEqual(config.lora_r, 16)
  72. self.assertEqual(config.lora_alpha, 32)
  73. self.assertIsNotNone(config.target_modules)
  74. def test_base_config(self):
  75. """测试基础配置"""
  76. from finetunex.models.base import BaseModelConfig
  77. config = BaseModelConfig()
  78. self.assertEqual(config.lora_r, 16)
  79. self.assertEqual(config.learning_rate, 2e-4)
  80. class TestUtils(unittest.TestCase):
  81. """测试工具函数"""
  82. def test_format_time(self):
  83. """测试时间格式化"""
  84. from finetunex.utils.helpers import format_time
  85. self.assertEqual(format_time(30), "30.00s")
  86. self.assertEqual(format_time(90), "1.50m")
  87. self.assertEqual(format_time(3600), "1.00h")
  88. def test_count_parameters(self):
  89. """测试参数统计"""
  90. from finetunex.utils.helpers import count_parameters
  91. import torch.nn as nn
  92. model = nn.Linear(10, 5)
  93. params = count_parameters(model)
  94. self.assertEqual(params["total"], 55) # 10*5 + 5
  95. self.assertEqual(params["trainable"], 55)
  96. class TestAPI(unittest.TestCase):
  97. """测试 API"""
  98. def test_health_check(self):
  99. """测试健康检查端点"""
  100. from fastapi.testclient import TestClient
  101. from finetunex.api.server import app
  102. client = TestClient(app)
  103. response = client.get("/health")
  104. self.assertEqual(response.status_code, 200)
  105. self.assertEqual(response.json()["status"], "healthy")
  106. def test_root_endpoint(self):
  107. """测试根端点"""
  108. from fastapi.testclient import TestClient
  109. from finetunex.api.server import app
  110. client = TestClient(app)
  111. response = client.get("/")
  112. self.assertEqual(response.status_code, 200)
  113. self.assertIn("message", response.json())
  114. def run_tests():
  115. """运行所有测试"""
  116. # 创建测试套件
  117. loader = unittest.TestLoader()
  118. suite = unittest.TestSuite()
  119. # 添加测试
  120. suite.addTests(loader.loadTestsFromTestCase(TestDataLoading))
  121. suite.addTests(loader.loadTestsFromTestCase(TestPreprocessing))
  122. suite.addTests(loader.loadTestsFromTestCase(TestModelConfig))
  123. suite.addTests(loader.loadTestsFromTestCase(TestUtils))
  124. suite.addTests(loader.loadTestsFromTestCase(TestAPI))
  125. # 运行测试
  126. runner = unittest.TextTestRunner(verbosity=2)
  127. result = runner.run(suite)
  128. return result.wasSuccessful()
  129. if __name__ == "__main__":
  130. success = run_tests()
  131. sys.exit(0 if success else 1)