_360cc.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from __future__ import print_function, absolute_import
  2. import torch.utils.data as data
  3. import os
  4. import numpy as np
  5. import cv2
  6. import lib.utils.utils as utils
  7. class _360CC(data.Dataset):
  8. def __init__(self, config, is_train=True):
  9. self.root = config.DATASET.ROOT
  10. self.is_train = is_train
  11. self.inp_h = config.MODEL.IMAGE_SIZE.H
  12. self.inp_w = config.MODEL.IMAGE_SIZE.W
  13. self.dataset_name = config.DATASET.DATASET
  14. self.mean = np.array(config.DATASET.MEAN, dtype=np.float32)
  15. self.std = np.array(config.DATASET.STD, dtype=np.float32)
  16. char_file = config.DATASET.CHAR_FILE
  17. with open(char_file, 'rb') as file:
  18. char_dict = {num: char.strip().decode('gbk', 'ignore') for num, char in enumerate(file.readlines())}
  19. txt_file = config.DATASET.JSON_FILE['train'] if is_train else config.DATASET.JSON_FILE['val']
  20. # convert name:indices to name:string
  21. self.labels = []
  22. with open(txt_file, 'r', encoding='utf-8') as file:
  23. contents = file.readlines()
  24. for c in contents:
  25. imgname = c.split(' ')[0]
  26. indices = c.split(' ')[1:]
  27. string = ''.join([char_dict[int(idx)] for idx in indices])
  28. self.labels.append({imgname: string})
  29. print("load {} images!".format(self.__len__()))
  30. def __len__(self):
  31. return len(self.labels)
  32. def __getitem__(self, idx):
  33. img_name = list(self.labels[idx].keys())[0]
  34. img = cv2.imread(os.path.join(self.root, img_name))
  35. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  36. img_h, img_w = img.shape
  37. img = cv2.resize(img, (0,0), fx=self.inp_w / img_w, fy=self.inp_h / img_h, interpolation=cv2.INTER_CUBIC)
  38. img = np.reshape(img, (self.inp_h, self.inp_w, 1))
  39. img = img.astype(np.float32)
  40. img = (img/255. - self.mean) / self.std
  41. img = img.transpose([2, 0, 1])
  42. return img, idx