import os
import re

import cv2
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import Dataset
from pyctcdecode import build_ctcdecoder
# Character alphabet: the newline-stripped contents of the dictionary file;
# a character's position in map_str is its label index.
equation_char_dict_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "equation_torch.txt"
)
with open(equation_char_dict_path, "r", encoding="utf-8") as f:
    map_str = f.read().replace("\n", "")

decoder = build_ctcdecoder(list(map_str))
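# Note (an assumption inferred from the padding in py_ctc_decode() below):
# pyctcdecode appends an implicit CTC blank after the given labels, so a model
# feeding this decoder is expected to emit len(map_str) logit columns per time
# step; the blank column is appended at decode time.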
class EquationDataset(Dataset):
    """Dataset of equation images whose labels are encoded in the file names.

    When `paths` holds numpy arrays instead of file paths, the dataset runs in
    inference mode and yields empty labels.
    """

    def __init__(self, paths, image_shape, input_len, label_len, mode=0, channel=1):
        self.image_shape = image_shape  # (height, width)
        self.batch_size = image_shape[0]
        self.label_len = label_len  # fixed (padded) label length
        self.input_len = input_len  # CTC input (time-step) length
        self.mode = mode
        self.channel = channel
        with open(equation_char_dict_path, "r", encoding="utf-8") as f:
            self.map_str = f.read().replace("\n", "")
        # Normalize label characters: map the Chinese placeholder characters
        # ("star", "slash", "question mark") and look-alike symbols to the
        # characters actually present in the alphabet.
        self.char_map_dict = {
            "星": "*",
            "斜": "/",
            "问": "?",
            "x": "×",
            "？": "?",  # full-width to ASCII question mark
        }
        self.data, self.targets, self.data_len, self.targets_len, self.texts = self.gen_data(paths)
    def gen_data(self, paths):
        data_x = []
        data_y = []
        data_x_len = []
        data_y_len = []
        text_list = []
        # If raw numpy images are passed instead of paths, run in inference
        # mode: no labels are parsed from file names.
        inference_flag = bool(paths) and isinstance(paths[0], np.ndarray)
        for p in paths:
            if not inference_flag:
                file_name = p.split(os.sep)[-1]
                char_index_list = []
                char_len = self.label_len
                if self.mode == 0:
                    # Label characters are encoded in the file name after the
                    # first underscore: "<id>_<c1>_<c2>_..._<cn>.<ext>".
                    chars = file_name.split(".")[0].split("_")[1:]
                    text_list.append(chars)
                    char_len = 0
                    for c in chars:
                        if c in self.char_map_dict:
                            c = self.char_map_dict[c]
                        if not c:
                            continue
                        char_index_list.append(self.map_str.index(c))
                        char_len += 1
                # Zero-pad label indices to the fixed label length.
                char_index_list.extend([0] * (self.label_len - len(char_index_list)))
                label = np.array(char_index_list)
                img1 = cv2.imread(p)
                if img1 is None:
                    # cv2.imread fails on some paths/encodings; fall back to PIL.
                    img1 = pil2np(Image.open(p).convert("RGB"))
            else:
                label = []
                char_len = 0
                img1 = p
            img1 = pil_resize(img1, self.image_shape[0], self.image_shape[1])
            if self.channel == 1:
                img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
                img1 = np.expand_dims(img1, axis=-1)
            img1 = img1 / 255.0
            data_x.append(img1)
            data_y.append(label)
            data_x_len.append(self.input_len)
            data_y_len.append(char_len)
        return data_x, data_y, data_x_len, data_y_len, text_list
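    # Worked example of the mode-0 label parsing above (the file name is
    # hypothetical; only the "<id>_<chars...>" pattern is implied by the code):
    #   "0017_1_+_2_=_3.png"
    #     -> chars ["1", "+", "2", "=", "3"]
    #     -> label indices [map_str.index("1"), map_str.index("+"), ...],
    #        zero-padded to label_len, with char_len = 5.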
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.targets[idx]
        x_len = self.data_len[idx]
        y_len = self.targets_len[idx]
        return x, y, x_len, y_len
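# --- Usage sketch (hypothetical values; not part of this module) ------------
# A minimal sketch of driving the dataset with a DataLoader for CTC training;
# the directory, image_shape, input_len, and label_len below are assumptions,
# not values taken from this code.
#
#   from torch.utils.data import DataLoader
#
#   paths = ["train/0017_1_+_2_=_3.png"]           # labels encoded in names
#   dataset = EquationDataset(paths, image_shape=(32, 192),
#                             input_len=48, label_len=10)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)
#   for x, y, x_len, y_len in loader:
#       ...  # feed x to the model; (y, x_len, y_len) go to nn.CTCLoss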
def py_ctc_decode(logits):
    """Beam-search decode a batch of CTC logits of shape (batch, time, vocab)."""
    batch_result_list = []
    # Append the blank column the decoder expects as its last label; the large
    # negative value effectively disables blank emissions.
    logits = F.pad(logits, (0, 1), value=-1000)
    for i in range(logits.shape[0]):
        sub_logits = logits[i, :, :].detach().numpy()
        beams = decoder.decode_beams(sub_logits, beam_width=10)
        # Each beam is a tuple whose first element is the decoded text; strip
        # '@', which appears to serve as a padding/placeholder character.
        text = [beam[0] for beam in beams]
        text = [re.sub("@", "", t) for t in text]
        batch_result_list.append(text)
    return batch_result_list
def pil_resize(image_np, height, width):
    """Resize a BGR numpy image with PIL's bicubic filter, returning BGR."""
    image_pil = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
    image_pil = image_pil.resize((int(width), int(height)), Image.BICUBIC)
    return cv2.cvtColor(np.asarray(image_pil), cv2.COLOR_RGB2BGR)


def pil2np(image_pil):
    """Convert a PIL RGB image to an OpenCV-style BGR numpy array."""
    return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
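
# Smoke test: a minimal sketch, assuming only that the alphabet file loaded
# above is present; the random tensor stands in for a real model's output of
# shape (batch, time, len(map_str)).
if __name__ == "__main__":
    import torch

    dummy_logits = torch.log_softmax(torch.randn(2, 20, len(map_str)), dim=-1)
    print(py_ctc_decode(dummy_logits))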
|