DetMetric.py

import numpy as np

from torchocr.metrics.iou_utils import DetectionIoUEvaluator


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        return self


class DetMetric():
    def __init__(self, is_output_polygon=False):
        self.is_output_polygon = is_output_polygon
        self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon)

    def __call__(self, batch, output, box_thresh=0.6):
        '''
        batch: a dict produced by dataloaders, containing:
            image: tensor of shape (N, C, H, W).
            text_polys: tensor of shape (N, K, 4, 2), the polygons of text regions.
            ignore_tags: tensor of shape (N, K), indicating whether a region is ignorable.
            shape: the original shapes of the images.
            filename: the original filenames of the images.
        output: (polygons, scores), the predicted polygons and their confidence scores.
        '''
        results = []
        gt_polygons_batch = batch['text_polys']
        ignore_tags_batch = batch['ignore_tags']
        pred_polygons_batch = np.array(output[0])
        pred_scores_batch = np.array(output[1])
        for polygons, pred_polygons, pred_scores, ignore_tags in zip(
                gt_polygons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
            gt = [dict(points=np.int64(polygons[i]), ignore=ignore_tags[i])
                  for i in range(len(polygons))]
            if self.is_output_polygon:
                pred = [dict(points=pred_polygons[i]) for i in range(len(pred_polygons))]
            else:
                # Keep only boxes whose confidence clears the threshold.
                pred = [dict(points=pred_polygons[i, :, :].astype(np.int64))
                        for i in range(pred_polygons.shape[0])
                        if pred_scores[i] >= box_thresh]
            results.append(self.evaluator.evaluate_image(gt, pred))
        return results

    def gather_measure(self, raw_metrics):
        # Flatten the per-batch lists into a single list of per-image metrics.
        raw_metrics = [image_metrics
                       for batch_metrics in raw_metrics
                       for image_metrics in batch_metrics]
        result = self.evaluator.combine_results(raw_metrics)
        precision = AverageMeter()
        recall = AverageMeter()
        fmeasure = AverageMeter()
        precision.update(result['precision'], n=len(raw_metrics))
        recall.update(result['recall'], n=len(raw_metrics))
        fmeasure_score = 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8)
        fmeasure.update(fmeasure_score)
        return {
            'precision': precision,
            'recall': recall,
            'fmeasure': fmeasure
        }
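

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes
# torchocr's DetectionIoUEvaluator is importable and follows the batch layout
# documented in DetMetric.__call__ above; the box coordinates and score below
# are dummy values chosen only to show the call sequence.
if __name__ == '__main__':
    square = np.array([[0, 0], [10, 0], [10, 10], [0, 10]], dtype=np.float32)
    batch = {
        'text_polys': np.array([[square]]),   # shape (N=1, K=1, 4, 2)
        'ignore_tags': np.array([[False]]),   # shape (N=1, K=1)
    }
    # output: (per-image predicted boxes, per-image confidence scores)
    output = (np.array([[square]]), np.array([[0.9]]))

    metric = DetMetric(is_output_polygon=False)
    raw_metrics = [metric(batch, output)]     # one batch of per-image results
    measures = metric.gather_measure(raw_metrics)
    print({k: v.avg for k, v in measures.items()})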