cfg_rec_crnn.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/5/19 21:44
  3. # @Author : xiangjing
  4. # ####################rec_train_options 参数说明##########################
  5. # 识别训练参数
  6. # base_lr:初始学习率
  7. # fine_tune_stage:
  8. # if you want to freeze some stage, and tune the others.
  9. # ['backbone', 'neck', 'head'], 所有参数都参与调优
  10. # ['backbone'], 只调优backbone部分的参数
  11. # 后续更新: 1、添加bn层freeze的代码
  12. # optimizer 和 optimizer_step:
  13. # 优化器的配置, 成对
  14. # example1: 'optimizer':['SGD'], 'optimizer_step':[],表示一直用SGD优化器
  15. # example2: 'optimizer':['SGD', 'Adam'], 'optimizer_step':[160], 表示前[0,160)个epoch使用SGD优化器,
  16. # [160,~]采用Adam优化器
  17. # lr_scheduler和lr_scheduler_info:
  18. # 学习率scheduler的设置
  19. # ckpt_save_type作用是选择模型保存的方式
  20. # HighestAcc:只保存在验证集上精度最高的模型(还是在训练集上loss最小)
  21. # FixedEpochStep: 按一定间隔保存模型
  22. ###
  23. from addict import Dict
  24. config = Dict()
  25. config.exp_name = 'CRNN'
  26. config.train_options = {
  27. # for train
  28. 'resume_from': '', # 继续训练地址
  29. 'third_party_name': '', # 加载paddle模型可选
  30. 'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint", # 模型保存地址,log文件也保存在这里
  31. 'device': 'cuda:0', # 不建议修改
  32. 'epochs': 20,
  33. 'fine_tune_stage': ['backbone', 'neck', 'head'],
  34. 'print_interval': 10, # step为单位
  35. 'val_interval': 300, # step为单位
  36. 'ckpt_save_type': 'HighestAcc', # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
  37. 'ckpt_save_epoch': 4, # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
  38. }
  39. config.SEED = 927
  40. config.optimizer = {
  41. 'type': 'Adam',
  42. 'lr': 0.001,
  43. 'weight_decay': 1e-4,
  44. }
  45. config.lr_scheduler = {
  46. 'type': 'StepLR',
  47. 'step_size': 60,
  48. 'gamma': 0.5
  49. }
  50. config.model = {
  51. # backbone 可以设置'pretrained': False/True
  52. 'type': "RecModel",
  53. # 'backbone': {"type": "ResNet", 'layers': 34},
  54. # 'neck': {"type": 'PPaddleRNN',"hidden_size": 256},
  55. # 'head': {"type": "CTC", 'n_class': 5990},
  56. # 'in_channels': 3,
  57. 'backbone': {"type": "MobileNetV3", 'model_name': 'small'},
  58. 'neck': {"type": 'PPaddleRNN', "hidden_size": 48},
  59. 'head': {"type": "CTC", 'n_class': 5990},
  60. 'in_channels': 3,
  61. }
  62. config.loss = {
  63. 'type': 'CTCLoss',
  64. 'blank_idx': 0,
  65. }
  66. # for dataset
  67. # ##lable文件
  68. ### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
  69. config.dataset = {
  70. 'alphabet': r'/data2/znj/CRNN_Chinese_Characters_Rec/lib/dataset/txt/char_std_5990.txt',
  71. 'train': {
  72. 'dataset': {
  73. 'type': 'RecTextLineDataset',
  74. 'file': r'/data2/znj/CRNN_Chinese_Characters_Rec/lib/dataset/txt/train.txt',
  75. 'input_h': 32,
  76. 'mean': 0.5,
  77. 'std': 0.5,
  78. 'augmentation': False,
  79. },
  80. 'loader': {
  81. 'type': 'DataLoader', # 使用torch dataloader只需要改为 DataLoader
  82. 'batch_size': 16,
  83. 'shuffle': True,
  84. 'num_workers': 3,
  85. 'collate_fn': {
  86. 'type': 'RecCollateFn',
  87. 'img_w': 320
  88. }
  89. }
  90. },
  91. 'eval': {
  92. 'dataset': {
  93. 'type': 'RecTextLineDataset',
  94. 'file': r'/data2/znj/CRNN_Chinese_Characters_Rec/lib/dataset/txt/test.txt',
  95. 'input_h': 32,
  96. 'mean': 0.5,
  97. 'std': 0.5,
  98. 'augmentation': False,
  99. },
  100. 'loader': {
  101. 'type': 'RecDataLoader',
  102. 'batch_size': 32,
  103. 'shuffle': False,
  104. 'num_workers': 2,
  105. 'collate_fn': {
  106. 'type': 'RecCollateFn',
  107. 'img_w': 320
  108. }
  109. }
  110. }
  111. }
  112. # 转换为 Dict
  113. for k, v in config.items():
  114. if isinstance(v, dict):
  115. config[k] = Dict(v)