RecMetric.py 957 B

12345678910111213141516171819202122232425262728
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/15 14:07
  3. # @Author : zhoujun
  4. import Levenshtein
  5. class RecMetric:
  6. def __init__(self, converter):
  7. """
  8. 文本识别相关指标计算类
  9. :param converter: 用于label转换的转换器
  10. """
  11. self.converter = converter
  12. def __call__(self, predictions, labels):
  13. n_correct = 0
  14. norm_edit_dis = 0.0
  15. predictions = predictions.softmax(dim=2).detach().cpu().numpy()
  16. preds_str = self.converter.decode(predictions)
  17. show_str = []
  18. for (pred, pred_conf), target in zip(preds_str, labels):
  19. norm_edit_dis += Levenshtein.distance(pred, target) / max(len(pred), len(target))
  20. show_str.append(f'{pred} -> {target}')
  21. print(f'{pred} -> {target}')
  22. if pred == target:
  23. n_correct += 1
  24. return {'n_correct': n_correct, 'norm_edit_dis': norm_edit_dis, 'show_str': show_str}