rec_postprocess.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import string
  16. import paddle
  17. from paddle.nn import functional as F
  18. class BaseRecLabelDecode(object):
  19. """ Convert between text-label and text-index """
  20. def __init__(self,
  21. character_dict_path=None,
  22. character_type='ch',
  23. use_space_char=False):
  24. support_character_type = [
  25. 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
  26. 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
  27. 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
  28. 'ne', 'EN'
  29. ]
  30. assert character_type in support_character_type, "Only {} are supported now but get {}".format(
  31. support_character_type, character_type)
  32. self.beg_str = "sos"
  33. self.end_str = "eos"
  34. if character_type == "en":
  35. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  36. dict_character = list(self.character_str)
  37. elif character_type == "EN_symbol":
  38. # same with ASTER setting (use 94 char).
  39. self.character_str = string.printable[:-6]
  40. dict_character = list(self.character_str)
  41. elif character_type in support_character_type:
  42. self.character_str = ""
  43. assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
  44. character_type)
  45. with open(character_dict_path, "rb") as fin:
  46. lines = fin.readlines()
  47. for line in lines:
  48. # 移除字符串头尾指定的字符
  49. line = line.decode('utf-8').strip("\n").strip("\r\n")
  50. self.character_str += line
  51. if use_space_char:
  52. self.character_str += " "
  53. dict_character = list(self.character_str)
  54. else:
  55. raise NotImplementedError
  56. self.character_type = character_type
  57. dict_character = self.add_special_char(dict_character)
  58. self.dict = {}
  59. for i, char in enumerate(dict_character):
  60. self.dict[char] = i
  61. self.character = dict_character
  62. def add_special_char(self, dict_character):
  63. return dict_character
  64. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  65. """ convert text-index into text-label. """
  66. # print(self.character)
  67. result_list = []
  68. ignored_tokens = self.get_ignored_tokens()
  69. batch_size = len(text_index)
  70. # batch内部循环,每一个是一个文字序列,即一张图片
  71. for batch_idx in range(batch_size):
  72. char_list = []
  73. conf_list = []
  74. for idx in range(len(text_index[batch_idx])):
  75. if text_index[batch_idx][idx] in ignored_tokens:
  76. continue
  77. if is_remove_duplicate:
  78. # only for predict
  79. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  80. batch_idx][idx]:
  81. continue
  82. # print(char_list)
  83. char_list.append(self.character[int(text_index[batch_idx][
  84. idx])])
  85. if text_prob is not None:
  86. conf_list.append(text_prob[batch_idx][idx])
  87. else:
  88. conf_list.append(1)
  89. text = ''.join(char_list)
  90. result_list.append((text, np.mean(conf_list)))
  91. return result_list
  92. def get_ignored_tokens(self):
  93. return [0] # for ctc blank
  94. class CTCLabelDecode(BaseRecLabelDecode):
  95. """ Convert between text-label and text-index """
  96. def __init__(self,
  97. character_dict_path=None,
  98. character_type='ch',
  99. use_space_char=False,
  100. **kwargs):
  101. super(CTCLabelDecode, self).__init__(character_dict_path,
  102. character_type, use_space_char)
  103. def __call__(self, preds, label=None, *args, **kwargs):
  104. if isinstance(preds, paddle.Tensor):
  105. preds = preds.numpy()
  106. # 预测值按列取最大值及其id
  107. preds_idx = preds.argmax(axis=2)
  108. preds_prob = preds.max(axis=2)
  109. # print("rec_postprocess")
  110. # print("preds", preds.shape)
  111. # print("preds_idx", preds_idx, len(preds_idx))
  112. # print("preds_prob", preds_prob, len(preds_prob))
  113. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  114. # print("text", text, len(text))
  115. if label is None:
  116. return text
  117. label = self.decode(label)
  118. return text, label
  119. def add_special_char(self, dict_character):
  120. dict_character = ['blank'] + dict_character
  121. return dict_character
  122. class AttnLabelDecode(BaseRecLabelDecode):
  123. """ Convert between text-label and text-index """
  124. def __init__(self,
  125. character_dict_path=None,
  126. character_type='ch',
  127. use_space_char=False,
  128. **kwargs):
  129. super(AttnLabelDecode, self).__init__(character_dict_path,
  130. character_type, use_space_char)
  131. def add_special_char(self, dict_character):
  132. self.beg_str = "sos"
  133. self.end_str = "eos"
  134. dict_character = dict_character
  135. dict_character = [self.beg_str] + dict_character + [self.end_str]
  136. return dict_character
  137. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  138. """ convert text-index into text-label. """
  139. result_list = []
  140. ignored_tokens = self.get_ignored_tokens()
  141. [beg_idx, end_idx] = self.get_ignored_tokens()
  142. batch_size = len(text_index)
  143. for batch_idx in range(batch_size):
  144. char_list = []
  145. conf_list = []
  146. for idx in range(len(text_index[batch_idx])):
  147. if text_index[batch_idx][idx] in ignored_tokens:
  148. continue
  149. if int(text_index[batch_idx][idx]) == int(end_idx):
  150. break
  151. if is_remove_duplicate:
  152. # only for predict
  153. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  154. batch_idx][idx]:
  155. continue
  156. char_list.append(self.character[int(text_index[batch_idx][
  157. idx])])
  158. if text_prob is not None:
  159. conf_list.append(text_prob[batch_idx][idx])
  160. else:
  161. conf_list.append(1)
  162. text = ''.join(char_list)
  163. result_list.append((text, np.mean(conf_list)))
  164. return result_list
  165. def __call__(self, preds, label=None, *args, **kwargs):
  166. """
  167. text = self.decode(text)
  168. if label is None:
  169. return text
  170. else:
  171. label = self.decode(label, is_remove_duplicate=False)
  172. return text, label
  173. """
  174. if isinstance(preds, paddle.Tensor):
  175. preds = preds.numpy()
  176. preds_idx = preds.argmax(axis=2)
  177. preds_prob = preds.max(axis=2)
  178. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  179. if label is None:
  180. return text
  181. label = self.decode(label, is_remove_duplicate=False)
  182. return text, label
  183. def get_ignored_tokens(self):
  184. beg_idx = self.get_beg_end_flag_idx("beg")
  185. end_idx = self.get_beg_end_flag_idx("end")
  186. return [beg_idx, end_idx]
  187. def get_beg_end_flag_idx(self, beg_or_end):
  188. if beg_or_end == "beg":
  189. idx = np.array(self.dict[self.beg_str])
  190. elif beg_or_end == "end":
  191. idx = np.array(self.dict[self.end_str])
  192. else:
  193. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  194. % beg_or_end
  195. return idx
  196. class SRNLabelDecode(BaseRecLabelDecode):
  197. """ Convert between text-label and text-index """
  198. def __init__(self,
  199. character_dict_path=None,
  200. character_type='en',
  201. use_space_char=False,
  202. **kwargs):
  203. super(SRNLabelDecode, self).__init__(character_dict_path,
  204. character_type, use_space_char)
  205. def __call__(self, preds, label=None, *args, **kwargs):
  206. pred = preds['predict']
  207. char_num = len(self.character_str) + 2
  208. if isinstance(pred, paddle.Tensor):
  209. pred = pred.numpy()
  210. pred = np.reshape(pred, [-1, char_num])
  211. preds_idx = np.argmax(pred, axis=1)
  212. preds_prob = np.max(pred, axis=1)
  213. preds_idx = np.reshape(preds_idx, [-1, 25])
  214. preds_prob = np.reshape(preds_prob, [-1, 25])
  215. text = self.decode(preds_idx, preds_prob)
  216. if label is None:
  217. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  218. return text
  219. label = self.decode(label)
  220. return text, label
  221. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  222. """ convert text-index into text-label. """
  223. result_list = []
  224. ignored_tokens = self.get_ignored_tokens()
  225. batch_size = len(text_index)
  226. for batch_idx in range(batch_size):
  227. char_list = []
  228. conf_list = []
  229. for idx in range(len(text_index[batch_idx])):
  230. if text_index[batch_idx][idx] in ignored_tokens:
  231. continue
  232. if is_remove_duplicate:
  233. # only for predict
  234. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  235. batch_idx][idx]:
  236. continue
  237. char_list.append(self.character[int(text_index[batch_idx][
  238. idx])])
  239. if text_prob is not None:
  240. conf_list.append(text_prob[batch_idx][idx])
  241. else:
  242. conf_list.append(1)
  243. text = ''.join(char_list)
  244. result_list.append((text, np.mean(conf_list)))
  245. return result_list
  246. def add_special_char(self, dict_character):
  247. dict_character = dict_character + [self.beg_str, self.end_str]
  248. return dict_character
  249. def get_ignored_tokens(self):
  250. beg_idx = self.get_beg_end_flag_idx("beg")
  251. end_idx = self.get_beg_end_flag_idx("end")
  252. return [beg_idx, end_idx]
  253. def get_beg_end_flag_idx(self, beg_or_end):
  254. if beg_or_end == "beg":
  255. idx = np.array(self.dict[self.beg_str])
  256. elif beg_or_end == "end":
  257. idx = np.array(self.dict[self.end_str])
  258. else:
  259. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  260. % beg_or_end
  261. return idx