load_data.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import json
  2. from datasets import Dataset
  3. def load_bid_data(train_path_list, dev_path_list):
  4. """加载招投标数据"""
  5. # 加载训练集
  6. train_data = []
  7. for train_path in train_path_list:
  8. with open(train_path, 'r', encoding='utf-8') as f:
  9. for line in f:
  10. line = json.loads(line.strip())
  11. # 构造Qwen的输入格式(Chat版格式)
  12. prompt = f"<|im_start|>user\n{line['input']}\n{line['instruction']}<|im_end|>\n<|im_start|>assistant\n{line['output']}<|im_end|>"
  13. train_data.append({"text": prompt})
  14. # 加载验证集
  15. dev_data = []
  16. for dev_path in dev_path_list:
  17. with open(dev_path, 'r', encoding='utf-8') as f:
  18. for line in f:
  19. line = json.loads(line.strip())
  20. prompt = f"<|im_start|>user\n{line['input']}\n{line['instruction']}<|im_end|>\n<|im_start|>assistant\n{line['output']}<|im_end|>"
  21. dev_data.append({"text": prompt})
  22. # dev_data = dev_data[:10]
  23. print('len(train_data)', len(train_data))
  24. print('len(dev_data)', len(dev_data))
  25. # 转为Dataset格式
  26. train_dataset = Dataset.from_list(train_data)
  27. dev_dataset = Dataset.from_list(dev_data)
  28. return train_dataset, dev_dataset
  29. def load_bid_data_dpo(train_path_list, dev_path_list):
  30. """加载招投标数据"""
  31. # 加载训练集
  32. train_data = []
  33. for train_path in train_path_list:
  34. with open(train_path, 'r', encoding='utf-8') as f:
  35. for line in f:
  36. line = json.loads(line.strip())
  37. train_data.append(
  38. {
  39. "text_prompt": line['prompt'], "chosen": line['chosen'],
  40. "rejected": line['rejected'],
  41. }
  42. )
  43. # 加载验证集
  44. dev_data = []
  45. for dev_path in dev_path_list:
  46. with open(dev_path, 'r', encoding='utf-8') as f:
  47. for line in f:
  48. line = json.loads(line.strip())
  49. dev_data.append(
  50. {
  51. "text_prompt": line['prompt'], "chosen": line['chosen'],
  52. "rejected": line['rejected'],
  53. }
  54. )
  55. # dev_data = dev_data[:10]
  56. print('len(train_data)', len(train_data))
  57. print('len(dev_data)', len(dev_data))
  58. # 转为Dataset格式
  59. train_dataset = Dataset.from_list(train_data)
  60. dev_dataset = Dataset.from_list(dev_data)
  61. # # 核心:添加这一列,内容全是 None
  62. # train_dataset = train_dataset.add_column("images", [None] * len(train_dataset))
  63. # dev_dataset = dev_dataset.add_column("images", [None] * len(dev_dataset))
  64. return train_dataset, dev_dataset
  65. def load_bid_data_kto(train_path_list, dev_path_list):
  66. train_data = []
  67. for train_path in train_path_list:
  68. with open(train_path, 'r', encoding='utf-8') as f:
  69. for line in f:
  70. line = json.loads(line.strip())
  71. prompt = line["prompt"]
  72. chosen = line["chosen"]
  73. rejected = line["rejected"]
  74. # 👇 正确:拆成两条 KTO 样本
  75. train_data.append({"prompt": prompt, "completion": chosen, "label": True})
  76. train_data.append({"prompt": prompt, "completion": rejected, "label": False})
  77. dev_data = []
  78. for dev_path in dev_path_list:
  79. with open(dev_path, 'r', encoding='utf-8') as f:
  80. for line in f:
  81. line = json.loads(line.strip())
  82. prompt = line["prompt"]
  83. chosen = line["chosen"]
  84. rejected = line["rejected"]
  85. dev_data.append({"prompt": prompt, "completion": chosen, "label": True})
  86. dev_data.append({"prompt": prompt, "completion": rejected, "label": False})
  87. print('len(train_data)', len(train_data))
  88. print('len(dev_data)', len(dev_data))
  89. train_dataset = Dataset.from_list(train_data)
  90. dev_dataset = Dataset.from_list(dev_data)
  91. return train_dataset, dev_dataset
  92. # 测试加载
  93. if __name__ == "__main__":
  94. train_ds, dev_ds = load_bid_data("data/train_data.jsonl", "data/dev_data.jsonl")
  95. print(f"训练集数量:{len(train_ds)}")
  96. print(f"验证集数量:{len(dev_ds)}")
  97. for ds in train_ds:
  98. print(f"示例数据:{ds['text']}")