train.py

import argparse
import os
from collections import OrderedDict

import yaml
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from easydict import EasyDict as edict
from tensorboardX import SummaryWriter

import lib.models.crnn as crnn
import lib.utils.utils as utils
from lib.dataset import get_dataset
from lib.core import function
import lib.config.alphabets as alphabets
from lib.utils.utils import model_info

def parse_arg():
    parser = argparse.ArgumentParser(description="train crnn")
    parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        # PyYAML >= 5.1 deprecates yaml.load without an explicit Loader,
        # and PyYAML 6 makes the Loader argument mandatory
        config = yaml.load(f, Loader=yaml.FullLoader)
        config = edict(config)

    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    return config
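# For reference, a sketch of the YAML layout this script expects, reconstructed
# from the keys read below. The key names follow the accesses in this file; the
# values are illustrative placeholders, not the repo's actual settings:
#
#   GPUID: 0
#   WORKERS: 4
#   PIN_MEMORY: true
#   CUDNN:
#     BENCHMARK: true
#     DETERMINISTIC: false
#     ENABLED: true
#   DATASET: {}                  # ALPHABETS is filled in by parse_arg
#   MODEL: {}                    # NUM_CLASSES is filled in by parse_arg
#   TRAIN:
#     BEGIN_EPOCH: 0
#     END_EPOCH: 100
#     BATCH_SIZE_PER_GPU: 32
#     SHUFFLE: true
#     LR_STEP: [60, 80]          # a list selects MultiStepLR, a scalar StepLR
#     LR_FACTOR: 0.1
#     FINETUNE:
#       IS_FINETUNE: false
#       FINETUNE_CHECKPOINIT: ''
#       FREEZE: false
#     RESUME:
#       IS_RESUME: false
#       FILE: ''
#   TEST:
#     BATCH_SIZE_PER_GPU: 32
#     SHUFFLE: false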

def main():
    # load config
    config = parse_arg()

    # create output folder
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }
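    # the two step counters are assumed to be advanced inside lib.core.function
    # as it logs scalars to TensorBoard; this file only initializes them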

    # construct the CRNN model
    model = crnn.get_crnn(config)

    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu")
    model = model.to(device)

    # define loss function
    criterion = torch.nn.CTCLoss()
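    # CTCLoss expects log-probabilities of shape (T, N, C) together with input
    # and target lengths; the train/validate helpers in lib.core.function are
    # assumed to feed it in that standard form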

    last_epoch = config.TRAIN.BEGIN_EPOCH
    optimizer = utils.get_optimizer(config, model)
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )
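    # with illustrative values LR_STEP = [60, 80] and LR_FACTOR = 0.1,
    # MultiStepLR multiplies the learning rate by 0.1 at epochs 60 and 80;
    # a scalar LR_STEP instead decays it every LR_STEP epochs via StepLR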

    if config.TRAIN.FINETUNE.IS_FINETUNE:
        # NOTE: the attribute name matches the (misspelled) key in the config file
        model_state_file = config.TRAIN.FINETUNE.FINETUNE_CHECKPOINIT
        if model_state_file == '':
            raise FileNotFoundError(" => no checkpoint found")
        checkpoint = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']

        # keep only the CNN backbone weights, stripping the 'cnn.' prefix
        # (e.g. 'cnn.conv1.weight' -> 'conv1.weight')
        model_dict = OrderedDict()
        for k, v in checkpoint.items():
            if 'cnn' in k:
                model_dict[k[4:]] = v
        model.cnn.load_state_dict(model_dict)

        if config.TRAIN.FINETUNE.FREEZE:
            for p in model.cnn.parameters():
                p.requires_grad = False
    elif config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            raise FileNotFoundError(" => no checkpoint found")
        checkpoint = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in checkpoint.keys():
            model.load_state_dict(checkpoint['state_dict'])
            last_epoch = checkpoint['epoch']
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        else:
            # a bare state_dict was saved without the surrounding metadata
            model.load_state_dict(checkpoint)

    model_info(model)

    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    best_acc = 0.5
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
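    # strLabelConverter maps text to integer label sequences for CTC and
    # decodes network output back to strings; see lib/utils/utils.py for the
    # exact interface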

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        function.train(config, train_loader, train_dataset, converter, model,
                       criterion, optimizer, device, epoch, writer_dict, output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter, model,
                                criterion, device, epoch, writer_dict, output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        print("is best:", is_best)
        print("best acc is:", best_acc)

        # save checkpoint
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                # "optimizer": optimizer.state_dict(),
                # "lr_scheduler": lr_scheduler.state_dict(),
                "best_acc": best_acc,
            },
            os.path.join(output_dict['chs_dir'], "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc))
        )

    writer_dict['writer'].close()

if __name__ == '__main__':
    main()
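
# usage (the config path is illustrative; point --cfg at your own file):
#   python train.py --cfg path/to/your_config.yaml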