"""
Dataset loading and formatting.
"""

import json
from typing import List, Dict, Any, Optional
from datasets import Dataset, DatasetDict
import torch
from torch.utils.data import Dataset as TorchDataset


def load_dataset(
    data_path: str,
    format: str = "json",
    split: str = "train",
    **kwargs
) -> Dataset:
    """
    Load a dataset from a local file or, as a fallback, via ``datasets.load_dataset``.

    Args:
        data_path: Path to the data file, or a dataset name for the fallback path.
        format: Data format. ``json``, ``csv``, ``parquet`` and ``text`` are read
            with the matching ``Dataset.from_*`` constructor; any other value
            falls through to ``datasets.load_dataset``.
        split: Split name (train, validation, test). Only used on the fallback
            path; the local-file constructors are called without it.
        **kwargs: Forwarded unchanged to whichever loader is selected.

    Returns:
        The loaded ``Dataset`` object.
    """
    print(f"正在加载数据集:{data_path}")

    # Map each local-file format to the name of its Dataset constructor.
    # The attribute is resolved lazily so only the requested loader is touched.
    local_loader_names = {
        "json": "from_json",
        "csv": "from_csv",
        "parquet": "from_parquet",
        "text": "from_text",
    }
    loader_name = local_loader_names.get(format)

    if loader_name is not None:
        dataset = getattr(Dataset, loader_name)(data_path, **kwargs)
    else:
        # Unknown format: try loading through datasets.load_dataset instead.
        from datasets import load_dataset as hf_load_dataset

        dataset = hf_load_dataset(data_path, split=split, **kwargs)

    print(f"数据集加载完成!样本数:{len(dataset)}")

    return dataset


def format_dataset(
    dataset: Dataset,
    instruction_column: str = "instruction",
    input_column: Optional[str] = "input",
    output_column: str = "output",
) -> Dataset:
    """
    Format a dataset into the standard instruction fine-tuning layout.

    Every example gains (or has overwritten) the keys ``instruction``,
    ``input``, ``output`` and ``text``, where ``text`` is the instruction
    optionally followed by the input.

    Args:
        dataset: The raw dataset.
        instruction_column: Name of the instruction column.
        input_column: Name of the input column (optional; ``None`` disables it).
        output_column: Name of the output column.

    Returns:
        The formatted dataset produced by ``dataset.map``.
    """

    def format_example(example):
        instruction = example.get(instruction_column, "")
        output_text = example.get(output_column, "")

        # A missing column and a null cell are both treated as "no input" so
        # the resulting "input" field is always a string.
        input_text = ""
        if input_column:
            input_text = example.get(input_column) or ""

        # Merge instruction and input into a single text field.
        if input_text:
            text = f"{instruction}\n\n输入:{input_text}"
        else:
            text = instruction

        return {
            "instruction": instruction,
            "input": input_text,
            "output": output_text,
            "text": text,
        }

    return dataset.map(format_example)


class InstructionDataset(TorchDataset):
    """
    Instruction fine-tuning dataset.

    Wraps an indexable collection of examples and a tokenizer. Each item is a
    dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors of
    length ``max_length``; ``labels`` is a copy of ``input_ids`` with the
    prompt tokens and the padding positions replaced by -100.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer,
        max_length: int = 512,
        instruction_column: str = "instruction",
        input_column: Optional[str] = "input",
        output_column: str = "output",
    ):
        """
        Args:
            dataset: Source examples; indexed by integer in ``__getitem__``.
            tokenizer: Callable tokenizer exposing ``eos_token`` and returning
                an encoding with ``input_ids`` and ``attention_mask``.
            max_length: Length every sequence is padded/truncated to.
            instruction_column: Name of the instruction column.
            input_column: Name of the input column (optional).
            output_column: Name of the output column.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.instruction_column = instruction_column
        self.input_column = input_column
        self.output_column = output_column

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Tokenize example ``idx`` and build its training tensors.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``.
        """
        example = self.dataset[idx]

        # Build the prompt.
        instruction = example[self.instruction_column]
        input_text = example.get(self.input_column, "") if self.input_column else ""
        output_text = example[self.output_column]

        if input_text:
            prompt = f"{instruction}\n\n输入:{input_text}\n\n回答:"
        else:
            prompt = f"{instruction}\n\n回答:"

        # Full text. Some tokenizers define no EOS token (eos_token is None);
        # fall back to an empty suffix rather than failing on str + None.
        eos_token = self.tokenizer.eos_token or ""
        full_text = prompt + output_text + eos_token

        # Tokenize.
        encoding = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Drop only the batch dimension; a bare squeeze() would also collapse
        # the sequence dimension when it has length 1.
        input_ids = encoding.input_ids.squeeze(0)
        attention_mask = encoding.attention_mask.squeeze(0)

        # Build labels (prompt part set to -100).
        # NOTE(review): the prompt is tokenized with the tokenizer's default
        # special-token handling; if it appends a trailing special token the
        # count is one too high - confirm for the tokenizer in use.
        prompt_encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        prompt_length = prompt_encoding.input_ids.shape[1]

        # With left padding the prompt starts after the pad tokens, so the
        # masked span has to be shifted by the amount of padding.
        prompt_start = 0
        if getattr(self.tokenizer, "padding_side", "right") == "left":
            prompt_start = int((attention_mask == 0).sum().item())

        labels = input_ids.clone()
        labels[prompt_start:prompt_start + prompt_length] = -100
        # Padding positions must not contribute to the loss either.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }