ckpt.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/16 8:56
  3. # @Author : zhoujun
  4. import os
  5. import torch
  6. def load_checkpoint(_model, resume_from, to_use_device, _optimizers=None, third_name=None):
  7. """
  8. 加载预训练模型
  9. Args:
  10. _model: 模型
  11. resume_from: 预训练模型路径
  12. to_use_device: 设备
  13. _optimizers: 如果不为None,则表明采用模型的训练参数
  14. third_name: 第三方预训练模型的名称
  15. Returns:
  16. """
  17. global_state = {}
  18. if not third_name:
  19. state = torch.load(resume_from, map_location=to_use_device)
  20. _model.load_state_dict({'module.' + k: v for k, v in state['state_dict'].items()}, strict=True)
  21. # _model.load_state_dict(state['state_dict'])
  22. if 'optimizer' in state and _optimizers is not None:
  23. _optimizers.load_state_dict(state['optimizer'])
  24. if 'global_state' in state:
  25. global_state = state['global_state']
  26. return _model, _optimizers, global_state
  27. def save_checkpoint(checkpoint_path, model, _optimizers, logger, cfg, **kwargs):
  28. # if isinstance(model, torch.nn.DataParallel()):
  29. # mode_state_dict = model.module.state_dict()
  30. # else:
  31. mode_state_dict = model.module.state_dict()
  32. state = {'state_dict': mode_state_dict,
  33. 'optimizer': _optimizers.state_dict(),
  34. 'cfg': cfg}
  35. state.update(kwargs)
  36. torch.save(state, checkpoint_path)
  37. logger.info('models saved to %s' % checkpoint_path)
  38. def save_checkpoint_logic(total_loss, total_num, min_loss, net, solver, epoch, rec_train_options, logger):
  39. """
  40. 根据配置文件保存模型
  41. Args:
  42. total_loss:
  43. total_num:
  44. min_loss:
  45. net:
  46. epoch:
  47. rec_train_options:
  48. logger:
  49. Returns:
  50. """
  51. # operation for model save as parameter ckpt_save_type is HighestAcc
  52. if rec_train_options['ckpt_save_type'] == 'HighestAcc':
  53. loss_mean = sum([total_loss[idx] * total_num[idx] for idx in range(len(total_loss))]) / sum(total_num)
  54. if loss_mean < min_loss:
  55. min_loss = loss_mean
  56. save_checkpoint(os.path.join(rec_train_options['checkpoint_save_dir'], 'epoch_' + str(epoch) + '.pth'), net,
  57. solver, epoch, logger)
  58. else:
  59. if epoch % rec_train_options['ckpt_save_epoch'] == 0:
  60. save_checkpoint(os.path.join(rec_train_options['checkpoint_save_dir'], 'epoch_' + str(epoch) + '.pth'), net,
  61. solver, epoch, logger)
  62. return min_loss