save_load.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import errno
  18. import os
  19. import pickle
  20. import six
  21. import paddle
# Public API exposed by this module.
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
  23. def _mkdir_if_not_exist(path, logger):
  24. """
  25. mkdir if not exists, ignore the exception when multiprocess mkdir together
  26. """
  27. if not os.path.exists(path):
  28. try:
  29. os.makedirs(path)
  30. except OSError as e:
  31. if e.errno == errno.EEXIST and os.path.isdir(path):
  32. logger.warning(
  33. 'be happy if some process has already created {}'.format(
  34. path))
  35. else:
  36. raise OSError('Failed to mkdir {}'.format(path))
  37. def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
  38. # if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
  39. # raise ValueError("Model pretrain path {} does not "
  40. # "exists.".format(path))
  41. if not (os.path.isdir(path) or os.path.exists(path + '.pdiparams')):
  42. raise ValueError("Model pretrain path {} does not "
  43. "exists.".format(path))
  44. if load_static_weights:
  45. pre_state_dict = paddle.static.load_program_state(path)
  46. param_state_dict = {}
  47. model_dict = model.state_dict()
  48. for key in model_dict.keys():
  49. weight_name = model_dict[key].name
  50. weight_name = weight_name.replace('binarize', '').replace(
  51. 'thresh', '') # for DB
  52. if weight_name in pre_state_dict.keys():
  53. # logger.info('Load weight: {}, shape: {}'.format(
  54. # weight_name, pre_state_dict[weight_name].shape))
  55. if 'encoder_rnn' in key:
  56. # delete axis which is 1
  57. pre_state_dict[weight_name] = pre_state_dict[
  58. weight_name].squeeze()
  59. # change axis
  60. if len(pre_state_dict[weight_name].shape) > 1:
  61. pre_state_dict[weight_name] = pre_state_dict[
  62. weight_name].transpose((1, 0))
  63. param_state_dict[key] = pre_state_dict[weight_name]
  64. else:
  65. param_state_dict[key] = model_dict[key]
  66. model.set_state_dict(param_state_dict)
  67. return param_state_dict
  68. # param_state_dict = paddle.load(path + '.pdparams')
  69. param_state_dict = paddle.load(path)
  70. # print("param_state_dict", param_state_dict)
  71. model.set_state_dict(param_state_dict)
  72. return param_state_dict
  73. def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
  74. """
  75. load model from checkpoint or pretrained_model
  76. """
  77. gloabl_config = config['Global']
  78. checkpoints = gloabl_config.get('checkpoints')
  79. pretrained_model = gloabl_config.get('pretrained_model')
  80. best_model_dict = {}
  81. if checkpoints:
  82. assert os.path.exists(checkpoints + ".pdparams"), \
  83. "Given dir {}.pdparams not exist.".format(checkpoints)
  84. assert os.path.exists(checkpoints + ".pdopt"), \
  85. "Given dir {}.pdopt not exist.".format(checkpoints)
  86. para_dict = paddle.load(checkpoints + '.pdparams')
  87. opti_dict = paddle.load(checkpoints + '.pdopt')
  88. print("read .pdparams")
  89. model.set_state_dict(para_dict)
  90. if optimizer is not None:
  91. print("read .pdopt")
  92. optimizer.set_state_dict(opti_dict)
  93. if os.path.exists(checkpoints + '.states'):
  94. print("read .states")
  95. with open(checkpoints + '.states', 'rb') as f:
  96. states_dict = pickle.load(f) if six.PY2 else pickle.load(
  97. f, encoding='latin1')
  98. best_model_dict = states_dict.get('best_model_dict', {})
  99. if 'epoch' in states_dict:
  100. best_model_dict['start_epoch'] = states_dict['epoch'] + 1
  101. logger.info("resume from {}".format(checkpoints))
  102. # 加载预训练模型
  103. elif pretrained_model:
  104. load_static_weights = gloabl_config.get('load_static_weights', False)
  105. if not isinstance(pretrained_model, list):
  106. pretrained_model = [pretrained_model]
  107. if not isinstance(load_static_weights, list):
  108. load_static_weights = [load_static_weights] * len(pretrained_model)
  109. for idx, pretrained in enumerate(pretrained_model):
  110. load_static = load_static_weights[idx]
  111. best_model_dict = load_dygraph_pretrain(
  112. model, logger, path=pretrained, load_static_weights=load_static)
  113. logger.info("load pretrained model from {}".format(
  114. pretrained_model))
  115. else:
  116. logger.info('train from scratch')
  117. return best_model_dict
  118. def save_model(net,
  119. optimizer,
  120. model_path,
  121. logger,
  122. is_best=False,
  123. prefix='ppocr',
  124. **kwargs):
  125. """
  126. save model to the target path
  127. """
  128. _mkdir_if_not_exist(model_path, logger)
  129. model_prefix = os.path.join(model_path, prefix)
  130. paddle.save(net.state_dict(), model_prefix + '.pdparams')
  131. paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
  132. # save metric and config
  133. with open(model_prefix + '.states', 'wb') as f:
  134. pickle.dump(kwargs, f, protocol=2)
  135. # if is_best:
  136. # logger.info('save best model is to {}'.format(model_prefix))
  137. # else:
  138. # logger.info("save model in {}".format(model_prefix))