# pre_process_torch.py
  1. import os
  2. import re
  3. import cv2
  4. import numpy as np
  5. from PIL import Image
  6. import torch.nn.functional as F
  7. from torch.utils.data import Dataset
  8. from pyctcdecode import build_ctcdecoder
  9. equation_char_dict_path = os.path.dirname(os.path.abspath(__file__)) + "/equation_torch.txt"
  10. with open(equation_char_dict_path, "r", encoding='utf-8') as f:
  11. map_list = f.readlines()
  12. map_str = "".join(map_list)
  13. map_str = re.sub("\n", "", map_str)
  14. decoder = build_ctcdecoder([x for x in map_str])
  15. class EquationDataset(Dataset):
  16. def __init__(self, paths, image_shape, input_len, label_len, mode=0, channel=1):
  17. self.image_shape = image_shape
  18. self.batch_size = image_shape[0]
  19. self.label_len = label_len
  20. self.input_len = input_len
  21. self.mode = mode
  22. self.channel = channel
  23. with open(equation_char_dict_path, "r", encoding='utf-8') as f:
  24. map_list = f.readlines()
  25. map_str = "".join(map_list)
  26. self.map_str = re.sub("\n", "", map_str)
  27. self.char_map_dict = {
  28. "星": '*',
  29. "斜": "/",
  30. "问": "?",
  31. 'x': '×',
  32. '?': '?'
  33. }
  34. self.data, self.targets, self.data_len, self.targets_len, self.texts = self.gen_data(paths)
  35. def gen_data(self, paths):
  36. data_x = []
  37. data_y = []
  38. data_x_len = []
  39. data_y_len = []
  40. text_list = []
  41. inference_flag = 0
  42. if paths and type(paths[0]) == np.ndarray:
  43. inference_flag = 1
  44. for p in paths:
  45. if not inference_flag:
  46. path = p.split(os.sep)[-1]
  47. char_index_list = []
  48. char_len = self.label_len
  49. if self.mode == 0:
  50. chars = path.split(".")[0].split('_')[1:]
  51. text_list.append(chars)
  52. char_len = 0
  53. for c in chars:
  54. if c in self.char_map_dict.keys():
  55. c = self.char_map_dict.get(c)
  56. if not c:
  57. continue
  58. char_index_list.append(self.map_str.index(c))
  59. char_len += 1
  60. char_index_list.extend([0] * (self.label_len - len(char_index_list)))
  61. label = np.array(char_index_list)
  62. img1 = cv2.imread(p)
  63. else:
  64. label = []
  65. char_len = 0
  66. img1 = p
  67. if img1 is None:
  68. img_pil = Image.open(p)
  69. img1 = pil2np(img_pil)
  70. img1 = pil_resize(img1, self.image_shape[0], self.image_shape[1])
  71. # cv2.imshow("gen_char", img1)
  72. # cv2.waitKey(0)
  73. if self.channel == 1:
  74. img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
  75. img1 = np.expand_dims(img1, axis=-1)
  76. img1 = img1 / 255.
  77. data_x.append(img1)
  78. data_y.append(label)
  79. data_x_len.append(self.input_len)
  80. data_y_len.append(char_len)
  81. return data_x, data_y, data_x_len, data_y_len, text_list
  82. def __len__(self):
  83. return len(self.data)
  84. def __getitem__(self, idx):
  85. x = self.data[idx]
  86. y = self.targets[idx]
  87. x_len = self.data_len[idx]
  88. y_len = self.targets_len[idx]
  89. # print(self.texts[idx], x_len, y_len)
  90. return x, y, x_len, y_len
  91. def py_ctc_decode(logits):
  92. # labels = [
  93. # " ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l",
  94. # "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
  95. # ]
  96. batch_result_list = []
  97. logits = F.pad(logits, (0, 1), value=-1000)
  98. # print('logits.shape', logits.shape)
  99. for i in range(logits.shape[0]):
  100. # prepare decoder and decode logits via shallow fusion
  101. sub_logits = logits[i, :, :].detach().numpy()
  102. # start_time = time.time()
  103. text = decoder.decode_beams(sub_logits, beam_width=10)
  104. # print('logits.numpy() cost', time.time()-start_time)
  105. text = [x[0] for x in text]
  106. text = [re.sub('@', '', x) for x in text]
  107. batch_result_list.append(text)
  108. return batch_result_list
  109. def pil_resize(image_np, height, width):
  110. image_pil = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
  111. image_pil = image_pil.resize((int(width), int(height)), Image.BICUBIC)
  112. image_np = cv2.cvtColor(np.asarray(image_pil), cv2.COLOR_RGB2BGR)
  113. return image_np
  114. def pil2np(image_pil):
  115. image_np = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
  116. return image_np