# cfg_det_dis.py

  1. from addict import Dict
  2. config = Dict()
  3. config.exp_name = 'DBNet_icdar_distill'
  4. config.train_options = {
  5. # for train
  6. 'resume_from': '', # 继续训练地址
  7. 'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint", # 模型保存地址,log文件也保存在这里
  8. 'device': 'cuda:0', # 不建议修改
  9. 'epochs': 600,
  10. 'fine_tune_stage': ['backbone', 'neck', 'head'],
  11. 'print_interval': 5, # step为单位
  12. 'val_interval': 1, # epoch为单位
  13. 'ckpt_save_type': 'HighestAcc', # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
  14. 'ckpt_save_epoch': 4, # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
  15. }
  16. config.SEED = 927
  17. config.optimizer = {
  18. 'type': 'Adam',
  19. 'lr': 0.0002,
  20. 'weight_decay': 0,
  21. }
  22. config.model = {
  23. 'type': 'DistillationModel',
  24. 'algorithm': 'Distillation',
  25. 'init_weight': False, # 当不使用任何预训练模型(子网络或任意子网络backbone)时打开
  26. 'models': {
  27. 'Teacher': {
  28. 'type': "DetModel",
  29. 'freeze_params': True,
  30. 'backbone': {"type": "ResNet", 'pretrained': False, 'layers': 18},
  31. 'neck': {"type": 'DB_fpn', 'out_channels': 256},
  32. 'head': {"type": "DBHead"},
  33. 'in_channels': 3,
  34. 'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/DBNet_icdar_res18_fast_pre.pth'
  35. },
  36. 'Student': {
  37. 'type': "DetModel",
  38. 'freeze_params': False,
  39. 'backbone': {"type": "MobileNetV3", 'pretrained': False, 'disable_se': False},
  40. 'neck': {"type": 'DB_fpn', 'out_channels': 96},
  41. 'head': {"type": "DBHead"},
  42. 'in_channels': 3,
  43. 'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/mbv3.pth'
  44. },
  45. 'Student2': {
  46. 'type': "DetModel",
  47. 'freeze_params': False,
  48. 'backbone': {"type": "MobileNetV3", 'pretrained': False, 'disable_se': False},
  49. 'neck': {"type": 'DB_fpn', 'out_channels': 96},
  50. 'head': {"type": "DBHead"},
  51. 'in_channels': 3,
  52. 'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/mbv3.pth'
  53. }
  54. }
  55. }
  56. config.loss = {
  57. 'type': 'CombinedLoss',
  58. 'combine_list': {
  59. 'DistillationDilaDBLoss': {
  60. 'weight': 2.0,
  61. 'model_name_pairs': [("Student", "Teacher"), ("Student2", "Teacher")],
  62. # 'model_name_pairs': [("Student", "Teacher")],
  63. 'key': 'maps',
  64. 'balance_loss': True,
  65. 'main_loss_type': 'DiceLoss',
  66. 'alpha': 5,
  67. 'beta': 10,
  68. 'ohem_ratio': 3,
  69. },
  70. 'DistillationDMLLoss': {
  71. 'maps_name': "thrink_maps",
  72. 'weight': 1.0,
  73. 'model_name_pairs': ["Student", "Student2"],
  74. 'key': 'maps'
  75. },
  76. 'DistillationDBLoss': {
  77. 'weight': 1.0,
  78. 'model_name_list': ["Student"],
  79. 'balance_loss': True,
  80. 'main_loss_type': 'DiceLoss',
  81. 'alpha': 5,
  82. 'beta': 10,
  83. 'ohem_ratio': 3}
  84. }
  85. }
  86. config.post_process = {
  87. 'type': 'DistillationDBPostProcess',
  88. 'model_name': ["Student", "Student2", "Teacher"],
  89. # 'model_name': ["Student", "Teacher"],
  90. 'thresh': 0.3, # 二值化输出map的阈值
  91. 'box_thresh': 0.5, # 低于此阈值的box丢弃
  92. 'unclip_ratio': 1.5 # 扩大框的比例
  93. }
  94. config.metric = {
  95. 'name': 'DistillationMetric',
  96. 'base_metric_name': 'DetMetric',
  97. 'main_indicator': 'hmean',
  98. 'key': "Student"
  99. }
  100. # for dataset
  101. # ##lable文件
  102. ### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
  103. config.dataset = {
  104. 'train': {
  105. 'dataset': {
  106. 'type': 'JsonDataset',
  107. 'file': r'/path/to/your/workspace/dataset/icdar15-detection/train.json',
  108. 'mean': [0.485, 0.456, 0.406],
  109. 'std': [0.229, 0.224, 0.225],
  110. # db 预处理,不需要修改
  111. 'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}},
  112. {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
  113. {'type': 'Resize', 'args': {'size': [0.5, 3]}}]},
  114. {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}},
  115. {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}},
  116. {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}}],
  117. 'filter_keys': ['img_name', 'text_polys', 'texts', 'ignore_tags', 'shape'],
  118. # 需要从data_dict里过滤掉的key
  119. 'ignore_tags': ['*', '###', ' '],
  120. 'img_mode': 'RGB'
  121. },
  122. 'loader': {
  123. 'type': 'DataLoader', # 使用torch dataloader只需要改为 DataLoader
  124. 'batch_size': 20,
  125. 'shuffle': True,
  126. 'num_workers': 20,
  127. 'collate_fn': {
  128. 'type': ''
  129. }
  130. }
  131. },
  132. 'eval': {
  133. 'dataset': {
  134. 'type': 'JsonDataset',
  135. 'file': r'/path/to/your/workspace/dataset/icdar15-detection/test.json',
  136. 'mean': [0.485, 0.456, 0.406],
  137. 'std': [0.229, 0.224, 0.225],
  138. 'pre_processes': [{'type': 'ResizeShortSize', 'args': {'short_size': 736, 'resize_text_polys': False}}],
  139. 'filter_keys': [], # 需要从data_dict里过滤掉的key
  140. 'ignore_tags': ['*', '###', ' '],
  141. 'img_mode': 'RGB'
  142. },
  143. 'loader': {
  144. 'type': 'DataLoader',
  145. 'batch_size': 1, # 必须为1
  146. 'shuffle': False,
  147. 'num_workers': 10,
  148. 'collate_fn': {
  149. 'type': 'DetCollectFN'
  150. }
  151. }
  152. }
  153. }
  154. # 转换为 Dict
  155. for k, v in config.items():
  156. if isinstance(v, dict):
  157. config[k] = Dict(v)