train.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import argparse
  2. from easydict import EasyDict as edict
  3. import yaml
  4. import os
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. from torch.utils.data import DataLoader
  8. import lib.models.crnn as crnn
  9. import lib.utils.utils as utils
  10. from lib.dataset import get_dataset
  11. from lib.core import function
  12. import lib.config.alphabets as alphabets
  13. from tensorboardX import SummaryWriter
  14. def parse_arg():
  15. parser = argparse.ArgumentParser(description="train crnn")
  16. parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
  17. args = parser.parse_args()
  18. with open(args.cfg, 'r') as f:
  19. # config = yaml.load(f, Loader=yaml.FullLoader)
  20. config = yaml.load(f)
  21. config = edict(config)
  22. config.DATASET.ALPHABETS = alphabets.alphabet
  23. config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
  24. return config
  25. def main():
  26. # load config
  27. config = parse_arg()
  28. # create output folder
  29. output_dict = utils.create_log_folder(config, phase='train')
  30. # cudnn
  31. cudnn.benchmark = config.CUDNN.BENCHMARK
  32. cudnn.deterministic = config.CUDNN.DETERMINISTIC
  33. cudnn.enabled = config.CUDNN.ENABLED
  34. # writer dict
  35. writer_dict = {
  36. 'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
  37. 'train_global_steps': 0,
  38. 'valid_global_steps': 0,
  39. }
  40. # construct face related neural networks
  41. model = crnn.get_crnn(config)
  42. # get device
  43. if torch.cuda.is_available():
  44. device = torch.device("cuda:{}".format(config.GPUID))
  45. else:
  46. device = torch.device("cpu:0")
  47. model = model.to(device)
  48. # define loss function
  49. criterion = torch.nn.CTCLoss()
  50. optimizer = utils.get_optimizer(config, model)
  51. last_epoch = config.TRAIN.BEGIN_EPOCH
  52. if config.TRAIN.RESUME.IS_RESUME:
  53. model_state_file = config.TRAIN.RESUME.FILE
  54. if model_state_file == '':
  55. print(" => no checkpoint found")
  56. checkpoint = torch.load(model_state_file, map_location='cpu')
  57. model.load_state_dict(checkpoint['state_dict'])
  58. last_epoch = checkpoint['epoch']
  59. if isinstance(config.TRAIN.LR_STEP, list):
  60. lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
  61. optimizer, config.TRAIN.LR_STEP,
  62. config.TRAIN.LR_FACTOR, last_epoch-1
  63. )
  64. else:
  65. lr_scheduler = torch.optim.lr_scheduler.StepLR(
  66. optimizer, config.TRAIN.LR_STEP,
  67. config.TRAIN.LR_FACTOR, last_epoch - 1
  68. )
  69. train_dataset = get_dataset(config)(config, is_train=True)
  70. train_loader = DataLoader(
  71. dataset=train_dataset,
  72. batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
  73. shuffle=config.TRAIN.SHUFFLE,
  74. num_workers=config.WORKERS,
  75. pin_memory=config.PIN_MEMORY,
  76. )
  77. val_dataset = get_dataset(config)(config, is_train=False)
  78. val_loader = DataLoader(
  79. dataset=val_dataset,
  80. batch_size=config.TEST.BATCH_SIZE_PER_GPU,
  81. shuffle=config.TEST.SHUFFLE,
  82. num_workers=config.WORKERS,
  83. pin_memory=config.PIN_MEMORY,
  84. )
  85. best_acc = 0.5
  86. converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
  87. for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
  88. function.train(config, train_loader, train_dataset, converter, model, criterion, optimizer, device, epoch, writer_dict, output_dict)
  89. lr_scheduler.step()
  90. acc = function.validate(config, val_loader, val_dataset, converter, model, criterion, device, epoch, writer_dict, output_dict)
  91. is_best = acc > best_acc
  92. best_acc = max(acc, best_acc)
  93. print("is best:", is_best)
  94. print("best acc is:", best_acc)
  95. # save checkpoint
  96. torch.save(
  97. {
  98. "state_dict": model.state_dict(),
  99. "epoch": epoch + 1,
  100. "best_acc": best_acc,
  101. }, os.path.join(output_dict['chs_dir'], "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc))
  102. )
  103. writer_dict['writer'].close()
  104. if __name__ == '__main__':
  105. main()