# det_train_pse.py — training script for the PSENet text-detection model (torchocr).
  1. import os
  2. import sys
  3. import pathlib
  4. # 将 torchocr路径加到python路径里
  5. __dir__ = pathlib.Path(os.path.abspath(__file__))
  6. sys.path.append(str(__dir__))
  7. sys.path.append(str(__dir__.parent.parent))
  8. import os
  9. # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  10. # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  11. import random
  12. import time
  13. import shutil
  14. import traceback
  15. from importlib import import_module
  16. import numpy as np
  17. import torch
  18. from tqdm import tqdm
  19. from torch import nn
  20. from torchocr.networks import build_model, build_loss
  21. from torchocr.postprocess import build_post_process
  22. from torchocr.datasets import build_dataloader
  23. from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
  24. from torchocr.metrics import DetMetric
  25. def parse_args():
  26. import argparse
  27. parser = argparse.ArgumentParser(description='train')
  28. parser.add_argument('--config', type=str, default='config/cfg_det_pse.py', help='train config file path')
  29. args = parser.parse_args()
  30. # 解析.py文件
  31. config_path = os.path.abspath(os.path.expanduser(args.config))
  32. assert os.path.isfile(config_path)
  33. if config_path.endswith('.py'):
  34. module_name = os.path.basename(config_path)[:-3]
  35. config_dir = os.path.dirname(config_path)
  36. sys.path.insert(0, config_dir)
  37. mod = import_module(module_name)
  38. sys.path.pop(0)
  39. return mod.config
  40. # cfg_dict = {
  41. # name: value
  42. # for name, value in mod.__dict__.items()
  43. # if not name.startswith('__')
  44. # }
  45. # return cfg_dict
  46. else:
  47. raise IOError('Only py type are supported now!')
  48. def set_random_seed(seed, use_cuda=True, deterministic=False):
  49. """Set random seed.
  50. Args:
  51. seed (int): Seed to be used.
  52. use_cuda: whether depend on cuda
  53. deterministic (bool): Whether to set the deterministic option for
  54. CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
  55. to True and `torch.backends.cudnn.benchmark` to False.
  56. Default: False.
  57. """
  58. random.seed(seed)
  59. np.random.seed(seed)
  60. if use_cuda:
  61. torch.manual_seed(seed)
  62. torch.cuda.manual_seed_all(seed)
  63. if deterministic:
  64. torch.backends.cudnn.deterministic = True
  65. torch.backends.cudnn.benchmark = False
  66. def build_optimizer(params, config):
  67. """
  68. 优化器
  69. Returns:
  70. """
  71. from torch import optim
  72. opt_type = config.pop('type')
  73. opt = getattr(optim, opt_type)(params, **config)
  74. return opt
  75. def adjust_learning_rate(optimizer, base_lr, iter, all_iters, factor, warmup_iters=300, warmup_factor=1.0 / 3):
  76. """
  77. 带 warmup 的学习率衰减
  78. :param optimizer: 优化器
  79. :param base_lr: 开始的学习率
  80. :param iter: 当前迭代次数
  81. :param all_iters: 总的迭代次数
  82. :param factor: 学习率衰减系数
  83. :param warmup_iters: warmup 迭代数
  84. :param warmup_factor: warmup 系数
  85. :return:
  86. """
  87. """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
  88. if iter < warmup_iters:
  89. alpha = float(iter) / warmup_iters
  90. rate = warmup_factor * (1 - alpha) + alpha
  91. else:
  92. rate = np.power(1.0 - iter / float(all_iters + 1), factor)
  93. lr = rate * base_lr
  94. for param_group in optimizer.param_groups:
  95. param_group['lr'] = lr
  96. return lr
  97. def get_fine_tune_params(net, finetune_stage):
  98. """
  99. 获取需要优化的参数
  100. Args:
  101. net:
  102. Returns: 需要优化的参数
  103. """
  104. to_return_parameters = []
  105. for stage in finetune_stage:
  106. attr = getattr(net.module, stage, None)
  107. for element in attr.parameters():
  108. to_return_parameters.append(element)
  109. return to_return_parameters
def evaluate(net, val_loader, to_use_device, logger, post_process, metric):
    """Evaluate the model on the validation set.

    :param net: the network
    :param val_loader: validation dataloader
    :param to_use_device: device to run inference on
    :param logger: logger instance
    :param post_process: callable turning raw network output into (preds, boxes)
    :param metric: object computing recall / precision / hmean from predictions
        and ground-truth labels
    :return: dict with recall, precision and hmean, e.g.:
        {
            'recall': 0,
            'precision': 0.99,
            'hmean': 0.9999,
        }
    """
    logger.info('start evaluate')
    net.eval()
    raw_metrics = []
    total_frame = 0.0
    total_time = 0.0
    with torch.no_grad():
        idx = 0
        for batch_data in tqdm(val_loader):
            start = time.time()
            output = net.forward(batch_data['img'].to(to_use_device))
            # NOTE(review): only the first sample's 'shape' is read — this
            # presumably assumes the val loader uses batch_size=1; confirm.
            h, w = batch_data['shape'][0].item(), batch_data['shape'][1].item()
            preds, boxes_list = post_process(output[0], 1)
            # Scale factor from the post-processed map resolution back to the
            # original image resolution.
            scale = (preds.shape[1] * 1.0 / w, preds.shape[0] * 1.0 / h)
            if len(boxes_list):
                boxes_list = boxes_list / scale
            # The post-process returns no confidences; every box scores 1.
            scores = [1] * len(boxes_list)
            # Debug helpers (kept for reference): dump the binarized mask and
            # the source image to disk for visual inspection.
            # x = output.detach().cpu().numpy().squeeze()
            # x = x > 0.7
            # x = x * 255
            # x = x.astype(np.uint8)
            # import cv2
            # cv2.imwrite(f'mask{idx}.png', x)
            # img = cv2.imread(batch_data['img_path'][0])
            # cv2.imwrite(f'gt{idx}.png', img)
            # idx += 1
            total_frame += batch_data['img'].size()[0]
            total_time += time.time() - start
            raw_metric = metric(batch_data, ([boxes_list], [scores]))
            raw_metrics.append(raw_metric)
    metrics = metric.gather_measure(raw_metrics)
    # Restore training mode before returning to the train loop.
    net.train()
    result_dict = {'recall': metrics['recall'].avg, 'precision': metrics['precision'].avg, 'hmean': metrics['fmeasure'].avg}
    for k, v in result_dict.items():
        logger.info(f'{k}:{v}')
    logger.info('FPS:{}'.format(total_frame / total_time))
    return result_dict
  162. def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
  163. cfg, global_state, logger, post_process):
  164. """
  165. 训练函数
  166. :param net: 网络
  167. :param optimizer: 优化器
  168. :param scheduler: 学习率更新器
  169. :param loss_func: loss函数
  170. :param train_loader: 训练数据集 dataloader
  171. :param eval_loader: 验证数据集 dataloader
  172. :param to_use_device: device
  173. :param cfg: 当前训练所使用的配置
  174. :param global_state: 训练过程中的一些全局状态,如cur_epoch,cur_iter,最优模型的相关信息
  175. :param logger: logger 对象
  176. :param post_process: 后处理类对象
  177. :return: None
  178. """
  179. train_options = cfg.train_options
  180. metric = DetMetric()
  181. # ===>
  182. logger.info('Training...')
  183. # ===> print loss信息的参数
  184. all_step = len(train_loader)
  185. logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
  186. logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
  187. if len(global_state) > 0:
  188. best_model = global_state['best_model']
  189. start_epoch = global_state['start_epoch']
  190. global_step = global_state['global_step']
  191. else:
  192. best_model = {'recall': 0, 'precision': 0, 'hmean': 0, 'best_model_epoch': 0}
  193. start_epoch = 0
  194. global_step = 0
  195. # 开始训练
  196. base_lr = cfg['optimizer']['lr']
  197. all_iters = all_step * train_options['epochs']
  198. warmup_iters = 3 * all_step
  199. # eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
  200. try:
  201. for epoch in range(start_epoch, train_options['epochs']): # traverse each epoch
  202. net.train() # train mode
  203. train_loss = 0.
  204. start = time.time()
  205. for i, batch_data in enumerate(train_loader): # traverse each batch in the epoch
  206. current_lr = adjust_learning_rate(optimizer, base_lr, global_step, all_iters, 0.9, warmup_iters=warmup_iters)
  207. # 数据进行转换和丢到gpu
  208. # for key, value in batch_data.items():
  209. # if value is not None:
  210. # if isinstance(value, torch.Tensor):
  211. # batch_data[key] = value.to(to_use_device)
  212. # 清零梯度及反向传播
  213. optimizer.zero_grad()
  214. output = net.forward(batch_data['img'].to(to_use_device))
  215. labels, training_mask = batch_data['score_maps'].to(to_use_device), batch_data['training_mask'].to(to_use_device)
  216. loss_c, loss_s, loss = loss_func(output, labels, training_mask)
  217. loss.backward()
  218. optimizer.step()
  219. # statistic loss for print
  220. train_loss += loss.item()
  221. loss_str = 'loss: {:.4f} - '.format(loss.item())
  222. if (i + 1) % train_options['print_interval'] == 0:
  223. interval_batch_time = time.time() - start
  224. logger.info(f"[{epoch}/{train_options['epochs']}] - "
  225. f"[{i + 1}/{all_step}] - "
  226. f"lr:{current_lr} - "
  227. f"{loss_str} - "
  228. f"time:{interval_batch_time:.4f}")
  229. start = time.time()
  230. global_step += 1
  231. logger.info(f'train_loss: {train_loss / len(train_loader)}')
  232. if (epoch + 1) % train_options['val_interval'] == 0:
  233. global_state['start_epoch'] = epoch
  234. global_state['best_model'] = best_model
  235. global_state['global_step'] = global_step
  236. net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
  237. save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
  238. if train_options['ckpt_save_type'] == 'HighestAcc':
  239. # val
  240. eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
  241. if eval_dict['hmean'] > best_model['hmean']:
  242. best_model.update(eval_dict)
  243. best_model['best_model_epoch'] = epoch
  244. best_model['models'] = net_save_path
  245. global_state['start_epoch'] = epoch
  246. global_state['best_model'] = best_model
  247. global_state['global_step'] = global_step
  248. net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
  249. save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
  250. elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
  251. shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
  252. best_str = 'current best, '
  253. for k, v in best_model.items():
  254. best_str += '{}: {}, '.format(k, v)
  255. logger.info(best_str)
  256. except KeyboardInterrupt:
  257. import os
  258. save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg, global_state=global_state)
  259. except:
  260. error_msg = traceback.format_exc()
  261. logger.error(error_msg)
  262. finally:
  263. for k, v in best_model.items():
  264. logger.info(f'{k}: {v}')
  265. def main():
  266. # ===> 获取配置文件参数
  267. cfg = parse_args()
  268. os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
  269. logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))
  270. # ===> 训练信息的打印
  271. train_options = cfg.train_options
  272. logger.info(cfg)
  273. # ===>
  274. to_use_device = torch.device(
  275. train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
  276. set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)
  277. # ===> build network
  278. net = build_model(cfg['model'])
  279. # ===> 模型初始化及模型部署到对应的设备
  280. # net.apply(weight_init) # 使用 pretrained时,注释掉这句话
  281. # if torch.cuda.device_count() > 1:
  282. net = nn.DataParallel(net)
  283. net = net.to(to_use_device)
  284. net.train()
  285. # ===> get fine tune layers
  286. params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
  287. # ===> solver and lr scheduler
  288. optimizer = build_optimizer(net.parameters(), cfg['optimizer'])
  289. # ===> whether to resume from checkpoint
  290. resume_from = train_options['resume_from']
  291. if resume_from:
  292. net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
  293. third_name=train_options['third_party_name'])
  294. if _resumed_optimizer:
  295. optimizer = _resumed_optimizer
  296. logger.info(f'net resume from {resume_from}')
  297. else:
  298. global_state = {}
  299. logger.info(f'net resume from scratch.')
  300. # ===> loss function
  301. loss_func = build_loss(cfg['loss'])
  302. loss_func = loss_func.to(to_use_device)
  303. # ===> data loader
  304. train_loader = build_dataloader(cfg.dataset.train)
  305. eval_loader = build_dataloader(cfg.dataset.eval)
  306. # post_process
  307. post_process = build_post_process(cfg['post_process'])
  308. # ===> train
  309. train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger, post_process)
# Script entry point: only start training when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()