label_convert.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/15 14:29
  3. # @Author : zhoujun
  4. import torch
  5. class CTCLabelConverter(object):
  6. """ Convert between text-label and text-index """
  7. def __init__(self, character):
  8. # character (str): set of the possible characters.
  9. dict_character = []
  10. with open(character, "rb") as fin:
  11. lines = fin.readlines()
  12. for line in lines:
  13. line = line.decode('utf-8').strip("\n").strip("\r\n")
  14. dict_character += list(line)
  15. # dict_character = list(character)
  16. self.dict = {}
  17. for i, char in enumerate(dict_character):
  18. # NOTE: 0 is reserved for 'blank' token required by CTCLoss
  19. self.dict[char] = i + 1
  20. #TODO replace ‘ ’ with special symbol
  21. self.character = ['[blank]'] + dict_character+[' '] # dummy '[blank]' token for CTCLoss (index 0)
  22. def encode(self, text, batch_max_length=None):
  23. """convert text-label into text-index.
  24. input:
  25. text: text labels of each image. [batch_size]
  26. output:
  27. text: concatenated text index for CTCLoss.
  28. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
  29. length: length of each text. [batch_size]
  30. """
  31. length = [len(s) for s in text]
  32. # text = ''.join(text)
  33. # text = [self.dict[char] for char in text]
  34. d = []
  35. batch_max_length = max(length)
  36. for s in text:
  37. t = [self.dict[char] for char in s]
  38. t.extend([0] * (batch_max_length - len(s)))
  39. d.append(t)
  40. return (torch.tensor(d, dtype=torch.long), torch.tensor(length, dtype=torch.long))
  41. def decode(self, preds, raw=False):
  42. """ convert text-index into text-label. """
  43. preds_idx = preds.argmax(axis=2)
  44. preds_prob = preds.max(axis=2)
  45. result_list = []
  46. for word, prob in zip(preds_idx, preds_prob):
  47. if raw:
  48. result_list.append((''.join([self.character[int(i)] for i in word]), prob))
  49. else:
  50. result = []
  51. conf = []
  52. for i, index in enumerate(word):
  53. if word[i] != 0 and (not (i > 0 and word[i - 1] == word[i])):
  54. result.append(self.character[int(index)])
  55. conf.append(prob[i])
  56. result_list.append((''.join(result), conf))
  57. return result_list