| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- """
- 数据预处理工具
- """
- from typing import Dict, Any, Optional, List
def create_prompt(
    instruction: str,
    input_text: Optional[str] = None,
    template: str = "default"
) -> str:
    """
    Build a prompt string from an instruction and an optional input.

    Args:
        instruction: The instruction text.
        input_text: Optional extra input; any falsy value (None, "") is
            treated as "no input" and the input section is omitted.
        template: Template name. "alpaca" and "chat" select their own
            layouts; every other value uses the default layout.

    Returns:
        The formatted prompt, ending where the model's answer should begin.
    """
    # Evaluate truthiness once; all three layouts branch on the same question.
    has_input = bool(input_text)

    if template == "alpaca":
        if has_input:
            preamble = (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
            )
            input_section = f"### Input:\n{input_text}\n\n"
        else:
            preamble = (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
            )
            input_section = ""
        return (
            preamble
            + f"### Instruction:\n{instruction}\n\n"
            + input_section
            + "### Response:\n"
        )

    if template == "chat":
        input_section = f"\n输入:{input_text}" if has_input else ""
        return f"User: {instruction}{input_section}\n\nAssistant: "

    # Any other template name (including "default") uses the plain layout.
    input_section = f"\n\n输入:{input_text}" if has_input else ""
    return f"{instruction}{input_section}\n\n回答:"
def preprocess_data(
    data: List[Dict[str, Any]],
    template: str = "default",
) -> List[Dict[str, Any]]:
    """
    Turn raw records into records that also carry a rendered prompt.

    Args:
        data: Raw records; each is read via its "instruction", "input" and
            "output" keys, with a missing key defaulting to "".
        template: Template name forwarded to create_prompt.

    Returns:
        A new list with one dict per input record, holding the three
        original fields plus "prompt" and "full_text" (prompt followed by
        the output).
    """

    def _build_record(raw: Dict[str, Any]) -> Dict[str, Any]:
        # Read the raw fields first, then derive the prompt-based ones, so
        # the resulting key order is instruction/input/output/prompt/full_text.
        record = {
            "instruction": raw.get("instruction", ""),
            "input": raw.get("input", ""),
            "output": raw.get("output", ""),
        }
        record["prompt"] = create_prompt(
            record["instruction"], record["input"], template
        )
        record["full_text"] = record["prompt"] + record["output"]
        return record

    return [_build_record(raw) for raw in data]
def validate_dataset(dataset) -> bool:
    """
    Check that a dataset exposes the columns this module needs.

    Args:
        dataset: Any object; it is expected to have a `column_names`
            attribute that can be iterated into a set.

    Returns:
        True when both "instruction" and "output" are among the column
        names; False when the attribute is absent or a column is missing.
        A status message is printed in every case.
    """
    # Guard clause: without column metadata nothing else can be checked.
    if not hasattr(dataset, "column_names"):
        print("错误:数据集格式不正确")
        return False

    columns = set(dataset.column_names)
    missing = {"instruction", "output"} - columns

    if missing:
        print(f"错误:数据集缺少必要的列:{missing}")
        return False

    print(f"数据集验证通过!列:{columns}")
    return True
|