save.py 398 B

123456789101112
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/15 17:14
  3. # @Author : zhoujun
  4. import torch
  5. def save_checkpoint(checkpoint_path, model, _optimizers, epoch, logger):
  6. state = {'state_dict': model.state_dict(),
  7. 'optimizer': [_.state_dict() for _ in _optimizers],
  8. 'epoch': epoch}
  9. torch.save(state, checkpoint_path)
  10. logger.info('models saved to %s' % checkpoint_path)