save_load.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import errno
  18. import os
  19. import pickle
  20. import six
  21. import paddle
  22. __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
  23. def _mkdir_if_not_exist(path, logger):
  24. """
  25. mkdir if not exists, ignore the exception when multiprocess mkdir together
  26. """
  27. if not os.path.exists(path):
  28. try:
  29. os.makedirs(path)
  30. except OSError as e:
  31. if e.errno == errno.EEXIST and os.path.isdir(path):
  32. logger.warning(
  33. 'be happy if some process has already created {}'.format(
  34. path))
  35. else:
  36. raise OSError('Failed to mkdir {}'.format(path))
  37. def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
  38. if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
  39. raise ValueError("Model pretrain path {} does not "
  40. "exists.".format(path))
  41. if load_static_weights:
  42. pre_state_dict = paddle.static.load_program_state(path)
  43. param_state_dict = {}
  44. model_dict = model.state_dict()
  45. for key in model_dict.keys():
  46. weight_name = model_dict[key].name
  47. weight_name = weight_name.replace('binarize', '').replace(
  48. 'thresh', '') # for DB
  49. if weight_name in pre_state_dict.keys():
  50. # logger.info('Load weight: {}, shape: {}'.format(
  51. # weight_name, pre_state_dict[weight_name].shape))
  52. if 'encoder_rnn' in key:
  53. # delete axis which is 1
  54. pre_state_dict[weight_name] = pre_state_dict[
  55. weight_name].squeeze()
  56. # change axis
  57. if len(pre_state_dict[weight_name].shape) > 1:
  58. pre_state_dict[weight_name] = pre_state_dict[
  59. weight_name].transpose((1, 0))
  60. param_state_dict[key] = pre_state_dict[weight_name]
  61. else:
  62. param_state_dict[key] = model_dict[key]
  63. model.set_state_dict(param_state_dict)
  64. return param_state_dict
  65. param_state_dict = paddle.load(path + '.pdparams')
  66. # print("param_state_dict", param_state_dict)
  67. model.set_state_dict(param_state_dict)
  68. return param_state_dict
  69. def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
  70. """
  71. load model from checkpoint or pretrained_model
  72. """
  73. gloabl_config = config['Global']
  74. checkpoints = gloabl_config.get('checkpoints')
  75. pretrained_model = gloabl_config.get('pretrained_model')
  76. best_model_dict = {}
  77. if checkpoints:
  78. assert os.path.exists(checkpoints + ".pdparams"), \
  79. "Given dir {}.pdparams not exist.".format(checkpoints)
  80. assert os.path.exists(checkpoints + ".pdopt"), \
  81. "Given dir {}.pdopt not exist.".format(checkpoints)
  82. para_dict = paddle.load(checkpoints + '.pdparams')
  83. opti_dict = paddle.load(checkpoints + '.pdopt')
  84. model.set_state_dict(para_dict)
  85. if optimizer is not None:
  86. optimizer.set_state_dict(opti_dict)
  87. if os.path.exists(checkpoints + '.states'):
  88. with open(checkpoints + '.states', 'rb') as f:
  89. states_dict = pickle.load(f) if six.PY2 else pickle.load(
  90. f, encoding='latin1')
  91. best_model_dict = states_dict.get('best_model_dict', {})
  92. if 'epoch' in states_dict:
  93. best_model_dict['start_epoch'] = states_dict['epoch'] + 1
  94. logger.info("resume from {}".format(checkpoints))
  95. # 加载预训练模型
  96. elif pretrained_model:
  97. load_static_weights = gloabl_config.get('load_static_weights', False)
  98. if not isinstance(pretrained_model, list):
  99. pretrained_model = [pretrained_model]
  100. if not isinstance(load_static_weights, list):
  101. load_static_weights = [load_static_weights] * len(pretrained_model)
  102. for idx, pretrained in enumerate(pretrained_model):
  103. load_static = load_static_weights[idx]
  104. best_model_dict = load_dygraph_pretrain(
  105. model, logger, path=pretrained, load_static_weights=load_static)
  106. logger.info("load pretrained model from {}".format(
  107. pretrained_model))
  108. else:
  109. logger.info('train from scratch')
  110. return best_model_dict
  111. def save_model(net,
  112. optimizer,
  113. model_path,
  114. logger,
  115. is_best=False,
  116. prefix='ppocr',
  117. **kwargs):
  118. """
  119. save model to the target path
  120. """
  121. _mkdir_if_not_exist(model_path, logger)
  122. model_prefix = os.path.join(model_path, prefix)
  123. paddle.save(net.state_dict(), model_prefix + '.pdparams')
  124. paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
  125. # save metric and config
  126. with open(model_prefix + '.states', 'wb') as f:
  127. pickle.dump(kwargs, f, protocol=2)
  128. # if is_best:
  129. # logger.info('save best model is to {}'.format(model_prefix))
  130. # else:
  131. # logger.info("save model in {}".format(model_prefix))