# -*- coding: utf-8 -*-
# @Time : 2020/6/16 8:56
# @Author : zhoujun
import os
import torch
def load_checkpoint(_model, resume_from, to_use_device, _optimizers=None, third_name=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        _model: model whose parameters are restored in place.
        resume_from: path of the checkpoint file to load.
        to_use_device: device the checkpoint tensors are mapped onto.
        _optimizers: when not None, its state is restored from the
            checkpoint's 'optimizer' entry as well.
        third_name: name of a third-party pretrained model; when given,
            this function performs no loading at all (unchanged behavior).

    Returns:
        Tuple ``(_model, _optimizers, global_state)`` where ``global_state``
        is the checkpoint's 'global_state' dict, or ``{}`` if absent.
    """
    global_state = {}
    if not third_name:
        checkpoint = torch.load(resume_from, map_location=to_use_device)
        # NOTE(review): a 'module.' prefix is re-added to every key, which
        # assumes _model is DataParallel-wrapped while the checkpoint was
        # saved unwrapped — confirm against the saving side.
        prefixed = {'module.' + key: value
                    for key, value in checkpoint['state_dict'].items()}
        _model.load_state_dict(prefixed, strict=True)
        if _optimizers is not None and 'optimizer' in checkpoint:
            _optimizers.load_state_dict(checkpoint['optimizer'])
        global_state = checkpoint.get('global_state', global_state)
    return _model, _optimizers, global_state
def save_checkpoint(checkpoint_path, model, _optimizers, logger, cfg, **kwargs):
    """Serialize model, optimizer and config state to ``checkpoint_path``.

    Args:
        checkpoint_path: destination file path for ``torch.save``.
        model: model to save; a ``DataParallel``-wrapped model is unwrapped
            first so the stored keys carry no 'module.' prefix.
        _optimizers: optimizer whose state is stored under 'optimizer'.
        logger: logger used to report where the checkpoint was written.
        cfg: configuration object stored under 'cfg'.
        **kwargs: extra entries merged into the checkpoint dict (e.g. epoch).
    """
    # BUGFIX: the original unconditionally read model.module.state_dict(),
    # which raises AttributeError for a plain (un-wrapped) model. Restore
    # the intended isinstance check (the commented-out version also wrongly
    # instantiated DataParallel inside isinstance).
    if isinstance(model, torch.nn.DataParallel):
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    state = {'state_dict': model_state_dict,
             'optimizer': _optimizers.state_dict(),
             'cfg': cfg}
    state.update(kwargs)
    torch.save(state, checkpoint_path)
    logger.info('models saved to %s' % checkpoint_path)
def save_checkpoint_logic(total_loss, total_num, min_loss, net, solver, epoch, rec_train_options, logger):
    """Decide, from the training config, whether to save a checkpoint now.

    Args:
        total_loss: per-batch loss values for the epoch.
        total_num: per-batch sample counts (weights for the loss average).
        min_loss: best (lowest) weighted mean loss seen so far.
        net: model forwarded to ``save_checkpoint``.
        solver: optimizer forwarded to ``save_checkpoint``.
        epoch: current epoch number, used in the checkpoint file name.
        rec_train_options: dict with 'ckpt_save_type', 'checkpoint_save_dir'
            and (for fixed-interval saving) 'ckpt_save_epoch'.
        logger: logger forwarded to ``save_checkpoint``.

    Returns:
        The (possibly updated) minimum weighted mean loss.
    """
    checkpoint_path = os.path.join(rec_train_options['checkpoint_save_dir'],
                                   'epoch_{}.pth'.format(epoch))
    if rec_train_options['ckpt_save_type'] == 'HighestAcc':
        # Weighted mean of the per-batch losses, weighted by batch size.
        loss_mean = sum(l * n for l, n in zip(total_loss, total_num)) / sum(total_num)
        if loss_mean < min_loss:
            min_loss = loss_mean
            # BUGFIX: the original calls passed ``epoch`` where
            # ``save_checkpoint`` expects ``logger`` (and ``logger`` as
            # ``cfg``), crashing on epoch.info(...). Arguments now match
            # save_checkpoint(checkpoint_path, model, _optimizers, logger,
            # cfg, **kwargs).
            save_checkpoint(checkpoint_path, net, solver, logger, rec_train_options, epoch=epoch)
    elif epoch % rec_train_options['ckpt_save_epoch'] == 0:
        save_checkpoint(checkpoint_path, net, solver, logger, rec_train_options, epoch=epoch)
    return min_loss