__init__.py 697 B

1234567891011121314151617181920212223
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/5/15 17:43
  3. # @Author : zhoujun
  4. import copy
  5. from addict import Dict
  6. from .DBLoss import DBLoss
  7. from .CTCLoss import CTCLoss,EnhancedCTCLoss
  8. from .PSELoss import PSELoss
  9. from .CombinedLoss import CombinedLoss
  10. from .FCELoss import FCELoss
  11. __all__ = ['build_loss']
  12. support_loss = ['DBLoss', 'CTCLoss','PSELoss','CombinedLoss','FCELoss','EnhancedCTCLoss']
  13. def build_loss(config):
  14. copy_config = copy.deepcopy(config)
  15. loss_type = copy_config.pop('type')
  16. assert loss_type in support_loss, f'all support loss is {support_loss}'
  17. criterion = eval(loss_type)(copy_config)
  18. # criterion = eval(loss_type)(Dict(copy_config))
  19. return criterion