"""Image dataset and CTC beam-search decoding helpers.

The vocabulary is read from ``equation_torch.txt``, which must sit next to
this module. Importing the module reads that file and builds the shared
``decoder``.
"""
import os
import re

import cv2
import numpy as np
import torch.nn.functional as F
from PIL import Image
from pyctcdecode import build_ctcdecoder
from torch.utils.data import Dataset

# Every character of the file (newlines removed) is one vocabulary entry; a
# character's position in ``map_str`` is its label index.
equation_char_dict_path = os.path.dirname(os.path.abspath(__file__)) + "/equation_torch.txt"
with open(equation_char_dict_path, "r", encoding='utf-8') as f:
    map_list = f.readlines()
map_str = "".join(map_list)
map_str = re.sub("\n", "", map_str)

# Shared beam-search decoder used by ``py_ctc_decode``.
decoder = build_ctcdecoder(list(map_str))


class EquationDataset(Dataset):
    """In-memory dataset of preprocessed images and zero-padded label indices.

    Every sample is loaded, resized and scaled eagerly in ``__init__``.

    Args:
        paths: Sequence of image file paths, or of already-decoded
            ``np.ndarray`` images. The type of the first element decides
            which; arrays switch the dataset to inference mode, where labels
            are empty and label lengths are 0.
        image_shape: Indexable whose items 0 and 1 are the target height and
            width handed to ``pil_resize``.
        input_len: Value stored as the input length of every sample.
        label_len: Length that label index lists are zero-padded to.
        mode: ``0`` parses the label out of the file name. Any other value
            produces an all-zero label of ``label_len`` entries whose
            recorded length is ``label_len``.
        channel: ``1`` converts each image to a single grayscale channel.
    """

    def __init__(self, paths, image_shape, input_len, label_len, mode=0, channel=1):
        self.image_shape = image_shape
        # NOTE(review): named batch_size but taken from image_shape[0], the
        # value gen_data uses as the resize height - confirm intent.
        self.batch_size = image_shape[0]
        self.label_len = label_len
        self.input_len = input_len
        self.mode = mode
        self.channel = channel

        with open(equation_char_dict_path, "r", encoding='utf-8') as f:
            vocab_lines = f.readlines()
        self.map_str = re.sub("\n", "", "".join(vocab_lines))

        # File-name tokens that stand in for other characters.
        self.char_map_dict = {
            "星": '*',
            "斜": "/",
            "问": "?",
            'x': '×',
            '?': '?'
        }
        self.data, self.targets, self.data_len, self.targets_len, self.texts = self.gen_data(paths)

    def gen_data(self, paths):
        """Load every sample and build the parallel per-sample lists.

        In file mode with ``mode == 0`` the label tokens are the file name up
        to its first ``.``, split on ``_``, with the first piece dropped.
        Each token is translated through ``char_map_dict`` and looked up in
        ``map_str``.

        Args:
            paths: See the class docstring.

        Returns:
            Tuple ``(images, labels, input_lens, label_lens, texts)``.
            ``images`` are arrays scaled by 1/255; ``texts`` holds the raw
            token list of each sample and is only filled when ``mode == 0``
            in file mode.

        Raises:
            ValueError: If a label token is not found in ``map_str``.
        """
        data_x = []
        data_y = []
        data_x_len = []
        data_y_len = []
        text_list = []

        # The first element decides how the whole sequence is interpreted.
        inference_flag = bool(paths) and isinstance(paths[0], np.ndarray)

        for p in paths:
            if inference_flag:
                label = []
                char_len = 0
                img1 = p
            else:
                file_name = p.split(os.sep)[-1]
                char_index_list = []
                char_len = self.label_len
                if self.mode == 0:
                    chars = file_name.split(".")[0].split('_')[1:]
                    text_list.append(chars)
                    char_len = 0
                    for c in chars:
                        c = self.char_map_dict.get(c, c)
                        if not c:
                            # Empty tokens (e.g. from "__") carry no label.
                            continue
                        char_index_list.append(self.map_str.index(c))
                        char_len += 1
                # Zero-pad up to label_len; longer labels are left untouched.
                char_index_list.extend([0] * (self.label_len - len(char_index_list)))
                label = np.array(char_index_list)
                img1 = cv2.imread(p)

            if img1 is None:
                # cv2.imread signals failure by returning None; retry with
                # PIL and make sure the file handle is released.
                with Image.open(p) as img_pil:
                    img1 = pil2np(img_pil)

            img1 = pil_resize(img1, self.image_shape[0], self.image_shape[1])
            if self.channel == 1:
                img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
                img1 = np.expand_dims(img1, axis=-1)
            img1 = img1 / 255.

            data_x.append(img1)
            data_y.append(label)
            data_x_len.append(self.input_len)
            data_y_len.append(char_len)

        return data_x, data_y, data_x_len, data_y_len, text_list

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label, input_len, label_len)`` for sample ``idx``."""
        x = self.data[idx]
        y = self.targets[idx]
        x_len = self.data_len[idx]
        y_len = self.targets_len[idx]
        return x, y, x_len, y_len


def py_ctc_decode(logits):
    """Beam-search decode a batch of logits with the module-level decoder.

    Args:
        logits: Tensor indexed as ``[batch, :, :]``. Its last dimension is
            padded with one extra column of -1000 before decoding.
            NOTE(review): presumably this lines the class count up with the
            decoder's vocabulary - confirm against the model's output size.

    Returns:
        One list per batch item, holding the text of every beam returned for
        that item with all ``@`` characters removed.
    """
    batch_result_list = []
    logits = F.pad(logits, (0, 1), value=-1000)
    for i in range(logits.shape[0]):
        # .cpu() is required before .numpy() when the tensor lives on an
        # accelerator; it is a no-op for tensors already on the CPU.
        sub_logits = logits[i, :, :].detach().cpu().numpy()
        beams = decoder.decode_beams(sub_logits, beam_width=10)
        batch_result_list.append([beam[0].replace('@', '') for beam in beams])
    return batch_result_list


def pil_resize(image_np, height, width):
    """Resize a BGR image array with PIL bicubic interpolation.

    Args:
        image_np: Image array in BGR channel order.
        height: Target height; truncated to ``int``.
        width: Target width; truncated to ``int``.

    Returns:
        The resized image as a BGR array.
    """
    image_pil = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
    # PIL takes the size as (width, height).
    image_pil = image_pil.resize((int(width), int(height)), Image.BICUBIC)
    image_np = cv2.cvtColor(np.asarray(image_pil), cv2.COLOR_RGB2BGR)
    return image_np


def pil2np(image_pil):
    """Convert a PIL image to a BGR array.

    The image is converted to RGB first so that grayscale and palette images
    also yield a 3-channel result instead of failing in the colour
    conversion.

    Args:
        image_pil: A ``PIL.Image.Image`` in any mode ``convert("RGB")``
            accepts.

    Returns:
        The image as a 3-channel array in BGR channel order.
    """
    rgb = np.array(image_pil.convert("RGB"))
    image_np = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    return image_np