| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- """
- 数据预处理脚本
- """
- import json
- import argparse
- import os
- import sys
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
- from finetunex.data import preprocess_data, validate_dataset
- from finetunex.data.dataset import load_dataset
def convert_to_alpaca_format(input_file, output_file):
    """Convert a JSON list of records to Alpaca format and write it to disk.

    Each output record keeps only the ``instruction``, ``input`` and
    ``output`` keys, in that order. A key that is missing or explicitly
    ``null`` becomes ``""``, so every Alpaca field is a string.

    Args:
        input_file: Path to a UTF-8 JSON file whose top level is a list of
            objects (dicts).
        output_file: Destination path. Missing parent directories are
            created.

    Returns:
        The list of converted records (the same data written to
        ``output_file``).
    """
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    def _field(item, key):
        # item.get(key, "") still yields None for an explicit JSON null,
        # which would leak `null` into a field that must be a string.
        value = item.get(key)
        return "" if value is None else value

    alpaca_data = [
        {key: _field(item, key) for key in ("instruction", "input", "output")}
        for item in data
    ]

    # open(..., "w") fails if the parent directory does not exist yet.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(alpaca_data, f, ensure_ascii=False, indent=2)

    print(f"已转换 {len(alpaca_data)} 条数据到 {output_file}")
    return alpaca_data
def main(argv=None):
    """Command-line entry point: load, optionally validate, preprocess, save.

    Args:
        argv: Optional list of CLI arguments, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            matching the original behavior.

    Exits with status 1 (via ``SystemExit``) when ``--validate`` is given
    and ``validate_dataset`` reports failure.
    """
    parser = argparse.ArgumentParser(description="数据预处理工具")
    parser.add_argument("--input", type=str, required=True, help="输入文件路径")
    parser.add_argument("--output", type=str, help="输出文件路径")
    parser.add_argument("--template", type=str, default="default",
                        choices=["default", "alpaca", "chat"], help="模板类型")
    parser.add_argument("--validate", action="store_true", help="验证数据集")

    args = parser.parse_args(argv)

    # Load the dataset; the input is always read with format="json".
    dataset = load_dataset(args.input, format="json")

    # Optional validation. A failure must produce a non-zero exit status so
    # shell scripts and CI notice it (a bare `return` exited with status 0).
    if args.validate and not validate_dataset(dataset):
        print("数据集验证失败!", file=sys.stderr)
        sys.exit(1)

    # Materialize the dataset as a list before preprocessing.
    # NOTE(review): presumably preprocess_data needs a list rather than the
    # dataset object; confirm.
    data_list = list(dataset)
    processed = preprocess_data(data_list, template=args.template)

    # Save the result if an output path was given; otherwise only report
    # the count.
    if args.output:
        # open(..., "w") fails if the parent directory does not exist yet.
        out_dir = os.path.dirname(args.output)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(processed, f, ensure_ascii=False, indent=2)
        print(f"已保存 {len(processed)} 条数据到 {args.output}")
    else:
        print(f"处理完成,共 {len(processed)} 条数据")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    # main() returns None on success, so sys.exit(None) exits with status 0.
    sys.exit(main())
|