preprocess.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. """
  2. 数据预处理工具
  3. """
  4. from typing import Dict, Any, Optional, List
  5. def create_prompt(
  6. instruction: str,
  7. input_text: Optional[str] = None,
  8. template: str = "default"
  9. ) -> str:
  10. """
  11. 创建提示模板
  12. Args:
  13. instruction: 指令文本
  14. input_text: 输入文本(可选)
  15. template: 模板类型 (default, chat, alpaca)
  16. Returns:
  17. 格式化后的 prompt
  18. """
  19. if template == "alpaca":
  20. if input_text:
  21. return (
  22. "Below is an instruction that describes a task, paired with an input that provides further context. "
  23. "Write a response that appropriately completes the request.\n\n"
  24. f"### Instruction:\n{instruction}\n\n"
  25. f"### Input:\n{input_text}\n\n"
  26. f"### Response:\n"
  27. )
  28. else:
  29. return (
  30. "Below is an instruction that describes a task. "
  31. "Write a response that appropriately completes the request.\n\n"
  32. f"### Instruction:\n{instruction}\n\n"
  33. f"### Response:\n"
  34. )
  35. elif template == "chat":
  36. if input_text:
  37. return f"User: {instruction}\n输入:{input_text}\n\nAssistant: "
  38. else:
  39. return f"User: {instruction}\n\nAssistant: "
  40. else: # default
  41. if input_text:
  42. return f"{instruction}\n\n输入:{input_text}\n\n回答:"
  43. else:
  44. return f"{instruction}\n\n回答:"
  45. def preprocess_data(
  46. data: List[Dict[str, Any]],
  47. template: str = "default",
  48. ) -> List[Dict[str, Any]]:
  49. """
  50. 预处理数据
  51. Args:
  52. data: 原始数据列表
  53. template: 使用的模板类型
  54. Returns:
  55. 预处理后的数据
  56. """
  57. processed = []
  58. for item in data:
  59. instruction = item.get("instruction", "")
  60. input_text = item.get("input", "")
  61. output_text = item.get("output", "")
  62. # 创建 prompt
  63. prompt = create_prompt(instruction, input_text, template)
  64. processed_item = {
  65. "instruction": instruction,
  66. "input": input_text,
  67. "output": output_text,
  68. "prompt": prompt,
  69. "full_text": prompt + output_text,
  70. }
  71. processed.append(processed_item)
  72. return processed
  73. def validate_dataset(dataset) -> bool:
  74. """
  75. 验证数据集格式
  76. Args:
  77. dataset: 数据集对象
  78. Returns:
  79. 是否有效
  80. """
  81. required_columns = {"instruction", "output"}
  82. if not hasattr(dataset, "column_names"):
  83. print("错误:数据集格式不正确")
  84. return False
  85. columns = set(dataset.column_names)
  86. if not required_columns.issubset(columns):
  87. missing = required_columns - columns
  88. print(f"错误:数据集缺少必要的列:{missing}")
  89. return False
  90. print(f"数据集验证通过!列:{columns}")
  91. return True