simple_dataset.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import traceback
  15. import numpy as np
  16. import os
  17. import random
  18. from PIL import Image
  19. from paddle.io import Dataset
  20. from ppocr.data.text2Image import create_image, delete_image, my_image_aug
  21. from .imaug import transform, create_operators
  22. import sys
# Module-import side effect: raises the recursion limit for the whole process.
# NOTE(review): presumably added so that recursive retries while loading
# dataset samples do not hit RecursionError — confirm whether anything still
# depends on it before removing.
sys.setrecursionlimit(100000)
  24. class SimpleDataSet(Dataset):
  25. def __init__(self, config, mode, logger, seed=None):
  26. super(SimpleDataSet, self).__init__()
  27. self.logger = logger
  28. global_config = config['Global']
  29. # 读取Train相关参数
  30. dataset_config = config[mode]['dataset']
  31. loader_config = config[mode]['loader']
  32. self.delimiter = dataset_config.get('delimiter', '\t')
  33. # 图片路径对应文字txt
  34. label_file_list = dataset_config.pop('label_file_list')
  35. data_source_num = len(label_file_list)
  36. ratio_list = dataset_config.get("ratio_list", [1.0])
  37. if isinstance(ratio_list, (float, int)):
  38. ratio_list = [float(ratio_list)] * int(data_source_num)
  39. assert len(
  40. ratio_list
  41. ) == data_source_num, "The length of ratio_list should be the same as the file_list."
  42. # 图片路径
  43. self.data_dir = dataset_config['data_dir']
  44. self.do_shuffle = loader_config['shuffle']
  45. self.seed = seed
  46. logger.info("Initialize indexs of datasets:%s" % label_file_list)
  47. self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
  48. self.data_idx_order_list = list(range(len(self.data_lines)))
  49. if mode.lower() == "train":
  50. self.shuffle_data_random()
  51. self.ops = create_operators(dataset_config['transforms'], global_config)
  52. def get_image_info_list(self, file_list, ratio_list):
  53. if isinstance(file_list, str):
  54. file_list = [file_list]
  55. data_lines = []
  56. for idx, file in enumerate(file_list):
  57. with open(file, "rb") as f:
  58. lines = f.readlines()
  59. random.seed(self.seed)
  60. lines = random.sample(lines,
  61. round(len(lines) * ratio_list[idx]))
  62. data_lines.extend(lines)
  63. return data_lines
  64. def shuffle_data_random(self):
  65. if self.do_shuffle:
  66. random.seed(self.seed)
  67. random.shuffle(self.data_lines)
  68. return
  69. def __getitem__(self, idx):
  70. file_idx = self.data_idx_order_list[idx]
  71. data_line = self.data_lines[file_idx]
  72. try:
  73. data_line = data_line.decode('utf-8')
  74. substr = data_line.strip("\n").split(self.delimiter)
  75. # 图片文件路径、图片文字标识
  76. file_name = substr[0]
  77. label = substr[1]
  78. if file_name[:5] != "image":
  79. # 临时按Label创建图片
  80. create_image(self.data_dir, file_name, label)
  81. # 读取图片
  82. img_path = os.path.join(self.data_dir, file_name)
  83. data = {'img_path': img_path, 'label': label}
  84. if not os.path.exists(img_path):
  85. raise Exception("{} does not exist!".format(img_path))
  86. with open(data['img_path'], 'rb') as f:
  87. img = f.read()
  88. data['image'] = img
  89. outs = transform(data, self.ops)
  90. # 删除临时图片文件
  91. delete_image(self.data_dir, file_name)
  92. else:
  93. # 直接读取文件中有的图片
  94. img_path = os.path.join(self.data_dir, file_name)
  95. data = {'img_path': img_path, 'label': label}
  96. if not os.path.exists(img_path):
  97. raise Exception("{} does not exist!".format(img_path))
  98. # img_pil = Image.open(img_path)
  99. # img_pil = my_image_aug(img_pil)
  100. # img_pil.save(img_path)
  101. with open(data['img_path'], 'rb') as f:
  102. img = f.read()
  103. data['image'] = img
  104. outs = transform(data, self.ops)
  105. except Exception as e:
  106. traceback.print_exc()
  107. self.logger.error(
  108. "When parsing line {}, error happened with msg: {}".format(
  109. data_line, e))
  110. outs = None
  111. if outs is None:
  112. return self.__getitem__(np.random.randint(self.__len__()))
  113. return outs
  114. def __len__(self):
  115. return len(self.data_idx_order_list)