det_train_disti.py

# -*- coding: utf-8 -*-
# @Time : 2020/5/19 21:44
# @Author : xiangjing
import os
import sys
import pathlib

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# add the torchocr path to the python path
__dir__ = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))

import random
import time
import shutil
import traceback
from importlib import import_module

import numpy as np
import torch
from tqdm import tqdm
from torch import nn

from torchocr.networks import build_model, build_loss
from torchocr.postprocess import build_post_process
from torchocr.datasets import build_dataloader
from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
from torchocr.metrics import build_metric


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--config', type=str, default='config/cfg_det_dis.py', help='train config file path')
    args = parser.parse_args()
    # parse the .py config file as a python module and return its config object
    config_path = os.path.abspath(os.path.expanduser(args.config))
    assert os.path.isfile(config_path)
    if config_path.endswith('.py'):
        module_name = os.path.basename(config_path)[:-3]
        config_dir = os.path.dirname(config_path)
        sys.path.insert(0, config_dir)
        mod = import_module(module_name)
        sys.path.pop(0)
        return mod.config
    else:
        raise IOError('Only .py config files are supported now!')
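
# A minimal sketch of the config module this script expects. The keys below are
# inferred from how `cfg` is used later in this file; the concrete values are
# hypothetical placeholders. Note that the object must support both attribute
# access (cfg.train_options, cfg.dataset.train) and dict access (cfg['model']),
# e.g. an AttrDict-style container:
#
#   config = AttrDict(
#       SEED=927,
#       train_options={'device': 'cuda:0', 'epochs': 1200, 'print_interval': 10,
#                      'val_interval': 1, 'ckpt_save_type': 'HighestAcc',
#                      'ckpt_save_epoch': 10, 'checkpoint_save_dir': './output',
#                      'resume_from': ''},
#       model={...}, loss={...}, optimizer={'type': 'Adam', 'lr': 0.001},
#       metric={...}, post_process={...}, dataset={...},
#   )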


def set_random_seed(seed, use_cuda=True, deterministic=False):
    """Set the random seed.

    Args:
        seed (int): Seed to be used.
        use_cuda (bool): Whether CUDA is in use, i.e. whether the CUDA
            seeds also need to be set.
        deterministic (bool): Whether to set the deterministic option for
            the CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    if use_cuda:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        if deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False


def build_optimizer(params, config):
    """Build the optimizer named in the config from torch.optim.

    Only parameters with requires_grad=True are passed to the optimizer.

    Returns:
        the optimizer instance
    """
    from torch import optim
    opt_type = config.pop('type')
    opt = getattr(optim, opt_type)(filter(lambda p: p.requires_grad, params), **config)
    return opt
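
# Example (hypothetical values): with config = {'type': 'Adam', 'lr': 0.001,
# 'weight_decay': 1e-4}, the call above resolves to
# torch.optim.Adam(trainable_params, lr=0.001, weight_decay=1e-4).
# Note that 'type' is popped from the dict, so the same config object cannot
# be reused to build a second optimizer.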


def adjust_learning_rate(optimizer, base_lr, iter, all_iters, factor, warmup_iters=0, warmup_factor=1.0 / 3):
    """Learning rate schedule: linear warmup followed by polynomial ("poly") decay.

    :param optimizer: optimizer
    :param base_lr: initial learning rate
    :param iter: current iteration
    :param all_iters: total number of iterations
    :param factor: decay exponent
    :param warmup_iters: number of warmup iterations
    :param warmup_factor: starting multiplier for the warmup ramp
    :return: the learning rate that was set
    """
    if iter < warmup_iters:
        alpha = float(iter) / warmup_iters
        rate = warmup_factor * (1 - alpha) + alpha
    else:
        rate = np.power(1.0 - iter / float(all_iters + 1), factor)
    lr = rate * base_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
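
# Worked example (hypothetical numbers): with base_lr=0.001, factor=0.9,
# all_iters=10000, warmup_iters=300 and the default warmup_factor=1/3:
#   iter=150  -> alpha = 0.5, rate = (1/3) * 0.5 + 0.5 ≈ 0.667, lr ≈ 6.7e-4
#   iter=5000 -> rate = (1 - 5000 / 10001) ** 0.9 ≈ 0.536,      lr ≈ 5.4e-4
# i.e. the lr ramps up linearly over the warmup, then decays towards 0.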


def get_fine_tune_params(net, finetune_stage):
    """Collect the parameters that should be optimized.

    Args:
        net: the network, wrapped in nn.DataParallel
        finetune_stage: names of the sub-modules to fine-tune

    Returns:
        the parameters of the selected sub-modules
    """
    to_return_parameters = []
    for stage in finetune_stage:
        attr = getattr(net.module, stage, None)
        if attr is None:
            continue
        for element in attr.parameters():
            to_return_parameters.append(element)
    return to_return_parameters
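
# Hypothetical usage: with train_options['fine_tune_stage'] = ['backbone', 'neck'],
# only the parameters of net.module.backbone and net.module.neck would be
# returned. The call site in main() below is currently commented out, so by
# default every parameter of `net` with requires_grad=True goes to the optimizer.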


def evaluate(net, val_loader, to_use_device, logger, post_process, metric):
    """Evaluate the model on the validation set.

    :param net: the network
    :param val_loader: validation dataloader
    :param to_use_device: device
    :param logger: logger object
    :param post_process: post-processing object
    :param metric: object that computes recall etc. from the network output and labels
    :return: a dict containing recall, precision and fmeasure,
        e.g. {
            'recall': 0,
            'precision': 0.99,
            'fmeasure': 0.9999,
        }
    """
    logger.info('start evaluate')
    net.eval()
    total_frame = 0.0
    total_time = 0.0
    with torch.no_grad():
        for batch_data in tqdm(val_loader):
            start = time.time()
            output = net.forward(batch_data['img'].to(to_use_device))
            box_score_tuple = post_process(output, batch_data['shape'])
            total_frame += batch_data['img'].size()[0]
            total_time += time.time() - start
            metric(batch_data, box_score_tuple)
    metrics = metric.get_metric()
    # back to train mode, but keep the teacher frozen in eval mode:
    # in this distillation setup only the student is being trained
    net.train()
    net.module.model_dict['Teacher'].eval()
    metrics = {key: val.avg for key, val in metrics.items()}
    for k, v in metrics.items():
        logger.info(f'{k}:{v}')
    logger.info('FPS:{}'.format(total_frame / total_time))
    return metrics


def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger, post_process, metric):
    """Training loop.

    :param net: the network
    :param optimizer: optimizer
    :param loss_func: loss function
    :param train_loader: training dataloader
    :param eval_loader: validation dataloader
    :param to_use_device: device
    :param cfg: the config used for this training run
    :param global_state: global training state, e.g. cur_epoch, cur_iter and info about the best model so far
    :param logger: logger object
    :param post_process: post-processing object
    :param metric: evaluation metric
    :return: None
    """
    train_options = cfg.train_options
    logger.info('Train beginning...')
    # ===> parameters for printing loss info
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples, {all_step} batches in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples, {len(eval_loader)} batches in dataloader')
    if len(global_state) > 0:
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'recall': 0, 'precision': 0, 'fmeasure': 0, 'best_model_epoch': 0}
        start_epoch = 0
        global_step = 0
    # start training
    base_lr = cfg['optimizer']['lr']
    all_iters = all_step * train_options['epochs']
    warmup_iters = 3 * all_step
    # eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode; the teacher below stays frozen in eval mode
            net.module.model_dict['Teacher'].eval()
            train_loss = 0.
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = adjust_learning_rate(optimizer, base_lr, global_step, all_iters, 0.9,
                                                  warmup_iters=warmup_iters)
                # move the batch tensors to the gpu
                for key, value in batch_data.items():
                    if value is not None:
                        if isinstance(value, torch.Tensor):
                            batch_data[key] = value.to(to_use_device)
                # zero the gradients, then forward and backward pass
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                optimizer.step()
                # accumulate the loss for printing
                train_loss += loss_dict['loss'].item()
                loss_str = 'loss: {:.4f} - '.format(loss_dict.pop('loss').item())
                for idx, (key, value) in enumerate(loss_dict.items()):
                    loss_dict[key] = value.item()
                    loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                    if idx < len(loss_dict) - 1:
                        loss_str += ' - '
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"{loss_str} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                global_step += 1
            logger.info(f'train_loss: {train_loss / len(train_loader)}')
            if (epoch + 1) % train_options['val_interval'] == 0:
                global_state['start_epoch'] = epoch
                global_state['best_model'] = best_model
                global_state['global_step'] = global_step
                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                if train_options['ckpt_save_type'] == 'HighestAcc':
                    # validate and keep the checkpoint with the highest fmeasure
                    eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
                    if eval_dict['fmeasure'] > best_model['fmeasure']:
                        best_model.update(eval_dict)
                        best_model['best_model_epoch'] = epoch
                        best_model['models'] = net_save_path
                        global_state['start_epoch'] = epoch
                        global_state['best_model'] = best_model
                        global_state['global_step'] = global_step
                        net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                        save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                    shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                best_str = 'current best, '
                for k, v in best_model.items():
                    best_str += '{}: {}, '.format(k, v)
                logger.info(best_str)
    except KeyboardInterrupt:
        # save a final checkpoint if training is interrupted by hand
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg,
                        global_state=global_state)
    except Exception:
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')


def main():
    # ===> load the config
    cfg = parse_args()
    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
    # ===> set up training logging
    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))
    logger.info(cfg)
    # ===> pick the device and seed the RNGs
    train_options = cfg.train_options
    to_use_device = torch.device(
        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)
    # ===> build network
    net = build_model(cfg['model'])
    # ===> move the model to the target device
    net = nn.DataParallel(net)
    net = net.to(to_use_device)
    # ===> build metric
    metric = build_metric(cfg['metric'])
    # ===> get fine tune layers
    # params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
    # ===> solver and lr scheduler
    optimizer = build_optimizer(net.parameters(), cfg['optimizer'])
    net.train()
    net.module.model_dict['Teacher'].eval()
    # ===> whether to resume from checkpoint
    resume_from = train_options['resume_from']
    if resume_from:
        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer)
        if _resumed_optimizer:
            optimizer = _resumed_optimizer
        logger.info(f'net resume from {resume_from}')
    else:
        global_state = {}
        logger.info('net resume from scratch.')
    # ===> loss function
    loss_func = build_loss(cfg['loss'])
    loss_func = loss_func.to(to_use_device)
    # ===> data loader
    train_loader = build_dataloader(cfg.dataset.train)
    eval_loader = build_dataloader(cfg.dataset.eval)
    # ===> post process
    post_process = build_post_process(cfg['post_process'])
    # ===> train
    train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger, post_process, metric)


if __name__ == '__main__':
    main()
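
# Example invocation (the config path is just the default from parse_args;
# adjust it to your own config):
#   python det_train_disti.py --config config/cfg_det_dis.py
# Note that CUDA_VISIBLE_DEVICES is hard-coded to GPU 3 at the top of this
# file; change it to match your machine before running.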