""" 数据预处理脚本 """ import json import argparse import os import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from finetunex.data import preprocess_data, validate_dataset from finetunex.data.dataset import load_dataset def convert_to_alpaca_format(input_file, output_file): """转换为 Alpaca 格式""" with open(input_file, "r", encoding="utf-8") as f: data = json.load(f) alpaca_data = [] for item in data: alpaca_item = { "instruction": item.get("instruction", ""), "input": item.get("input", ""), "output": item.get("output", ""), } alpaca_data.append(alpaca_item) with open(output_file, "w", encoding="utf-8") as f: json.dump(alpaca_data, f, ensure_ascii=False, indent=2) print(f"已转换 {len(alpaca_data)} 条数据到 {output_file}") def main(): parser = argparse.ArgumentParser(description="数据预处理工具") parser.add_argument("--input", type=str, required=True, help="输入文件路径") parser.add_argument("--output", type=str, help="输出文件路径") parser.add_argument("--template", type=str, default="default", choices=["default", "alpaca", "chat"], help="模板类型") parser.add_argument("--validate", action="store_true", help="验证数据集") args = parser.parse_args() # 加载数据集 dataset = load_dataset(args.input, format="json") # 验证数据集 if args.validate: if not validate_dataset(dataset): print("数据集验证失败!") return # 预处理数据 data_list = list(dataset) processed = preprocess_data(data_list, template=args.template) # 保存结果 if args.output: with open(args.output, "w", encoding="utf-8") as f: json.dump(processed, f, ensure_ascii=False, indent=2) print(f"已保存 {len(processed)} 条数据到 {args.output}") else: print(f"处理完成,共 {len(processed)} 条数据") if __name__ == "__main__": main()