simple_dataset.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os
  16. import random
  17. from paddle.io import Dataset
  18. from ocr.ppocr.data.text2Image import create_image, delete_image
  19. from .imaug import transform, create_operators
  20. import sys
  21. sys.setrecursionlimit(100000)
  22. class SimpleDataSet(Dataset):
  23. def __init__(self, config, mode, logger, seed=None):
  24. super(SimpleDataSet, self).__init__()
  25. self.logger = logger
  26. global_config = config['Global']
  27. # 读取Train相关参数
  28. dataset_config = config[mode]['dataset']
  29. loader_config = config[mode]['loader']
  30. self.delimiter = dataset_config.get('delimiter', '\t')
  31. # 图片路径对应文字txt
  32. label_file_list = dataset_config.pop('label_file_list')
  33. data_source_num = len(label_file_list)
  34. ratio_list = dataset_config.get("ratio_list", [1.0])
  35. if isinstance(ratio_list, (float, int)):
  36. ratio_list = [float(ratio_list)] * int(data_source_num)
  37. assert len(
  38. ratio_list
  39. ) == data_source_num, "The length of ratio_list should be the same as the file_list."
  40. # 图片路径
  41. self.data_dir = dataset_config['data_dir']
  42. self.do_shuffle = loader_config['shuffle']
  43. self.seed = seed
  44. logger.info("Initialize indexs of datasets:%s" % label_file_list)
  45. self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
  46. self.data_idx_order_list = list(range(len(self.data_lines)))
  47. if mode.lower() == "train":
  48. self.shuffle_data_random()
  49. self.ops = create_operators(dataset_config['transforms'], global_config)
  50. def get_image_info_list(self, file_list, ratio_list):
  51. if isinstance(file_list, str):
  52. file_list = [file_list]
  53. data_lines = []
  54. for idx, file in enumerate(file_list):
  55. with open(file, "rb") as f:
  56. lines = f.readlines()
  57. random.seed(self.seed)
  58. lines = random.sample(lines,
  59. round(len(lines) * ratio_list[idx]))
  60. data_lines.extend(lines)
  61. return data_lines
  62. def shuffle_data_random(self):
  63. if self.do_shuffle:
  64. random.seed(self.seed)
  65. random.shuffle(self.data_lines)
  66. return
  67. def __getitem__(self, idx):
  68. file_idx = self.data_idx_order_list[idx]
  69. data_line = self.data_lines[file_idx]
  70. try:
  71. data_line = data_line.decode('utf-8')
  72. substr = data_line.strip("\n").split(self.delimiter)
  73. # 图片文件路径、图片文字标识
  74. file_name = substr[0]
  75. label = substr[1]
  76. if file_name[:5] != "image":
  77. # 临时按Label创建图片
  78. create_image(self.data_dir, file_name, label)
  79. # 读取图片
  80. img_path = os.path.join(self.data_dir, file_name)
  81. data = {'img_path': img_path, 'label': label}
  82. if not os.path.exists(img_path):
  83. raise Exception("{} does not exist!".format(img_path))
  84. with open(data['img_path'], 'rb') as f:
  85. img = f.read()
  86. data['image'] = img
  87. outs = transform(data, self.ops)
  88. # 删除临时图片文件
  89. delete_image(self.data_dir, file_name)
  90. else:
  91. # 直接读取文件中有的图片
  92. img_path = os.path.join(self.data_dir, file_name)
  93. data = {'img_path': img_path, 'label': label}
  94. if not os.path.exists(img_path):
  95. raise Exception("{} does not exist!".format(img_path))
  96. with open(data['img_path'], 'rb') as f:
  97. img = f.read()
  98. data['image'] = img
  99. outs = transform(data, self.ops)
  100. except Exception as e:
  101. self.logger.error(
  102. "When parsing line {}, error happened with msg: {}".format(
  103. data_line, e))
  104. outs = None
  105. if outs is None:
  106. return self.__getitem__(np.random.randint(self.__len__()))
  107. return outs
  108. def __len__(self):
  109. return len(self.data_idx_order_list)