# preprocess_data.py
"""
Data preprocessing script.
"""
import argparse
import json
import os
import sys

# Put the parent of this script's directory on the import path before the
# project imports below (presumably where the finetunex package lives -- confirm).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from finetunex.data import preprocess_data, validate_dataset
from finetunex.data.dataset import load_dataset
  11. def convert_to_alpaca_format(input_file, output_file):
  12. """转换为 Alpaca 格式"""
  13. with open(input_file, "r", encoding="utf-8") as f:
  14. data = json.load(f)
  15. alpaca_data = []
  16. for item in data:
  17. alpaca_item = {
  18. "instruction": item.get("instruction", ""),
  19. "input": item.get("input", ""),
  20. "output": item.get("output", ""),
  21. }
  22. alpaca_data.append(alpaca_item)
  23. with open(output_file, "w", encoding="utf-8") as f:
  24. json.dump(alpaca_data, f, ensure_ascii=False, indent=2)
  25. print(f"已转换 {len(alpaca_data)} 条数据到 {output_file}")
  26. def main():
  27. parser = argparse.ArgumentParser(description="数据预处理工具")
  28. parser.add_argument("--input", type=str, required=True, help="输入文件路径")
  29. parser.add_argument("--output", type=str, help="输出文件路径")
  30. parser.add_argument("--template", type=str, default="default",
  31. choices=["default", "alpaca", "chat"], help="模板类型")
  32. parser.add_argument("--validate", action="store_true", help="验证数据集")
  33. args = parser.parse_args()
  34. # 加载数据集
  35. dataset = load_dataset(args.input, format="json")
  36. # 验证数据集
  37. if args.validate:
  38. if not validate_dataset(dataset):
  39. print("数据集验证失败!")
  40. return
  41. # 预处理数据
  42. data_list = list(dataset)
  43. processed = preprocess_data(data_list, template=args.template)
  44. # 保存结果
  45. if args.output:
  46. with open(args.output, "w", encoding="utf-8") as f:
  47. json.dump(processed, f, ensure_ascii=False, indent=2)
  48. print(f"已保存 {len(processed)} 条数据到 {args.output}")
  49. else:
  50. print(f"处理完成,共 {len(processed)} 条数据")
  51. if __name__ == "__main__":
  52. main()