utils.py

import os
import time
from pathlib import Path

import torch
import torch.optim as optim


def get_optimizer(config, model):
    """Build the optimizer named by config.TRAIN.OPTIMIZER over the model's
    trainable parameters. Returns None if the name is not recognized."""
    optimizer = None
    if config.TRAIN.OPTIMIZER == "sgd":
        optimizer = optim.SGD(
            # only parameters with requires_grad=True are optimized,
            # so frozen layers are skipped
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV,
        )
    elif config.TRAIN.OPTIMIZER == "adam":
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=config.TRAIN.LR,
        )
    elif config.TRAIN.OPTIMIZER == "rmsprop":
        optimizer = optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            # alpha=config.TRAIN.RMSPROP_ALPHA,
            # centered=config.TRAIN.RMSPROP_CENTERED
        )
    return optimizer
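
# Usage sketch (hypothetical values; a yacs-style CfgNode is assumed here only
# because of the cfg.TRAIN.* attribute layout -- any object exposing these
# attributes works):
#
#   from yacs.config import CfgNode as CN
#   cfg = CN(); cfg.TRAIN = CN()
#   cfg.TRAIN.OPTIMIZER = "sgd"
#   cfg.TRAIN.LR = 0.001
#   cfg.TRAIN.MOMENTUM = 0.9
#   cfg.TRAIN.WD = 1e-4
#   cfg.TRAIN.NESTEROV = False
#   optimizer = get_optimizer(cfg, model)  # model: any torch.nn.Module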


def create_log_folder(cfg, phase='train'):
    """Create timestamped checkpoint and TensorBoard log directories
    under cfg.OUTPUT_DIR and return their paths."""
    root_output_dir = Path(cfg.OUTPUT_DIR)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        root_output_dir.mkdir()

    dataset = cfg.DATASET.DATASET
    model = cfg.MODEL.NAME
    time_str = time.strftime('%Y-%m-%d-%H-%M')

    checkpoints_output_dir = root_output_dir / dataset / model / time_str / 'checkpoints'
    print('=> creating {}'.format(checkpoints_output_dir))
    checkpoints_output_dir.mkdir(parents=True, exist_ok=True)

    tensorboard_log_dir = root_output_dir / dataset / model / time_str / 'log'
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return {'chs_dir': str(checkpoints_output_dir), 'tb_dir': str(tensorboard_log_dir)}
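
# Layout sketch (field values are made up): with cfg.OUTPUT_DIR = 'output',
# cfg.DATASET.DATASET = 'mydata' and cfg.MODEL.NAME = 'crnn', a call at
# 2024-01-31 12:00 creates
#
#   output/mydata/crnn/2024-01-31-12-00/checkpoints   -> dirs['chs_dir']
#   output/mydata/crnn/2024-01-31-12-00/log           -> dirs['tb_dir']
#
#   dirs = create_log_folder(cfg)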


def get_batch_label(d, i):
    """Collect the label strings of dataset `d` at the batch indices `i`
    (each entry of d.labels is a dict whose first value is the label)."""
    label = []
    for idx in i:
        label.append(list(d.labels[idx].values())[0])
    return label


class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether to ignore case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        decode_flag = isinstance(text[0], bytes)

        for item in text:
            if decode_flag:
                item = item.decode('utf-8', 'strict')
            if self._ignore_case:
                # the lookup table only holds lower-case keys in this mode
                item = item.lower()
            length.append(len(item))
            for char in item:
                result.append(self.dict[char])
        return torch.IntTensor(result), torch.IntTensor(length)

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]): encoded texts.
            length (torch.IntTensor [n]): length of each text.
            raw (bool): if True, keep blanks and repeated characters.

        Raises:
            AssertionError: when the texts and their lengths do not match.

        Returns:
            text (str or list of str): decoded texts.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, \
                "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    # CTC decoding: drop blanks (index 0) and collapse repeats
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode: decode each text separately
            assert t.numel() == length.sum(), \
                "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
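
# Usage sketch (the alphabet is illustrative): encode flattens label strings
# into one IntTensor of 1-based indices (0 is reserved for the CTC blank);
# decode with raw=False collapses repeats and blanks, so it is meant for raw
# network output rather than arbitrary label tensors.
#
#   converter = strLabelConverter('abcdefghijklmnopqrstuvwxyz')
#   targets, lengths = converter.encode(['ocr', 'demo'])
#   # targets = IntTensor([15, 3, 18, 4, 5, 13, 15]), lengths = IntTensor([3, 4])
#   converter.decode(targets, lengths)  # -> ['ocr', 'demo']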


def get_char_dict(path):
    """Read a GBK-encoded character file (one character per line) into a
    {line_number: character} dict."""
    with open(path, 'rb') as file:
        char_dict = {num: char.strip().decode('gbk', 'ignore')
                     for num, char in enumerate(file.readlines())}
    return char_dict


def model_info(model):  # Plots a line-by-line description of a PyTorch model
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    print('\n%5s %50s %9s %12s %20s %12s %12s' % (
        'layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
    for i, (name, p) in enumerate(model.named_parameters()):
        name = name.replace('module_list.', '')
        print('%5g %50s %9s %12g %20s %12.3g %12.3g' % (
            i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
    print('Model Summary: %g layers, %g parameters, %g gradients\n' % (i + 1, n_p, n_g))
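
# Usage sketch (the tiny model below is made up purely for illustration):
#
#   import torch.nn as nn
#   model_info(nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2)))
#   # prints one row per parameter tensor, then the summary line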