12345678910111213141516171819202122232425262728 |
- # -*- coding: utf-8 -*-
- # @Time : 2020/6/15 14:07
- # @Author : zhoujun
- import Levenshtein
class RecMetric:
    """Accumulate text-recognition metrics for one batch of predictions."""

    def __init__(self, converter):
        """
        Metric calculator for text recognition.

        :param converter: converter used for label conversion; its
            ``decode(predictions)`` method is called on the softmaxed scores.
        """
        self.converter = converter

    def __call__(self, predictions, labels):
        """
        Compare decoded predictions against the ground-truth labels.

        :param predictions: raw score tensor; softmax is applied over
            dimension 2 before it is detached, moved to CPU and converted to
            a NumPy array for decoding.
        :param labels: iterable of ground-truth strings, paired positionally
            with the decoded predictions.
        :return: dict with
            ``n_correct`` - number of exact string matches,
            ``norm_edit_dis`` - SUM (not mean) over the batch of
            ``edit_distance / max(len(pred), len(target))``,
            ``show_str`` - list of ``"pred -> target"`` strings.
        """
        n_correct = 0
        norm_edit_dis = 0.0
        predictions = predictions.softmax(dim=2).detach().cpu().numpy()
        # decode() is expected to yield (text, confidence) pairs.
        preds_str = self.converter.decode(predictions)
        show_str = []
        for (pred, pred_conf), target in zip(preds_str, labels):
            longest = max(len(pred), len(target))
            # Both strings empty -> they are identical, so the normalized
            # distance is 0; skipping the division avoids ZeroDivisionError.
            if longest:
                norm_edit_dis += Levenshtein.distance(pred, target) / longest
            show_str.append(f'{pred} -> {target}')
            print(f'{pred} -> {target}')
            if pred == target:
                n_correct += 1
        return {'n_correct': n_correct, 'norm_edit_dis': norm_edit_dis, 'show_str': show_str}
|