| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- """
- 数据集加载和格式化
- """
- import json
- from typing import List, Dict, Any, Optional
- from datasets import Dataset, DatasetDict
- import torch
- from torch.utils.data import Dataset as TorchDataset
def load_dataset(
    data_path: str,
    format: str = "json",
    split: str = "train",
    **kwargs
) -> Dataset:
    """
    Load a dataset from a local data file, or fall back to the Hugging Face Hub.

    Args:
        data_path: Path of a local data file, or a dataset name/path handed to
            ``datasets.load_dataset`` when ``format`` is not a local format.
        format: "json", "csv", "parquet" or "text" selects the matching
            ``Dataset.from_*`` constructor. Any other value falls back to
            ``datasets.load_dataset``.
        split: Split requested from ``datasets.load_dataset``. It is only used
            by the fallback path; local formats ignore it.
        **kwargs: Forwarded unchanged to whichever loader is used.

    Returns:
        The object returned by the selected loader.
    """
    print(f"正在加载数据集:{data_path}")

    # Local file formats and the Dataset constructor that reads each one.
    file_readers = {
        "json": Dataset.from_json,
        "csv": Dataset.from_csv,
        "parquet": Dataset.from_parquet,
        "text": Dataset.from_text,
    }
    reader = file_readers.get(format)

    if reader is not None:
        loaded = reader(data_path, **kwargs)
    else:
        # Not a known local format: import lazily and try the generic loader.
        from datasets import load_dataset as hf_load_dataset
        loaded = hf_load_dataset(data_path, split=split, **kwargs)

    print(f"数据集加载完成!样本数:{len(loaded)}")
    return loaded
def format_dataset(
    dataset: Dataset,
    instruction_column: str = "instruction",
    input_column: Optional[str] = "input",
    output_column: str = "output",
) -> Dataset:
    """
    Format a dataset into the standard instruction-tuning layout.

    Every example gains (or has overwritten) four columns: ``instruction``,
    ``input``, ``output`` and ``text``. ``text`` is the instruction, followed
    by a blank line and an "输入:" (input) section when the input is non-empty.
    It does not include the output. A missing column and an explicit null
    value are both normalised to ``""``, so these columns never mix ``None``
    with strings.

    Args:
        dataset: Source dataset.
        instruction_column: Name of the instruction column.
        input_column: Name of the optional input column. ``None`` means the
            dataset has no input column.
        output_column: Name of the output (response) column.

    Returns:
        The dataset returned by ``dataset.map`` with the columns above added.
    """

    def _field(example, column):
        # Treat "column absent" and "value is null" the same way. Previously
        # the first gave "" and the second gave None.
        if column is None:
            return ""
        value = example.get(column)
        return "" if value is None else value

    def format_example(example):
        instruction = _field(example, instruction_column)
        input_text = _field(example, input_column)
        output_text = _field(example, output_column)

        # Merge the instruction and the optional input into a single prompt.
        if input_text:
            text = f"{instruction}\n\n输入:{input_text}"
        else:
            text = instruction

        return {
            "instruction": instruction,
            "input": input_text,
            "output": output_text,
            "text": text,
        }

    formatted_dataset = dataset.map(format_example)
    return formatted_dataset
class InstructionDataset(TorchDataset):
    """
    Instruction-tuning dataset that yields fixed-length tokenized examples.

    Each item is the tokenization of ``prompt + output + eos_token``, padded
    and truncated to ``max_length``. The prompt is the instruction, an
    optional "输入:" (input) section, and a trailing "回答:" (answer) marker.
    ``labels`` copies ``input_ids`` but sets prompt tokens and all padding
    positions to -100, so only the response (and its EOS) is supervised.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer,
        max_length: int = 512,
        instruction_column: str = "instruction",
        input_column: Optional[str] = "input",
        output_column: str = "output",
    ):
        """
        Args:
            dataset: Indexable, sized collection whose items are dict-like
                rows keyed by the column names below.
            tokenizer: Callable tokenizer exposing ``eos_token`` and returning
                an encoding with ``input_ids``/``attention_mask`` tensors.
                Presumably a Hugging Face tokenizer; ``padding="max_length"``
                needs a pad token configured, so confirm the caller sets one.
            max_length: Length every returned tensor is padded/truncated to.
            instruction_column: Column holding the instruction text.
            input_column: Column holding the optional input. ``None`` disables it.
            output_column: Column holding the target response.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.instruction_column = instruction_column
        self.input_column = input_column
        self.output_column = output_column

    def __len__(self):
        return len(self.dataset)

    def _build_prompt(self, example) -> str:
        """Return the prompt text that precedes the response for *example*."""
        instruction = example[self.instruction_column]
        input_text = example.get(self.input_column, "") if self.input_column else ""
        if input_text:
            return f"{instruction}\n\n输入:{input_text}\n\n回答:"
        return f"{instruction}\n\n回答:"

    def __getitem__(self, idx):
        """
        Return ``{"input_ids", "attention_mask", "labels"}`` for item *idx*.

        Each value is row 0 of the tokenizer's batched "pt" output, i.e. a
        1-D tensor padded/truncated to ``max_length``.
        """
        example = self.dataset[idx]
        prompt = self._build_prompt(example)
        output_text = example[self.output_column]

        # Full training text: prompt, response, then the end-of-sequence token.
        full_text = prompt + output_text + self.tokenizer.eos_token

        encoding = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Take row 0 of the batched output explicitly. A bare .squeeze() would
        # also collapse the sequence dimension when max_length == 1.
        input_ids = encoding.input_ids[0]
        attention_mask = encoding.attention_mask[0]

        # Number of tokens the prompt occupies when encoded on its own.
        # NOTE(review): if the tokenizer appends special tokens to a lone
        # prompt, or merges tokens across the prompt/response boundary, this
        # can be off by a token or two. Confirm for the tokenizer in use.
        prompt_encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        prompt_length = prompt_encoding.input_ids.shape[1]

        labels = input_ids.clone()
        # Never supervise padding. Mask via attention_mask rather than
        # pad_token_id, because pad_token is often set equal to eos_token and
        # the real trailing EOS must stay in the loss.
        labels[attention_mask == 0] = -100
        # Real tokens start after any left padding. Offset the prompt mask by
        # the first attended position (0 when padding is on the right).
        attended = attention_mask.nonzero(as_tuple=True)[0]
        start = int(attended[0]) if attended.numel() > 0 else 0
        labels[start:start + prompt_length] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
|