# create_rec_lmdb_dataset.py
# -*- coding: utf-8 -*-
# @Time : 2019/11/6 15:31
# @Author : zhoujun
""" A modified version of the CRNN torch repository's dataset tool: https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """
  5. import os
  6. import lmdb
  7. import cv2
  8. from tqdm import tqdm
  9. import numpy as np
  10. def checkImageIsValid(imageBin):
  11. if imageBin is None:
  12. return False
  13. imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
  14. img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
  15. imgH, imgW = img.shape[0], img.shape[1]
  16. if imgH * imgW == 0:
  17. return False
  18. return True
  19. def writeCache(env, cache):
  20. with env.begin(write=True) as txn:
  21. for k, v in cache.items():
  22. txn.put(k, v)
  23. def createDataset(data_list, lmdb_save_path, checkValid=True):
  24. """
  25. Create LMDB dataset for training and evaluation.
  26. ARGS:
  27. data_list : a list contains img_path\tlabel
  28. lmdb_save_path : LMDB output path
  29. checkValid : if true, check the validity of every image
  30. """
  31. os.makedirs(lmdb_save_path, exist_ok=True)
  32. env = lmdb.open(lmdb_save_path, map_size=109951162)
  33. cache = {}
  34. cnt = 1
  35. for imagePath, label in tqdm(data_list, desc=f'make dataset, save to {lmdb_save_path}'):
  36. with open(imagePath, 'rb') as f:
  37. imageBin = f.read()
  38. if checkValid:
  39. try:
  40. if not checkImageIsValid(imageBin):
  41. print('%s is not a valid image' % imagePath)
  42. continue
  43. except:
  44. continue
  45. imageKey = 'image-%09d'.encode() % cnt
  46. labelKey = 'label-%09d'.encode() % cnt
  47. cache[imageKey] = imageBin
  48. cache[labelKey] = label.encode()
  49. if cnt % 1000 == 0:
  50. writeCache(env, cache)
  51. cache = {}
  52. cnt += 1
  53. nSamples = cnt - 1
  54. cache['num-samples'.encode()] = str(nSamples).encode()
  55. writeCache(env, cache)
  56. print('Created dataset with %d samples' % nSamples)
  57. if __name__ == '__main__':
  58. import pathlib
  59. label_file = r"path/val.txt"
  60. lmdb_save_path = r'path/lmdb/eval'
  61. os.makedirs(lmdb_save_path, exist_ok=True)
  62. data_list = []
  63. with open(label_file, 'r', encoding='utf-8') as f:
  64. for line in tqdm(f.readlines(), desc=f'load data from {label_file}'):
  65. line = line.strip('\n').replace('.jpg ', '.jpg\t').replace('.png ', '.png\t').split('\t')
  66. if len(line) > 1:
  67. img_path = pathlib.Path(line[0].strip(' '))
  68. label = line[1]
  69. if img_path.exists() and img_path.stat().st_size > 0:
  70. data_list.append((str(img_path), label))
  71. createDataset(data_list, lmdb_save_path)