rec_fineturn.py

# -*- coding: utf-8 -*-
# @Time : 2020/5/19 21:44
# @Author : xiangjing
import os
import sys
import pathlib

# add the torchocr path to the python path
__dir__ = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))

import random
import time
import shutil
import traceback
from importlib import import_module

import numpy as np
import torch
from tqdm import tqdm
from torch import nn
from torch import optim

from torchocr.networks import build_model, build_loss
from torchocr.datasets import build_dataloader
from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--config', type=str, default='/data2/znj/PytorchOCR/config/cfg_rec_crnn_doc_fineturn.py',
                        help='train config file path')
    args = parser.parse_args()
    # parse the .py config file
    config_path = os.path.abspath(os.path.expanduser(args.config))
    assert os.path.isfile(config_path)
    if config_path.endswith('.py'):
        module_name = os.path.basename(config_path)[:-3]
        config_dir = os.path.dirname(config_path)
        sys.path.insert(0, config_dir)
        mod = import_module(module_name)
        sys.path.pop(0)
        return mod.config
        # cfg_dict = {
        #     name: value
        #     for name, value in mod.__dict__.items()
        #     if not name.startswith('__')
        # }
        # return cfg_dict
    else:
        raise IOError('Only .py type is supported now!')


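# Hedged usage sketch: the script is driven entirely by the --config file, e.g.
#   python rec_fineturn.py --config /data2/znj/PytorchOCR/config/cfg_rec_crnn_doc_fineturn.py
# The config module is expected to expose a `config` object (see `return mod.config` above).

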
def set_random_seed(seed, use_cuda=True, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        use_cuda (bool): Whether the run depends on CUDA.
        deterministic (bool): Whether to set the deterministic option for
            the CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    if use_cuda:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        if deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False


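# Hedged usage sketch (the seed value is illustrative, not from the config):
#   set_random_seed(2021, use_cuda=torch.cuda.is_available(), deterministic=True)

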
def build_optimizer(params, config):
    """
    Build the optimizer from the config dict: the `type` field selects a class
    from `torch.optim`, the remaining fields are passed through as keyword arguments.

    Returns: the optimizer instance
    """
    opt_type = config.pop('type')
    opt = getattr(optim, opt_type)(params, **config)
    return opt


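# Hedged config sketch: `type` must name a class in torch.optim, the rest are its
# keyword arguments (the values below are hypothetical, not from the real config):
#   optimizer = build_optimizer(params_to_train, {'type': 'Adam', 'lr': 0.001, 'weight_decay': 0.0})

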
def build_scheduler(optimizer, config):
    """
    Build the learning-rate scheduler from the config dict.
    """
    scheduler = None
    sch_type = config.pop('type')
    if sch_type == 'LambdaLR':
        burn_in, steps = config['burn_in'], config['steps']

        # Learning rate setup
        def burnin_schedule(i):
            if i < burn_in:
                factor = pow(i / burn_in, 4)
            elif i < steps[0]:
                factor = 1.0
            elif i < steps[1]:
                factor = 0.1
            else:
                factor = 0.01
            return factor

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
    elif sch_type == 'StepLR':
        # Decay the learning rate at fixed intervals: multiply it by `gamma`
        # every `step_size` steps, where a step usually means one epoch.
        step_size, gamma = config['step_size'], config['gamma']
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    elif sch_type == 'ReduceLROnPlateau':
        # Reduce the learning rate when a monitored metric stops improving,
        # e.g. when the validation loss stops decreasing or the validation
        # accuracy stops increasing. A very practical scheduling strategy.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                               patience=3, verbose=True, threshold=1e-4)
    return scheduler


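# Hedged config sketch (the numbers are illustrative only):
#   scheduler = build_scheduler(optimizer, {'type': 'StepLR', 'step_size': 60, 'gamma': 0.5})
# Note: torch's ReduceLROnPlateau expects the monitored metric in scheduler.step(metric);
# the epoch-level scheduler.step() call in train() below does not pass one.

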
def get_fine_tune_params(net, finetune_stage):
    """
    Collect the parameters to optimize during fine-tuning.

    Args:
        net: the (DataParallel-wrapped) network
        finetune_stage: names of the sub-modules to fine-tune

    Returns: the list of parameters to optimize
    """
    to_return_parameters = []
    for stage in finetune_stage:
        attr = getattr(net.module, stage, None)
        if attr is None:
            # skip stage names that do not exist on the model
            continue
        for element in attr.parameters():
            to_return_parameters.append(element)
    return to_return_parameters


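# Hedged usage sketch: the stage names below are assumptions and must match
# sub-module attributes of net.module (e.g. the recognition model's neck / head):
#   params_to_train = get_fine_tune_params(net, ['neck', 'head'])

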
def evaluate(net, val_loader, loss_func, to_use_device, logger, converter, metric):
    """
    Evaluate the model on the validation set.

    :param net: the network
    :param val_loader: validation dataloader
    :param loss_func: loss function
    :param to_use_device: device
    :param logger: logger instance
    :param converter: label converter instance
    :param metric: object that computes acc and related metrics from the network output and the labels
    :return: a dict containing eval_loss, eval_acc and norm_edit_dis,
        e.g. {
            'eval_loss': 0,
            'eval_acc': 0.99,
            'norm_edit_dis': 0.9999,
        }
    """
    logger.info('start evaluate')
    net.eval()
    nums = 0
    result_dict = {'eval_loss': 0., 'eval_acc': 0., 'norm_edit_dis': 0.}
    show_str = []
    with torch.no_grad():
        for batch_data in tqdm(val_loader):
            targets, targets_lengths = converter.encode(batch_data['label'])
            batch_data['targets'] = targets
            batch_data['targets_lengths'] = targets_lengths
            output = net.forward(batch_data['img'].to(to_use_device))
            loss = loss_func(output, batch_data)
            nums += batch_data['img'].shape[0]
            acc_dict = metric(output[1], batch_data['label'])
            result_dict['eval_loss'] += loss['loss'].item()
            result_dict['eval_acc'] += acc_dict['n_correct']
            result_dict['norm_edit_dis'] += acc_dict['norm_edit_dis']
            show_str.extend(acc_dict['show_str'])
            print('nums:', nums, 'right_nums:', result_dict['eval_acc'])
    result_dict['eval_loss'] /= len(val_loader)
    result_dict['eval_acc'] /= nums
    result_dict['norm_edit_dis'] = 1 - result_dict['norm_edit_dis'] / nums
    logger.info(f"eval_loss:{result_dict['eval_loss']}")
    logger.info(f"eval_acc:{result_dict['eval_acc']}")
    logger.info(f"norm_edit_dis:{result_dict['norm_edit_dis']}")
    for s in show_str[:10]:
        logger.info(s)
    net.train()
    return result_dict


def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger):
    """
    Training loop.

    :param net: the network
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler
    :param loss_func: loss function
    :param train_loader: training dataloader
    :param eval_loader: validation dataloader
    :param to_use_device: device
    :param cfg: the config used for this training run
    :param global_state: global training state, e.g. cur_epoch, cur_iter and the best-model info
    :param logger: logger instance
    :return: None
    """
    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> parameters for printing the loss info
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    if len(global_state) > 0:
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0
        global_step = 0
    # start training
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                # zero the gradients and run the backward pass
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output[1], batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    global_state['start_epoch'] = epoch
                    global_state['best_model'] = best_model
                    global_state['global_step'] = global_step
                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # validate
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            best_model['models'] = net_save_path
                            global_state['start_epoch'] = epoch
                            global_state['best_model'] = best_model
                            global_state['global_step'] = global_step
                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                global_step += 1
            scheduler.step()
    except KeyboardInterrupt:
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg,
                        global_state=global_state)
    except Exception:
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')


def main():
    # ===> read the config file
    cfg = parse_args()
    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))

    # ===> print the training info
    train_options = cfg.train_options
    logger.info(cfg)
    # ===>
    to_use_device = torch.device(
        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)

    # ===> build network
    net = build_model(cfg['model'])

    # ===> initialize the model and move it to the target device
    if not cfg['model']['backbone']['pretrained']:  # randomly initialize only when no pretrained backbone is used
        net.apply(weight_init)
    # if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
    net = net.to(to_use_device)
    net.train()

    # ===> get fine tune layers
    params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
    # ===> solver and lr scheduler
    optimizer = build_optimizer(params_to_train, cfg['optimizer'])
    scheduler = build_scheduler(optimizer, cfg['lr_scheduler'])

    # ===> whether to resume from checkpoint
    resume_from = train_options['resume_from']
    if resume_from:
        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
                                                                third_name=train_options['third_party_name'])
        if _resumed_optimizer:
            optimizer = _resumed_optimizer
        logger.info(f'net resume from {resume_from}')
    else:
        global_state = {}
        logger.info('net resume from scratch.')

    # ===> loss function
    loss_func = build_loss(cfg['loss'])
    loss_func = loss_func.to(to_use_device)

    # ===> data loader
    cfg.dataset.train.dataset.alphabet = cfg.dataset.alphabet
    train_loader = build_dataloader(cfg.dataset.train)
    cfg.dataset.eval.dataset.alphabet = cfg.dataset.alphabet
    eval_loader = build_dataloader(cfg.dataset.eval)

    # ===> train
    train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger)


if __name__ == '__main__':
    main()