123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- from addict import Dict
- config = Dict()
- config.exp_name = 'DBNet_icdar_distill'
- config.train_options = {
- # for train
- 'resume_from': '', # 继续训练地址
- 'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint", # 模型保存地址,log文件也保存在这里
- 'device': 'cuda:0', # 不建议修改
- 'epochs': 600,
- 'fine_tune_stage': ['backbone', 'neck', 'head'],
- 'print_interval': 5, # step为单位
- 'val_interval': 1, # epoch为单位
- 'ckpt_save_type': 'HighestAcc', # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
- 'ckpt_save_epoch': 4, # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
- }
- config.SEED = 927
- config.optimizer = {
- 'type': 'Adam',
- 'lr': 0.0002,
- 'weight_decay': 0,
- }
- config.model = {
- 'type': 'DistillationModel',
- 'algorithm': 'Distillation',
- 'init_weight': False, # 当不使用任何预训练模型(子网络或任意子网络backbone)时打开
- 'models': {
- 'Teacher': {
- 'type': "DetModel",
- 'freeze_params': True,
- 'backbone': {"type": "ResNet", 'pretrained': False, 'layers': 18},
- 'neck': {"type": 'DB_fpn', 'out_channels': 256},
- 'head': {"type": "DBHead"},
- 'in_channels': 3,
- 'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/DBNet_icdar_res18_fast_pre.pth'
- },
- 'Student': {
- 'type': "DetModel",
- 'freeze_params': False,
- 'backbone': {"type": "MobileNetV3", 'pretrained': False, 'disable_se': False},
- 'neck': {"type": 'DB_fpn', 'out_channels': 96},
- 'head': {"type": "DBHead"},
- 'in_channels': 3,
- 'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/mbv3.pth'
- },
- 'Student2': {
- 'type': "DetModel",
- 'freeze_params': False,
- 'backbone': {"type": "MobileNetV3", 'pretrained': False, 'disable_se': False},
- 'neck': {"type": 'DB_fpn', 'out_channels': 96},
- 'head': {"type": "DBHead"},
- 'in_channels': 3,
- 'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/mbv3.pth'
- }
- }
- }
- config.loss = {
- 'type': 'CombinedLoss',
- 'combine_list': {
- 'DistillationDilaDBLoss': {
- 'weight': 2.0,
- 'model_name_pairs': [("Student", "Teacher"), ("Student2", "Teacher")],
- # 'model_name_pairs': [("Student", "Teacher")],
- 'key': 'maps',
- 'balance_loss': True,
- 'main_loss_type': 'DiceLoss',
- 'alpha': 5,
- 'beta': 10,
- 'ohem_ratio': 3,
- },
- 'DistillationDMLLoss': {
- 'maps_name': "thrink_maps",
- 'weight': 1.0,
- 'model_name_pairs': ["Student", "Student2"],
- 'key': 'maps'
- },
- 'DistillationDBLoss': {
- 'weight': 1.0,
- 'model_name_list': ["Student"],
- 'balance_loss': True,
- 'main_loss_type': 'DiceLoss',
- 'alpha': 5,
- 'beta': 10,
- 'ohem_ratio': 3}
- }
- }
- config.post_process = {
- 'type': 'DistillationDBPostProcess',
- 'model_name': ["Student", "Student2", "Teacher"],
- # 'model_name': ["Student", "Teacher"],
- 'thresh': 0.3, # 二值化输出map的阈值
- 'box_thresh': 0.5, # 低于此阈值的box丢弃
- 'unclip_ratio': 1.5 # 扩大框的比例
- }
- config.metric = {
- 'name': 'DistillationMetric',
- 'base_metric_name': 'DetMetric',
- 'main_indicator': 'hmean',
- 'key': "Student"
- }
- # for dataset
- # ##lable文件
- ### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
- config.dataset = {
- 'train': {
- 'dataset': {
- 'type': 'JsonDataset',
- 'file': r'/path/to/your/workspace/dataset/icdar15-detection/train.json',
- 'mean': [0.485, 0.456, 0.406],
- 'std': [0.229, 0.224, 0.225],
- # db 预处理,不需要修改
- 'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}},
- {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
- {'type': 'Resize', 'args': {'size': [0.5, 3]}}]},
- {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}},
- {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}},
- {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}}],
- 'filter_keys': ['img_name', 'text_polys', 'texts', 'ignore_tags', 'shape'],
- # 需要从data_dict里过滤掉的key
- 'ignore_tags': ['*', '###', ' '],
- 'img_mode': 'RGB'
- },
- 'loader': {
- 'type': 'DataLoader', # 使用torch dataloader只需要改为 DataLoader
- 'batch_size': 20,
- 'shuffle': True,
- 'num_workers': 20,
- 'collate_fn': {
- 'type': ''
- }
- }
- },
- 'eval': {
- 'dataset': {
- 'type': 'JsonDataset',
- 'file': r'/path/to/your/workspace/dataset/icdar15-detection/test.json',
- 'mean': [0.485, 0.456, 0.406],
- 'std': [0.229, 0.224, 0.225],
- 'pre_processes': [{'type': 'ResizeShortSize', 'args': {'short_size': 736, 'resize_text_polys': False}}],
- 'filter_keys': [], # 需要从data_dict里过滤掉的key
- 'ignore_tags': ['*', '###', ' '],
- 'img_mode': 'RGB'
- },
- 'loader': {
- 'type': 'DataLoader',
- 'batch_size': 1, # 必须为1
- 'shuffle': False,
- 'num_workers': 10,
- 'collate_fn': {
- 'type': 'DetCollectFN'
- }
- }
- }
- }
- # 转换为 Dict
- for k, v in config.items():
- if isinstance(v, dict):
- config[k] = Dict(v)
|