# encoding: utf-8
"""
@time: 2021/3/6 19:48
@author: Bourne-M
"""
# -*- coding: utf-8 -*-
# @Time : 2020/5/19 21:44
# @Author : xiangjing
# #################### train_options parameter description ##########################
# Training parameters
# base_lr: initial learning rate
# fine_tune_stage:
#     if you want to freeze some stages and tune the others:
#     ['backbone', 'neck', 'head']: all parameters are tuned
#     ['backbone']: only the backbone parameters are tuned
#     To do later: 1. add code for freezing BN layers
# optimizer and optimizer_step:
#     optimizer configuration, used as a pair
#     example1: 'optimizer': ['SGD'], 'optimizer_step': [] -> use the SGD optimizer throughout
#     example2: 'optimizer': ['SGD', 'Adam'], 'optimizer_step': [160] -> use SGD for epochs [0, 160)
#               and Adam from epoch 160 onwards
# lr_scheduler and lr_scheduler_info:
#     learning-rate scheduler settings
# ckpt_save_type selects how checkpoints are saved
#     HighestAcc: only keep the model with the highest accuracy on the validation set
#                 (or the one with the lowest loss on the training set)
#     FixedEpochStep: save a checkpoint at a fixed epoch interval
###
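# A hedged sketch (not part of this config) of what the paired 'optimizer' /
# 'optimizer_step' and 'lr_scheduler' / 'lr_scheduler_info' entries described
# above could look like; the scheduler name and kwargs below are hypothetical,
# and the active config in this file uses the single-optimizer dict form instead:
#
# train_options_sketch = {
#     'optimizer': ['SGD', 'Adam'],    # SGD for epochs [0, 160), Adam from epoch 160 on
#     'optimizer_step': [160],
#     'lr_scheduler': 'StepLR',                               # hypothetical scheduler name
#     'lr_scheduler_info': {'step_size': 200, 'gamma': 0.1},  # hypothetical scheduler kwargs
# }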
# from addict import Dict
#
# config = Dict()
# config.exp_name = 'DBNet_res18_init'
# config.train_options = {
#     # for train
#     'resume_from': '',  # path to resume training from
#     'third_party_name': '',  # optional, for loading a Paddle model
#     'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # checkpoint save directory; log files are also stored here
#     'device': 'cuda:0',  # not recommended to change
#     'epochs': 1200,
#     'fine_tune_stage': ['backbone', 'neck', 'head'],
#     'print_interval': 32,  # in steps
#     'val_interval': 10,  # in epochs
#     'ckpt_save_type': 'HighestAcc',  # HighestAcc: only keep the highest-accuracy model; FixedEpochStep: save every ckpt_save_epoch epochs
#     'ckpt_save_epoch': 4,  # in epochs; only used when ckpt_save_type is FixedEpochStep
# }
#
# config.SEED = 927
# config.optimizer = {
#     'type': 'Adam',
#     'lr': 0.001,
#     'weight_decay': 1e-4,
# }
#
# config.model = {
#     'type': "DetModel",
#     'backbone': {"type": "ResNet", 'layers': 18, 'pretrained': True},  # ResNet or MobileNetV3
#     'neck': {"type": 'DB_fpn', 'out_channels': 256},
#     'head': {"type": "DBHead"},
#     'in_channels': 3,
# }
#
# config.loss = {
#     'type': 'DBLoss',
#     'alpha': 1,
#     'beta': 10
# }
#
# config.post_process = {
#     'type': 'DBPostProcess',
#     'thresh': 0.3,  # binarization threshold for the output map
#     'box_thresh': 0.7,  # boxes scoring below this threshold are discarded
#     'unclip_ratio': 1.5  # ratio used to expand (unclip) the boxes
# }
# # for dataset
# # ## label file
# ### open question: should the gt str --> label conversion live in the loss or in the dataloader?
# config.dataset = {
#     'train': {
#         'dataset': {
#             'type': 'JsonDataset',
#             'file': r'/home/zhouyufei/Work/DataSet/icdar2015/detection/train.json',
#             'mean': [0.485, 0.456, 0.406],
#             'std': [0.229, 0.224, 0.225],
#             # DB pre-processing, no need to modify
#             'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}},
#                                                               {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
#                                                               {'type': 'Resize', 'args': {'size': [0.5, 3]}}]},
#                               {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}},
#                               {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}},
#                               {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}}],
#             'filter_keys': ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags', 'shape'],  # keys to filter out of data_dict
#             'ignore_tags': ['*', '###'],
#             'img_mode': 'RGB'
#         },
#         'loader': {
#             'type': 'DataLoader',  # to use the torch dataloader, just set this to DataLoader
#             'batch_size': 32,
#             'shuffle': True,
#             'num_workers': 30,
#             'collate_fn': {
#                 'type': ''
#             }
#         }
#     },
#     'eval': {
#         'dataset': {
#             'type': 'JsonDataset',
#             'file': r'/home/zhouyufei/Work/DataSet/icdar2015/detection/test.json',
#             'mean': [0.485, 0.456, 0.406],
#             'std': [0.229, 0.224, 0.225],
#             'pre_processes': [{'type': 'ResizeShortSize', 'args': {'short_size': 736, 'resize_text_polys': False}}],
#             'filter_keys': [],  # keys to filter out of data_dict
#             'ignore_tags': ['*', '###'],
#             'img_mode': 'RGB'
#         },
#         'loader': {
#             'type': 'DataLoader',
#             'batch_size': 1,  # must be 1
#             'shuffle': False,
#             'num_workers': 20,
#             'collate_fn': {
#                 'type': 'DetCollectFN'
#             }
#         }
#     }
# }
#
# # convert top-level sub-dicts to Dict
# for k, v in config.items():
#     if isinstance(v, dict):
#         config[k] = Dict(v)
from addict import Dict

config = Dict()
config.exp_name = 'psenet_mbv3'
config.train_options = {
    # for train
    'resume_from': '',  # path to resume training from
    'third_party_name': '',  # optional, for loading a Paddle model
    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # checkpoint save directory; log files are also stored here
    'device': 'cuda:0',  # not recommended to change
    'epochs': 1200,
    'fine_tune_stage': ['backbone', 'neck', 'head'],
    'print_interval': 20,  # in steps
    'val_interval': 1,  # in epochs
    'ckpt_save_type': 'HighestAcc',  # HighestAcc: only keep the highest-accuracy model; FixedEpochStep: save every ckpt_save_epoch epochs
    'ckpt_save_epoch': 4,  # in epochs; only used when ckpt_save_type is FixedEpochStep
}
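# A minimal, hypothetical sketch of how 'fine_tune_stage' could be applied to
# freeze every stage that is not listed (assumes the built model exposes
# backbone / neck / head submodules; not part of this config file):
#
# def apply_fine_tune_stage(model, stages):
#     for name in ('backbone', 'neck', 'head'):
#         trainable = name in stages
#         for p in getattr(model, name).parameters():
#             p.requires_grad = trainable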
config.SEED = 927
config.optimizer = {
    'type': 'Adam',
    'lr': 0.001,
    'weight_decay': 0,
}
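# A hedged sketch of how this dict might be turned into a torch.optim optimizer
# by the training code (the actual build logic in this repo may differ):
#
# import torch.optim as optim
#
# def build_optimizer(optimizer_cfg, params):
#     cfg = dict(optimizer_cfg)
#     opt_cls = getattr(optim, cfg.pop('type'))   # e.g. optim.Adam
#     return opt_cls(params, **cfg)               # Adam(params, lr=0.001, weight_decay=0)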
config.model = {
    'type': "DetModel",
    'backbone': {"type": "MobileNetV3", 'pretrained': True},  # ResNet or MobileNetV3
    'neck': {"type": 'pse_fpn', 'out_channels': 256},
    'head': {"type": "PseHead"},
    'in_channels': 3,
}
config.loss = {
    'type': 'PSELoss',
    'Lambda': 0.7
}
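# 'Lambda' balances the two terms of the PSE loss, as in the PSENet paper:
#     L = Lambda * L_text + (1 - Lambda) * L_kernel
# where L_text is the loss on the complete text regions and L_kernel the loss on
# the shrunk kernels (hedged: the exact term names depend on the PSELoss implementation).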
config.post_process = {
    'type': 'pse_postprocess'
}
# for dataset
# ## label file
### open question: should the gt str --> label conversion live in the loss or in the dataloader?
config.dataset = {
    'train': {
        'dataset': {
            'type': 'MyDataset',
            'file': r'/DataSet/icdar2015/detection/train.json',
            'data_shape': 640,
            'n': 6,
            'm': 0.5,
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'filter_keys': ['text_polys', 'ignore_tags', 'shape', 'texts'],  # keys to filter out of data_dict
            'ignore_tags': ['*', '###'],
            'img_mode': 'RGB'
        },
        'loader': {
            'type': 'DataLoader',  # to use the torch dataloader, just set this to DataLoader
            'batch_size': 20,
            'shuffle': True,
            'num_workers': 20
        }
    },
    'eval': {
        'dataset': {
            'type': 'MyDataset',
            'file': r'/DataSet/icdar2015/detection/test.json',
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'n': 6,
            'm': 0.5,
            'data_shape': 640,
            'filter_keys': ['score_maps', 'training_mask'],  # keys to filter out of data_dict
            'ignore_tags': ['*', '###'],
            'img_mode': 'RGB'
        },
        'loader': {
            'type': 'DataLoader',
            'batch_size': 1,  # must be 1
            'shuffle': False,
            'num_workers': 10
        }
    }
}
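# A hedged sketch of what the 'mean' / 'std' entries above are typically used for:
# standard ImageNet statistics applied after scaling pixels to [0, 1] (the actual
# dataset implementation may differ):
#
# import numpy as np
#
# def normalize(img):                                 # img: HxWx3 RGB, uint8
#     img = img.astype(np.float32) / 255.0
#     img -= np.array([0.485, 0.456, 0.406], dtype=np.float32)
#     img /= np.array([0.229, 0.224, 0.225], dtype=np.float32)
#     return img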
# convert top-level sub-dicts to Dict so they support attribute access
for k, v in config.items():
    if isinstance(v, dict):
        config[k] = Dict(v)
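# Example usage after the conversion above (addict's Dict supports attribute access
# on nested dicts; the import path is whatever this module resolves to):
#
# from cfg_det_pse import config
# assert config.train_options.epochs == 1200
# assert config.dataset.train.loader.batch_size == 20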