dataset.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. """
  2. 数据集加载和格式化
  3. """
  4. import json
  5. from typing import List, Dict, Any, Optional
  6. from datasets import Dataset, DatasetDict
  7. import torch
  8. from torch.utils.data import Dataset as TorchDataset
  9. def load_dataset(
  10. data_path: str,
  11. format: str = "json",
  12. split: str = "train",
  13. **kwargs
  14. ) -> Dataset:
  15. """
  16. 加载数据集
  17. Args:
  18. data_path: 数据文件路径或数据集名称
  19. format: 数据格式 (json, csv, parquet, text)
  20. split: 数据集划分 (train, validation, test)
  21. Returns:
  22. Dataset 对象
  23. """
  24. print(f"正在加载数据集:{data_path}")
  25. if format == "json":
  26. dataset = Dataset.from_json(data_path, **kwargs)
  27. elif format == "csv":
  28. dataset = Dataset.from_csv(data_path, **kwargs)
  29. elif format == "parquet":
  30. dataset = Dataset.from_parquet(data_path, **kwargs)
  31. elif format == "text":
  32. dataset = Dataset.from_text(data_path, **kwargs)
  33. else:
  34. # 尝试从 HuggingFace 加载
  35. from datasets import load_dataset as hf_load_dataset
  36. dataset = hf_load_dataset(data_path, split=split, **kwargs)
  37. print(f"数据集加载完成!样本数:{len(dataset)}")
  38. return dataset
  39. def format_dataset(
  40. dataset: Dataset,
  41. instruction_column: str = "instruction",
  42. input_column: Optional[str] = "input",
  43. output_column: str = "output",
  44. ) -> Dataset:
  45. """
  46. 格式化数据集为标准指令微调格式
  47. Args:
  48. dataset: 原始数据集
  49. instruction_column: 指令列名
  50. input_column: 输入列名(可选)
  51. output_column: 输出列名
  52. Returns:
  53. 格式化后的数据集
  54. """
  55. def format_example(example):
  56. instruction = example.get(instruction_column, "")
  57. input_text = example.get(input_column, "") if input_column else ""
  58. output_text = example.get(output_column, "")
  59. # 合并 instruction 和 input
  60. if input_text:
  61. text = f"{instruction}\n\n输入:{input_text}"
  62. else:
  63. text = instruction
  64. return {
  65. "instruction": instruction,
  66. "input": input_text,
  67. "output": output_text,
  68. "text": text,
  69. }
  70. formatted_dataset = dataset.map(format_example)
  71. return formatted_dataset
  72. class InstructionDataset(TorchDataset):
  73. """
  74. 指令微调数据集
  75. """
  76. def __init__(
  77. self,
  78. dataset: Dataset,
  79. tokenizer,
  80. max_length: int = 512,
  81. instruction_column: str = "instruction",
  82. input_column: Optional[str] = "input",
  83. output_column: str = "output",
  84. ):
  85. self.dataset = dataset
  86. self.tokenizer = tokenizer
  87. self.max_length = max_length
  88. self.instruction_column = instruction_column
  89. self.input_column = input_column
  90. self.output_column = output_column
  91. def __len__(self):
  92. return len(self.dataset)
  93. def __getitem__(self, idx):
  94. example = self.dataset[idx]
  95. # 构建 prompt
  96. instruction = example[self.instruction_column]
  97. input_text = example.get(self.input_column, "") if self.input_column else ""
  98. output_text = example[self.output_column]
  99. if input_text:
  100. prompt = f"{instruction}\n\n输入:{input_text}\n\n回答:"
  101. else:
  102. prompt = f"{instruction}\n\n回答:"
  103. # 完整文本
  104. full_text = prompt + output_text + self.tokenizer.eos_token
  105. # Tokenize
  106. encoding = self.tokenizer(
  107. full_text,
  108. max_length=self.max_length,
  109. padding="max_length",
  110. truncation=True,
  111. return_tensors="pt",
  112. )
  113. input_ids = encoding.input_ids.squeeze()
  114. attention_mask = encoding.attention_mask.squeeze()
  115. # 创建 labels(将 prompt 部分设为 -100)
  116. prompt_encoding = self.tokenizer(
  117. prompt,
  118. max_length=self.max_length,
  119. truncation=True,
  120. return_tensors="pt",
  121. )
  122. prompt_length = prompt_encoding.input_ids.shape[1]
  123. labels = input_ids.clone()
  124. labels[:prompt_length] = -100
  125. return {
  126. "input_ids": input_ids,
  127. "attention_mask": attention_mask,
  128. "labels": labels,
  129. }