## 添加新算法 ### Dataset #### 检测算法 不同的检测算法会有不同的图片预处理和label制作方式,添加新dataset的步骤如下 1. 在`torchocr/datasets/det_modules`下添加算法的图片预处理和label制作方式, 每个处理步骤(module)用一个文件存储,module的形式如下 ```python class ModuleName: def __init__(self, *args,**kwargs): pass def __call__(self, data: dict) -> dict: im = data['img'] text_polys = data['text_polys'] # 执行你的处理 data['img'] = im data['text_polys'] = text_polys return data ``` 算法的所有处理步骤由不同的module顺序执行而成,在config文件中按照列表的形式组合并执行。如: ```python 'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}}, {'type': 'Affine', 'args': {'rotate': [-10, 10]}}, {'type': 'Resize', 'args': {'size': [0.5, 3]}}]}, {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}}, {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}}, {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}}] ``` #### 识别算法 对于attention和ctc系列算法,我们已经提供了内置的dataset,其他类型的需要在`torchocr/datasets/RecDataSet.py` 文件里添加一个dataset并在config文件中使用 ### 网络 PytorchOCR将网络划分为三部分 * backbone: 从图片中提取特征,如Resnet,MobileNetV3 * neck: 对backbone输出的特征进行强化,如FPN,CRNN的RNN部分 * head: 在neck输出特征的基础上进行完成算法的输出 `backbone`和`neck`均需要`out_channels`属性以便后续组件构造网络。 若PytorchOCR已提供的组件中没有算法所需组件,就需要在对应的文件夹内实现新组件,一个文件夹存放一个组件, 然后将新组建在`torchocr/networks/architectures/DetModel.py`或`torchocr/networks/architectures/RecModel.py`进行导入并添加到对应的dict 各组件对应文件如下: * backbone: `torchocr/networks/backbones` * necks: `torchocr/networks/necks` * heads: `torchocr/networks/heads` ### 损失函数 损失函数的存文件夹为`torchocr/networks/losses`,损失函数的输出应该是一个dict,格式如下 ```python { 'loss':loss_value, # 总的loss,由l1,l2,l3,...,ln加权组成 '其他的loss': value # 组成总loss的子loss } ``` loss module 的形式如下 ```python class ModuleName(nn.Module): def __init__(self, *args,**kwargs): pass def forward(self, pred, batch): """ :param pred: :param batch: bach为一个dict{ '其他计算loss所需的输入':'vaue' } :return: """ # 计算loss loss_dict = {'loss':loss,'other_sub_loss':value} return loss_dict ``` ### 配置文件 将配置文件里的对应地方换成新增的组件,那么新的网络就添加完成了,在测试性能无误后就可推送到PytorchOCR仓库