distill_metric.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import importlib
  2. from .DetMetric import DetMetric
  3. class DistillationMetric(object):
  4. def __init__(self, key=None, base_metric_name=None, main_indicator=None, **kwargs):
  5. self.main_indicator = main_indicator
  6. self.key = key
  7. self.main_indicator = main_indicator
  8. self.base_metric_name = base_metric_name
  9. self.kwargs = kwargs
  10. self.metrics = None
  11. self.out = dict()
  12. def _init_metric(self, preds):
  13. self.metrics = dict()
  14. mod = importlib.import_module(__name__)
  15. for key in preds:
  16. self.metrics[key] = getattr(mod, self.base_metric_name)(**self.kwargs)
  17. def __call__(self,batch, preds, **kwargs):
  18. assert isinstance(preds, dict), f'preds should be dict,not {type(preds)}'
  19. if self.metrics is None:
  20. self._init_metric(preds)
  21. for key in preds:
  22. self.out.setdefault(key, []).append(self.metrics[key].__call__( batch,preds[key], **kwargs))
  23. def get_metric(self):
  24. output = dict()
  25. for key, val in self.out.items():
  26. metric = self.metrics[key].gather_measure(val)
  27. if key == self.key:
  28. output.update(metric)
  29. else:
  30. for sub_key in metric:
  31. output['{}_{}'.format(key, sub_key)] = metric[sub_key]
  32. self.out.clear()
  33. return output
  34. def reset(self):
  35. for key in self.metrics:
  36. self.metrics[key].reset()