znj 2 жил өмнө
commit
c4a2410ca7
100 өөрчлөгдсөн 32253 нэмэгдсэн , 0 устгасан
  1. 14 0
      .gitignore
  2. 90 0
      README.md
  3. 7782 0
      char_std_7782.txt
  4. 139 0
      config/cfg_det_db.py
  5. 172 0
      config/cfg_det_dis.py
  6. 227 0
      config/cfg_det_pse.py
  7. 124 0
      config/cfg_rec_crnn.py
  8. 113 0
      config/cfg_rec_crnn_lmdb.py
  9. 123 0
      config/cfg_rec_crnn_test1.py
  10. BIN
      doc/imgs/exampl1.png
  11. BIN
      doc/imgs/exampl2.png
  12. 14 0
      doc/检测+识别推理.md
  13. 62 0
      doc/检测.md
  14. 79 0
      doc/添加新算法.md
  15. BIN
      doc/田氏颜体大字库2.0.ttf
  16. 68 0
      doc/识别.md
  17. 16 0
      requirements.txt
  18. BIN
      test_image/Snipaste_2023-08-21_17-07-58.jpg
  19. BIN
      test_image/Snipaste_2023-08-21_17-07-582.jpg
  20. BIN
      test_image/Snipaste_2023-08-23_15-35-13.jpg
  21. BIN
      test_image/Snipaste_2023-08-25_10-34-45.jpg
  22. BIN
      test_image/Snipaste_2023-08-25_11-32-11.jpg
  23. BIN
      test_image/Snipaste_2023-08-25_11-33-51.jpg
  24. BIN
      test_image/Snipaste_2023-09-08_10-45-17.jpg
  25. 4 0
      test_image/test.txt
  26. 84 0
      tools/create_rec_lmdb_dataset.py
  27. 88 0
      tools/det_infer.py
  28. 339 0
      tools/det_train.py
  29. 333 0
      tools/det_train_disti.py
  30. 350 0
      tools/det_train_pse.py
  31. 364 0
      tools/doc_test.py
  32. 364 0
      tools/doc_test_resnet.py
  33. 393 0
      tools/minimum_2stage_inference.py
  34. 113 0
      tools/ocr_infer.py
  35. 349 0
      tools/rec_fineturn.py
  36. 92 0
      tools/rec_infer.py
  37. 141 0
      tools/rec_infer_att_test.py
  38. 347 0
      tools/rec_train.py
  39. 365 0
      tools/test_one.py
  40. 3 0
      torchocr/__init__.py
  41. 33 0
      torchocr/datasets/DetCollateFN.py
  42. 155 0
      torchocr/datasets/DetDataSet.py
  43. 160 0
      torchocr/datasets/DetDataSetFce.py
  44. 757 0
      torchocr/datasets/DetDateSetPse.py
  45. 86 0
      torchocr/datasets/RecCollateFn.py
  46. 461 0
      torchocr/datasets/RecDataSet.py
  47. 99 0
      torchocr/datasets/__init__.py
  48. 91 0
      torchocr/datasets/alphabets/dict.txt
  49. 3827 0
      torchocr/datasets/alphabets/dict_baidu.txt
  50. 10 0
      torchocr/datasets/alphabets/digit.txt
  51. 92 0
      torchocr/datasets/alphabets/enAlphaNumPunc90.txt
  52. 6623 0
      torchocr/datasets/alphabets/ppocr_keys_v1.txt
  53. 732 0
      torchocr/datasets/det_modules/FCE_aug.py
  54. 658 0
      torchocr/datasets/det_modules/FCE_target.py
  55. 14 0
      torchocr/datasets/det_modules/__init__.py
  56. 385 0
      torchocr/datasets/det_modules/augment.py
  57. 68 0
      torchocr/datasets/det_modules/iaa_augment.py
  58. 122 0
      torchocr/datasets/det_modules/make_border_map.py
  59. 132 0
      torchocr/datasets/det_modules/make_shrink_map.py
  60. 200 0
      torchocr/datasets/det_modules/random_crop_data.py
  61. 84 0
      torchocr/datasets/icdar15/ICDAR15CropSave.py
  62. 5 0
      torchocr/datasets/icdar15/__init__.py
  63. 49 0
      torchocr/datasets/icdar15/convert_icdar2015_rec.py
  64. 104 0
      torchocr/datasets/训练用数据集汇总.md
  65. 110 0
      torchocr/deprecated/FeaturePyramidNetwork.py
  66. 3 0
      torchocr/deprecated/__init__.py
  67. 82 0
      torchocr/metrics/DetMetric.py
  68. 28 0
      torchocr/metrics/RecMetric.py
  69. 18 0
      torchocr/metrics/__init__.py
  70. 42 0
      torchocr/metrics/distill_metric.py
  71. 256 0
      torchocr/metrics/iou_utils.py
  72. 303 0
      torchocr/networks/CommonModules.py
  73. 4 0
      torchocr/networks/__init__.py
  74. 82 0
      torchocr/networks/architectures/DetModel.py
  75. 58 0
      torchocr/networks/architectures/DistillationDetModel.py
  76. 41 0
      torchocr/networks/architectures/RecModel.py
  77. 21 0
      torchocr/networks/architectures/__init__.py
  78. 166 0
      torchocr/networks/backbones/ConvNext.py
  79. 311 0
      torchocr/networks/backbones/DetGhostNet.py
  80. 175 0
      torchocr/networks/backbones/DetMobilenetV3.py
  81. 213 0
      torchocr/networks/backbones/DetResNetvd.py
  82. 254 0
      torchocr/networks/backbones/MobileViT.py
  83. 141 0
      torchocr/networks/backbones/RecMobileNetV3.py
  84. 189 0
      torchocr/networks/backbones/RecResNetvd.py
  85. 646 0
      torchocr/networks/backbones/Transformer.py
  86. 3 0
      torchocr/networks/backbones/__init__.py
  87. 76 0
      torchocr/networks/heads/DetDbHead.py
  88. 26 0
      torchocr/networks/heads/DetPseHead.py
  89. 73 0
      torchocr/networks/heads/FCEHead.py
  90. 18 0
      torchocr/networks/heads/RecCTCHead.py
  91. 3 0
      torchocr/networks/heads/__init__.py
  92. 143 0
      torchocr/networks/losses/CTCLoss.py
  93. 165 0
      torchocr/networks/losses/CTCLoss_test.py
  94. 27 0
      torchocr/networks/losses/CombinedLoss.py
  95. 59 0
      torchocr/networks/losses/DBLoss.py
  96. 182 0
      torchocr/networks/losses/DetBasicLoss.py
  97. 209 0
      torchocr/networks/losses/FCELoss.py
  98. 104 0
      torchocr/networks/losses/PSELoss.py
  99. 23 0
      torchocr/networks/losses/__init__.py
  100. 300 0
      torchocr/networks/losses/distillation_loss.py

+ 14 - 0
.gitignore

@@ -0,0 +1,14 @@
+**/.DS_Store
+*.pth
+*.pyc
+*.pyo
+*.log
+*.tmp
+*.pkl
+*_local*
+__pycache__/
+.idea/
+output/
+test/*.jpg
+local/
+.DS_Store/

+ 90 - 0
README.md

@@ -0,0 +1,90 @@
+# PytorchOCR
+
+## 简介
+PytorchOCR旨在打造一套训练,推理,部署一体的OCR引擎库
+
+**添加微信z572459439或者nsnovio,然后进群讨论。备注ocr进群。**
+
+## 更新日志
+* 2022.02.24 更新:新增convnext作为backbone
+* 2022.01.28 更新:新增transformer作为backbone
+* 2022.01.07 更新:
+1. 检测模型新增backbone类型ghostnet
+2. 新增pse模型
+3. 新增dbnet的蒸馏版本
+4. 新增新版轻量化检测模型
+5. 修复一些bug
+* 2021.02.27 添加移动端识别模型文件、移动端DBNet模型文件
+* 2021.02.25 添加服务器端识别模型文件
+* 2021.02.09 添加DBNet模型,修改DBNet网络结构的fpn,inference时候的缩放及后处理
+* 2020.07.01 添加 添加新算法文档
+* 2020.06.29 添加检测的mb3和resnet50_vd预训练模型
+* 2020.06.25 检测模块的训练和预测ok
+* 2020.06.18 更新README
+* 2020.06.17 识别模块的训练和预测ok
+
+## todo list
+* [x] crnn训练与python版预测
+* [x] DB训练与python版预测
+* [x] imagenet预训练模型
+* [x] 服务器端识别模型文件
+* [x] DB通用模型
+* [ ] 手机端部署
+* [x] With Triton,[推荐使用Savior](https://github.com/novioleo/Savior)
+
+## 环境配置
+
+需要的环境如下
+* pytorch 1.4+
+* torchvision 0.5+
+* gcc 4.9+ (pse,pan会用到)
+
+快速安装环境
+```bash
+pip3 install -r requirements.txt
+```
+
+## 文档教程
+* [文字检测](doc/检测.md)
+* [文字识别](doc/识别.md)
+* [添加新算法](doc/添加新算法.md)
+
+## 文本检测算法
+
+PytorchOCR开源的文本检测算法列表:
+- [x]  DB([paper](https://arxiv.org/abs/1911.08947))
+
+| 模型简介                              | 骨干网络             | 推荐场景 | 大小    | 下载链接                                                     |
+|-----------------------------------|------------------|  ----  |-------|----------------------------------------------------------|
+| 预训练模型                             | ResNet50         | 服务器端| 97.3M | [ 3cmz](https://pan.baidu.com/s/1l4T0KX4W-PFy1EH5Nh9HSA) |
+| 原始超轻量模型,支持中英文、多语种文本检测             | MoblieNet        | 移动端| 2.3M  | [c9ko](https://pan.baidu.com/s/1DpM_HzwYFgAJhjgUtQ7CCw)  |
+| 新版 轻量模型,支持中英文、多语种文本检测             | MoblieNet        | 移动端| 2.3M  | [39ne](https://pan.baidu.com/s/1h52tjRYuWdcFEfXjQVYEFQ)  |
+| 通用模型,支持中英文、多语种文本检测,比超轻量模型更大,但效果更好 | ResNet18         | 服务器端| 47.2M | [r26k](https://pan.baidu.com/s/1Pt1P0Z8b280AAjr9jLMqeg)  |
+| 预训练模型                             | swin_transformer | 服务器端| 240M  | [se32](https://pan.baidu.com/s/1VhoxcjHrOLChwrp03JNwtg)  |
+| 预训练模型                             | convnext         | 服务器端| 113M  | [46is](https://pan.baidu.com/s/1-XylC8SzrolKDp53NGApag)  |
+
+
+## 文本识别算法
+PytorchOCR开源的文本识别算法列表:
+- [x]  CRNN([paper](https://arxiv.org/abs/1507.05717))
+
+| 模型简介 | 骨干网络 | 推荐场景 |大小 |  下载链接 |
+|  ----  | ----  |  ----  | ----  | ----  |
+|原始超轻量模型,支持中英文、数字识别|MoblieNet| 移动端|4.2M|[7x9q](https://pan.baidu.com/s/1l2BhmrjO1ZtmNw5yWCdPZQ)|
+|通用模型,支持中英文、数字识别|ResNet34| 服务器端|106.4M|[sdnc](https://pan.baidu.com/s/1gnFVXHW-nOz1r8c53u-QFQ)|
+
+
+## 预训练模型下载地址
+链接: https://pan.baidu.com/s/1uMWys5lQ5ZfhnaOCPBVqZw  密码: i9du
+
+## 结果展示
+
+![检测](doc/imgs/exampl1.png)
+
+![检测](doc/imgs/exampl2.png)
+
+## 贡献代码
+我们非常欢迎你为PytorchOCR贡献代码,也十分感谢你的反馈。
+
+## 相关仓库
+* https://github.com/WenmuZhou/OCR_DataSet

+ 7782 - 0
char_std_7782.txt

@@ -0,0 +1,7782 @@
+'
+疗
+绚
+诚
+娇
+溜
+题
+贿
+者
+廖
+更
+纳
+加
+奉
+公
+一
+就
+汴
+计
+与
+路
+房
+原
+妇
+2
+0
+8
+-
+7
+其
+>
+:
+]
+,
+,
+骑
+刈
+全
+消
+昏
+傈
+安
+久
+钟
+嗅
+不
+影
+处
+驽
+蜿
+资
+关
+椤
+地
+瘸
+专
+问
+忖
+票
+嫉
+炎
+韵
+要
+月
+田
+节
+陂
+鄙
+捌
+备
+拳
+伺
+眼
+网
+盎
+大
+傍
+心
+东
+愉
+汇
+蹿
+科
+每
+业
+里
+航
+晏
+字
+平
+录
+先
+1
+3
+彤
+鲶
+产
+稍
+督
+腴
+有
+象
+岳
+注
+绍
+在
+泺
+文
+定
+核
+名
+水
+过
+理
+让
+偷
+率
+等
+这
+发
+”
+为
+含
+肥
+酉
+相
+鄱
+七
+编
+猥
+锛
+日
+镀
+蒂
+掰
+倒
+辆
+栾
+栗
+综
+涩
+州
+雌
+滑
+馀
+了
+机
+块
+司
+宰
+甙
+兴
+矽
+抚
+保
+用
+沧
+秩
+如
+收
+息
+滥
+页
+疑
+埠
+!
+!
+姥
+异
+橹
+钇
+向
+下
+跄
+的
+椴
+沫
+国
+绥
+獠
+报
+开
+民
+蜇
+何
+分
+凇
+长
+讥
+藏
+掏
+施
+羽
+中
+讲
+派
+嘟
+人
+提
+浼
+间
+世
+而
+古
+多
+倪
+唇
+饯
+控
+庚
+首
+赛
+蜓
+味
+断
+制
+觉
+技
+替
+艰
+溢
+潮
+夕
+钺
+外
+摘
+枋
+动
+双
+单
+啮
+户
+枇
+确
+锦
+曜
+杜
+或
+能
+效
+霜
+盒
+然
+侗
+电
+晁
+放
+步
+鹃
+新
+杖
+蜂
+吒
+濂
+瞬
+评
+总
+隍
+对
+独
+合
+也
+是
+府
+青
+天
+诲
+墙
+组
+滴
+级
+邀
+帘
+示
+已
+时
+骸
+仄
+泅
+和
+遨
+店
+雇
+疫
+持
+巍
+踮
+境
+只
+亨
+目
+鉴
+崤
+闲
+体
+泄
+杂
+作
+般
+轰
+化
+解
+迂
+诿
+蛭
+璀
+腾
+告
+版
+服
+省
+师
+小
+规
+程
+线
+海
+办
+引
+二
+桧
+牌
+砺
+洄
+裴
+修
+图
+痫
+胡
+许
+犊
+事
+郛
+基
+柴
+呼
+食
+研
+奶
+律
+蛋
+因
+葆
+察
+戏
+褒
+戒
+再
+李
+骁
+工
+貂
+油
+鹅
+章
+啄
+休
+场
+给
+睡
+纷
+豆
+器
+捎
+说
+敏
+学
+会
+浒
+设
+诊
+格
+廓
+查
+来
+霓
+室
+溆
+¢
+诡
+寥
+焕
+舜
+柒
+狐
+回
+戟
+砾
+厄
+实
+翩
+尿
+五
+入
+径
+惭
+喹
+股
+宇
+篝
+|
+;
+美
+期
+云
+九
+祺
+扮
+靠
+锝
+槌
+系
+企
+酰
+阊
+暂
+蚕
+忻
+豁
+本
+羹
+执
+条
+钦
+H
+獒
+限
+进
+季
+楦
+于
+芘
+玖
+铋
+茯
+未
+答
+粘
+括
+样
+精
+欠
+矢
+甥
+帷
+嵩
+扣
+令
+仔
+风
+皈
+行
+支
+部
+蓉
+刮
+站
+蜡
+救
+钊
+汗
+松
+嫌
+成
+可
+.
+鹤
+院
+从
+交
+政
+怕
+活
+调
+球
+局
+验
+髌
+第
+韫
+谗
+串
+到
+圆
+年
+米
+/
+*
+友
+忿
+检
+区
+看
+自
+敢
+刃
+个
+兹
+弄
+流
+留
+同
+没
+齿
+星
+聆
+轼
+湖
+什
+三
+建
+蛔
+儿
+椋
+汕
+震
+颧
+鲤
+跟
+力
+情
+璺
+铨
+陪
+务
+指
+族
+训
+滦
+鄣
+濮
+扒
+商
+箱
+十
+召
+慷
+辗
+所
+莞
+管
+护
+臭
+横
+硒
+嗓
+接
+侦
+六
+露
+党
+馋
+驾
+剖
+高
+侬
+妪
+幂
+猗
+绺
+骐
+央
+酐
+孝
+筝
+课
+徇
+缰
+门
+男
+西
+项
+句
+谙
+瞒
+秃
+篇
+教
+碲
+罚
+声
+呐
+景
+前
+富
+嘴
+鳌
+稀
+免
+朋
+啬
+睐
+去
+赈
+鱼
+住
+肩
+愕
+速
+旁
+波
+厅
+健
+茼
+厥
+鲟
+谅
+投
+攸
+炔
+数
+方
+击
+呋
+谈
+绩
+别
+愫
+僚
+躬
+鹧
+胪
+炳
+招
+喇
+膨
+泵
+蹦
+毛
+结
+5
+4
+谱
+识
+陕
+粽
+婚
+拟
+构
+且
+搜
+任
+潘
+比
+郢
+妨
+醪
+陀
+桔
+碘
+扎
+选
+哈
+骷
+楷
+亿
+明
+缆
+脯
+监
+睫
+逻
+婵
+共
+赴
+淝
+凡
+惦
+及
+达
+揖
+谩
+澹
+减
+焰
+蛹
+番
+祁
+柏
+员
+禄
+怡
+峤
+龙
+白
+叽
+生
+闯
+起
+细
+装
+谕
+竟
+聚
+钙
+上
+导
+渊
+按
+艾
+辘
+挡
+耒
+盹
+饪
+臀
+记
+邮
+蕙
+受
+各
+医
+搂
+普
+滇
+朗
+茸
+带
+翻
+酚
+(
+光
+堤
+墟
+蔷
+万
+幻
+〓
+瑙
+辈
+昧
+盏
+亘
+蛀
+吉
+铰
+请
+子
+假
+闻
+税
+井
+诩
+哨
+嫂
+好
+面
+琐
+校
+馊
+鬣
+缂
+营
+访
+炖
+占
+农
+缀
+否
+经
+钚
+棵
+趟
+张
+亟
+吏
+茶
+谨
+捻
+论
+迸
+堂
+玉
+信
+吧
+瞠
+乡
+姬
+寺
+咬
+溏
+苄
+皿
+意
+赉
+宝
+尔
+钰
+艺
+特
+唳
+踉
+都
+荣
+倚
+登
+荐
+丧
+奇
+涵
+批
+炭
+近
+符
+傩
+感
+道
+着
+菊
+虹
+仲
+众
+懈
+濯
+颞
+眺
+南
+释
+北
+缝
+标
+既
+茗
+整
+撼
+迤
+贲
+挎
+耱
+拒
+某
+妍
+卫
+哇
+英
+矶
+藩
+治
+他
+元
+领
+膜
+遮
+穗
+蛾
+飞
+荒
+棺
+劫
+么
+市
+火
+温
+拈
+棚
+洼
+转
+果
+奕
+卸
+迪
+伸
+泳
+斗
+邡
+侄
+涨
+屯
+萋
+胭
+氡
+崮
+枞
+惧
+冒
+彩
+斜
+手
+豚
+随
+旭
+淑
+妞
+形
+菌
+吲
+沱
+争
+驯
+歹
+挟
+兆
+柱
+传
+至
+包
+内
+响
+临
+红
+功
+弩
+衡
+寂
+禁
+老
+棍
+耆
+渍
+织
+害
+氵
+渑
+布
+载
+靥
+嗬
+虽
+苹
+咨
+娄
+库
+雉
+榜
+帜
+嘲
+套
+瑚
+亲
+簸
+欧
+边
+6
+腿
+旮
+抛
+吹
+瞳
+得
+镓
+梗
+厨
+继
+漾
+愣
+憨
+士
+策
+窑
+抑
+躯
+襟
+脏
+参
+贸
+言
+干
+绸
+鳄
+穷
+藜
+音
+折
+详
+)
+举
+悍
+甸
+癌
+黎
+谴
+死
+罩
+迁
+寒
+驷
+袖
+媒
+蒋
+掘
+模
+纠
+恣
+观
+祖
+蛆
+碍
+位
+稿
+主
+澧
+跌
+筏
+京
+锏
+帝
+贴
+证
+糠
+才
+黄
+鲸
+略
+炯
+饱
+四
+出
+园
+犀
+牧
+容
+汉
+杆
+浈
+汰
+瑷
+造
+虫
+瘩
+怪
+驴
+济
+应
+花
+沣
+谔
+夙
+旅
+价
+矿
+以
+考
+s
+u
+呦
+晒
+巡
+茅
+准
+肟
+瓴
+詹
+仟
+褂
+译
+桌
+混
+宁
+怦
+郑
+抿
+些
+余
+鄂
+饴
+攒
+珑
+群
+阖
+岔
+琨
+藓
+预
+环
+洮
+岌
+宀
+杲
+瀵
+最
+常
+囡
+周
+踊
+女
+鼓
+袭
+喉
+简
+范
+薯
+遐
+疏
+粱
+黜
+禧
+法
+箔
+斤
+遥
+汝
+奥
+直
+贞
+撑
+置
+绱
+集
+她
+馅
+逗
+钧
+橱
+魉
+[
+恙
+躁
+唤
+9
+旺
+膘
+待
+脾
+惫
+购
+吗
+依
+盲
+度
+瘿
+蠖
+俾
+之
+镗
+拇
+鲵
+厝
+簧
+续
+款
+展
+啃
+表
+剔
+品
+钻
+腭
+损
+清
+锶
+统
+涌
+寸
+滨
+贪
+链
+吠
+冈
+伎
+迥
+咏
+吁
+览
+防
+迅
+失
+汾
+阔
+逵
+绀
+蔑
+列
+川
+凭
+努
+熨
+揪
+利
+俱
+绉
+抢
+鸨
+我
+即
+责
+膦
+易
+毓
+鹊
+刹
+玷
+岿
+空
+嘞
+绊
+排
+术
+估
+锷
+违
+们
+苟
+铜
+播
+肘
+件
+烫
+审
+鲂
+广
+像
+铌
+惰
+铟
+巳
+胍
+鲍
+康
+憧
+色
+恢
+想
+拷
+尤
+疳
+知
+S
+Y
+F
+D
+A
+峄
+裕
+帮
+握
+搔
+氐
+氘
+难
+墒
+沮
+雨
+叁
+缥
+悴
+藐
+湫
+娟
+苑
+稠
+颛
+簇
+后
+阕
+闭
+蕤
+缚
+怎
+佞
+码
+嘤
+蔡
+痊
+舱
+螯
+帕
+赫
+昵
+升
+烬
+岫
+、
+疵
+蜻
+髁
+蕨
+隶
+烛
+械
+丑
+盂
+梁
+强
+鲛
+由
+拘
+揉
+劭
+龟
+撤
+钩
+呕
+孛
+费
+妻
+漂
+求
+阑
+崖
+秤
+甘
+通
+深
+补
+赃
+坎
+床
+啪
+承
+吼
+量
+暇
+钼
+烨
+阂
+擎
+脱
+逮
+称
+P
+神
+属
+矗
+华
+届
+狍
+葑
+汹
+育
+患
+窒
+蛰
+佼
+静
+槎
+运
+鳗
+庆
+逝
+曼
+疱
+克
+代
+官
+此
+麸
+耧
+蚌
+晟
+例
+础
+榛
+副
+测
+唰
+缢
+迹
+灬
+霁
+身
+岁
+赭
+扛
+又
+菡
+乜
+雾
+板
+读
+陷
+徉
+贯
+郁
+虑
+变
+钓
+菜
+圾
+现
+琢
+式
+乐
+维
+渔
+浜
+左
+吾
+脑
+钡
+警
+T
+啵
+拴
+偌
+漱
+湿
+硕
+止
+骼
+魄
+积
+燥
+联
+踢
+玛
+则
+窿
+见
+振
+畿
+送
+班
+钽
+您
+赵
+刨
+印
+讨
+踝
+籍
+谡
+舌
+崧
+汽
+蔽
+沪
+酥
+绒
+怖
+财
+帖
+肱
+私
+莎
+勋
+羔
+霸
+励
+哼
+帐
+将
+帅
+渠
+纪
+婴
+娩
+岭
+厘
+滕
+吻
+伤
+坝
+冠
+戊
+隆
+瘁
+介
+涧
+物
+黍
+并
+姗
+奢
+蹑
+掣
+垸
+锴
+命
+箍
+捉
+病
+辖
+琰
+眭
+迩
+艘
+绌
+繁
+寅
+若
+毋
+思
+诉
+类
+诈
+燮
+轲
+酮
+狂
+重
+反
+职
+筱
+县
+委
+磕
+绣
+奖
+晋
+濉
+志
+徽
+肠
+呈
+獐
+坻
+口
+片
+碰
+几
+村
+柿
+劳
+料
+获
+亩
+惕
+晕
+厌
+号
+罢
+池
+正
+鏖
+煨
+家
+棕
+复
+尝
+懋
+蜥
+锅
+岛
+扰
+队
+坠
+瘾
+钬
+@
+卧
+疣
+镇
+譬
+冰
+彷
+频
+黯
+据
+垄
+采
+八
+缪
+瘫
+型
+熹
+砰
+楠
+襁
+箐
+但
+嘶
+绳
+啤
+拍
+盥
+穆
+傲
+洗
+盯
+塘
+怔
+筛
+丿
+台
+恒
+喂
+葛
+永
+¥
+烟
+酒
+桦
+书
+砂
+蚝
+缉
+态
+瀚
+袄
+圳
+轻
+蛛
+超
+榧
+遛
+姒
+奘
+铮
+右
+荽
+望
+偻
+卡
+丶
+氰
+附
+做
+革
+索
+戚
+坨
+桷
+唁
+垅
+榻
+岐
+偎
+坛
+莨
+山
+殊
+微
+骇
+陈
+爨
+推
+嗝
+驹
+澡
+藁
+呤
+卤
+嘻
+糅
+逛
+侵
+郓
+酌
+德
+摇
+※
+鬃
+被
+慨
+殡
+羸
+昌
+泡
+戛
+鞋
+河
+宪
+沿
+玲
+鲨
+翅
+哽
+源
+铅
+语
+照
+邯
+址
+荃
+佬
+顺
+鸳
+町
+霭
+睾
+瓢
+夸
+椁
+晓
+酿
+痈
+咔
+侏
+券
+噎
+湍
+签
+嚷
+离
+午
+尚
+社
+锤
+背
+孟
+使
+浪
+缦
+潍
+鞅
+军
+姹
+驶
+笑
+鳟
+鲁
+》
+孽
+钜
+绿
+洱
+礴
+焯
+椰
+颖
+囔
+乌
+孔
+巴
+互
+性
+椽
+哞
+聘
+昨
+早
+暮
+胶
+炀
+隧
+低
+彗
+昝
+铁
+呓
+氽
+藉
+喔
+癖
+瑗
+姨
+权
+胱
+韦
+堑
+蜜
+酋
+楝
+砝
+毁
+靓
+歙
+锲
+究
+屋
+喳
+骨
+辨
+碑
+武
+鸠
+宫
+辜
+烊
+适
+坡
+殃
+培
+佩
+供
+走
+蜈
+迟
+翼
+况
+姣
+凛
+浔
+吃
+飘
+债
+犟
+金
+促
+苛
+崇
+坂
+莳
+畔
+绂
+兵
+蠕
+斋
+根
+砍
+亢
+欢
+恬
+崔
+剁
+餐
+榫
+快
+扶
+‖
+濒
+缠
+鳜
+当
+彭
+驭
+浦
+篮
+昀
+锆
+秸
+钳
+弋
+娣
+瞑
+夷
+龛
+苫
+拱
+致
+%
+嵊
+障
+隐
+弑
+初
+娓
+抉
+汩
+累
+蓖
+"
+唬
+助
+苓
+昙
+押
+毙
+破
+城
+郧
+逢
+嚏
+獭
+瞻
+溱
+婿
+赊
+跨
+恼
+璧
+萃
+姻
+貉
+灵
+炉
+密
+氛
+陶
+砸
+谬
+衔
+点
+琛
+沛
+枳
+层
+岱
+诺
+脍
+榈
+埂
+征
+冷
+裁
+打
+蹴
+素
+瘘
+逞
+蛐
+聊
+激
+腱
+萘
+踵
+飒
+蓟
+吆
+取
+咙
+簋
+涓
+矩
+曝
+挺
+揣
+座
+你
+史
+舵
+焱
+尘
+苏
+笈
+脚
+溉
+榨
+诵
+樊
+邓
+焊
+义
+庶
+儋
+蟋
+蒲
+赦
+呷
+杞
+诠
+豪
+还
+试
+颓
+茉
+太
+除
+紫
+逃
+痴
+草
+充
+鳕
+珉
+祗
+墨
+渭
+烩
+蘸
+慕
+璇
+镶
+穴
+嵘
+恶
+骂
+险
+绋
+幕
+碉
+肺
+戳
+刘
+潞
+秣
+纾
+潜
+銮
+洛
+须
+罘
+销
+瘪
+汞
+兮
+屉
+r
+林
+厕
+质
+探
+划
+狸
+殚
+善
+煊
+烹
+〒
+锈
+逯
+宸
+辍
+泱
+柚
+袍
+远
+蹋
+嶙
+绝
+峥
+娥
+缍
+雀
+徵
+认
+镱
+谷
+=
+贩
+勉
+撩
+鄯
+斐
+洋
+非
+祚
+泾
+诒
+饿
+撬
+威
+晷
+搭
+芍
+锥
+笺
+蓦
+候
+琊
+档
+礁
+沼
+卵
+荠
+忑
+朝
+凹
+瑞
+头
+仪
+弧
+孵
+畏
+铆
+突
+衲
+车
+浩
+气
+茂
+悖
+厢
+枕
+酝
+戴
+湾
+邹
+飚
+攘
+锂
+写
+宵
+翁
+岷
+无
+喜
+丈
+挑
+嗟
+绛
+殉
+议
+槽
+具
+醇
+淞
+笃
+郴
+阅
+饼
+底
+壕
+砚
+弈
+询
+缕
+庹
+翟
+零
+筷
+暨
+舟
+闺
+甯
+撞
+麂
+茌
+蔼
+很
+珲
+捕
+棠
+角
+阉
+媛
+娲
+诽
+剿
+尉
+爵
+睬
+韩
+诰
+匣
+危
+糍
+镯
+立
+浏
+阳
+少
+盆
+舔
+擘
+匪
+申
+尬
+铣
+旯
+抖
+赘
+瓯
+居
+哮
+游
+锭
+茏
+歌
+坏
+甚
+秒
+舞
+沙
+仗
+劲
+潺
+阿
+燧
+郭
+嗖
+霏
+忠
+材
+奂
+耐
+跺
+砀
+输
+岖
+媳
+氟
+极
+摆
+灿
+今
+扔
+腻
+枝
+奎
+药
+熄
+吨
+话
+q
+额
+慑
+嘌
+协
+喀
+壳
+埭
+视
+著
+於
+愧
+陲
+翌
+峁
+颅
+佛
+腹
+聋
+侯
+咎
+叟
+秀
+颇
+存
+较
+罪
+哄
+岗
+扫
+栏
+钾
+羌
+己
+璨
+枭
+霉
+煌
+涸
+衿
+键
+镝
+益
+岢
+奏
+连
+夯
+睿
+冥
+均
+糖
+狞
+蹊
+稻
+爸
+刿
+胥
+煜
+丽
+肿
+璃
+掸
+跚
+灾
+垂
+樾
+濑
+乎
+莲
+窄
+犹
+撮
+战
+馄
+软
+络
+显
+鸢
+胸
+宾
+妲
+恕
+埔
+蝌
+份
+遇
+巧
+瞟
+粒
+恰
+剥
+桡
+博
+讯
+凯
+堇
+阶
+滤
+卖
+斌
+骚
+彬
+兑
+磺
+樱
+舷
+两
+娱
+福
+仃
+差
+找
+桁
+净
+把
+阴
+污
+戬
+雷
+碓
+蕲
+楚
+罡
+焖
+抽
+妫
+咒
+仑
+闱
+尽
+邑
+菁
+爱
+贷
+沥
+鞑
+牡
+嗉
+崴
+骤
+塌
+嗦
+订
+拮
+滓
+捡
+锻
+次
+坪
+杩
+臃
+箬
+融
+珂
+鹗
+宗
+枚
+降
+鸬
+妯
+阄
+堰
+盐
+毅
+必
+杨
+崃
+俺
+甬
+状
+莘
+货
+耸
+菱
+腼
+铸
+唏
+痤
+孚
+澳
+懒
+溅
+翘
+疙
+杷
+淼
+缙
+骰
+喊
+悉
+砻
+坷
+艇
+赁
+界
+谤
+纣
+宴
+晃
+茹
+归
+饭
+梢
+铡
+街
+抄
+肼
+鬟
+苯
+颂
+撷
+戈
+炒
+咆
+茭
+瘙
+负
+仰
+客
+琉
+铢
+封
+卑
+珥
+椿
+镧
+窨
+鬲
+寿
+御
+袤
+铃
+萎
+砖
+餮
+脒
+裳
+肪
+孕
+嫣
+馗
+嵇
+恳
+氯
+江
+石
+褶
+冢
+祸
+阻
+狈
+羞
+银
+靳
+透
+咳
+叼
+敷
+芷
+啥
+它
+瓤
+兰
+痘
+懊
+逑
+肌
+往
+捺
+坊
+甩
+呻
+〃
+沦
+忘
+膻
+祟
+菅
+剧
+崆
+智
+坯
+臧
+霍
+墅
+攻
+眯
+倘
+拢
+骠
+铐
+庭
+岙
+瓠
+′
+缺
+泥
+迢
+捶
+?
+?
+郏
+喙
+掷
+沌
+纯
+秘
+种
+听
+绘
+固
+螨
+团
+香
+盗
+妒
+埚
+蓝
+拖
+旱
+荞
+铀
+血
+遏
+汲
+辰
+叩
+拽
+幅
+硬
+惶
+桀
+漠
+措
+泼
+唑
+齐
+肾
+念
+酱
+虚
+屁
+耶
+旗
+砦
+闵
+婉
+馆
+拭
+绅
+韧
+忏
+窝
+醋
+葺
+顾
+辞
+倜
+堆
+辋
+逆
+玟
+贱
+疾
+董
+惘
+倌
+锕
+淘
+嘀
+莽
+俭
+笏
+绑
+鲷
+杈
+择
+蟀
+粥
+嗯
+驰
+逾
+案
+谪
+褓
+胫
+哩
+昕
+颚
+鲢
+绠
+躺
+鹄
+崂
+儒
+俨
+丝
+尕
+泌
+啊
+萸
+彰
+幺
+吟
+骄
+苣
+弦
+脊
+瑰
+〈
+诛
+镁
+析
+闪
+剪
+侧
+哟
+框
+螃
+守
+嬗
+燕
+狭
+铈
+缮
+概
+迳
+痧
+鲲
+俯
+售
+笼
+痣
+扉
+挖
+满
+咋
+援
+邱
+扇
+歪
+便
+玑
+绦
+峡
+蛇
+叨
+〖
+泽
+胃
+斓
+喋
+怂
+坟
+猪
+该
+蚬
+炕
+弥
+赞
+棣
+晔
+娠
+挲
+狡
+创
+疖
+铕
+镭
+稷
+挫
+弭
+啾
+翔
+粉
+履
+苘
+哦
+楼
+秕
+铂
+土
+锣
+瘟
+挣
+栉
+习
+享
+桢
+袅
+磨
+桂
+谦
+延
+坚
+蔚
+噗
+署
+谟
+猬
+钎
+恐
+嬉
+雒
+倦
+衅
+亏
+璩
+睹
+刻
+殿
+王
+算
+雕
+麻
+丘
+柯
+骆
+丸
+塍
+谚
+添
+鲈
+垓
+桎
+蚯
+芥
+予
+飕
+镦
+谌
+窗
+醚
+菀
+亮
+搪
+莺
+蒿
+羁
+足
+J
+真
+轶
+悬
+衷
+靛
+翊
+掩
+哒
+炅
+掐
+冼
+妮
+l
+谐
+稚
+荆
+擒
+犯
+陵
+虏
+浓
+崽
+刍
+陌
+傻
+孜
+千
+靖
+演
+矜
+钕
+煽
+杰
+酗
+渗
+伞
+栋
+俗
+泫
+戍
+罕
+沾
+疽
+灏
+煦
+芬
+磴
+叱
+阱
+榉
+湃
+蜀
+叉
+醒
+彪
+租
+郡
+篷
+屎
+良
+垢
+隗
+弱
+陨
+峪
+砷
+掴
+颁
+胎
+雯
+绵
+贬
+沐
+撵
+隘
+篙
+暖
+曹
+陡
+栓
+填
+臼
+彦
+瓶
+琪
+潼
+哪
+鸡
+摩
+啦
+俟
+锋
+域
+耻
+蔫
+疯
+纹
+撇
+毒
+绶
+痛
+酯
+忍
+爪
+赳
+歆
+嘹
+辕
+烈
+册
+朴
+钱
+吮
+毯
+癜
+娃
+谀
+邵
+厮
+炽
+璞
+邃
+丐
+追
+词
+瓒
+忆
+轧
+芫
+谯
+喷
+弟
+半
+冕
+裙
+掖
+墉
+绮
+寝
+苔
+势
+顷
+褥
+切
+衮
+君
+佳
+嫒
+蚩
+霞
+佚
+洙
+逊
+镖
+暹
+唛
+&
+殒
+顶
+碗
+獗
+轭
+铺
+蛊
+废
+恹
+汨
+崩
+珍
+那
+杵
+曲
+纺
+夏
+薰
+傀
+闳
+淬
+姘
+舀
+拧
+卷
+楂
+恍
+讪
+厩
+寮
+篪
+赓
+乘
+灭
+盅
+鞣
+沟
+慎
+挂
+饺
+鼾
+杳
+树
+缨
+丛
+絮
+娌
+臻
+嗳
+篡
+侩
+述
+衰
+矛
+圈
+蚜
+匕
+筹
+匿
+濞
+晨
+叶
+骋
+郝
+挚
+蚴
+滞
+增
+侍
+描
+瓣
+吖
+嫦
+蟒
+匾
+圣
+赌
+毡
+癞
+恺
+百
+曳
+需
+篓
+肮
+庖
+帏
+卿
+驿
+遗
+蹬
+鬓
+骡
+歉
+芎
+胳
+屐
+禽
+烦
+晌
+寄
+媾
+狄
+翡
+苒
+船
+廉
+终
+痞
+殇
+々
+畦
+饶
+改
+拆
+悻
+萄
+£
+瓿
+乃
+訾
+桅
+匮
+溧
+拥
+纱
+铍
+骗
+蕃
+龋
+缬
+父
+佐
+疚
+栎
+醍
+掳
+蓄
+x
+惆
+颜
+鲆
+榆
+〔
+猎
+敌
+暴
+谥
+鲫
+贾
+罗
+玻
+缄
+扦
+芪
+癣
+落
+徒
+臾
+恿
+猩
+托
+邴
+肄
+牵
+春
+陛
+耀
+刊
+拓
+蓓
+邳
+堕
+寇
+枉
+淌
+啡
+湄
+兽
+酷
+萼
+碚
+濠
+萤
+夹
+旬
+戮
+梭
+琥
+椭
+昔
+勺
+蜊
+绐
+晚
+孺
+僵
+宣
+摄
+冽
+旨
+萌
+忙
+蚤
+眉
+噼
+蟑
+付
+契
+瓜
+悼
+颡
+壁
+曾
+窕
+颢
+澎
+仿
+俑
+浑
+嵌
+浣
+乍
+碌
+褪
+乱
+蔟
+隙
+玩
+剐
+葫
+箫
+纲
+围
+伐
+决
+伙
+漩
+瑟
+刑
+肓
+镳
+缓
+蹭
+氨
+皓
+典
+畲
+坍
+铑
+檐
+塑
+洞
+倬
+储
+胴
+淳
+戾
+吐
+灼
+惺
+妙
+毕
+珐
+缈
+虱
+盖
+羰
+鸿
+磅
+谓
+髅
+娴
+苴
+唷
+蚣
+霹
+抨
+贤
+唠
+犬
+誓
+逍
+庠
+逼
+麓
+籼
+釉
+呜
+碧
+秧
+氩
+摔
+霄
+穸
+纨
+辟
+妈
+映
+完
+牛
+缴
+嗷
+炊
+恩
+荔
+茆
+掉
+紊
+慌
+莓
+羟
+阙
+萁
+磐
+另
+蕹
+辱
+鳐
+湮
+吡
+吩
+唐
+睦
+垠
+舒
+圜
+冗
+瞿
+溺
+芾
+囱
+匠
+僳
+汐
+菩
+饬
+漓
+黑
+霰
+浸
+濡
+窥
+毂
+蒡
+兢
+驻
+鹉
+芮
+诙
+迫
+雳
+厂
+忐
+臆
+猴
+鸣
+蚪
+栈
+箕
+羡
+渐
+莆
+捍
+眈
+哓
+趴
+蹼
+埕
+嚣
+骛
+宏
+淄
+斑
+噜
+严
+瑛
+垃
+椎
+诱
+压
+庾
+绞
+焘
+廿
+抡
+迄
+棘
+夫
+纬
+锹
+眨
+瞌
+侠
+脐
+竞
+瀑
+孳
+骧
+遁
+姜
+颦
+荪
+滚
+萦
+伪
+逸
+粳
+爬
+锁
+矣
+役
+趣
+洒
+颔
+诏
+逐
+奸
+甭
+惠
+攀
+蹄
+泛
+尼
+拼
+阮
+鹰
+亚
+颈
+惑
+勒
+〉
+际
+肛
+爷
+刚
+钨
+丰
+养
+冶
+鲽
+辉
+蔻
+画
+覆
+皴
+妊
+麦
+返
+醉
+皂
+擀
+〗
+酶
+凑
+粹
+悟
+诀
+硖
+港
+卜
+z
+杀
+涕
+舍
+铠
+抵
+弛
+段
+敝
+镐
+奠
+拂
+轴
+跛
+袱
+e
+t
+沉
+菇
+俎
+薪
+峦
+秭
+蟹
+历
+盟
+菠
+寡
+液
+肢
+喻
+染
+裱
+悱
+抱
+氙
+赤
+捅
+猛
+跑
+氮
+谣
+仁
+尺
+辊
+窍
+烙
+衍
+架
+擦
+倏
+璐
+瑁
+币
+楞
+胖
+夔
+趸
+邛
+惴
+饕
+虔
+蝎
+哉
+贝
+宽
+辫
+炮
+扩
+饲
+籽
+魏
+菟
+锰
+伍
+猝
+末
+琳
+哚
+蛎
+邂
+呀
+姿
+鄞
+却
+歧
+仙
+恸
+椐
+森
+牒
+寤
+袒
+婆
+虢
+雅
+钉
+朵
+贼
+欲
+苞
+寰
+故
+龚
+坭
+嘘
+咫
+礼
+硷
+兀
+睢
+汶
+’
+铲
+烧
+绕
+诃
+浃
+钿
+哺
+柜
+讼
+颊
+璁
+腔
+洽
+咐
+脲
+簌
+筠
+镣
+玮
+鞠
+谁
+兼
+姆
+挥
+梯
+蝴
+谘
+漕
+刷
+躏
+宦
+弼
+b
+垌
+劈
+麟
+莉
+揭
+笙
+渎
+仕
+嗤
+仓
+配
+怏
+抬
+错
+泯
+镊
+孰
+猿
+邪
+仍
+秋
+鼬
+壹
+歇
+吵
+炼
+<
+尧
+射
+柬
+廷
+胧
+霾
+凳
+隋
+肚
+浮
+梦
+祥
+株
+堵
+退
+L
+鹫
+跎
+凶
+毽
+荟
+炫
+栩
+玳
+甜
+沂
+鹿
+顽
+伯
+爹
+赔
+蛴
+徐
+匡
+欣
+狰
+缸
+雹
+蟆
+疤
+默
+沤
+啜
+痂
+衣
+禅
+w
+i
+h
+辽
+葳
+黝
+钗
+停
+沽
+棒
+馨
+颌
+肉
+吴
+硫
+悯
+劾
+娈
+马
+啧
+吊
+悌
+镑
+峭
+帆
+瀣
+涉
+咸
+疸
+滋
+泣
+翦
+拙
+癸
+钥
+蜒
++
+尾
+庄
+凝
+泉
+婢
+渴
+谊
+乞
+陆
+锉
+糊
+鸦
+淮
+I
+B
+N
+晦
+弗
+乔
+庥
+葡
+尻
+席
+橡
+傣
+渣
+拿
+惩
+麋
+斛
+缃
+矮
+蛏
+岘
+鸽
+姐
+膏
+催
+奔
+镒
+喱
+蠡
+摧
+钯
+胤
+柠
+拐
+璋
+鸥
+卢
+荡
+倾
+^
+_
+珀
+逄
+萧
+塾
+掇
+贮
+笆
+聂
+圃
+冲
+嵬
+M
+滔
+笕
+值
+炙
+偶
+蜱
+搐
+梆
+汪
+蔬
+腑
+鸯
+蹇
+敞
+绯
+仨
+祯
+谆
+梧
+糗
+鑫
+啸
+豺
+囹
+猾
+巢
+柄
+瀛
+筑
+踌
+沭
+暗
+苁
+鱿
+蹉
+脂
+蘖
+牢
+热
+木
+吸
+溃
+宠
+序
+泞
+偿
+拜
+檩
+厚
+朐
+毗
+螳
+吞
+媚
+朽
+担
+蝗
+橘
+畴
+祈
+糟
+盱
+隼
+郜
+惜
+珠
+裨
+铵
+焙
+琚
+唯
+咚
+噪
+骊
+丫
+滢
+勤
+棉
+呸
+咣
+淀
+隔
+蕾
+窈
+饨
+挨
+煅
+短
+匙
+粕
+镜
+赣
+撕
+墩
+酬
+馁
+豌
+颐
+抗
+酣
+氓
+佑
+搁
+哭
+递
+耷
+涡
+桃
+贻
+碣
+截
+瘦
+昭
+镌
+蔓
+氚
+甲
+猕
+蕴
+蓬
+散
+拾
+纛
+狼
+猷
+铎
+埋
+旖
+矾
+讳
+囊
+糜
+迈
+粟
+蚂
+紧
+鲳
+瘢
+栽
+稼
+羊
+锄
+斟
+睁
+桥
+瓮
+蹙
+祉
+醺
+鼻
+昱
+剃
+跳
+篱
+跷
+蒜
+翎
+宅
+晖
+嗑
+壑
+峻
+癫
+屏
+狠
+陋
+袜
+途
+憎
+祀
+莹
+滟
+佶
+溥
+臣
+约
+盛
+峰
+磁
+慵
+婪
+拦
+莅
+朕
+鹦
+粲
+裤
+哎
+疡
+嫖
+琵
+窟
+堪
+谛
+嘉
+儡
+鳝
+斩
+郾
+驸
+酊
+妄
+胜
+贺
+徙
+傅
+噌
+钢
+栅
+庇
+恋
+匝
+巯
+邈
+尸
+锚
+粗
+佟
+蛟
+薹
+纵
+蚊
+郅
+绢
+锐
+苗
+俞
+篆
+淆
+膀
+鲜
+煎
+诶
+秽
+寻
+涮
+刺
+怀
+噶
+巨
+褰
+魅
+灶
+灌
+桉
+藕
+谜
+舸
+薄
+搀
+恽
+借
+牯
+痉
+渥
+愿
+亓
+耘
+杠
+柩
+锔
+蚶
+钣
+珈
+喘
+蹒
+幽
+赐
+稗
+晤
+莱
+泔
+扯
+肯
+菪
+裆
+腩
+豉
+疆
+骜
+腐
+倭
+珏
+唔
+粮
+亡
+润
+慰
+伽
+橄
+玄
+誉
+醐
+胆
+龊
+粼
+塬
+陇
+彼
+削
+嗣
+绾
+芽
+妗
+垭
+瘴
+爽
+薏
+寨
+龈
+泠
+弹
+赢
+漪
+猫
+嘧
+涂
+恤
+圭
+茧
+烽
+屑
+痕
+巾
+赖
+荸
+凰
+腮
+畈
+亵
+蹲
+偃
+苇
+澜
+艮
+换
+骺
+烘
+苕
+梓
+颉
+肇
+哗
+悄
+氤
+涠
+葬
+屠
+鹭
+植
+竺
+佯
+诣
+鲇
+瘀
+鲅
+邦
+移
+滁
+冯
+耕
+癔
+戌
+茬
+沁
+巩
+悠
+湘
+洪
+痹
+锟
+循
+谋
+腕
+鳃
+钠
+捞
+焉
+迎
+碱
+伫
+急
+榷
+奈
+邝
+卯
+辄
+皲
+卟
+醛
+畹
+忧
+稳
+雄
+昼
+缩
+阈
+睑
+扌
+耗
+曦
+涅
+捏
+瞧
+邕
+淖
+漉
+铝
+耦
+禹
+湛
+喽
+莼
+琅
+诸
+苎
+纂
+硅
+始
+嗨
+傥
+燃
+臂
+赅
+嘈
+呆
+贵
+屹
+壮
+肋
+亍
+蚀
+卅
+豹
+腆
+邬
+迭
+浊
+}
+童
+螂
+捐
+圩
+勐
+触
+寞
+汊
+壤
+荫
+膺
+渌
+芳
+懿
+遴
+螈
+泰
+蓼
+蛤
+茜
+舅
+枫
+朔
+膝
+眙
+避
+梅
+判
+鹜
+璜
+牍
+缅
+垫
+藻
+黔
+侥
+惚
+懂
+踩
+腰
+腈
+札
+丞
+唾
+慈
+顿
+摹
+荻
+琬
+~
+斧
+沈
+滂
+胁
+胀
+幄
+莜
+Z
+匀
+鄄
+掌
+绰
+茎
+焚
+赋
+萱
+谑
+汁
+铒
+瞎
+夺
+蜗
+野
+娆
+冀
+弯
+篁
+懵
+灞
+隽
+芡
+脘
+俐
+辩
+芯
+掺
+喏
+膈
+蝈
+觐
+悚
+踹
+蔗
+熠
+鼠
+呵
+抓
+橼
+峨
+畜
+缔
+禾
+崭
+弃
+熊
+摒
+凸
+拗
+穹
+蒙
+抒
+祛
+劝
+闫
+扳
+阵
+醌
+踪
+喵
+侣
+搬
+仅
+荧
+赎
+蝾
+琦
+买
+婧
+瞄
+寓
+皎
+冻
+赝
+箩
+莫
+瞰
+郊
+笫
+姝
+筒
+枪
+遣
+煸
+袋
+舆
+痱
+涛
+母
+〇
+启
+践
+耙
+绲
+盘
+遂
+昊
+搞
+槿
+诬
+纰
+泓
+惨
+檬
+亻
+越
+C
+o
+憩
+熵
+祷
+钒
+暧
+塔
+阗
+胰
+咄
+娶
+魔
+琶
+钞
+邻
+扬
+杉
+殴
+咽
+弓
+〆
+髻
+】
+吭
+揽
+霆
+拄
+殖
+脆
+彻
+岩
+芝
+勃
+辣
+剌
+钝
+嘎
+甄
+佘
+皖
+伦
+授
+徕
+憔
+挪
+皇
+庞
+稔
+芜
+踏
+溴
+兖
+卒
+擢
+饥
+鳞
+煲
+‰
+账
+颗
+叻
+斯
+捧
+鳍
+琮
+讹
+蛙
+纽
+谭
+酸
+兔
+莒
+睇
+伟
+觑
+羲
+嗜
+宜
+褐
+旎
+辛
+卦
+诘
+筋
+鎏
+溪
+挛
+熔
+阜
+晰
+鳅
+丢
+奚
+灸
+呱
+献
+陉
+黛
+鸪
+甾
+萨
+疮
+拯
+洲
+疹
+辑
+叙
+恻
+谒
+允
+柔
+烂
+氏
+逅
+漆
+拎
+惋
+扈
+湟
+纭
+啕
+掬
+擞
+哥
+忽
+涤
+鸵
+靡
+郗
+瓷
+扁
+廊
+怨
+雏
+钮
+敦
+E
+懦
+憋
+汀
+拚
+啉
+腌
+岸
+f
+痼
+瞅
+尊
+咀
+眩
+飙
+忌
+仝
+迦
+熬
+毫
+胯
+篑
+茄
+腺
+凄
+舛
+碴
+锵
+诧
+羯
+後
+漏
+汤
+宓
+仞
+蚁
+壶
+谰
+皑
+铄
+棰
+罔
+辅
+晶
+苦
+牟
+闽
+\
+烃
+饮
+聿
+丙
+蛳
+朱
+煤
+涔
+鳖
+犁
+罐
+荼
+砒
+淦
+妤
+黏
+戎
+孑
+婕
+瑾
+戢
+钵
+枣
+捋
+砥
+衩
+狙
+桠
+稣
+阎
+肃
+梏
+诫
+孪
+昶
+婊
+衫
+嗔
+侃
+塞
+蜃
+樵
+峒
+貌
+屿
+欺
+缫
+阐
+栖
+诟
+珞
+荭
+吝
+萍
+嗽
+恂
+啻
+蜴
+磬
+峋
+俸
+豫
+谎
+徊
+镍
+韬
+魇
+晴
+U
+囟
+猜
+蛮
+坐
+囿
+伴
+亭
+肝
+佗
+蝠
+妃
+胞
+滩
+榴
+氖
+垩
+苋
+砣
+扪
+馏
+姓
+轩
+厉
+夥
+侈
+禀
+垒
+岑
+赏
+钛
+辐
+痔
+披
+纸
+碳
+“
+坞
+蠓
+挤
+荥
+沅
+悔
+铧
+帼
+蒌
+蝇
+a
+p
+y
+n
+g
+哀
+浆
+瑶
+凿
+桶
+馈
+皮
+奴
+苜
+佤
+伶
+晗
+铱
+炬
+优
+弊
+氢
+恃
+甫
+攥
+端
+锌
+灰
+稹
+炝
+曙
+邋
+亥
+眶
+碾
+拉
+萝
+绔
+捷
+浍
+腋
+姑
+菖
+凌
+涞
+麽
+锢
+桨
+潢
+绎
+镰
+殆
+锑
+渝
+铬
+困
+绽
+觎
+匈
+糙
+暑
+裹
+鸟
+盔
+肽
+迷
+綦
+『
+亳
+佝
+俘
+钴
+觇
+骥
+仆
+疝
+跪
+婶
+郯
+瀹
+唉
+脖
+踞
+针
+晾
+忒
+扼
+瞩
+叛
+椒
+疟
+嗡
+邗
+肆
+跆
+玫
+忡
+捣
+咧
+唆
+艄
+蘑
+潦
+笛
+阚
+沸
+泻
+掊
+菽
+贫
+斥
+髂
+孢
+镂
+赂
+麝
+鸾
+屡
+衬
+苷
+恪
+叠
+希
+粤
+爻
+喝
+茫
+惬
+郸
+绻
+庸
+撅
+碟
+宄
+妹
+膛
+叮
+饵
+崛
+嗲
+椅
+冤
+搅
+咕
+敛
+尹
+垦
+闷
+蝉
+霎
+勰
+败
+蓑
+泸
+肤
+鹌
+幌
+焦
+浠
+鞍
+刁
+舰
+乙
+竿
+裔
+。
+茵
+函
+伊
+兄
+丨
+娜
+匍
+謇
+莪
+宥
+似
+蝽
+翳
+酪
+翠
+粑
+薇
+祢
+骏
+赠
+叫
+Q
+噤
+噻
+竖
+芗
+莠
+潭
+俊
+羿
+耜
+O
+郫
+趁
+嗪
+囚
+蹶
+芒
+洁
+笋
+鹑
+敲
+硝
+啶
+堡
+渲
+揩
+』
+携
+宿
+遒
+颍
+扭
+棱
+割
+萜
+蔸
+葵
+琴
+捂
+饰
+衙
+耿
+掠
+募
+岂
+窖
+涟
+蔺
+瘤
+柞
+瞪
+怜
+匹
+距
+楔
+炜
+哆
+秦
+缎
+幼
+茁
+绪
+痨
+恨
+楸
+娅
+瓦
+桩
+雪
+嬴
+伏
+榔
+妥
+铿
+拌
+眠
+雍
+缇
+‘
+卓
+搓
+哌
+觞
+噩
+屈
+哧
+髓
+咦
+巅
+娑
+侑
+淫
+膳
+祝
+勾
+姊
+莴
+胄
+疃
+薛
+蜷
+胛
+巷
+芙
+芋
+熙
+闰
+勿
+窃
+狱
+剩
+钏
+幢
+陟
+铛
+慧
+靴
+耍
+k
+浙
+浇
+飨
+惟
+绗
+祜
+澈
+啼
+咪
+磷
+摞
+诅
+郦
+抹
+跃
+壬
+吕
+肖
+琏
+颤
+尴
+剡
+抠
+凋
+赚
+泊
+津
+宕
+殷
+倔
+氲
+漫
+邺
+涎
+怠
+$
+垮
+荬
+遵
+俏
+叹
+噢
+饽
+蜘
+孙
+筵
+疼
+鞭
+羧
+牦
+箭
+潴
+c
+眸
+祭
+髯
+啖
+坳
+愁
+芩
+驮
+倡
+巽
+穰
+沃
+胚
+怒
+凤
+槛
+剂
+趵
+嫁
+v
+邢
+灯
+鄢
+桐
+睽
+檗
+锯
+槟
+婷
+嵋
+圻
+诗
+蕈
+颠
+遭
+痢
+芸
+怯
+馥
+竭
+锗
+徜
+恭
+遍
+籁
+剑
+嘱
+苡
+龄
+僧
+桑
+潸
+弘
+澶
+楹
+悲
+讫
+愤
+腥
+悸
+谍
+椹
+呢
+桓
+葭
+攫
+阀
+翰
+躲
+敖
+柑
+郎
+笨
+橇
+呃
+魁
+燎
+脓
+葩
+磋
+垛
+玺
+狮
+沓
+砜
+蕊
+锺
+罹
+蕉
+翱
+虐
+闾
+巫
+旦
+茱
+嬷
+枯
+鹏
+贡
+芹
+汛
+矫
+绁
+拣
+禺
+佃
+讣
+舫
+惯
+乳
+趋
+疲
+挽
+岚
+虾
+衾
+蠹
+蹂
+飓
+氦
+铖
+孩
+稞
+瑜
+壅
+掀
+勘
+妓
+畅
+髋
+W
+庐
+牲
+蓿
+榕
+练
+垣
+唱
+邸
+菲
+昆
+婺
+穿
+绡
+麒
+蚱
+掂
+愚
+泷
+涪
+漳
+妩
+娉
+榄
+讷
+觅
+旧
+藤
+煮
+呛
+柳
+腓
+叭
+庵
+烷
+阡
+罂
+蜕
+擂
+猖
+咿
+媲
+脉
+【
+沏
+貅
+黠
+熏
+哲
+烁
+坦
+酵
+兜
+潇
+撒
+剽
+珩
+圹
+乾
+摸
+樟
+帽
+嗒
+襄
+魂
+轿
+憬
+锡
+〕
+喃
+皆
+咖
+隅
+脸
+残
+泮
+袂
+鹂
+珊
+囤
+捆
+咤
+误
+徨
+闹
+淙
+芊
+淋
+怆
+囗
+拨
+梳
+渤
+R
+G
+绨
+蚓
+婀
+幡
+狩
+麾
+谢
+唢
+裸
+旌
+伉
+纶
+裂
+驳
+砼
+咛
+澄
+樨
+蹈
+宙
+澍
+倍
+貔
+操
+勇
+蟠
+摈
+砧
+虬
+够
+缁
+悦
+藿
+撸
+艹
+摁
+淹
+豇
+虎
+榭
+吱
+d
+喧
+荀
+踱
+侮
+奋
+偕
+饷
+犍
+惮
+坑
+璎
+徘
+宛
+妆
+袈
+倩
+窦
+昂
+荏
+乖
+K
+怅
+撰
+鳙
+牙
+袁
+酞
+X
+痿
+琼
+闸
+雁
+趾
+荚
+虻
+涝
+《
+杏
+韭
+偈
+烤
+绫
+鞘
+卉
+症
+遢
+蓥
+诋
+杭
+荨
+匆
+竣
+簪
+辙
+敕
+虞
+丹
+缭
+咩
+黟
+m
+淤
+瑕
+咂
+铉
+硼
+茨
+嶂
+痒
+畸
+敬
+涿
+粪
+窘
+熟
+叔
+嫔
+盾
+忱
+裘
+憾
+梵
+赡
+珙
+咯
+娘
+庙
+溯
+胺
+葱
+痪
+摊
+荷
+卞
+乒
+髦
+寐
+铭
+坩
+胗
+枷
+爆
+溟
+嚼
+羚
+砬
+轨
+惊
+挠
+罄
+竽
+菏
+氧
+浅
+楣
+盼
+枢
+炸
+阆
+杯
+谏
+噬
+淇
+渺
+俪
+秆
+墓
+泪
+跻
+砌
+痰
+垡
+渡
+耽
+釜
+讶
+鳎
+煞
+呗
+韶
+舶
+绷
+鹳
+缜
+旷
+铊
+皱
+龌
+檀
+霖
+奄
+槐
+艳
+蝶
+旋
+哝
+赶
+骞
+蚧
+腊
+盈
+丁
+`
+蜚
+矸
+蝙
+睨
+嚓
+僻
+鬼
+醴
+夜
+彝
+磊
+笔
+拔
+栀
+糕
+厦
+邰
+纫
+逭
+纤
+眦
+膊
+馍
+躇
+烯
+蘼
+冬
+诤
+暄
+骶
+哑
+瘠
+」
+臊
+丕
+愈
+咱
+螺
+擅
+跋
+搏
+硪
+谄
+笠
+淡
+嘿
+骅
+谧
+鼎
+皋
+姚
+歼
+蠢
+驼
+耳
+胬
+挝
+涯
+狗
+蒽
+孓
+犷
+凉
+芦
+箴
+铤
+孤
+嘛
+坤
+V
+茴
+朦
+挞
+尖
+橙
+诞
+搴
+碇
+洵
+浚
+帚
+蜍
+漯
+柘
+嚎
+讽
+芭
+荤
+咻
+祠
+秉
+跖
+埃
+吓
+糯
+眷
+馒
+惹
+娼
+鲑
+嫩
+讴
+轮
+瞥
+靶
+褚
+乏
+缤
+宋
+帧
+删
+驱
+碎
+扑
+俩
+俄
+偏
+涣
+竹
+噱
+皙
+佰
+渚
+唧
+斡
+#
+镉
+刀
+崎
+筐
+佣
+夭
+贰
+肴
+峙
+哔
+艿
+匐
+牺
+镛
+缘
+仡
+嫡
+劣
+枸
+堀
+梨
+簿
+鸭
+蒸
+亦
+稽
+浴
+{
+衢
+束
+槲
+j
+阁
+揍
+疥
+棋
+潋
+聪
+窜
+乓
+睛
+插
+冉
+阪
+苍
+搽
+「
+蟾
+螟
+幸
+仇
+樽
+撂
+慢
+跤
+幔
+俚
+淅
+覃
+觊
+溶
+妖
+帛
+侨
+曰
+妾
+泗
+:
+瀘
+風
+(
+)
+∶
+紅
+紗
+瑭
+雲
+頭
+鶏
+財
+許
+•
+樂
+焗
+麗
+—
+;
+滙
+東
+榮
+繪
+興
+…
+門
+業
+楊
+國
+顧
+盤
+寳
+龍
+鳳
+島
+誌
+緣
+結
+銭
+萬
+勝
+祎
+璟
+優
+歡
+臨
+時
+購
+=
+★
+藍
+昇
+鐵
+觀
+勅
+農
+聲
+畫
+兿
+術
+發
+劉
+記
+專
+耑
+園
+書
+壴
+種
+●
+褀
+號
+銀
+匯
+敟
+锘
+葉
+橪
+廣
+進
+蒄
+鑽
+阝
+祙
+貢
+鍋
+豊
+夬
+喆
+團
+閣
+開
+燁
+賓
+館
+酡
+沔
+順
++
+硚
+劵
+饸
+陽
+車
+湓
+復
+萊
+氣
+軒
+華
+堃
+迮
+纟
+戶
+馬
+學
+裡
+電
+嶽
+獨
+マ
+シ
+サ
+ジ
+燘
+袪
+環
+❤
+臺
+灣
+専
+賣
+孖
+聖
+攝
+線
+▪
+傢
+俬
+夢
+達
+莊
+喬
+貝
+薩
+劍
+羅
+壓
+棛
+饦
+尃
+璈
+囍
+醫
+G
+I
+A
+#
+N
+鷄
+髙
+嬰
+啓
+約
+隹
+潔
+賴
+藝
+~
+寶
+籣
+麺
+ 
+嶺
+√
+義
+網
+峩
+長
+∧
+魚
+機
+構
+②
+鳯
+偉
+L
+B
+㙟
+畵
+鴿
+'
+詩
+溝
+嚞
+屌
+藔
+佧
+玥
+蘭
+織
+1
+3
+9
+0
+7
+點
+砭
+鴨
+鋪
+銘
+廳
+弍
+‧
+創
+湯
+坶
+℃
+卩
+骝
+&
+烜
+荘
+當
+潤
+扞
+係
+懷
+碶
+钅
+蚨
+讠
+☆
+叢
+爲
+埗
+涫
+塗
+→
+楽
+現
+鯨
+愛
+瑪
+鈺
+忄
+悶
+藥
+飾
+樓
+視
+孬
+ㆍ
+燚
+苪
+師
+①
+丼
+锽
+│
+韓
+標
+兒
+閏
+匋
+張
+漢
+髪
+會
+閑
+檔
+習
+裝
+の
+峯
+菘
+輝
+雞
+釣
+億
+浐
+K
+O
+R
+8
+H
+E
+P
+T
+W
+D
+S
+C
+M
+F
+姌
+饹
+晞
+廰
+嵯
+鷹
+負
+飲
+絲
+冚
+楗
+澤
+綫
+區
+❋
+←
+質
+靑
+揚
+③
+滬
+統
+産
+協
+﹑
+乸
+畐
+經
+運
+際
+洺
+岽
+為
+粵
+諾
+崋
+豐
+碁
+V
+2
+6
+齋
+誠
+訂
+勑
+雙
+陳
+無
+泩
+媄
+夌
+刂
+i
+c
+t
+o
+r
+a
+嘢
+耄
+燴
+暃
+壽
+媽
+靈
+抻
+體
+唻
+冮
+甹
+鎮
+錦
+蜛
+蠄
+尓
+駕
+戀
+飬
+逹
+倫
+貴
+極
+寬
+磚
+嶪
+郎
+職
+|
+間
+n
+d
+剎
+伈
+課
+飛
+橋
+瘊
+№
+譜
+骓
+圗
+滘
+縣
+粿
+咅
+養
+濤
+彳
+%
+Ⅱ
+啰
+㴪
+見
+矞
+薬
+糁
+邨
+鲮
+顔
+罱
+選
+話
+贏
+氪
+俵
+競
+瑩
+繡
+枱
+綉
+獅
+爾
+™
+麵
+戋
+淩
+徳
+個
+劇
+場
+務
+簡
+寵
+h
+實
+膠
+轱
+圖
+築
+嘣
+樹
+㸃
+營
+耵
+孫
+饃
+鄺
+飯
+麯
+遠
+輸
+坫
+孃
+乚
+閃
+鏢
+㎡
+題
+廠
+關
+↑
+爺
+將
+軍
+連
+篦
+覌
+參
+箸
+-
+窠
+棽
+寕
+夀
+爰
+歐
+呙
+閥
+頡
+熱
+雎
+垟
+裟
+凬
+勁
+帑
+馕
+夆
+疌
+枼
+馮
+貨
+蒤
+樸
+彧
+旸
+靜
+龢
+暢
+㐱
+鳥
+珺
+鏡
+灡
+爭
+堷
+廚
+騰
+診
+┅
+蘇
+褔
+凱
+頂
+豕
+亞
+帥
+嘬
+⊥
+仺
+桖
+複
+饣
+絡
+穂
+顏
+棟
+納
+▏
+濟
+親
+設
+計
+攵
+埌
+烺
+頤
+燦
+蓮
+撻
+節
+講
+濱
+濃
+娽
+洳
+朿
+燈
+鈴
+護
+膚
+铔
+過
+補
+Z
+U
+5
+4
+坋
+闿
+䖝
+餘
+缐
+铞
+貿
+铪
+桼
+趙
+鍊
+[
+㐂
+垚
+菓
+揸
+捲
+鐘
+滏
+𣇉
+爍
+輪
+燜
+鴻
+鮮
+動
+鹞
+鷗
+丄
+慶
+鉌
+翥
+飮
+腸
+⇋
+漁
+覺
+來
+熘
+昴
+翏
+鲱
+圧
+鄉
+萭
+頔
+爐
+嫚
+貭
+類
+聯
+幛
+輕
+訓
+鑒
+夋
+锨
+芃
+珣
+䝉
+扙
+嵐
+銷
+處
+ㄱ
+語
+誘
+苝
+歸
+儀
+燒
+楿
+內
+粢
+葒
+奧
+麥
+礻
+滿
+蠔
+穵
+瞭
+態
+鱬
+榞
+硂
+鄭
+黃
+煙
+祐
+奓
+逺
+*
+瑄
+獲
+聞
+薦
+讀
+這
+樣
+決
+問
+啟
+們
+執
+説
+轉
+單
+隨
+唘
+帶
+倉
+庫
+還
+贈
+尙
+皺
+■
+餅
+產
+○
+∈
+報
+狀
+楓
+賠
+琯
+嗮
+禮
+`
+傳
+>
+≤
+嗞
+≥
+換
+咭
+∣
+↓
+曬
+応
+寫
+″
+終
+様
+純
+費
+療
+聨
+凍
+壐
+郵
+黒
+∫
+製
+塊
+調
+軽
+確
+撃
+級
+馴
+Ⅲ
+涇
+繹
+數
+碼
+證
+狒
+処
+劑
+<
+晧
+賀
+衆
+]
+櫥
+兩
+陰
+絶
+對
+鯉
+憶
+◎
+p
+e
+Y
+蕒
+煖
+頓
+測
+試
+鼽
+僑
+碩
+妝
+帯
+≈
+鐡
+舖
+權
+喫
+倆
+該
+悅
+俫
+.
+f
+s
+b
+m
+k
+g
+u
+j
+貼
+淨
+濕
+針
+適
+備
+l
+/
+給
+謢
+強
+觸
+衛
+與
+⊙
+$
+緯
+變
+⑴
+⑵
+⑶
+㎏
+殺
+∩
+幚
+─
+價
+▲
+離
+飄
+烏
+関
+閟
+﹝
+﹞
+邏
+輯
+鍵
+驗
+訣
+導
+歷
+屆
+層
+▼
+儱
+錄
+熳
+艦
+吋
+錶
+辧
+飼
+顯
+④
+禦
+販
+気
+対
+枰
+閩
+紀
+幹
+瞓
+貊
+淚
+△
+眞
+墊
+獻
+褲
+縫
+緑
+亜
+鉅
+餠
+{
+}
+◆
+蘆
+薈
+█
+◇
+溫
+彈
+晳
+粧
+犸
+穩
+訊
+崬
+凖
+熥
+舊
+條
+紋
+圍
+Ⅳ
+筆
+尷
+難
+雜
+錯
+綁
+識
+頰
+鎖
+艶
+□
+殁
+殼
+⑧
+├
+▕
+鵬
+糝
+綱
+▎
+盜
+饅
+醬
+籤
+蓋
+釀
+鹽
+據
+辦
+◥
+彐
+┌
+婦
+獸
+鲩
+伱
+蒟
+蒻
+齊
+袆
+腦
+寧
+凈
+妳
+煥
+詢
+偽
+謹
+啫
+鯽
+騷
+鱸
+損
+傷
+鎻
+髮
+買
+冏
+儥
+両
+﹢
+∞
+載
+喰
+z
+羙
+悵
+燙
+曉
+員
+組
+徹
+艷
+痠
+鋼
+鼙
+縮
+細
+嚒
+爯
+≠
+維
+"
+鱻
+壇
+厍
+帰
+浥
+犇
+薡
+軎
+應
+醜
+刪
+緻
+鶴
+賜
+噁
+軌
+尨
+镔
+鷺
+槗
+彌
+葚
+濛
+請
+溇
+緹
+賢
+訪
+獴
+瑅
+資
+縤
+陣
+蕟
+栢
+韻
+祼
+恁
+伢
+謝
+劃
+涑
+總
+衖
+踺
+砋
+凉
+籃
+駿
+苼
+瘋
+昽
+紡
+驊
+腎
+﹗
+響
+杋
+剛
+嚴
+禪
+歓
+槍
+傘
+檸
+檫
+炣
+勢
+鏜
+鎢
+銑
+尐
+減
+奪
+惡
+僮
+婭
+臘
+殻
+鉄
+∑
+蛲
+焼
+緖
+續
+紹
+懮
+⑤
+⑥
+⑦
+媪
+―
+韂
+⑨
+﹒
+⑩
+觽
+⑾
+⑿
+髃
+遽
+⒀
+骃
+⒁
+頉
+狎
+⒂
+曩
+苌
+⒃
+弒
+赀
+娡
+赟
+柰
+⒄
+愍
+畤
+菑
+蚡
+⒅
+鲧
+踰
+鬻
+笞
+阏
+⒆
+橐
+哙
+馔
+遑
+︰
+圉
+轸
+彘
+驺
+豨
+扃
+逡
+苻
+曪
+焻
+彀
+恚
+絷
+郄
+赍
+﹐
+薨
+雠
+鴈
+巿
+舁
+–
+聒
+蒯
+廪
+闼
+辔
+诳
+黥
+Ⅰ
+顼
+辇
+筮
+媵
+瞽
+缗
+徼
+鼋
+箧
+龁
+醮
+瘗
+礶
+繇
+
+檄
+僖
+卬
+爇
+髡
+儣
+驩
+━
+捽
+贽
+牝
+杼
+噫
+缯
+赧
+诮
+瘳
+獘
+篃
+絜
+杓
+溍
+笥
+鸱
+觥
+椟
+溲
+鞫
+猢
+笄
+翕
+嗥
+卺
+夡
+奭
+棂
+樗
+狲
+怙
+哂
+抟
+轵
+彊
+嬖
+僦
+裰
+舄
+拊
+旃
+俛
+瘥
+禳
+愆
+陬
+墀
+聩
+僊
+眛
+阍
+毐
+刭
+喾
+唿
+缑
+迨
+愦
+牖
+嚭
+邾
+悒
+殽
+斫
+兕
+铙
+镪
+踣
+胙
+臱
+骈
+旄
+豢
+帔
+僭
+忤
+棹
+诎
+氇
+獾
+殂
+倨
+詈
+頫
+掾
+鸩
+氆
+辎
+罴
+鄜
+珪
+曷
+膑
+牂
+捱
+怵
+怛
+觌
+舂
+廨
+怍
+欷
+汧
+鼍
+喟
+殓
+蓺
+奁
+鄗
+悝
+袴
+僇
+酹
+搒
+跽
+姁
+鞮
+纥
+梃
+卮
+肣
+湎
+揄
+迕
+汜
+髫
+炷
+汭
+挈
+蝄
+噙
+歔
+撺
+欤
+冑
+蹻
+鲠
+傒
+醦
+隰
+掼
+琖
+駆
+暲
+犒
+甑
+楫
+嫪
+裀
+贳
+劬
+龏
+酎
+逋
+眇
+佻
+幞
+鉏
+磔
+殄
+浞
+衽
+裾
+廛
+芈
+燔
+伛
+縠
+虮
+祓
+筰
+喁
+俦
+褫
+僰
+旻
+搢
+茕
+柈
+绖
+畑
+鳏
+溷
+楯
+祇
+怼
+褊
+╱
+缧
+齮
+蓐
+怿
+豳
+犴
+窋
+酆
+谶
+讙
+镬
+襦
+纮
+舐
+黙
+縯
+蹀
+枥
+豸
+揶
+闇
+焒
+匳
+髭
+鲰
+筴
+弁
+揆
+跸
+搠
+缞
+旒
+屣
+孱
+槁
+榼
+夤
+埶
+愠
+欻
+刽
+刎
+骖
+冁
+釂
+麤
+珰
+谮
+埒
+耎
+噉
+蟜
+秏
+呶
+悞
+猱
+镵
+鸮
+趺
+簏
+坼
+凫
+诂
+骀
+谲
+薮
+亶
+黾
+螫
+嶲
+茀
+蓍
+┐
+遘
+乩
+褴
+郈
+踽
+叵
+伋
+襆
+伧
+醳
+鄠
+圄
+楮
+迓
+锱
+腉
+纡
+疋
+愀
+滈
+杪
+椀
+懑
+劓
+囫
+脔
+巉
+缒
+蝼
+醢
+忝
+嗫
+勖
+噭
+猊
+儇
+觳
+缟
+郐
+剜
+徭
+愎
+魋
+殛
+篾
+躞
+纔
+粝
+穑
+钲
+徂
+﹖
+棓
+囵
+怫
+屦
+歘
+缱
+荦
+愬
+嗛
+铩
+馐
+媸
+て
+曛
+蹰
+窭
+亹
+駹
+嫜
+姞
+赇
+樭
+澙
+笮
+孀
+狻
+榇
+侪
+盍
+堙
+毶
+癀
+镞
+酤
+譄
+薜
+郿
+⒉
+埽
+阃
+遶
+酺
+辂
+鷪
+貋
+刳
+恫
+挹
+铳
+蒍
+孥
+纻
+旘
+耨
+翮
+洹
+坌
+捭
+睒
+轺
+崚
+仫
+庑
+邽
+麃
+縻
+
+瞋
+螭
+埤
+啁
+讦
+妁
+桞
+匏
+杌
+魑
+峇
+斄
+缶
+酩
+酢
+潏
+韪
+侔
+郪
+踔
+皁
+蜔
+魍
+祧
+粜
+晡
+蹩
+畎
+啱
+窳
+瞾
+
+
+
+舡
+葴
+耋
+鲐
+踧
+遫
+踟
+溊
+觜
+涒
+茔
+⒈
+谸
+跬
+浿
+な
+轘
+郇
+姮
+奡
+钤
+俅
+獬
+儆
+乂
+餍
+胾
+碛
+魭
+喑
+哏
+嶓
+俳
+蟭
+躅
+羖
+羑
+雩
+焜
+鸷
+箦
+铚
+缳
+酇
+罃
+罅
+庳
+褛
+罥
+蒺
+禨
+戕
+岬
+痍
+窴
+邠
+诨
+狁
+顒
+戆
+窎
+儙
+螾
+镕
+跣
+繻
+赜
+槃
+趄
+嬛
+睚
+跹
+壖
+戗
+沬
+畼
+嚋
+珮
+娀
+畀
+谇
+欃
+龂
+鲋
+鹆
+郕
+疴
+讧
+惇
+跂
+扢
+赪
+鈇
+釐
+槊
+寘
+暾
+莩
+钹
+犨
+刓
+逶
+澝
+嬃
+黡
+沕
+恝
+洟
+緃
+媢
+霣
+⒊
+慝
+炟
+皤
+囐
+
+瞀
+烝
+瓻
+醵
+殪
+樯
+缵
+伻
+玊
+觚
+踯
+噔
+忪
+峣
+搤
+嗾
+鞚
+巂
+蘧
+榱
+锾
+隳
+饟
+馎
+驵
+骘
+髀
+髑
+鮼
+鲔
+鹘
+鹚
+﹔
+刖
+啐
+嘭
+嚬
+嚰
+圯
+嫄
+寖
+嶶
+帇
+幤
+悫
+慙
+揜
+撝
+旳
+昃
+玕
+璆
+玃
+猃
+狃
+祊
+燹
+燠
+熛
+窣
+窬
+糌
+紬
+濩
+飧
+肸
+臬
+荜
+襜
+觖
+豭
+贇
+檠
+檇
+邘
+鄏
+鑙
+氅
+柢
+悭
+鄳
+蒗
+虺
+沇
+薤
+墠
+唶
+骍
+帨
+逖
+鹣
+臛
+鹖
+磛
+弢
+懜
+闟
+遹
+垝
+杅
+笤
+佈
+嚅
+蝮
+谳
+眢
+∵
+枵
+騳
+嗌
+玦
+嗄
+劙
+騠
+蚰
+趱
+珅
+洫
+颀
+趹
+蛩
+馓
+轫
+叡
+蒉
+睪
+漦
+胝
+瘐
+逦
+嶷
+傕
+斲
+瘵
+縢
+渖
+灊
+訇
+歃
+讵
+嫱
+狝
+脁
+堌
+塩
+茞
+嶋
+檑
+佺
+皞
+竩
+@
+暘
+訸
+ぴ
+亷
+皊
+澛
+酂
+壎
+戡
+橦
+嬿
+錡
+﹤
+柵
+蜾
+鉍
+玱
+Ⅴ
+虓
+钖
+Q
+缲
+鵾
+栌
+鞒
+锒
+樑
+赑
+泘
+垱
+貟
+崙
+尢
+沄
+廼
+鲖
+夼
+钌
+擖
+棻
+└
+菂
+淏
+湲
+晙
+鶄
+潆
+箅
+甡
+炘
+溦
+崑
+铓
+芏
+颋
+飏
+俣
+琲
+鎔
+誊
+秈
+筲
+耖
+柟
+玢
+_
+洑
+埇
+琤
+桯
+洧
+湜
+枧
+紘
+伭
+岺
+倥
+郃
+镫
+堉
+埸
+摺
+窺
+捩
+潁
+⒌
+愷
+氫
+卲
+铴
+霂
+阌
+韜
+玓
+茚
+⒚
+仂
+J
+冾
+钸
+褙
+硋
+龑
+蘋
+卻
+甁
+訚
+硊
+矬
+堨
+镙
+炤
+黉
+燀
+捃
+娒
+沚
+X
+増
+铗
+陸
+┕
+v
+轹
+垾
+苈
+絨
+鏮
+茳
+辻
+⑼
+仵
+澔
+胨
+∕
+燉
+浬
+鑑
+轳
+牮
+袷
+炻
+燊
+ⅰ
+霈
+垵
+裢
+倖
+貮
+瑀
+芨
+浉
+闶
+鳢
+砩
+铼
+鏐
+瑸
+筘
+濬
+钭
+範
+琭
+箨
+昫
+耩
+缡
+岵
+殳
+迺
+暎
+蔴
+轾
+糸
+塥
+馇
+圊
+睥
+鈊
+铫
+俤
+砟
+韨
+﹪
+鶯
+塅
+犄
+矼
+骉
+翃
+璠
+鋆
+牤
+湧
+劢
+瑱
+圬
+菉
+镡
+崟
+笪
+廘
+硐
+辚
+囝
+滹
+埏
+俙
+靭
+琇
+聶
+澴
+蘅
+褡
+笳
+桫
+烔
+磙
+诖
+倞
+鞥
+璘
+樘
+苧
+郉
+翀
+焩
+酴
+曌
+夿
+劼
+饫
+掛
+蔵
+枟
+鮦
+麴
+岠
+ⅲ
+焓
+缷
+駜
+漴
+舣
+蛱
+凃
+翚
+婥
+銛
+禇
+圪
+柃
+艽
+镆
+橺
+鞔
+舨
+ぁ
+蝥
+钫
+漈
+鄧
+玎
+洸
+蒨
+驖
+贶
+谠
+舳
+蒔
+僕
+棬
+す
+側
+锞
+頣
+琎
+锍
+鉮
+崾
+浛
+埝
+邙
+甍
+﹥
+硎
+菰
+蓁
+浯
+韡
+苳
+硭
+鎳
+翾
+鷇
+艋
+鹍
+禢
+埴
+昉
+桴
+査
+琍
+垐
+忭
+枨
+釭
+瘛
+淠
+漷
+泃
+蚈
+妉
+舾
+∮
+沺
+撖
+⒒
+菭
+奤
+犭
+竝
+骢
+湔
+锜
+嶝
+挏
+沨
+蕻
+朊
+苖
+枘
+梶
+逷
+ㄖ
+顸
+竑
+踅
+佾
+瑢
+鹁
+朣
+屺
+闩
+槠
+甦
+玠
+玭
+勍
+汎
+迴
+厐
+剀
+胓
+勔
+侉
+澥
+鼐
+嘏
+仉
+柽
+澂
+塚
+陔
+堽
+俢
+玙
+伲
+鋉

+ 139 - 0
config/cfg_det_db.py

@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+# ####################rec_train_options 参数说明##########################
+# 识别训练参数
+# base_lr:初始学习率
+# fine_tune_stage:
+#     if you want to freeze some stage, and tune the others.
+#     ['backbone', 'neck', 'head'], 所有参数都参与调优
+#     ['backbone'], 只调优backbone部分的参数
+#     后续更新: 1、添加bn层freeze的代码
+# optimizer 和 optimizer_step:
+#     优化器的配置, 成对
+#     example1: 'optimizer':['SGD'], 'optimizer_step':[],表示一直用SGD优化器
+#     example2:  'optimizer':['SGD', 'Adam'], 'optimizer_step':[160], 表示前[0,160)个epoch使用SGD优化器,
+#                [160,~]采用Adam优化器
+# lr_scheduler和lr_scheduler_info:
+#     学习率scheduler的设置
+# ckpt_save_type作用是选择模型保存的方式
+#      HighestAcc:只保存在验证集上精度最高的模型(还是在训练集上loss最小)
+#      FixedEpochStep: 按一定间隔保存模型
+###
+from addict import Dict
+
+config = Dict()
+config.exp_name = 'DBNet'
+config.train_options = {
+    # for train
+    'resume_from': '',  # 继续训练地址
+    'third_party_name': '',  # 加载paddle模型可选
+    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # 模型保存地址,log文件也保存在这里
+    'device': 'cuda:0',  # 不建议修改
+    'epochs': 1200,
+    'fine_tune_stage': ['backbone', 'neck', 'head'],
+    'print_interval': 1,  # step为单位
+    'val_interval': 1,  # epoch为单位
+    'ckpt_save_type': 'HighestAcc',  # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
+    'ckpt_save_epoch': 4,  # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
+}
+
+config.SEED = 927
+config.optimizer = {
+    'type': 'Adam',
+    'lr': 0.001,
+    'weight_decay': 1e-4,
+}
+# backbone设置为swin_transformer时使用
+# config.optimizer = {
+#     'type': 'AdamW',
+#     'lr': 0.0001,
+#     'betas': (0.9, 0.999),
+#     'weight_decay': 0.05,
+# }
+
+
+config.model = {
+    # backbone 可以设置'pretrained': False/True
+    'type': "DetModel",
+    'backbone': {"type": "ResNet", 'layers': 18, 'pretrained': True},  # ResNet or MobileNetV3
+    # 'backbone': {"type": "SwinTransformer", 'pretrained': True},#swin_transformer
+    # 'backbone': {"type": "ConvNeXt", 'pretrained': True},
+    'neck': {"type": 'DB_fpn', 'out_channels': 256},
+    'head': {"type": "DBHead"},
+    'in_channels': 3,
+}
+
+config.loss = {
+    'type': 'DBLoss',
+    'alpha': 1,
+    'beta': 10
+}
+
+config.post_process = {
+    'type': 'DBPostProcess',
+    'thresh': 0.3,  # 二值化输出map的阈值
+    'box_thresh': 0.7,  # 低于此阈值的box丢弃
+    'unclip_ratio': 1.5  # 扩大框的比例
+}
+
+# for dataset
+# ##label文件
+### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
+config.dataset = {
+    'train': {
+        'dataset': {
+            'type': 'JsonDataset',
+            'file': r'train.json',
+            'mean': [0.485, 0.456, 0.406],
+            'std': [0.229, 0.224, 0.225],
+            # db 预处理,不需要修改
+            'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}},
+                                                              {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
+                                                              {'type': 'Resize', 'args': {'size': [0.5, 3]}}]},
+                              {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}},
+                              {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}},
+                              {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}}
+                              ],
+            'filter_keys': ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags', 'shape'],  # 需要从data_dict里过滤掉的key
+            'ignore_tags': ['*', '###'],
+            'img_mode': 'RGB'
+        },
+        'loader': {
+            'type': 'DataLoader',  # 使用torch dataloader只需要改为 DataLoader
+            'batch_size': 8,
+            'shuffle': True,
+            'num_workers': 1,
+            'collate_fn': {
+                'type': ''
+            }
+        }
+    },
+    'eval': {
+        'dataset': {
+            'type': 'JsonDataset',
+            'file': r'test.json',
+            'mean': [0.485, 0.456, 0.406],
+            'std': [0.229, 0.224, 0.225],
+            'pre_processes': [{'type': 'ResizeShortSize', 'args': {'short_size': 736, 'resize_text_polys': False}}],
+            'filter_keys': [],  # 需要从data_dict里过滤掉的key
+            'ignore_tags': ['*', '###'],
+            'img_mode': 'RGB'
+        },
+        'loader': {
+            'type': 'DataLoader',
+            'batch_size': 1,  # 必须为1
+            'shuffle': False,
+            'num_workers': 1,
+            'collate_fn': {
+                'type': 'DetCollectFN'
+            }
+        }
+    }
+}
+
+# 转换为 Dict
+for k, v in config.items():
+    if isinstance(v, dict):
+        config[k] = Dict(v)

+ 172 - 0
config/cfg_det_dis.py

@@ -0,0 +1,172 @@
+from addict import Dict
+
+config = Dict()
+config.exp_name = 'DBNet_icdar_distill'
+config.train_options = {
+    # for train
+    'resume_from': '',  # 继续训练地址
+    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # 模型保存地址,log文件也保存在这里
+    'device': 'cuda:0',  # 不建议修改
+    'epochs': 600,
+    'fine_tune_stage': ['backbone', 'neck', 'head'],
+    'print_interval': 5,  # step为单位
+    'val_interval': 1,  # epoch为单位
+    'ckpt_save_type': 'HighestAcc',  # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
+    'ckpt_save_epoch': 4,  # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
+}
+
+config.SEED = 927
+config.optimizer = {
+    'type': 'Adam',
+    'lr': 0.0002,
+    'weight_decay': 0,
+}
+
+config.model = {
+    'type': 'DistillationModel',
+    'algorithm': 'Distillation',
+    'init_weight': False,  # 当不使用任何预训练模型(子网络或任意子网络backbone)时打开
+    'models': {
+        'Teacher': {
+            'type': "DetModel",
+            'freeze_params': True,
+            'backbone': {"type": "ResNet", 'pretrained': False, 'layers': 18},
+            'neck': {"type": 'DB_fpn', 'out_channels': 256},
+            'head': {"type": "DBHead"},
+            'in_channels': 3,
+            'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/DBNet_icdar_res18_fast_pre.pth'
+        },
+        'Student': {
+            'type': "DetModel",
+            'freeze_params': False,
+            'backbone': {"type": "MobileNetV3", 'pretrained': False, 'disable_se': False},
+            'neck': {"type": 'DB_fpn', 'out_channels': 96},
+            'head': {"type": "DBHead"},
+            'in_channels': 3,
+            'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/mbv3.pth'
+        },
+        'Student2': {
+            'type': "DetModel",
+            'freeze_params': False,
+            'backbone': {"type": "MobileNetV3", 'pretrained': False, 'disable_se': False},
+            'neck': {"type": 'DB_fpn', 'out_channels': 96},
+            'head': {"type": "DBHead"},
+            'in_channels': 3,
+            'pretrained': '/path/to/your/workspace/work/PytorchOCR/models/dismodels/mbv3.pth'
+        }
+
+    }
+
+}
+
+config.loss = {
+    'type': 'CombinedLoss',
+    'combine_list': {
+        'DistillationDilaDBLoss': {
+            'weight': 2.0,
+            'model_name_pairs': [("Student", "Teacher"), ("Student2", "Teacher")],
+            # 'model_name_pairs': [("Student", "Teacher")],
+            'key': 'maps',
+            'balance_loss': True,
+            'main_loss_type': 'DiceLoss',
+            'alpha': 5,
+            'beta': 10,
+            'ohem_ratio': 3,
+        },
+
+        'DistillationDMLLoss': {
+            'maps_name': "thrink_maps",
+            'weight': 1.0,
+            'model_name_pairs': ["Student", "Student2"],
+            'key': 'maps'
+        },
+
+        'DistillationDBLoss': {
+            'weight': 1.0,
+            'model_name_list': ["Student"],
+            'balance_loss': True,
+            'main_loss_type': 'DiceLoss',
+            'alpha': 5,
+            'beta': 10,
+            'ohem_ratio': 3}
+
+    }
+
+}
+
+config.post_process = {
+    'type': 'DistillationDBPostProcess',
+    'model_name': ["Student", "Student2", "Teacher"],
+    # 'model_name': ["Student", "Teacher"],
+    'thresh': 0.3,  # 二值化输出map的阈值
+    'box_thresh': 0.5,  # 低于此阈值的box丢弃
+    'unclip_ratio': 1.5  # 扩大框的比例
+
+}
+
+config.metric = {
+    'name': 'DistillationMetric',
+    'base_metric_name': 'DetMetric',
+    'main_indicator': 'hmean',
+    'key': "Student"
+}
+
+# for dataset
+# ##label文件
+### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
+config.dataset = {
+    'train': {
+        'dataset': {
+            'type': 'JsonDataset',
+            'file': r'/path/to/your/workspace/dataset/icdar15-detection/train.json',
+            'mean': [0.485, 0.456, 0.406],
+            'std': [0.229, 0.224, 0.225],
+            # db 预处理,不需要修改
+            'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}},
+                                                              {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
+                                                              {'type': 'Resize', 'args': {'size': [0.5, 3]}}]},
+                              {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}},
+                              {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}},
+                              {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}}],
+            'filter_keys': ['img_name', 'text_polys', 'texts', 'ignore_tags', 'shape'],
+            # 需要从data_dict里过滤掉的key
+            'ignore_tags': ['*', '###', ' '],
+            'img_mode': 'RGB'
+        },
+        'loader': {
+            'type': 'DataLoader',  # 使用torch dataloader只需要改为 DataLoader
+            'batch_size': 20,
+            'shuffle': True,
+            'num_workers': 20,
+            'collate_fn': {
+                'type': ''
+            }
+        }
+    },
+    'eval': {
+        'dataset': {
+            'type': 'JsonDataset',
+            'file': r'/path/to/your/workspace/dataset/icdar15-detection/test.json',
+            'mean': [0.485, 0.456, 0.406],
+            'std': [0.229, 0.224, 0.225],
+            'pre_processes': [{'type': 'ResizeShortSize', 'args': {'short_size': 736, 'resize_text_polys': False}}],
+            'filter_keys': [],  # 需要从data_dict里过滤掉的key
+            'ignore_tags': ['*', '###', ' '],
+            'img_mode': 'RGB'
+        },
+        'loader': {
+            'type': 'DataLoader',
+            'batch_size': 1,  # 必须为1
+            'shuffle': False,
+            'num_workers': 10,
+            'collate_fn': {
+                'type': 'DetCollectFN'
+            }
+        }
+    }
+}
+
+# 转换为 Dict
+for k, v in config.items():
+    if isinstance(v, dict):
+        config[k] = Dict(v)

+ 227 - 0
config/cfg_det_pse.py

@@ -0,0 +1,227 @@
+# encoding: utf-8
+"""
+@time: 2021/3/6 19:48
+@author: Bourne-M
+"""
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+# ####################rec_train_options 参数说明##########################
+# 识别训练参数
+# base_lr:初始学习率
+# fine_tune_stage:
+#     if you want to freeze some stage, and tune the others.
+#     ['backbone', 'neck', 'head'], 所有参数都参与调优
+#     ['backbone'], 只调优backbone部分的参数
+#     后续更新: 1、添加bn层freeze的代码
+# optimizer 和 optimizer_step:
+#     优化器的配置, 成对
+#     example1: 'optimizer':['SGD'], 'optimizer_step':[],表示一直用SGD优化器
+#     example2:  'optimizer':['SGD', 'Adam'], 'optimizer_step':[160], 表示前[0,160)个epoch使用SGD优化器,
+#                [160,~]采用Adam优化器
+# lr_scheduler和lr_scheduler_info:
+#     学习率scheduler的设置
+# ckpt_save_type作用是选择模型保存的方式
+#      HighestAcc:只保存在验证集上精度最高的模型(还是在训练集上loss最小)
+#      FixedEpochStep: 按一定间隔保存模型
+###
+# from addict import Dict
+#
+# config = Dict()
+# config.exp_name = 'DBNet_res18_init'
+# config.train_options = {
+#     # for train
+#     'resume_from': '',  # 继续训练地址
+#     'third_party_name': '',  # 加载paddle模型可选
+#     'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # 模型保存地址,log文件也保存在这里
+#     'device': 'cuda:0',  # 不建议修改
+#     'epochs': 1200,
+#     'fine_tune_stage': ['backbone', 'neck', 'head'],
+#     'print_interval': 32,  # step为单位
+#     'val_interval': 10,  # epoch为单位
+#     'ckpt_save_type': 'HighestAcc',  # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
+#     'ckpt_save_epoch': 4,  # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
+# }
+#
+# config.SEED = 927
+# config.optimizer = {
+#     'type': 'Adam',
+#     'lr': 0.001,
+#     'weight_decay': 1e-4,
+# }
+#
+# config.model = {
+#     'type': "DetModel",
+#     'backbone': {"type": "ResNet", 'layers': 18, 'pretrained': True},  # ResNet or MobileNetV3
+#     'neck': {"type": 'DB_fpn', 'out_channels': 256},
+#     'head': {"type": "DBHead"},
+#     'in_channels': 3,
+# }
+#
+# config.loss = {
+#     'type': 'DBLoss',
+#     'alpha': 1,
+#     'beta': 10
+# }
+#
+# config.post_process = {
+#     'type': 'DBPostProcess',
+#     'thresh': 0.3,  # 二值化输出map的阈值
+#     'box_thresh': 0.7,  # 低于此阈值的box丢弃
+#     'unclip_ratio': 1.5  # 扩大框的比例
+# }
+# # for dataset
+# # ##label文件
+# ### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
+# config.dataset = {
+#     'train': {
+#         'dataset': {
+#             'type': 'JsonDataset',
+#             'file': r'/home/zhouyufei/Work/DataSet/icdar2015/detection/train.json',
+#             'mean': [0.485, 0.456, 0.406],
+#             'std': [0.229, 0.224, 0.225],
+#             # db 预处理,不需要修改
+#             'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}},
+#                                                               {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
+#                                                               {'type': 'Resize', 'args': {'size': [0.5, 3]}}]},
+#                               {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}},
+#                               {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}},
+#                               {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}}],
+#             'filter_keys': ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags', 'shape'],  # 需要从data_dict里过滤掉的key
+#             'ignore_tags': ['*', '###'],
+#             'img_mode': 'RGB'
+#         },
+#         'loader': {
+#             'type': 'DataLoader',  # 使用torch dataloader只需要改为 DataLoader
+#             'batch_size': 32,
+#             'shuffle': True,
+#             'num_workers': 30,
+#             'collate_fn': {
+#                 'type': ''
+#             }
+#         }
+#     },
+#     'eval': {
+#         'dataset': {
+#             'type': 'JsonDataset',
+#             'file': r'/home/zhouyufei/Work/DataSet/icdar2015/detection/test.json',
+#             'mean': [0.485, 0.456, 0.406],
+#             'std': [0.229, 0.224, 0.225],
+#             'pre_processes': [{'type': 'ResizeShortSize', 'args': {'short_size': 736, 'resize_text_polys': False}}],
+#             'filter_keys': [],  # 需要从data_dict里过滤掉的key
+#             'ignore_tags': ['*', '###'],
+#             'img_mode': 'RGB'
+#         },
+#         'loader': {
+#             'type': 'DataLoader',
+#             'batch_size': 1,  # 必须为1
+#             'shuffle': False,
+#             'num_workers': 20,
+#             'collate_fn': {
+#                 'type': 'DetCollectFN'
+#             }
+#         }
+#     }
+# }
+#
+# # 转换为 Dict
+# for k, v in config.items():
+#     if isinstance(v, dict):
+#         config[k] = Dict(v)
+
+from addict import Dict
+config = Dict()
+config.exp_name = 'psenet_mbv3'
+config.train_options = {
+    # for train
+    'resume_from': '',  # 继续训练地址
+    'third_party_name': '',  # 加载paddle模型可选
+    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # 模型保存地址,log文件也保存在这里
+    'device': 'cuda:0',  # 不建议修改
+    'epochs': 1200,
+    'fine_tune_stage': ['backbone', 'neck', 'head'],
+    'print_interval': 20,  # step为单位
+    'val_interval': 1,  # epoch为单位
+    'ckpt_save_type': 'HighestAcc',  # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
+    'ckpt_save_epoch': 4,  # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
+}
+
+config.SEED = 927
+config.optimizer = {
+    'type': 'Adam',
+    'lr': 0.001,
+    'weight_decay': 0,
+}
+
+config.model = {
+    'type': "DetModel",
+    'backbone': {"type": "MobileNetV3", 'pretrained': True},  # ResNet or MobileNetV3
+    'neck': {"type": 'pse_fpn', 'out_channels': 256},
+    'head': {"type": "PseHead"},
+    'in_channels': 3,
+}
+
+config.loss = {
+    'type': 'PSELoss',
+    'Lambda': 0.7
+}
+
+config.post_process = {
+    'type': 'pse_postprocess'
+}
+
+# for dataset
+# ##label文件
+### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
+config.dataset = {
+    'train': {
+        'dataset': {
+            'type': 'MyDataset',
+            'file': r'/DataSet/icdar2015/detection/train.json',
+            'data_shape':640,
+            'n':6,
+            'm':0.5,
+            'mean': [0.485, 0.456, 0.406],
+            'std': [0.229, 0.224, 0.225],
+
+            'filter_keys': ['text_polys', 'ignore_tags', 'shape','texts'],  # 需要从data_dict里过滤掉的key
+            'ignore_tags': ['*', '###'],
+            'img_mode': 'RGB'
+        },
+        'loader': {
+            'type': 'DataLoader',  # 使用torch dataloader只需要改为 DataLoader
+            'batch_size': 20,
+            'shuffle': True,
+            'num_workers': 20
+
+        }
+    },
+    'eval': {
+        'dataset': {
+            'type': 'MyDataset',
+            'file': r'/DataSet/icdar2015/detection/test.json',
+            'mean': [0.485, 0.456, 0.406],
+            'std': [0.229, 0.224, 0.225],
+            'n':6,
+            'm':0.5,
+            'data_shape':640,
+            'filter_keys': ['score_maps','training_mask'],  # 需要从data_dict里过滤掉的key
+            'ignore_tags': ['*', '###'],
+            'img_mode': 'RGB'
+        },
+        'loader': {
+            'type': 'DataLoader',
+            'batch_size': 1,  # 必须为1
+            'shuffle': False,
+            'num_workers': 10
+        }
+    }
+}
+
+# 转换为 Dict
+for k, v in config.items():
+    if isinstance(v, dict):
+        config[k] = Dict(v)
+
+

+ 124 - 0
config/cfg_rec_crnn.py

@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+# ####################rec_train_options 参数说明##########################
+# 识别训练参数
+# base_lr:初始学习率
+# fine_tune_stage:
+#     if you want to freeze some stage, and tune the others.
+#     ['backbone', 'neck', 'head'], 所有参数都参与调优
+#     ['backbone'], 只调优backbone部分的参数
+#     后续更新: 1、添加bn层freeze的代码
+# optimizer 和 optimizer_step:
+#     优化器的配置, 成对
+#     example1: 'optimizer':['SGD'], 'optimizer_step':[],表示一直用SGD优化器
+#     example2:  'optimizer':['SGD', 'Adam'], 'optimizer_step':[160], 表示前[0,160)个epoch使用SGD优化器,
+#                [160,~]采用Adam优化器
+# lr_scheduler和lr_scheduler_info:
+#     学习率scheduler的设置
+# ckpt_save_type作用是选择模型保存的方式
+#      HighestAcc:只保存在验证集上精度最高的模型(还是在训练集上loss最小)
+#      FixedEpochStep: 按一定间隔保存模型
+###
+from addict import Dict
+
+config = Dict()
+config.exp_name = 'CRNN'
+config.train_options = {
+    # for train
+    'resume_from': '',  # 继续训练地址
+    'third_party_name': '',  # 加载paddle模型可选
+    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # 模型保存地址,log文件也保存在这里
+    'device': 'cuda:0',  # 不建议修改
+    'epochs': 20,
+    'fine_tune_stage': ['backbone', 'neck', 'head'],
+    'print_interval': 10,  # step为单位
+    'val_interval': 300,  # step为单位
+    'ckpt_save_type': 'HighestAcc',  # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
+    'ckpt_save_epoch': 4,  # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
+}
+
+config.SEED = 927
+config.optimizer = {
+    'type': 'Adam',
+    'lr': 0.001,
+    'weight_decay': 1e-4,
+}
+
+config.lr_scheduler = {
+    'type': 'StepLR',
+    'step_size': 60,
+    'gamma': 0.5
+}
+config.model = {
+    # backbone 可以设置'pretrained': False/True
+    'type': "RecModel",
+
+    # 'backbone': {"type": "ResNet", 'layers': 34},
+    # 'neck': {"type": 'PPaddleRNN',"hidden_size": 256},
+    # 'head': {"type": "CTC", 'n_class': 5990},
+    # 'in_channels': 3,
+
+    'backbone': {"type": "MobileNetV3", 'model_name': 'small'},
+    'neck': {"type": 'PPaddleRNN', "hidden_size": 48},
+    'head': {"type": "CTC", 'n_class': 5990},
+    'in_channels': 3,
+}
+
+config.loss = {
+    'type': 'CTCLoss',
+    'blank_idx': 0,
+}
+
+# for dataset
+# ##label文件
+### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
+config.dataset = {
+    'alphabet': r'/data2/znj/CRNN_Chinese_Characters_Rec/lib/dataset/txt/char_std_5990.txt',
+    'train': {
+        'dataset': {
+            'type': 'RecTextLineDataset',
+            'file': r'/data2/znj/CRNN_Chinese_Characters_Rec/lib/dataset/txt/train.txt',
+            'input_h': 32,
+            'mean': 0.5,
+            'std': 0.5,
+            'augmentation': False,
+        },
+        'loader': {
+            'type': 'DataLoader',  # 使用torch dataloader只需要改为 DataLoader
+            'batch_size': 16,
+            'shuffle': True,
+            'num_workers': 3,
+            'collate_fn': {
+                'type': 'RecCollateFn',
+                'img_w': 320
+            }
+        }
+    },
+    'eval': {
+        'dataset': {
+            'type': 'RecTextLineDataset',
+            'file': r'/data2/znj/CRNN_Chinese_Characters_Rec/lib/dataset/txt/test.txt',
+            'input_h': 32,
+            'mean': 0.5,
+            'std': 0.5,
+            'augmentation': False,
+        },
+        'loader': {
+            'type': 'RecDataLoader',
+            'batch_size': 32,
+            'shuffle': False,
+            'num_workers': 2,
+            'collate_fn': {
+                'type': 'RecCollateFn',
+                'img_w': 320
+            }
+        }
+    }
+}
+
+# 转换为 Dict
+for k, v in config.items():
+    if isinstance(v, dict):
+        config[k] = Dict(v)

+ 113 - 0
config/cfg_rec_crnn_lmdb.py

@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+# ####################rec_train_options 参数说明##########################
+# 识别训练参数
+# base_lr:初始学习率
+# fine_tune_stage:
+#     if you want to freeze some stage, and tune the others.
+#     ['backbone', 'neck', 'head'], 所有参数都参与调优
+#     ['backbone'], 只调优backbone部分的参数
+#     后续更新: 1、添加bn层freeze的代码
+# optimizer 和 optimizer_step:
+#     优化器的配置, 成对
+#     example1: 'optimizer':['SGD'], 'optimizer_step':[],表示一直用SGD优化器
+#     example2:  'optimizer':['SGD', 'Adam'], 'optimizer_step':[160], 表示前[0,160)个epoch使用SGD优化器,
+#                [160,~]采用Adam优化器
+# lr_scheduler和lr_scheduler_info:
+#     学习率scheduler的设置
+# ckpt_save_type作用是选择模型保存的方式
+#      HighestAcc:只保存在验证集上精度最高的模型(还是在训练集上loss最小)
+#      FixedEpochStep: 按一定间隔保存模型
+###
+from addict import Dict
+
+config = Dict()
+config.exp_name = 'CRNN'
+config.train_options = {
+    # for train
+    'resume_from': '',  # 继续训练地址
+    'third_party_name': '',  # 加载paddle模型可选
+    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # 模型保存地址,log文件也保存在这里
+    'device': 'cuda:0',# 不建议修改
+    'epochs': 200,
+    'fine_tune_stage': ['backbone', 'neck', 'head'],
+    'print_interval': 10,  # step为单位
+    'val_interval': 625,  # step为单位
+    'ckpt_save_type': 'HighestAcc',  # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
+    'ckpt_save_epoch': 4,  # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
+}
+
+config.SEED = 927
+config.optimizer = {
+    'type': 'Adam',
+    'lr': 0.001,
+    'weight_decay': 1e-4,
+}
+
+config.lr_scheduler = {
+    'type': 'StepLR',
+    'step_size': 60,
+    'gamma': 0.5
+}
+config.model = {
+    'type': "RecModel",
+    'backbone': {"type": "ResNet", 'layers': 18},
+    'neck': {"type": 'PPaddleRNN'},
+    'head': {"type": "CTC", 'n_class': 11},
+    'in_channels': 3,
+}
+
+config.loss = {
+    'type': 'CTCLoss',
+    'blank_idx': 0,
+}
+
+# for dataset
+# ##label文件
+### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
+config.dataset = {
+    'alphabet': r'torchocr/datasets/alphabets/digit.txt',
+    'train': {
+        'dataset': {
+            'type': 'RecLmdbDataset',
+            'file': r'path/lmdb/train',  # LMDB 数据集路径
+            'input_h': 32,
+            'mean': 0.5,
+            'std': 0.5,
+            'augmentation': False,
+        },
+        'loader': {
+            'type': 'DataLoader',  # 使用torch dataloader只需要改为 DataLoader
+            'batch_size': 16,
+            'shuffle': True,
+            'num_workers': 1,
+            'collate_fn': {
+                'type': 'RecCollateFn',
+                'img_w': 120
+            }
+        }
+    },
+    'eval': {
+        'dataset': {
+            'type': 'RecLmdbDataset',
+            'file': r'path/lmdb/eval',  # LMDB 数据集路径
+            'input_h': 32,
+            'mean': 0.5,
+            'std': 0.5,
+            'augmentation': False,
+        },
+        'loader': {
+            'type': 'RecDataLoader',
+            'batch_size': 4,
+            'shuffle': False,
+            'num_workers': 1,
+        }
+    }
+}
+
+# 转换为 Dict
+for k, v in config.items():
+    if isinstance(v, dict):
+        config[k] = Dict(v)

+ 123 - 0
config/cfg_rec_crnn_test1.py

@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+# ####################rec_train_options 参数说明##########################
+# 识别训练参数
+# base_lr:初始学习率
+# fine_tune_stage:
+#     if you want to freeze some stage, and tune the others.
+#     ['backbone', 'neck', 'head'], 所有参数都参与调优
+#     ['backbone'], 只调优backbone部分的参数
+#     后续更新: 1、添加bn层freeze的代码
+# optimizer 和 optimizer_step:
+#     优化器的配置, 成对
+#     example1: 'optimizer':['SGD'], 'optimizer_step':[],表示一直用SGD优化器
+#     example2:  'optimizer':['SGD', 'Adam'], 'optimizer_step':[160], 表示前[0,160)个epoch使用SGD优化器,
+#                [160,~]采用Adam优化器
+# lr_scheduler和lr_scheduler_info:
+#     学习率scheduler的设置
+# ckpt_save_type作用是选择模型保存的方式
+#      HighestAcc:只保存在验证集上精度最高的模型(还是在训练集上loss最小)
+#      FixedEpochStep: 按一定间隔保存模型
+###
+from addict import Dict
+
+config = Dict()
+config.exp_name = 'CRNN'
+config.train_options = {
+    # for train
+    'resume_from': 'C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best.pth',  # 继续训练地址
+    'third_party_name': '',  # 加载paddle模型可选
+    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint_resnet3",  # 模型保存地址,log文件也保存在这里
+    'device': 'cuda:0',  # 不建议修改
+    'epochs': 20,
+    'fine_tune_stage': ['backbone', 'neck', 'head'],
+    'print_interval': 500,  # step为单位
+    'val_interval': 100000,  # step为单位
+    'ckpt_save_type': 'HighestAcc',  # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
+    'ckpt_save_epoch': 4,  # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
+}
+
+config.SEED = 927
+config.optimizer = {
+    'type': 'Adam',
+    'lr': 0.0001,
+    'weight_decay': 1e-4,
+}
+
+config.lr_scheduler = {
+    'type': 'StepLR',
+    'step_size': 2,
+    'gamma': 0.5
+}
+config.model = {
+    # backbone 可以设置'pretrained': False/True
+    'type': "RecModel",
+    # 'backbone': {"type": "ResNet", 'layers': 34},
+    # 'neck': {"type": 'PPaddleRNN',"hidden_size": 256},
+    # 'head': {"type": "CTC", 'n_class': 93},
+    # 'in_channels': 3,
+
+    'backbone': {"type": "ResNet", 'layers':34},
+    'neck': {"type": 'PPaddleRNN', "hidden_size": 256},
+    'head': {"type": "CTC", 'n_class': 7782},
+    'in_channels': 3,
+}
+
+config.loss = {
+    'type': 'CTCLoss',
+    'blank_idx': 0,
+}
+
+# for dataset
+# ##label文件
+### 存在问题,gt中str-->label 是放在loss中还是放在dataloader中
+config.dataset = {
+    'alphabet': r'C:/Users/Administrator/Desktop/OCR_pytorch/PytorchOCR1/char_std_7782.txt',
+    'train': {
+        'dataset': {
+            'type': 'RecTextLineDataset',
+            'file': r'/data2/znj/CRNN_Chinese_Characters_Rec/lib/dataset/txt/train.txt',
+            'input_h': 32,
+            'mean': 0.5,
+            'std': 0.5,
+            'augmentation': False,
+        },
+        'loader': {
+            'type': 'DataLoader',  # 使用torch dataloader只需要改为 DataLoader
+            'batch_size': 16,
+            'shuffle': True,
+            'num_workers': 3,
+            'collate_fn': {
+                'type': 'RecCollateFn',
+                'img_w': 640
+            }
+        }
+    },
+    'eval': {
+        'dataset': {
+            'type': 'RecTextLineDataset',
+            'file': r'C:/Users/Administrator/Desktop/OCR_pytorch/PytorchOCR1/test_image/test.txt',
+            'input_h': 32,
+            'mean': 0.5,
+            'std': 0.5,
+            'augmentation': False,
+        },
+        'loader': {
+            'type': 'RecDataLoader',
+            'batch_size': 2,
+            'shuffle': False,
+            'num_workers': 2,
+            'collate_fn': {
+                'type': 'RecCollateFn',
+                'img_w': 640
+            }
+        }
+    }
+}
+
+# 转换为 Dict
+for k, v in config.items():
+    if isinstance(v, dict):
+        config[k] = Dict(v)

BIN
doc/imgs/exampl1.png


BIN
doc/imgs/exampl2.png


+ 14 - 0
doc/检测+识别推理.md

@@ -0,0 +1,14 @@
+## 检测和识别
+
+1. 检测和识别
+```shell script
+CUDA_VISIBLE_DEVICES=0 python3 tools/ocr_infer.py --det_path "" --rec_path "" --img_path ""
+```
+2.模型运行时间分析
+```shell script
+CUDA_VISIBLE_DEVICES=0 python3 tools/ocr_infer.py --det_path "" --rec_path "" --img_path "" -time_profile
+```
+3.模型运行内存分析
+```shell script
+CUDA_VISIBLE_DEVICES=0 python3 tools/ocr_infer.py --det_path "" --rec_path "" --img_path "" -mem_profile
+```

+ 62 - 0
doc/检测.md

@@ -0,0 +1,62 @@
+## 文字检测
+
+### 数据准备
+
+PytorchOCR的检测模块只支持`JsonDataset` 形式的数据格式
+
+* 构造数据集
+   `JsonDataset` 使用 json 格式来存储标注信息,具体格式为
+   ```json
+    {
+      "data_root": "存放图片文件的目录",
+      "data_list": [
+          {
+            "img_name": "relative/path/xxx.jpg",  # 图片相对于 data_root 的相对路径
+            "annotations": [  # 当前图片的所有标注,每个标注是一个对象
+                {
+                  "polygon": [[x1,y1],[x2,y2],...,[xn,yn]],  # 文本框的多点标注
+                  "text": "label",  # 文本框内容
+                  "illegibility": false,  # 是否模糊
+                  "language": "Latin",  # 文本语言类型
+                  "chars": [  # 当前文本框的字符级标注,标注含义同上
+                      {
+                        "polygon": [[x1,y1],[x2,y2],...,[xn,yn]],
+                        "char": "c",
+                        "illegibility": false,
+                        "language": "Latin"
+                      }
+                  ]
+                }
+            ]
+          }
+      ]
+    }
+    ```
+    我们提供了 [转换工具](https://github.com/WenmuZhou/OCR_DataSet/tree/master/convert/det)  以方便开发者将现有的公开数据集进行转换
+
+    除此之外,我们也为你准备了一批转换好的公开数据集,具体请参考 https://github.com/WenmuZhou/OCR_DataSet
+
+### 启动训练
+PytorchOCR提供了训练脚本和预测脚本,本节将以 DB 检测模型为例,按照如下步骤启动训练:
+1. 从百度网盘`pytorchocr/det/imagenet`下载预训练模型并放于`PytorchOCR/weights`目录下
+2. 拷贝`config/cfg_det_db.py` 为自己的配置文件
+3. 修改配置文件
+    * 必须修改的字段说明
+     1. `config.dataset.train.dataset.file`: 训练集json文件路径
+     2. `config.dataset.eval.dataset.file`: 验证集集json文件路径
+    
+    * 可选修改字段说明
+    1. `config.train_options.checkpoint_save_dir`: 模型和日志文件保存地址
+    
+    其他字段可根据需要修改
+    
+4. 通过如下命令启动训练
+```shell script
+CUDA_VISIBLE_DEVICES=0 python3 tools/det_train.py --config '你的配置文件路径'
+```
+
+PytorchOCR支持训练和评估交替进行, 可以在 `config.train_options`中修改 `val_interval` 设置评估频率,
+评估过程中默认将最佳hmean模型,保存为 best.pth。
+
+### 恢复训练
+只需修改 `config.train_options.resume_from` 为模型地址,即可从该模型断掉的地方继续训练
+
+### 预测
+通过以下命令启动预测
+```shell script
+CUDA_VISIBLE_DEVICES=0 python3 tools/det_infer.py --model_path '' --img_path ''
+```

+ 79 - 0
doc/添加新算法.md

@@ -0,0 +1,79 @@
+## 添加新算法
+
+### Dataset
+
+#### 检测算法
+不同的检测算法会有不同的图片预处理和label制作方式,添加新dataset的步骤如下
+1. 在`torchocr/datasets/det_modules`下添加算法的图片预处理和label制作方式,
+每个处理步骤(module)用一个文件存储,module的形式如下
+```python
+class ModuleName:
+    def __init__(self, *args,**kwargs):
+        pass
+    def __call__(self, data: dict) -> dict:
+        im = data['img']
+        text_polys = data['text_polys']
+        # 执行你的处理
+        data['img'] = im
+        data['text_polys'] = text_polys
+        return data
+```
+算法的所有处理步骤由不同的module顺序执行而成,在config文件中按照列表的形式组合并执行。如:
+```python
+'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}},
+                                                  {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
+                                                  {'type': 'Resize', 'args': {'size': [0.5, 3]}}]},
+                  {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}},
+                  {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}},
+                  {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}}]
+```
+
+#### 识别算法
+对于attention和ctc系列算法,我们已经提供了内置的dataset,其他类型的需要在`torchocr/datasets/RecDataSet.py`
+文件里添加一个dataset并在config文件中使用
+
+### 网络
+PytorchOCR将网络划分为三部分
+* backbone: 从图片中提取特征,如Resnet,MobileNetV3
+* neck: 对backbone输出的特征进行强化,如FPN,CRNN的RNN部分
+* head: 在neck输出特征的基础上进行完成算法的输出
+`backbone`和`neck`均需要`out_channels`属性以便后续组件构造网络。
+若PytorchOCR已提供的组件中没有算法所需组件,就需要在对应的文件夹内实现新组件,一个文件夹存放一个组件,
+然后将新组件在`torchocr/networks/architectures/DetModel.py`或`torchocr/networks/architectures/RecModel.py`进行导入并添加到对应的dict
+
+各组件对应文件如下: 
+* backbone: `torchocr/networks/backbones`
+* necks: `torchocr/networks/necks`
+* heads: `torchocr/networks/heads`
+
+### 损失函数
+损失函数的存放文件夹为`torchocr/networks/losses`,损失函数的输出应该是一个dict,格式如下
+```python
+{
+    'loss':loss_value, # 总的loss,由l1,l2,l3,...,ln加权组成
+    '其他的loss': value # 组成总loss的子loss
+}
+```
+loss module 的形式如下
+```python
+class ModuleName(nn.Module):
+    def __init__(self, *args,**kwargs):
+        pass
+
+    def forward(self, pred, batch):
+        """
+
+        :param pred:
+        :param batch: batch为一个dict{
+                                    '其他计算loss所需的输入':'value'
+                                    }
+        :return:
+        """
+        # 计算loss
+        loss_dict = {'loss':loss,'other_sub_loss':value}
+        return loss_dict
+```
+
+### 配置文件
+
+将配置文件里的对应地方换成新增的组件,那么新的网络就添加完成了,在测试性能无误后就可推送到PytorchOCR仓库

BIN
doc/田氏颜体大字库2.0.ttf


+ 68 - 0
doc/识别.md

@@ -0,0 +1,68 @@
+## 文字识别
+
+### 数据准备
+
+PytorchOCR的识别模块支持`TextLine` 和 `LMDB` 形式的数据格式
+
+* 构造数据集
+    *  `TextLine`
+    
+    训练集和验证集的格式一致,准备一个txt文件,里面每一行记录了图片路径和对应的标注,使用`\t`作为分隔符
+    如用其他方式分割将造成训练报错
+    
+    ```shell script
+    " 图像文件名                 图像标注信息 "
+    
+    train_data/train_0001.jpg   简单可依赖
+    train_data/train_0002.jpg   用科技让复杂的世界更简单
+    ```
+  
+    * `LMDB`
+    
+    准备好 `TextLine` 的txt文件之后,使用 [转换工具](../tools/create_rec_lmdb_dataset.py) 可以生成 `LMDB` 格式的数据集
+
+我们也提供了一批转换好的训练数据,具体请参考 https://github.com/WenmuZhou/OCR_DataSet
+    
+* 字典
+数据集准备完成之后,需要提前准备好一个包含训练验证集里全部字符的字典,应为如下格式,utf8编码保存
+```shell script
+0
+1
+2
+3
+4
+5
+```
+
+### 启动训练
+PytorchOCR提供了训练脚本和预测脚本,本节将以 CRNN 识别模型为例,按照如下步骤启动训练:
+1. 下载预训练模型(待提供)
+2. 拷贝`config/rec_train_config.py` 为自己的配置文件,LMDB 数据集使用 `config/rec_train_lmdb_config.py`
+3. 修改配置文件
+    * 必须修改的字段说明
+     1. `config.dataset.alphabet`: 字典文件
+     2. `config.dataset.train.dataset.file`: 训练集txt文件 or LMDB 数据集路径
+     3. `config.dataset.eval.dataset.file`: 验证集txt文件 or LMDB 数据集路径
+     4. `config.model.head.n_class`: 分类字符数+背景
+    
+    * 可选修改字段说明
+    1. `config.train_options.checkpoint_save_dir`: 模型和日志文件保存地址
+    
+    其他字段可根据需要修改
+    
+4. 通过如下命令启动训练
+```shell script
+CUDA_VISIBLE_DEVICES=0 python3 tools/rec_train.py --config '你的配置文件路径'
+```
+
+PytorchOCR支持训练和评估交替进行, 可以在 `config.train_options`中修改 `val_interval` 设置评估频率,
+评估过程中默认将最佳acc模型,保存为 best.pth。
+
+### 恢复训练
+只需修改 `config.train_options.resume_from` 为模型地址,即可从该模型断掉的地方继续训练
+
+### 预测
+通过以下命令启动预测
+```shell script
+CUDA_VISIBLE_DEVICES=0 python3 tools/rec_infer.py --model_path '' --img_path ''
+```

+ 16 - 0
requirements.txt

@@ -0,0 +1,16 @@
+numpy==1.18.4
+Pillow==8.2.0
+tqdm==4.46.0
+opencv-python==4.2.0.34
+addict==2.2.1
+termcolor==1.1.0
+pyclipper==1.1.0.post3
+shapely==1.7.0
+torch>=1.7.0
+torchvision>=0.8.0
+python-Levenshtein>=0.12.0
+lmdb>=0.98
+imgaug>=0.4.0
+line-profiler==3.2.6
+memory-profiler==0.58.0
+

BIN
test_image/Snipaste_2023-08-21_17-07-58.jpg


BIN
test_image/Snipaste_2023-08-21_17-07-582.jpg


BIN
test_image/Snipaste_2023-08-23_15-35-13.jpg


BIN
test_image/Snipaste_2023-08-25_10-34-45.jpg


BIN
test_image/Snipaste_2023-08-25_11-32-11.jpg


BIN
test_image/Snipaste_2023-08-25_11-33-51.jpg


BIN
test_image/Snipaste_2023-09-08_10-45-17.jpg


+ 4 - 0
test_image/test.txt

@@ -0,0 +1,4 @@
+C:\Users\Administrator\Desktop\OCR_pytorch\PytorchOCR1\test_image\Snipaste_2023-08-21_17-07-58.jpg 123
+C:\Users\Administrator\Desktop\OCR_pytorch\PytorchOCR1\test_image\Snipaste_2023-08-21_17-07-58.jpg 123
+C:\Users\Administrator\Desktop\OCR_pytorch\PytorchOCR1\test_image\Snipaste_2023-08-21_17-07-58.jpg 123
+

+ 84 - 0
tools/create_rec_lmdb_dataset.py

@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2019/11/6 15:31
+# @Author  : zhoujun
+
+""" a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """
+
+import os
+import lmdb
+import cv2
+from tqdm import tqdm
+import numpy as np
+
+def checkImageIsValid(imageBin):
+    if imageBin is None:
+        return False
+    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
+    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
+    imgH, imgW = img.shape[0], img.shape[1]
+    if imgH * imgW == 0:
+        return False
+    return True
+
+
+def writeCache(env, cache):
+    with env.begin(write=True) as txn:
+        for k, v in cache.items():
+            txn.put(k, v)
+
+
+def createDataset(data_list, lmdb_save_path, checkValid=True):
+    """
+    Create LMDB dataset for training and evaluation.
+    ARGS:
+        data_list  : a list contains img_path\tlabel
+        lmdb_save_path : LMDB output path
+        checkValid : if true, check the validity of every image
+    """
+    os.makedirs(lmdb_save_path, exist_ok=True)
+    env = lmdb.open(lmdb_save_path, map_size=109951162)
+    cache = {}
+    cnt = 1
+    for imagePath, label in tqdm(data_list, desc=f'make dataset, save to {lmdb_save_path}'):
+        with open(imagePath, 'rb') as f:
+            imageBin = f.read()
+        if checkValid:
+            try:
+                if not checkImageIsValid(imageBin):
+                    print('%s is not a valid image' % imagePath)
+                    continue
+            except:
+                continue
+
+        imageKey = 'image-%09d'.encode() % cnt
+        labelKey = 'label-%09d'.encode() % cnt
+        cache[imageKey] = imageBin
+        cache[labelKey] = label.encode()
+
+        if cnt % 1000 == 0:
+            writeCache(env, cache)
+            cache = {}
+        cnt += 1
+    nSamples = cnt - 1
+    cache['num-samples'.encode()] = str(nSamples).encode()
+    writeCache(env, cache)
+    print('Created dataset with %d samples' % nSamples)
+
+
+if __name__ == '__main__':
+    import pathlib
+    label_file = r"path/val.txt"
+    lmdb_save_path = r'path/lmdb/eval'
+    os.makedirs(lmdb_save_path, exist_ok=True)
+
+    data_list = []
+    with open(label_file, 'r', encoding='utf-8') as f:
+        for line in tqdm(f.readlines(), desc=f'load data from {label_file}'):
+            line = line.strip('\n').replace('.jpg ', '.jpg\t').replace('.png ', '.png\t').split('\t')
+            if len(line) > 1:
+                img_path = pathlib.Path(line[0].strip(' '))
+                label = line[1]
+                if img_path.exists() and img_path.stat().st_size > 0:
+                    data_list.append((str(img_path), label))
+
+    createDataset(data_list, lmdb_save_path)

+ 88 - 0
tools/det_infer.py

@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/6/16 10:57
+# @Author  : zhoujun
+import os
+import sys
+import pathlib
+
+# 将 torchocr 路径加到 python 路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+
+import torch
+from torch import nn
+from torchvision import transforms
+from torchocr.networks import build_model
+from torchocr.datasets.det_modules import ResizeShortSize, ResizeFixedSize
+from torchocr.postprocess import build_post_process
+
+
+class DetInfer:
+    def __init__(self, model_path):
+        ckpt = torch.load(model_path, map_location='cpu')
+        cfg = ckpt['cfg']
+        self.model = build_model(cfg['model'])
+        state_dict = {}
+        for k, v in ckpt['state_dict'].items():
+            state_dict[k.replace('module.', '')] = v
+        self.model.load_state_dict(state_dict)
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.model.to(self.device)
+        self.model.eval()
+        self.resize = ResizeFixedSize(736, False)
+        self.post_process = build_post_process(cfg['post_process'])
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=cfg['dataset']['train']['dataset']['mean'], std=cfg['dataset']['train']['dataset']['std'])
+        ])
+
+    def predict(self, img):
+        # 预处理根据训练来
+        data = {'img': img, 'shape': [img.shape[:2]], 'text_polys': []}
+        data = self.resize(data)
+        tensor = self.transform(data['img'])
+        tensor = tensor.unsqueeze(dim=0)
+        tensor = tensor.to(self.device)
+        with torch.no_grad():
+            out = self.model(tensor)
+        out = out.cpu().numpy()
+        box_list, score_list = self.post_process(out, data['shape'])
+        box_list, score_list = box_list[0], score_list[0]
+        if len(box_list) > 0:
+            idx = [x.sum() > 0 for x in box_list]
+            box_list = [box_list[i] for i, v in enumerate(idx) if v]
+            score_list = [score_list[i] for i, v in enumerate(idx) if v]
+        else:
+            box_list, score_list = [], []
+        return box_list, score_list
+
+
+def init_args():
+    import argparse
+    parser = argparse.ArgumentParser(description='PytorchOCR infer')
+    parser.add_argument('--model_path', required=True, type=str, help='rec model path')
+    parser.add_argument('--img_path', required=True, type=str, help='img dir for predict')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    import cv2
+    import time
+    from matplotlib import pyplot as plt
+    from torchocr.utils import draw_bbox
+
+    args = init_args()
+
+    model = DetInfer(args.model_path)
+    names = next(os.walk(args.img_path))[2]
+    st = time.time()
+    for name in names:
+        path = os.path.join(args.img_path, name)
+        img = cv2.imread(path)
+        box_list, score_list = model.predict(img)
+        out_path = os.path.join(args.img_path, 'res', name)
+        img = draw_bbox(img, box_list)
+        cv2.imwrite(out_path[:-4] + '_res.jpg', img)
+    print((time.time() - st) / len(names))

+ 339 - 0
tools/det_train.py

@@ -0,0 +1,339 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+import os
+import sys
+import pathlib
+
+# 将 torchocr路径加到python路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+import random
+import time
+import shutil
+import traceback
+from importlib import import_module
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch import nn
+
+from torchocr.networks import build_model, build_loss
+from torchocr.postprocess import build_post_process
+from torchocr.datasets import build_dataloader
+from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
+from torchocr.metrics import DetMetric
+
+
+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser(description='train')
+    parser.add_argument('--config', type=str, default='config/cfg_det_db.py', help='train config file path')
+    args = parser.parse_args()
+    # 解析.py文件
+    config_path = os.path.abspath(os.path.expanduser(args.config))
+    assert os.path.isfile(config_path)
+    if config_path.endswith('.py'):
+        module_name = os.path.basename(config_path)[:-3]
+        config_dir = os.path.dirname(config_path)
+        sys.path.insert(0, config_dir)
+        mod = import_module(module_name)
+        sys.path.pop(0)
+        return mod.config
+        # cfg_dict = {
+        #     name: value
+        #     for name, value in mod.__dict__.items()
+        #     if not name.startswith('__')
+        # }
+        # return cfg_dict
+    else:
+        raise IOError('Only py type are supported now!')
+
+
+def set_random_seed(seed, use_cuda=True, deterministic=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        use_cuda: whether depend on cuda
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    if use_cuda:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        if deterministic:
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+
+def build_optimizer(params, config):
+    """
+    优化器
+    Returns:
+    """
+    from torch import optim
+    opt_type = config.pop('type')
+    opt = getattr(optim, opt_type)(params, **config)
+    return opt
+
+
+def adjust_learning_rate(optimizer, base_lr, iter, all_iters, factor, warmup_iters=0, warmup_factor=1.0 / 3):
+    """
+    带 warmup 的学习率衰减
+    :param optimizer: 优化器
+    :param base_lr: 开始的学习率
+    :param iter: 当前迭代次数
+    :param all_iters: 总的迭代次数
+    :param factor: 学习率衰减系数
+    :param warmup_iters: warmup 迭代数
+    :param warmup_factor: warmup 系数
+    :return:
+    """
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    if iter < warmup_iters:
+        alpha = float(iter) / warmup_iters
+        rate = warmup_factor * (1 - alpha) + alpha
+    else:
+        rate = np.power(1.0 - iter / float(all_iters + 1), factor)
+    lr = rate * base_lr
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+    return lr
+
+
+def get_fine_tune_params(net, finetune_stage):
+    """
+    获取需要优化的参数
+    Args:
+        net:
+    Returns: 需要优化的参数
+    """
+    to_return_parameters = []
+    for stage in finetune_stage:
+        attr = getattr(net.module, stage, None)
+        for element in attr.parameters():
+            to_return_parameters.append(element)
+    return to_return_parameters
+
+
+def evaluate(net, val_loader, to_use_device, logger, post_process, metric):
+    """
+    在验证集上评估模型
+
+    :param net: 网络
+    :param val_loader: 验证集 dataloader
+    :param to_use_device: device
+    :param logger: logger类对象
+    :param post_process: 后处理类对象
+    :param metric: 根据网络输出和 label 计算 acc 等指标的类对象
+    :return:  一个包含 recall,precision 和 hmean 的 dict,
+        例子: {
+                'recall':0,
+                'precision': 0.99,
+                'hmean': 0.9999,
+                }
+    """
+    logger.info('start evaluate')
+    net.eval()
+    raw_metrics = []
+    total_frame = 0.0
+    total_time = 0.0
+    with torch.no_grad():
+        for batch_data in tqdm(val_loader):
+            start = time.time()
+            output = net.forward(batch_data['img'].to(to_use_device))
+            boxes, scores = post_process(output.cpu().numpy(), batch_data['shape'])
+            total_frame += batch_data['img'].size()[0]
+            total_time += time.time() - start
+            raw_metric = metric(batch_data, (boxes, scores))
+            raw_metrics.append(raw_metric)
+    metrics = metric.gather_measure(raw_metrics)
+    net.train()
+    result_dict = {'recall': metrics['recall'].avg, 'precision': metrics['precision'].avg,
+                   'hmean': metrics['fmeasure'].avg}
+    for k, v in result_dict.items():
+        logger.info(f'{k}:{v}')
+    logger.info('FPS:{}'.format(total_frame / total_time))
+    return result_dict
+
+
+def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
+          cfg, global_state, logger, post_process):
+    """
+    训练函数
+
+    :param net: 网络
+    :param optimizer: 优化器
+    :param scheduler: 学习率更新器
+    :param loss_func: loss函数
+    :param train_loader: 训练数据集 dataloader
+    :param eval_loader: 验证数据集 dataloader
+    :param to_use_device: device
+    :param cfg: 当前训练所使用的配置
+    :param global_state: 训练过程中的一些全局状态,如cur_epoch,cur_iter,最优模型的相关信息
+    :param logger: logger 对象
+    :param post_process: 后处理类对象
+    :return: None
+    """
+
+    train_options = cfg.train_options
+    metric = DetMetric()
+    # ===>
+    logger.info('Training...')
+    # ===> print loss信息的参数
+    all_step = len(train_loader)
+    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
+    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
+    if len(global_state) > 0:
+        best_model = global_state['best_model']
+        start_epoch = global_state['start_epoch']
+        global_step = global_state['global_step']
+    else:
+        best_model = {'recall': 0, 'precision': 0, 'hmean': 0, 'best_model_epoch': 0}
+        start_epoch = 0
+        global_step = 0
+    # 开始训练
+    base_lr = cfg['optimizer']['lr']
+    all_iters = all_step * train_options['epochs']
+    warmup_iters = 3 * all_step
+    try:
+        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
+            net.train()  # train mode
+            train_loss = 0.
+            start = time.time()
+            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
+                current_lr = adjust_learning_rate(optimizer, base_lr, global_step, all_iters, 0.9,
+                                                  warmup_iters=warmup_iters)
+                # 数据进行转换和丢到gpu
+                for key, value in batch_data.items():
+                    if value is not None:
+                        if isinstance(value, torch.Tensor):
+                            batch_data[key] = value.to(to_use_device)
+                # 清零梯度及反向传播
+                optimizer.zero_grad()
+                output = net.forward(batch_data['img'].to(to_use_device))
+                loss_dict = loss_func(output, batch_data)
+                loss_dict['loss'].backward()
+                optimizer.step()
+                # statistic loss for print
+                train_loss += loss_dict['loss'].item()
+                loss_str = 'loss: {:.4f} - '.format(loss_dict.pop('loss').item())
+                for idx, (key, value) in enumerate(loss_dict.items()):
+                    loss_dict[key] = value.item()
+                    loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
+                    if idx < len(loss_dict) - 1:
+                        loss_str += ' - '
+                if (i + 1) % train_options['print_interval'] == 0:
+                    interval_batch_time = time.time() - start
+                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
+                                f"[{i + 1}/{all_step}] - "
+                                f"lr:{current_lr} - "
+                                f"{loss_str} - "
+                                f"time:{interval_batch_time:.4f}")
+                    start = time.time()
+                global_step += 1
+            logger.info(f'train_loss: {train_loss / len(train_loader)}')
+            if (epoch + 1) % train_options['val_interval'] == 0:
+                global_state['start_epoch'] = epoch
+                global_state['best_model'] = best_model
+                global_state['global_step'] = global_step
+                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
+                save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
+                if train_options['ckpt_save_type'] == 'HighestAcc':
+                    # val
+                    eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
+                    if eval_dict['hmean'] > best_model['hmean']:
+                        best_model.update(eval_dict)
+                        best_model['best_model_epoch'] = epoch
+                        best_model['models'] = net_save_path
+
+                        global_state['start_epoch'] = epoch
+                        global_state['best_model'] = best_model
+                        global_state['global_step'] = global_step
+                        net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
+                        save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
+                elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options[
+                    'ckpt_save_epoch'] == 0:
+                    shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
+                best_str = 'current best, '
+                for k, v in best_model.items():
+                    best_str += '{}: {}, '.format(k, v)
+                logger.info(best_str)
+    except KeyboardInterrupt:
+        import os
+        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg,
+                        global_state=global_state)
+    except:
+        error_msg = traceback.format_exc()
+        logger.error(error_msg)
+    finally:
+        for k, v in best_model.items():
+            logger.info(f'{k}: {v}')
+
+
+def main():
+    # ===> 获取配置文件参数
+    cfg = parse_args()
+    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
+    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))
+
+    # ===> 训练信息的打印
+    train_options = cfg.train_options
+    logger.info(cfg)
+    # ===>
+    to_use_device = torch.device(
+        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
+    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)
+
+    # ===> build network
+    net = build_model(cfg['model'])
+
+    # ===> 模型初始化及模型部署到对应的设备
+    if not cfg['model']['backbone']['pretrained']:  # 使用 pretrained
+        net.apply(weight_init)
+    # if torch.cuda.device_count() > 1:
+    net = nn.DataParallel(net)
+    net = net.to(to_use_device)
+    net.train()
+
+    # ===> get fine tune layers
+    params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
+    # ===> solver and lr scheduler
+    optimizer = build_optimizer(net.parameters(), cfg['optimizer'])
+
+    # ===> whether to resume from checkpoint
+    resume_from = train_options['resume_from']
+    if resume_from:
+        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
+                                                                third_name=train_options['third_party_name'])
+        if _resumed_optimizer:
+            optimizer = _resumed_optimizer
+        logger.info(f'net resume from {resume_from}')
+    else:
+        global_state = {}
+        logger.info(f'net resume from scratch.')
+
+    # ===> loss function
+    loss_func = build_loss(cfg['loss'])
+    loss_func = loss_func.to(to_use_device)
+
+    # ===> data loader
+    train_loader = build_dataloader(cfg.dataset.train)
+    eval_loader = build_dataloader(cfg.dataset.eval)
+
+    # post_process
+    post_process = build_post_process(cfg['post_process'])
+    # ===> train
+    train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger, post_process)
+
+
+if __name__ == '__main__':
+    main()

+ 333 - 0
tools/det_train_disti.py

@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+import os
+import sys
+import pathlib
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ['CUDA_VISIBLE_DEVICES'] = '3'
+
+# 将 torchocr路径加到python路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+import random
+import time
+import shutil
+import traceback
+from importlib import import_module
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch import nn
+
+from torchocr.networks import build_model, build_loss
+from torchocr.postprocess import build_post_process
+from torchocr.datasets import build_dataloader
+from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
+from torchocr.metrics import build_metric
+
+
+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser(description='train')
+    parser.add_argument('--config', type=str, default='config/cfg_det_dis.py', help='train config file path')
+    args = parser.parse_args()
+    # 解析.py文件
+    config_path = os.path.abspath(os.path.expanduser(args.config))
+    assert os.path.isfile(config_path)
+    if config_path.endswith('.py'):
+        module_name = os.path.basename(config_path)[:-3]
+        config_dir = os.path.dirname(config_path)
+        sys.path.insert(0, config_dir)
+        mod = import_module(module_name)
+        sys.path.pop(0)
+        return mod.config
+    else:
+        raise IOError('Only py type are supported now!')
+
+
+def set_random_seed(seed, use_cuda=True, deterministic=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        use_cuda: whether depend on cuda
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    if use_cuda:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        if deterministic:
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+
+def build_optimizer(params, config):
+    """
+    优化器
+    Returns:
+    """
+    from torch import optim
+    opt_type = config.pop('type')
+    opt = getattr(optim, opt_type)(filter(lambda p: p.requires_grad,params), **config)
+    return opt
+
+
+def adjust_learning_rate(optimizer, base_lr, iter, all_iters, factor, warmup_iters=0, warmup_factor=1.0 / 3):
+    """
+    带 warmup 的学习率衰减
+    :param optimizer: 优化器
+    :param base_lr: 开始的学习率
+    :param iter: 当前迭代次数
+    :param all_iters: 总的迭代次数
+    :param factor: 学习率衰减系数
+    :param warmup_iters: warmup 迭代数
+    :param warmup_factor: warmup 系数
+    :return:
+    """
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    if iter < warmup_iters:
+        alpha = float(iter) / warmup_iters
+        rate = warmup_factor * (1 - alpha) + alpha
+    else:
+        rate = np.power(1.0 - iter / float(all_iters + 1), factor)
+    lr = rate * base_lr
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+    return lr
+
+
+def get_fine_tune_params(net, finetune_stage):
+    """
+    获取需要优化的参数
+    Args:
+        net:
+    Returns: 需要优化的参数
+    """
+    to_return_parameters = []
+    for stage in finetune_stage:
+        attr = getattr(net.module, stage, None)
+        for element in attr.parameters():
+            to_return_parameters.append(element)
+    return to_return_parameters
+
+
+def evaluate(net, val_loader, to_use_device, logger, post_process, metric):
+    """
+    在验证集上评估模型
+
+    :param net: 网络
+    :param val_loader: 验证集 dataloader
+    :param to_use_device: device
+    :param logger: logger类对象
+    :param post_process: 后处理类对象
+    :param metric: 根据网络输出和 label 计算 acc 等指标的类对象
+    :return:  一个包含 recall,precision 和 fmeasure 的 dict,
+        例子: {
+                'recall':0,
+                'precision': 0.99,
+                'fmeasure': 0.9999,
+                }
+    """
+    logger.info('start evaluate')
+    net.eval()
+    total_frame = 0.0
+    total_time = 0.0
+    with torch.no_grad():
+        for batch_data in tqdm(val_loader):
+            start = time.time()
+            output = net.forward(batch_data['img'].to(to_use_device))
+
+            box_score_tuple = post_process(output, batch_data['shape'])
+            total_frame += batch_data['img'].size()[0]
+            total_time += time.time() - start
+            metric(batch_data, box_score_tuple)
+    metrics = metric.get_metric()
+    net.train()
+    net.module.model_dict['Teacher'].eval()
+    metrics = {key: val.avg for key, val in metrics.items()}
+    for k, v in metrics.items():
+        logger.info(f'{k}:{v}')
+    logger.info('FPS:{}'.format(total_frame / total_time))
+    return metrics
+
+def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
+          cfg, global_state, logger, post_process, metric):
+    """
+    训练函数
+    :param net: 网络
+    :param optimizer: 优化器
+    :param scheduler: 学习率更新器
+    :param loss_func: loss函数
+    :param train_loader: 训练数据集 dataloader
+    :param eval_loader: 验证数据集 dataloader
+    :param to_use_device: device
+    :param cfg: 当前训练所使用的配置
+    :param global_state: 训练过程中的一些全局状态,如cur_epoch,cur_iter,最优模型的相关信息
+    :param logger: logger 对象
+    :param post_process: 后处理类对象
+    :param metric: 评测方法
+    :return: None
+    """
+
+    train_options = cfg.train_options
+    logger.info('Train beginning...')
+    # ===> print loss信息的参数
+    all_step = len(train_loader)
+    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
+    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
+    if len(global_state) > 0:
+        best_model = global_state['best_model']
+        start_epoch = global_state['start_epoch']
+        global_step = global_state['global_step']
+    else:
+        best_model = {'recall': 0, 'precision': 0, 'fmeasure': 0, 'best_model_epoch': 0}
+        start_epoch = 0
+        global_step = 0
+    # 开始训练
+    base_lr = cfg['optimizer']['lr']
+    all_iters = all_step * train_options['epochs']
+    warmup_iters = 3 * all_step
+    # eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
+    try:
+        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
+            net.train()  # train mode
+            net.module.model_dict['Teacher'].eval()
+            train_loss = 0.
+            start = time.time()
+            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
+                current_lr = adjust_learning_rate(optimizer, base_lr, global_step, all_iters, 0.9,
+                                                  warmup_iters=warmup_iters)
+                # 数据进行转换和丢到gpu
+                for key, value in batch_data.items():
+                    if value is not None:
+                        if isinstance(value, torch.Tensor):
+                            batch_data[key] = value.to(to_use_device)
+                # 清零梯度及反向传播
+                optimizer.zero_grad()
+                output = net.forward(batch_data['img'].to(to_use_device))
+                loss_dict = loss_func(output, batch_data)
+                loss_dict['loss'].backward()
+                optimizer.step()
+                # statistic loss for print
+                train_loss += loss_dict['loss'].item()
+                loss_str = 'loss: {:.4f} - '.format(loss_dict.pop('loss').item())
+                for idx, (key, value) in enumerate(loss_dict.items()):
+                    loss_dict[key] = value.item()
+                    loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
+                    if idx < len(loss_dict) - 1:
+                        loss_str += ' - '
+                if (i + 1) % train_options['print_interval'] == 0:
+                    interval_batch_time = time.time() - start
+                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
+                                f"[{i + 1}/{all_step}] - "
+                                f"lr:{current_lr} - "
+                                f"{loss_str} - "
+                                f"time:{interval_batch_time:.4f}")
+                    start = time.time()
+                global_step += 1
+            logger.info(f'train_loss: {train_loss / len(train_loader)}')
+            if (epoch + 1) % train_options['val_interval'] == 0:
+                global_state['start_epoch'] = epoch
+                global_state['best_model'] = best_model
+                global_state['global_step'] = global_step
+                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
+                save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
+                if train_options['ckpt_save_type'] == 'HighestAcc':
+                    # val
+                    eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
+                    if eval_dict['fmeasure'] > best_model['fmeasure']:
+                        best_model.update(eval_dict)
+                        best_model['best_model_epoch'] = epoch
+                        best_model['models'] = net_save_path
+
+                        global_state['start_epoch'] = epoch
+                        global_state['best_model'] = best_model
+                        global_state['global_step'] = global_step
+                        net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
+                        save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
+                elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
+                    shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
+                best_str = 'current best, '
+                for k, v in best_model.items():
+                    best_str += '{}: {}, '.format(k, v)
+                logger.info(best_str)
+    except KeyboardInterrupt:
+        import os
+        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg,
+                        global_state=global_state)
+    except:
+        error_msg = traceback.format_exc()
+        logger.error(error_msg)
+    finally:
+        for k, v in best_model.items():
+            logger.info(f'{k}: {v}')
+
+
def main():
    """Entry point for distillation training: build the model, optimizer,
    dataloaders, loss, metric and post-process from the config file, then
    hand everything to ``train``.
    """
    # ===> read parameters from the config .py file given on the command line
    cfg = parse_args()
    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
    # ===> logger writes to the console and to train.log in the checkpoint dir
    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))
    logger.info(cfg)

    # ===> device selection: fall back to cpu when cuda is unavailable
    train_options = cfg.train_options
    to_use_device = torch.device(
        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)

    # ===> build network
    net = build_model(cfg['model'])

    # ===> deploy the model to the device (always wrapped in DataParallel)
    net = nn.DataParallel(net)
    net = net.to(to_use_device)


    # ===> build the evaluation metric
    metric = build_metric(cfg['metric'])

    # ===> get fine tune layers
    # params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
    # ===> solver and lr scheduler
    optimizer = build_optimizer(net.parameters(), cfg['optimizer'])
    net.train()
    # keep the Teacher branch in eval mode (train() re-applies this each epoch);
    # NOTE(review): eval mode only fixes BN/dropout — gradients are presumably
    # handled by the loss/optimizer config, confirm Teacher params are excluded
    net.module.model_dict['Teacher'].eval()
    # ===> whether to resume from checkpoint
    resume_from = train_options['resume_from']
    if resume_from:
        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer)
        if _resumed_optimizer:
            optimizer = _resumed_optimizer
        logger.info(f'net resume from {resume_from}')
    else:
        global_state = {}
        logger.info(f'net resume from scratch.')

    # ===> loss function
    loss_func = build_loss(cfg['loss'])
    loss_func = loss_func.to(to_use_device)

    # ===> data loader
    train_loader = build_dataloader(cfg.dataset.train)
    eval_loader = build_dataloader(cfg.dataset.eval)

    # ===> post_process
    post_process = build_post_process(cfg['post_process'])
    # ===> train
    train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger, post_process,metric)


if __name__ == '__main__':
    main()

+ 350 - 0
tools/det_train_pse.py

@@ -0,0 +1,350 @@
+import os
+import sys
+import pathlib
+
+# 将 torchocr路径加到python路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+import os
+
+# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+import random
+import time
+import shutil
+import traceback
+from importlib import import_module
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch import nn
+
+from torchocr.networks import build_model, build_loss
+from torchocr.postprocess import build_post_process
+from torchocr.datasets import build_dataloader
+from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
+from torchocr.metrics import DetMetric
+
+
def parse_args():
    """Parse the command line and load the training config.

    ``--config`` must point at a ``.py`` file that defines a module-level
    ``config`` object; that object is returned.
    """
    import argparse
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--config', type=str, default='config/cfg_det_pse.py', help='train config file path')
    cli_args = parser.parse_args()
    cfg_path = os.path.abspath(os.path.expanduser(cli_args.config))
    assert os.path.isfile(cfg_path)
    if not cfg_path.endswith('.py'):
        raise IOError('Only py type are supported now!')
    cfg_dir, cfg_file = os.path.split(cfg_path)
    # make the config's directory importable, import it, then restore sys.path
    sys.path.insert(0, cfg_dir)
    mod = import_module(cfg_file[:-3])
    sys.path.pop(0)
    return mod.config
+
+
def set_random_seed(seed, use_cuda=True, deterministic=False):
    """Seed python, numpy and (optionally) torch RNGs for reproducibility.

    Args:
        seed (int): seed to be used.
        use_cuda (bool): also seed torch (CPU generator + all CUDA devices).
        deterministic (bool): additionally force deterministic cudnn kernels
            and disable the cudnn autotuner (reproducible but slower).
    """
    random.seed(seed)
    np.random.seed(seed)
    if not use_cuda:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
+
+
def build_optimizer(params, config):
    """Create a ``torch.optim`` optimizer from a config dict.

    Args:
        params: iterable of parameters (or param groups) to optimize.
        config (dict): must contain ``type`` — the name of a ``torch.optim``
            class (e.g. ``'Adam'``, ``'SGD'``); every remaining key is passed
            to that class's constructor (e.g. ``lr``, ``weight_decay``).

    Returns:
        torch.optim.Optimizer: the configured optimizer.
    """
    from torch import optim
    # Work on a copy: the original popped 'type' from the caller's dict in
    # place, silently destroying the config for any later use.
    kwargs = dict(config)
    opt_type = kwargs.pop('type')
    return getattr(optim, opt_type)(params, **kwargs)
+
+
def adjust_learning_rate(optimizer, base_lr, iter, all_iters, factor, warmup_iters=300, warmup_factor=1.0 / 3):
    """Polynomial ("poly") learning-rate decay with a linear warmup phase.

    During the first ``warmup_iters`` iterations the rate ramps linearly from
    ``warmup_factor * base_lr`` up to ``base_lr``; afterwards it decays as
    ``base_lr * (1 - iter / (all_iters + 1)) ** factor``.

    :param optimizer: optimizer whose param groups receive the new lr
    :param base_lr: initial learning rate
    :param iter: current global iteration
    :param all_iters: total number of iterations
    :param factor: decay exponent
    :param warmup_iters: length of the warmup phase
    :param warmup_factor: fraction of ``base_lr`` used at iteration 0
    :return: the learning rate written into every param group
    """
    if iter < warmup_iters:
        progress = float(iter) / warmup_iters
        # linear interpolation: warmup_factor -> 1.0
        scale = warmup_factor * (1 - progress) + progress
    else:
        scale = (1.0 - iter / float(all_iters + 1)) ** factor
    new_lr = scale * base_lr
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
+
+
def get_fine_tune_params(net, finetune_stage):
    """Collect the parameters of the sub-modules selected for fine-tuning.

    Args:
        net: model wrapped in ``nn.DataParallel`` (stages are looked up on
            ``net.module``).
        finetune_stage (Iterable[str]): attribute names of the sub-modules
            whose parameters should be trained.

    Returns:
        list: parameters of every requested stage that actually exists.
    """
    to_return_parameters = []
    for stage in finetune_stage:
        attr = getattr(net.module, stage, None)
        if attr is None:
            # The original crashed with AttributeError on a misspelled/missing
            # stage name; skip it instead and return the valid stages.
            continue
        to_return_parameters.extend(attr.parameters())
    return to_return_parameters
+
+
def evaluate(net, val_loader, to_use_device, logger, post_process, metric):
    """Evaluate the detection model on the validation set.

    :param net: the network
    :param val_loader: validation dataloader
    :param to_use_device: device to run inference on
    :param logger: logger object
    :param post_process: post-processing object turning raw score maps into boxes
    :param metric: object computing recall/precision from predicted boxes and labels
    :return: dict with the evaluation results, e.g.
        {
            'recall': 0,
            'precision': 0.99,
            'hmean': 0.9999,
        }
    """
    logger.info('start evaluate')
    net.eval()
    raw_metrics = []
    total_frame = 0.0
    total_time = 0.0
    with torch.no_grad():
        idx = 0
        for batch_data in tqdm(val_loader):
            start = time.time()
            output = net.forward(batch_data['img'].to(to_use_device))
            # original image size — assumes batch_data['shape'] holds (h, w);
            # TODO(review): confirm against the dataset implementation
            h, w = batch_data['shape'][0].item(), batch_data['shape'][1].item()
            preds, boxes_list = post_process(output[0], 1)
            # ratio of prediction-map size to original image size, used to map
            # boxes back into original-image coordinates
            scale = (preds.shape[1] * 1.0 / w, preds.shape[0] * 1.0 / h)
            if len(boxes_list):
                boxes_list = boxes_list / scale

            # the post-processing yields no per-box confidence, so use 1 for all
            scores = [1] * len(boxes_list)
            # x = output.detach().cpu().numpy().squeeze()
            # x = x > 0.7
            # x = x * 255
            # x = x.astype(np.uint8)
            # import cv2
            # cv2.imwrite(f'mask{idx}.png', x)
            # img = cv2.imread(batch_data['img_path'][0])
            # cv2.imwrite(f'gt{idx}.png', img)
            # idx += 1

            total_frame += batch_data['img'].size()[0]
            total_time += time.time() - start
            raw_metric = metric(batch_data, ([boxes_list], [scores]))
            raw_metrics.append(raw_metric)
    metrics = metric.gather_measure(raw_metrics)
    net.train()  # restore training mode for the caller
    result_dict = {'recall': metrics['recall'].avg, 'precision': metrics['precision'].avg, 'hmean': metrics['fmeasure'].avg}
    for k, v in result_dict.items():
        logger.info(f'{k}:{v}')
    logger.info('FPS:{}'.format(total_frame / total_time))
    return result_dict
+
+
def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger, post_process):
    """Main training loop for the PSE detection model.

    :param net: the network
    :param optimizer: optimizer
    :param loss_func: loss function; returns a ``(loss_c, loss_s, loss)`` tuple
    :param train_loader: training dataloader
    :param eval_loader: validation dataloader
    :param to_use_device: device
    :param cfg: full training config
    :param global_state: global training state (current epoch, current iter,
        info about the best model so far)
    :param logger: logger object
    :param post_process: post-processing object (used during evaluation)
    :return: None
    """

    train_options = cfg.train_options
    metric = DetMetric()
    # ===>
    logger.info('Training...')
    # ===> parameters used when printing loss info
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    if len(global_state) > 0:
        # resuming: restore progress from the checkpoint's global state
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'recall': 0, 'precision': 0, 'hmean': 0, 'best_model_epoch': 0}
        start_epoch = 0
        global_step = 0
    # start training
    base_lr = cfg['optimizer']['lr']
    all_iters = all_step * train_options['epochs']
    # warm up over the first 3 epochs worth of iterations
    warmup_iters = 3 * all_step
    # eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            train_loss = 0.
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                # poly lr decay with warmup, driven by the global step
                current_lr = adjust_learning_rate(optimizer, base_lr, global_step, all_iters, 0.9, warmup_iters=warmup_iters)
                # move the batch tensors to the gpu
                # for key, value in batch_data.items():
                #     if value is not None:
                #         if isinstance(value, torch.Tensor):
                #             batch_data[key] = value.to(to_use_device)
                # zero the gradients and backpropagate
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                labels, training_mask = batch_data['score_maps'].to(to_use_device), batch_data['training_mask'].to(to_use_device)
                # NOTE(review): loss_c / loss_s look like components of the
                # total loss; only `loss` is backpropagated and logged here
                loss_c, loss_s, loss = loss_func(output, labels, training_mask)
                loss.backward()
                optimizer.step()
                # statistic loss for print
                train_loss += loss.item()
                loss_str = 'loss: {:.4f} - '.format(loss.item())

                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"{loss_str} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                global_step += 1
            logger.info(f'train_loss: {train_loss / len(train_loader)}')
            if (epoch + 1) % train_options['val_interval'] == 0:
                # always snapshot the latest state before (possibly) evaluating
                global_state['start_epoch'] = epoch
                global_state['best_model'] = best_model
                global_state['global_step'] = global_step
                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                if train_options['ckpt_save_type'] == 'HighestAcc':
                    # val
                    eval_dict = evaluate(net, eval_loader, to_use_device, logger, post_process, metric)
                    if eval_dict['hmean'] > best_model['hmean']:
                        best_model.update(eval_dict)
                        best_model['best_model_epoch'] = epoch
                        best_model['models'] = net_save_path

                        global_state['start_epoch'] = epoch
                        global_state['best_model'] = best_model
                        global_state['global_step'] = global_step
                        net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                        save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                    shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                best_str = 'current best, '
                for k, v in best_model.items():
                    best_str += '{}: {}, '.format(k, v)
                logger.info(best_str)
    except KeyboardInterrupt:
        # Ctrl-C: save a final checkpoint so the run can be resumed
        import os
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg, global_state=global_state)
    except:
        # NOTE(review): bare except swallows every other error after logging it;
        # consider re-raising so failures are visible to the caller
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
+
+
def main():
    """Entry point for PSE detection training: build everything from the
    config file, optionally resume from a checkpoint, then call ``train``.
    """
    # ===> read parameters from the config file
    cfg = parse_args()
    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))

    # ===> log the training configuration
    train_options = cfg.train_options
    logger.info(cfg)
    # ===> device selection: fall back to cpu when cuda is unavailable
    to_use_device = torch.device(
        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)

    # ===> build network
    net = build_model(cfg['model'])

    # ===> model init and deployment to the device
    # net.apply(weight_init) # comment this line out when using pretrained weights
    # if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
    net = net.to(to_use_device)
    net.train()

    # ===> get fine tune layers
    # NOTE(review): params_to_train is unused — the optimizer below is built
    # on net.parameters(), i.e. all parameters are trained
    params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
    # ===> solver and lr scheduler
    optimizer = build_optimizer(net.parameters(), cfg['optimizer'])

    # ===> whether to resume from checkpoint
    resume_from = train_options['resume_from']
    if resume_from:
        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
                                                                third_name=train_options['third_party_name'])
        if _resumed_optimizer:
            optimizer = _resumed_optimizer
        logger.info(f'net resume from {resume_from}')
    else:
        global_state = {}
        logger.info(f'net resume from scratch.')

    # ===> loss function
    loss_func = build_loss(cfg['loss'])
    loss_func = loss_func.to(to_use_device)

    # ===> data loader
    train_loader = build_dataloader(cfg.dataset.train)
    eval_loader = build_dataloader(cfg.dataset.eval)

    # post_process
    post_process = build_post_process(cfg['post_process'])
    # ===> train
    train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger, post_process)


if __name__ == '__main__':
    main()

+ 364 - 0
tools/doc_test.py

@@ -0,0 +1,364 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+import os
+import sys
+import pathlib
+
+# 将 torchocr路径加到python路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+import random
+import time
+import shutil
+import traceback
+from importlib import import_module
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch import nn
+from torch import optim
+from torchocr.networks import build_model, build_loss
+from torchocr.datasets import build_dataloader
+from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
+
+
def parse_args():
    """Parse the command line and load the training config.

    ``--config`` must point at a ``.py`` file that defines a module-level
    ``config`` object; that object is returned.
    """
    import argparse
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--config', type=str, default='/data2/znj/PytorchOCR/config/cfg_rec_crnn.py',
                        help='train config file path')
    cli_args = parser.parse_args()
    cfg_path = os.path.abspath(os.path.expanduser(cli_args.config))
    assert os.path.isfile(cfg_path)
    if not cfg_path.endswith('.py'):
        raise IOError('Only py type are supported now!')
    cfg_dir, cfg_file = os.path.split(cfg_path)
    # make the config's directory importable, import it, then restore sys.path
    sys.path.insert(0, cfg_dir)
    mod = import_module(cfg_file[:-3])
    sys.path.pop(0)
    return mod.config
+
+
def set_random_seed(seed, use_cuda=True, deterministic=False):
    """Seed every RNG the training depends on (python, numpy, torch).

    Args:
        seed (int): seed value.
        use_cuda (bool): also seed torch's CPU generator and all CUDA devices.
        deterministic (bool): force deterministic cudnn kernels and disable
            the cudnn autotuner — reproducible at the cost of speed.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    if use_cuda:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        if deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
+
+
def build_optimizer(params, config):
    """Create a ``torch.optim`` optimizer from a config dict.

    Args:
        params: iterable of parameters (or param groups) to optimize.
        config (dict): must contain ``type`` — the name of a ``torch.optim``
            class (e.g. ``'Adam'``); the remaining keys are forwarded to the
            optimizer constructor (e.g. ``lr``, ``weight_decay``).

    Returns:
        torch.optim.Optimizer: the configured optimizer.
    """
    # Work on a copy: the original popped 'type' from the caller's dict in
    # place, silently destroying the config for any later use.
    kwargs = dict(config)
    opt_type = kwargs.pop('type')
    return getattr(optim, opt_type)(params, **kwargs)
+
+
def build_scheduler(optimizer, config):
    """Create a learning-rate scheduler from a config dict.

    Supported ``config['type']`` values: ``LambdaLR`` (quartic burn-in warmup
    then stepwise decay), ``StepLR`` (decay by ``gamma`` every ``step_size``
    steps) and ``ReduceLROnPlateau`` (shrink lr when the monitored, maximized
    metric stops improving). ``type`` is popped from ``config`` in place.

    Returns:
        The scheduler, or ``None`` for an unrecognized type.
    """
    sch_type = config.pop('type')
    if sch_type == 'LambdaLR':
        burn_in, steps = config['burn_in'], config['steps']

        # quartic ramp during burn-in, then stepwise 1.0 -> 0.1 -> 0.01
        def burnin_schedule(step):
            if step < burn_in:
                return pow(step / burn_in, 4)
            if step < steps[0]:
                return 1.0
            if step < steps[1]:
                return 0.1
            return 0.01

        return optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
    if sch_type == 'StepLR':
        # decay the lr by `gamma` every `step_size` scheduler steps
        # (the step unit is usually an epoch)
        return optim.lr_scheduler.StepLR(optimizer, step_size=config['step_size'], gamma=config['gamma'])
    if sch_type == 'ReduceLROnPlateau':
        # e.g. monitor validation accuracy and shrink the lr when it plateaus
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                          patience=3, verbose=True, threshold=1e-4)
    # unknown scheduler type: keep the original behavior and return None
    return None
+
+
def get_fine_tune_params(net, finetune_stage):
    """Collect the parameters of the sub-modules selected for fine-tuning.

    Args:
        net: model wrapped in ``nn.DataParallel`` (stages are looked up on
            ``net.module``).
        finetune_stage (Iterable[str]): attribute names of the sub-modules
            whose parameters should be trained.

    Returns:
        list: parameters of every requested stage that actually exists.
    """
    to_return_parameters = []
    for stage in finetune_stage:
        attr = getattr(net.module, stage, None)
        if attr is None:
            # The original crashed with AttributeError on a misspelled/missing
            # stage name; skip it instead and return the valid stages.
            continue
        to_return_parameters.extend(attr.parameters())
    return to_return_parameters
+
+
def evaluate(net, val_loader, loss_func, to_use_device, logger, converter, metric):
    """Evaluate the recognition model on the validation set.

    :param net: the network
    :param val_loader: validation dataloader
    :param loss_func: loss function
    :param to_use_device: device
    :param logger: logger object
    :param converter: label converter (text <-> index sequences)
    :param metric: object computing accuracy / edit distance from model output
        and labels
    :return: dict containing eval_loss, eval_acc and norm_edit_dis, e.g.
        {
            'eval_loss': 0,
            'eval_acc': 0.99,
            'norm_edit_dis': 0.9999,
        }
    """
    logger.info('start evaluate')
    net.eval()
    nums = 0
    result_dict = {'eval_loss': 0., 'eval_acc': 0., 'norm_edit_dis': 0.}
    show_str = []
    with torch.no_grad():
        # start =time.time()
        for batch_data in tqdm(val_loader):
            # encode the text labels so the loss can be computed
            targets, targets_lengths = converter.encode(batch_data['label'])
            batch_data['targets'] = targets
            batch_data['targets_lengths'] = targets_lengths
            output = net.forward(batch_data['img'].to(to_use_device))
            loss = loss_func(output, batch_data)

            nums += batch_data['img'].shape[0]
            acc_dict = metric(output[1], batch_data['label'])
            # accumulate sums here; normalized below once `nums` is final
            result_dict['eval_loss'] += loss['loss'].item()
            result_dict['eval_acc'] += acc_dict['n_correct']
            result_dict['norm_edit_dis'] += acc_dict['norm_edit_dis']
            show_str.extend(acc_dict['show_str'])

            # print("cost-time:",time.time()-start)
            # start = time.time()

    # fix: use the logger instead of a bare print() so these numbers are
    # consistent with the rest of the output and also land in train.log
    logger.info(f"nums: {nums}, right_nums: {result_dict['eval_acc']}")
    result_dict['eval_loss'] /= len(val_loader)
    result_dict['eval_acc'] /= nums
    result_dict['norm_edit_dis'] = 1 - result_dict['norm_edit_dis'] / nums
    logger.info(f"eval_loss:{result_dict['eval_loss']}")
    logger.info(f"eval_acc:{result_dict['eval_acc']}")
    logger.info(f"norm_edit_dis:{result_dict['norm_edit_dis']}")

    # show a handful of prediction examples
    for s in show_str[:10]:
        logger.info(s)
    net.train()  # restore training mode for the caller
    return result_dict
+
+
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger):
    """Main training loop for the recognition model.

    :param net: the network
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler (stepped once per epoch)
    :param loss_func: loss function
    :param train_loader: training dataloader
    :param eval_loader: validation dataloader
    :param to_use_device: device
    :param cfg: full training config
    :param global_state: global training state (current epoch, current iter,
        info about the best model so far)
    :param logger: logger object
    :return: None
    """

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> parameters used when printing loss info
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    if len(global_state) > 0:
        # resuming: restore progress from the checkpoint's global state
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0
        global_step = 0
    # start training
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                # encode text labels so the (CTC) loss can be computed
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                # zero the gradients and backpropagate
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                # gradient clipping keeps the (recurrent) training stable
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output[1], batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                # validation is triggered every val_interval *iterations*
                # (unlike the detection scripts, which validate per epoch)
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    global_state['start_epoch'] = epoch
                    global_state['best_model'] = best_model
                    global_state['global_step'] = global_step
                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # val
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            best_model['models'] = net_save_path

                            global_state['start_epoch'] = epoch
                            global_state['best_model'] = best_model
                            global_state['global_step'] = global_step
                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options[
                        'ckpt_save_epoch'] == 0:
                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                global_step += 1
            scheduler.step()
    except KeyboardInterrupt:
        # Ctrl-C: save a final checkpoint so the run can be resumed
        import os
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg,
                        global_state=global_state)
    except:
        # NOTE(review): bare except swallows every other error after logging it;
        # consider re-raising so failures are visible to the caller
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
+
+
def main():
    """Standalone evaluation entry point: build the recognition model from the
    config file, optionally resume a checkpoint, and run ``evaluate`` on the
    eval dataset (the training call is intentionally commented out).
    """
    # ===> read parameters from the config file
    cfg = parse_args()
    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))

    # ===> log the training configuration
    train_options = cfg.train_options
    logger.info(cfg)
    # ===> device selection: fall back to cpu when cuda is unavailable
    to_use_device = torch.device(
        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)

    # ===> build network
    net = build_model(cfg['model'])

    # ===> weight init (only when no pretrained backbone) and deployment
    if not cfg['model']['backbone']['pretrained']:  # using pretrained weights
        net.apply(weight_init)
    # if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
    net = net.to(to_use_device)
    net.train()

    # ===> get fine tune layers
    params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
    # ===> solver and lr scheduler (kept for parity with the training script;
    # the scheduler is unused while the train() call below stays commented out)
    optimizer = build_optimizer(params_to_train, cfg['optimizer'])
    scheduler = build_scheduler(optimizer, cfg['lr_scheduler'])

    # ===> whether to resume from checkpoint
    resume_from = train_options['resume_from']
    if resume_from:
        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
                                                                third_name=train_options['third_party_name'])
        if _resumed_optimizer:
            optimizer = _resumed_optimizer
        logger.info(f'net resume from {resume_from}')
    else:
        global_state = {}
        logger.info('net resume from scratch.')

    # ===> loss function
    loss_func = build_loss(cfg['loss'])
    loss_func = loss_func.to(to_use_device)

    # ===> data loader (eval only; the train loader is deliberately disabled)
    # cfg.dataset.train.dataset.alphabet = cfg.dataset.alphabet
    # train_loader = build_dataloader(cfg.dataset.train)
    cfg.dataset.eval.dataset.alphabet = cfg.dataset.alphabet
    eval_loader = build_dataloader(cfg.dataset.eval)

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    # fix: the converter was constructed twice in the original; once is enough
    _converter = CTCLabelConverter(cfg.dataset.alphabet)
    _metric = RecMetric(_converter)
    # start = time.time()
    eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, _converter, _metric)
    print(eval_dict)
    # print("cost_time:",time.time()-start)


    # ===> train
    # train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger)


if __name__ == '__main__':
    main()

+ 364 - 0
tools/doc_test_resnet.py

@@ -0,0 +1,364 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+import os
+import sys
+import pathlib
+
+# 将 torchocr路径加到python路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+import random
+import time
+import shutil
+import traceback
+from importlib import import_module
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch import nn
+from torch import optim
+from torchocr.networks import build_model, build_loss
+from torchocr.datasets import build_dataloader
+from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
+
+
+def parse_args():
+    """Parse CLI args and load the training config.
+
+    Returns the ``config`` object exported by the ``--config`` Python file.
+    Raises IOError for non-``.py`` config paths.
+    """
+    import argparse
+    parser = argparse.ArgumentParser(description='train')
+    parser.add_argument('--config', type=str, default='/data2/znj/PytorchOCR/config/cfg_rec_crnn_doc_test_resnet.py',
+                        help='train config file path')
+    args = parser.parse_args()
+    # Load the .py config file as a module and return its `config` attribute.
+    config_path = os.path.abspath(os.path.expanduser(args.config))
+    assert os.path.isfile(config_path)
+    if config_path.endswith('.py'):
+        module_name = os.path.basename(config_path)[:-3]
+        config_dir = os.path.dirname(config_path)
+        # Temporarily prepend the config dir so import_module can find it.
+        sys.path.insert(0, config_dir)
+        mod = import_module(module_name)
+        sys.path.pop(0)
+        return mod.config
+        # cfg_dict = {
+        #     name: value
+        #     for name, value in mod.__dict__.items()
+        #     if not name.startswith('__')
+        # }
+        # return cfg_dict
+    else:
+        raise IOError('Only py type are supported now!')
+
+
+def set_random_seed(seed, use_cuda=True, deterministic=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        use_cuda: whether depend on cuda
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    if use_cuda:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        if deterministic:
+            # Deterministic cuDNN trades speed for reproducibility.
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+
+def build_optimizer(params, config):
+    """
+    优化器
+    Returns:
+    """
+
+    opt_type = config.pop('type')
+    opt = getattr(optim, opt_type)(params, **config)
+    return opt
+
+
+def build_scheduler(optimizer, config):
+    """
+    """
+    scheduler = None
+    sch_type = config.pop('type')
+    if sch_type == 'LambdaLR':
+        burn_in, steps = config['burn_in'], config['steps']
+
+        # Learning rate setup
+        def burnin_schedule(i):
+            if i < burn_in:
+                factor = pow(i / burn_in, 4)
+            elif i < steps[0]:
+                factor = 1.0
+            elif i < steps[1]:
+                factor = 0.1
+            else:
+                factor = 0.01
+            return factor
+
+        scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
+    elif sch_type == 'StepLR':
+        # 等间隔调整学习率, 调整倍数为gamma倍,调整间隔为step_size,间隔单位是step,step通常是指epoch。
+        step_size, gamma = config['step_size'], config['gamma']
+        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
+    elif sch_type == 'ReduceLROnPlateau':
+        # 当某指标不再变化(下降或升高),调整学习率,这是非常实用的学习率调整策略。例如,当验证集的loss不再下降时,进行学习率调整;或者监测验证集的accuracy,当accuracy不再上升时,则调整学习率。
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
+                                                               patience=3, verbose=True, threshold=1e-4)
+    return scheduler
+
+
+def get_fine_tune_params(net, finetune_stage):
+    """
+    Collect the parameters that should be optimized.
+    Args:
+        net: DataParallel-wrapped model (stages are looked up on net.module)
+        finetune_stage: iterable of attribute names (e.g. 'backbone') whose
+            parameters should be trained
+    Returns: list of parameters to optimize
+    """
+    to_return_parameters = []
+    for stage in finetune_stage:
+        # NOTE(review): getattr defaults to None, so a misspelled stage name
+        # raises AttributeError on .parameters() below — confirm intended.
+        attr = getattr(net.module, stage, None)
+        for element in attr.parameters():
+            to_return_parameters.append(element)
+    return to_return_parameters
+
+
+def evaluate(net, val_loader, loss_func, to_use_device, logger, converter, metric):
+    """
+    Evaluate the model on the validation set.
+
+    :param net: network
+    :param val_loader: validation DataLoader
+    :param loss_func: loss function
+    :param to_use_device: device
+    :param logger: logger instance
+    :param converter: label converter (text <-> index sequences)
+    :param metric: object computing acc etc. from network output and labels
+    :return: dict containing eval_loss, eval_acc and norm_edit_dis, e.g.
+        {
+                'eval_loss':0,
+                'eval_acc': 0.99,
+                'norm_edit_dis': 0.9999,
+                }
+    """
+    logger.info('start evaluate')
+    net.eval()
+    nums = 0
+    result_dict = {'eval_loss': 0., 'eval_acc': 0., 'norm_edit_dis': 0.}
+    show_str = []
+    with torch.no_grad():
+        # start =time.time()
+        for batch_data in tqdm(val_loader):
+            targets, targets_lengths = converter.encode(batch_data['label'])
+            batch_data['targets'] = targets
+            batch_data['targets_lengths'] = targets_lengths
+            output = net.forward(batch_data['img'].to(to_use_device))
+            loss = loss_func(output, batch_data)
+
+            nums += batch_data['img'].shape[0]
+            # output[1] presumably holds the per-step class scores — TODO confirm.
+            acc_dict = metric(output[1], batch_data['label'])
+            result_dict['eval_loss'] += loss['loss'].item()
+            result_dict['eval_acc'] += acc_dict['n_correct']
+            result_dict['norm_edit_dis'] += acc_dict['norm_edit_dis']
+            show_str.extend(acc_dict['show_str'])
+
+            # print("cost-time:",time.time()-start)
+            # start = time.time()
+
+    print('nums:', nums, 'right_nums:', result_dict['eval_acc'])
+    # Loss is averaged over batches; acc / edit distance over samples.
+    result_dict['eval_loss'] /= len(val_loader)
+    result_dict['eval_acc'] /= nums
+    result_dict['norm_edit_dis'] = 1 - result_dict['norm_edit_dis'] / nums
+    logger.info(f"eval_loss:{result_dict['eval_loss']}")
+    logger.info(f"eval_acc:{result_dict['eval_acc']}")
+    logger.info(f"norm_edit_dis:{result_dict['norm_edit_dis']}")
+
+    # Log a few sample predictions for manual inspection.
+    for s in show_str[:10]:
+        logger.info(s)
+    net.train()
+    return result_dict
+
+
+def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
+          cfg, global_state, logger):
+    """
+    Training loop.
+
+    :param net: network
+    :param optimizer: optimizer
+    :param scheduler: learning-rate scheduler
+    :param loss_func: loss function
+    :param train_loader: training DataLoader
+    :param eval_loader: validation DataLoader
+    :param to_use_device: device
+    :param cfg: config used for this run
+    :param global_state: global training state (cur_epoch, cur_iter, best-model info)
+    :param logger: logger instance
+    :return: None
+    """
+
+    from torchocr.metrics import RecMetric
+    from torchocr.utils import CTCLabelConverter
+    converter = CTCLabelConverter(cfg.dataset.alphabet)
+    train_options = cfg.train_options
+    metric = RecMetric(converter)
+    # ===>
+    logger.info('Training...')
+    # ===> parameters for periodic loss printing
+    all_step = len(train_loader)
+    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
+    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
+    if len(global_state) > 0:
+        # Resume bookkeeping from a checkpoint's global state.
+        best_model = global_state['best_model']
+        start_epoch = global_state['start_epoch']
+        global_step = global_state['global_step']
+    else:
+        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
+        start_epoch = 0
+        global_step = 0
+    # start training
+    # try:
+    for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
+        net.train()  # train mode
+        start = time.time()
+        for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
+            current_lr = optimizer.param_groups[0]['lr']
+            cur_batch_size = batch_data['img'].shape[0]
+            targets, targets_lengths = converter.encode(batch_data['label'])
+            batch_data['targets'] = targets
+            batch_data['targets_lengths'] = targets_lengths
+            # zero gradients, forward, then backprop with gradient clipping
+            optimizer.zero_grad()
+            output = net.forward(batch_data['img'].to(to_use_device))
+            loss_dict = loss_func(output, batch_data)
+            loss_dict['loss'].backward()
+            torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
+            optimizer.step()
+            # statistic loss for print
+            acc_dict = metric(output[1], batch_data['label'])
+            acc = acc_dict['n_correct'] / cur_batch_size
+            norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
+            if (i + 1) % train_options['print_interval'] == 0:
+                interval_batch_time = time.time() - start
+                logger.info(f"[{epoch}/{train_options['epochs']}] - "
+                            f"[{i + 1}/{all_step}] - "
+                            f"lr:{current_lr} - "
+                            f"loss:{loss_dict['loss'].item():.4f} - "
+                            f"acc:{acc:.4f} - "
+                            f"norm_edit_dis:{norm_edit_dis:.4f} - "
+                            f"time:{interval_batch_time:.4f}")
+                start = time.time()
+            # Periodically checkpoint (and, in HighestAcc mode, evaluate).
+            if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
+                global_state['start_epoch'] = epoch
+                global_state['best_model'] = best_model
+                global_state['global_step'] = global_step
+                net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
+                save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
+                if train_options['ckpt_save_type'] == 'HighestAcc':
+                    # val
+                    eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
+                    if eval_dict['eval_acc'] > best_model['eval_acc']:
+                        best_model.update(eval_dict)
+                        best_model['best_model_epoch'] = epoch
+                        best_model['models'] = net_save_path
+
+                        global_state['start_epoch'] = epoch
+                        global_state['best_model'] = best_model
+                        global_state['global_step'] = global_step
+                        net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
+                        save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
+                elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options[
+                    'ckpt_save_epoch'] == 0:
+                    shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
+            global_step += 1
+        # NOTE(review): scheduler.step() is per-epoch while the LambdaLR
+        # burn-in schedule above appears iteration-based — confirm intended.
+        scheduler.step()
+    # except KeyboardInterrupt:
+    #     import os
+    #     save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg,
+    #                     global_state=global_state)
+    # except:
+    #     error_msg = traceback.format_exc()
+    #     logger.error(error_msg)
+    # finally:
+    for k, v in best_model.items():
+        logger.info(f'{k}: {v}')
+
+
+def main():
+    # ===> 获取配置文件参数
+    cfg = parse_args()
+    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
+    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))
+
+    # ===> 训练信息的打印
+    train_options = cfg.train_options
+    logger.info(cfg)
+    # ===>
+    to_use_device = torch.device(
+        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
+    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)
+
+    # ===> build network
+    net = build_model(cfg['model'])
+
+    # ===> 模型初始化及模型部署到对应的设备
+    if not cfg['model']['backbone']['pretrained']:  # 使用 pretrained
+        net.apply(weight_init)
+    # if torch.cuda.device_count() > 1:
+    net = nn.DataParallel(net)
+    net = net.to(to_use_device)
+    net.train()
+
+    # ===> get fine tune layers
+    params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
+    # ===> solver and lr scheduler
+    optimizer = build_optimizer(params_to_train, cfg['optimizer'])
+    scheduler = build_scheduler(optimizer, cfg['lr_scheduler'])
+
+    # ===> whether to resume from checkpoint
+    resume_from = train_options['resume_from']
+    if resume_from:
+        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
+                                                                third_name=train_options['third_party_name'])
+        if _resumed_optimizer:
+            optimizer = _resumed_optimizer
+        logger.info(f'net resume from {resume_from}')
+    else:
+        global_state = {}
+        logger.info(f'net resume from scratch.')
+
+    # ===> loss function
+    loss_func = build_loss(cfg['loss'])
+    loss_func = loss_func.to(to_use_device)
+
+    # ===> data loader
+    # cfg.dataset.train.dataset.alphabet = cfg.dataset.alphabet
+    # train_loader = build_dataloader(cfg.dataset.train)
+    cfg.dataset.eval.dataset.alphabet = cfg.dataset.alphabet
+    eval_loader = build_dataloader(cfg.dataset.eval)
+
+    from torchocr.metrics import RecMetric
+    from torchocr.utils import CTCLabelConverter
+    _converter = CTCLabelConverter(cfg.dataset.alphabet)
+    _metric = RecMetric(_converter)
+    _converter = CTCLabelConverter(cfg.dataset.alphabet)
+    # start = time.time()
+    eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, _converter, _metric)
+    print(eval_dict)
+    # print("cost_time:",time.time()-start)
+
+
+    # ===> train
+    # train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger)
+
+
+if __name__ == '__main__':
+    main()

+ 393 - 0
tools/minimum_2stage_inference.py

@@ -0,0 +1,393 @@
+import math
+import operator
+import os
+from functools import reduce
+
+import cv2
+import numpy as np
+import scipy
+import scipy.spatial.distance  # `import scipy` alone does not load scipy.spatial
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from torchvision import transforms
+from tqdm import tqdm
+
+from torchocr.networks.architectures.DetModel import *
+from torchocr.networks.architectures.RecModel import *
+from torchocr.utils import CTCLabelConverter
+
+default_font_for_annotate = ImageFont.truetype('./田氏颜体大字库2.0.ttf', 20)
+
+
+def get_data(_to_eval_directory, _to_eval_file, _transform):
+    """
+    将一个文件夹中所有的图片或特定图片转换为特定网络用tensor
+    :param _to_eval_directory:  图像所在文件夹
+    :param _to_eval_file:   需要进行评估的文件
+    :param _transform:  eval需要用到的transform
+    :return:    每张图片的tensor
+    """
+    available_extensions = {'.png', '.jpg', '.jpeg', '.bmp'}
+    # 找到有效图像
+    all_to_eval = []
+    if _to_eval_file is not None:
+        assert os.path.splitext(_to_eval_file)[1].lower() in available_extensions, f'{_to_eval_file} 格式不支持'
+        target_file_path = os.path.join(_to_eval_directory, _to_eval_file)
+        assert os.path.exists(target_file_path), f'{target_file_path} 文件不存在'
+        all_to_eval.append(target_file_path)
+    else:
+        assert os.path.exists(_to_eval_directory) and os.path.isdir(_to_eval_directory), f'{_to_eval_directory} 文件夹无效'
+        for m_file in os.listdir(_to_eval_directory):
+            if os.path.splitext(m_file)[1].lower() in available_extensions:
+                all_to_eval.append(os.path.join(_to_eval_directory, m_file))
+    for m_file in all_to_eval:
+        m_pil_img = Image.open(m_file)
+        yield m_file, m_pil_img, _transform().unsqueeze(0)
+
+
+def plot_detect_result_on_img(_img, _polygons):
+    """
+    Draw detection (and recognition) results on the image.
+    :param _img: the image (PIL)
+    :param _polygons: polygons containing text
+    :return: a copy of the image with the detection results drawn on it
+    """
+    to_return_img = _img.copy()
+    to_draw = ImageDraw.Draw(to_return_img)
+    for m_polygon_index, m_polygon in enumerate(_polygons):
+        # Outline each region in red and label it with its index in blue.
+        to_draw.polygon(m_polygon, outline="red")
+        to_draw.text(m_polygon[0], f'{m_polygon_index}', font=default_font_for_annotate, fill='blue')
+    return to_return_img
+
+
+def mask2polygon(_mask):
+    """
+    Convert a binary mask into polygon regions (convex-hull-like contours).
+    :param _mask: post-processed detection result (h*w array of 0/1)
+    :return: all polygons extractable from the mask
+    """
+    assert len(_mask.shape) == 2
+    to_return_polygons = []
+    # NOTE(review): the 2-tuple unpacking assumes OpenCV 4.x findContours.
+    contours, _ = cv2.findContours((_mask * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+    for contour in contours:
+        # Simplify each contour to within 0.5% of its perimeter.
+        epsilon = 0.005 * cv2.arcLength(contour, True)
+        approx = cv2.approxPolyDP(contour, epsilon, True)
+        points = approx.reshape((-1, 2))
+        if points.shape[0] < 4:
+            # Fewer than 4 vertices cannot form a text quadrilateral.
+            continue
+        to_return_polygons.append(points)
+    return to_return_polygons
+
+
+def clockwise_points(_point_coords):
+    """
+    Sort points clockwise, starting from the top-left.
+    Works by converting Cartesian coordinates to polar around the centroid
+    and sorting by the polar angle φ.
+    :param _point_coords: points to sort, [(x, y), ...]
+    :return: sorted points
+    """
+    # Centroid (mean x, mean y) of all points.
+    center_point = tuple(
+        map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), _point_coords), [len(_point_coords)] * 2))
+    # Sort by the angle to the centroid; the 135° offset puts the top-left
+    # point first.
+    return sorted(_point_coords, key=lambda coord: (135 - math.degrees(
+        math.atan2(*tuple(map(operator.sub, coord, center_point))[::-1]))) % 360)
+
+
+def assign_points_to_four_edges(_point_coords):
+    """
+    Distribute all polygon points onto four virtual edges (assumes the
+    points form a convex hull). Idea: the point farthest from the first
+    point is the third corner; the two saddle points between the first and
+    third corners become the second and fourth corners.
+    :param _point_coords: all key points (already sorted)
+    :return: four lists of points, one per edge
+    """
+    to_return_edges = []
+    max_distance_index = 0
+    max_distance = -1
+    points_np = np.array(_point_coords)
+
+    # Farthest point from point 0 -> third corner.
+    for i in range(1, len(_point_coords)):
+        distance = np.linalg.norm(points_np[0] - points_np[i])
+        if distance > max_distance:
+            max_distance_index = i
+            max_distance = distance
+    angles = []
+    for i in range(1, max_distance_index):
+        # Angle between the two edges meeting at point i.
+        # NOTE(review): atan2 returns radians but the modulus below is 180
+        # as if degrees — confirm intended.
+        m_angle_1 = math.atan2(points_np[i][1] - points_np[0][1],
+                               points_np[i][0] - points_np[0][0])
+        m_angle_2 = math.atan2(points_np[i][1] - points_np[max_distance_index][1],
+                               points_np[i][0] - points_np[max_distance_index][0])
+        angles.append(abs(m_angle_1 - m_angle_2) % 180)
+    second_point_index = np.argmin(angles)
+    angles.clear()
+    for i in range(max_distance_index, len(_point_coords) - 1):
+        m_angle_1 = math.atan2(points_np[i][1] - points_np[0][1],
+                               points_np[i][0] - points_np[0][0])
+        m_angle_2 = math.atan2(points_np[i][1] - points_np[max_distance_index][1],
+                               points_np[i][0] - points_np[max_distance_index][0])
+        angles.append(abs(m_angle_1 - m_angle_2) % 180)
+    # NOTE(review): np.argmin is relative to the scanned sub-ranges above
+    # (which start at 1 and at max_distance_index); using the raw argmin as
+    # a slice bound without adding that offset looks like an off-by-offset
+    # bug — verify before relying on the edge split.
+    fourth_point_index = np.argmin(angles)
+    to_return_edges.append(_point_coords[:second_point_index])
+    to_return_edges.append(_point_coords[second_point_index:max_distance_index])
+    to_return_edges.append(_point_coords[max_distance_index:fourth_point_index])
+    to_return_edges.append(_point_coords[fourth_point_index:])
+    return to_return_edges
+
+
+def polygon_rectify_to_rectangle(_img, _polygon):
+    """
+    Rectify a fan- or ring-shaped polygon region into a rectangle so the
+    recognition model can consume it.
+    :param _img: current image
+    :param _polygon: polygon region
+    :return: rectangular region
+    """
+    # Unimplemented stub — currently returns None.
+    pass
+
+
+def extract_rectangle_with_correct_aspect_ratio(_img, _four_corner_points):
+    """
+    Extract a quadrilateral region from the image and warp it to a rectangle
+    whose aspect ratio matches the real-world rectangle.
+    Appears to follow the focal-length-based rectification of Zhang & He,
+    "Whiteboard scanning and image enhancement" — TODO confirm.
+    :param _img: whole image (ndarray)
+    :param _four_corner_points: the rectangle's 4 corners, clockwise from top-left
+    :return: the rectified, perspective-transformed region
+    """
+    rows, cols = _img.shape[:2]
+    # Principal point assumed at the image center.
+    u0 = cols / 2.0
+    v0 = rows / 2.0
+    p = [_four_corner_points[0], _four_corner_points[1], _four_corner_points[3], _four_corner_points[2]]
+    # widths and heights of the projected image
+    w1 = scipy.spatial.distance.euclidean(p[0], p[1])
+    w2 = scipy.spatial.distance.euclidean(p[2], p[3])
+
+    h1 = scipy.spatial.distance.euclidean(p[0], p[2])
+    h2 = scipy.spatial.distance.euclidean(p[1], p[3])
+
+    w = max(w1, w2)
+    h = max(h1, h2)
+
+    # visible aspect ratio
+    ar_vis = float(w) / float(h)
+
+    # make numpy arrays and append 1 for linear algebra
+    m1 = np.array((p[0][0], p[0][1], 1)).astype('float32')
+    m2 = np.array((p[1][0], p[1][1], 1)).astype('float32')
+    m3 = np.array((p[2][0], p[2][1], 1)).astype('float32')
+    m4 = np.array((p[3][0], p[3][1], 1)).astype('float32')
+
+    # calculate the focal distance
+    k2 = np.dot(np.cross(m1, m4), m3) / np.dot(np.cross(m2, m4), m3)
+    k3 = np.dot(np.cross(m1, m4), m2) / np.dot(np.cross(m3, m4), m2)
+
+    n2 = k2 * m2 - m1
+    n3 = k3 * m3 - m1
+
+    n21 = n2[0]
+    n22 = n2[1]
+    n23 = n2[2]
+
+    n31 = n3[0]
+    n32 = n3[1]
+    n33 = n3[2]
+
+    f = math.sqrt(np.abs((1.0 / (n23 * n33)) * ((n21 * n31 - (n21 * n33 + n23 * n31) * u0 + n23 * n33 * u0 * u0) + (
+            n22 * n32 - (n22 * n33 + n23 * n32) * v0 + n23 * n33 * v0 * v0))))
+
+    # Intrinsic camera matrix built from the recovered focal length.
+    A = np.array([[f, 0, u0], [0, f, v0], [0, 0, 1]]).astype('float32')
+
+    At = np.transpose(A)
+    Ati = np.linalg.inv(At)
+    Ai = np.linalg.inv(A)
+
+    # calculate the real aspect ratio
+    ar_real = math.sqrt(np.dot(np.dot(np.dot(n2, Ati), Ai), n2) / np.dot(np.dot(np.dot(n3, Ati), Ai), n3))
+
+    if ar_real < ar_vis:
+        W = int(w)
+        H = int(W / ar_real)
+    else:
+        H = int(h)
+        W = int(ar_real * H)
+
+    pts1 = np.array(p).astype('float32')
+    pts2 = np.float32([[0, 0], [W, 0], [0, H], [W, H]])
+
+    # project the image with the new w/h
+    M = cv2.getPerspectiveTransform(pts1, pts2)
+
+    dst = cv2.warpPerspective(_img, M, (W, H))
+    return dst
+
+
+def polygon_to_rectangle_basic(_img, _polygon, _aspect_ratio_correct=True):
+    """
+    将多边形的最小面积四边形进行透视变换得到矩形
+    适用于绝大部分场景
+    :param  _img:   当前图像
+    :param _polygon:    多边形区域
+    :param _aspect_ratio_correct:   是否进行长宽比矫正,默认矫正
+    :return:    对应的矩形区域(ndarray)
+    """
+    sorted_points = clockwise_points(_polygon)
+    fours_edges = assign_points_to_four_edges(sorted_points)
+    arc_lengths = [cv2.arcLength(m_edge, False) for m_edge in fours_edges]
+    first_edge = (arc_lengths[0] + arc_lengths[2]) // 2
+    second_edge = (arc_lengths[1] + arc_lengths[3]) // 2
+    img_np = np.array(_img)
+    if not _aspect_ratio_correct:
+        M = cv2.getPerspectiveTransform([m_edge[0] for m_edge in fours_edges],
+                                        [[0, 0], [first_edge, 0], [first_edge, second_edge], [0, second_edge]])
+        warped_roi = cv2.warpPerspective(img_np, M, (first_edge, second_edge))
+    else:
+        warped_roi = extract_rectangle_with_correct_aspect_ratio(img_np, [m_edge[0] for m_edge in fours_edges])
+    if first_edge < second_edge:
+        return cv2.rotate(warped_roi, cv2.ROTATE_90_COUNTERCLOCKWISE)
+    else:
+        return warped_roi
+
+
+def detect_result_post_process(_detect_result, _detector_model_type):
+    """
+    Post-process the raw detection output according to the detector type.
+    All branches are currently unimplemented stubs, so this always returns [].
+    :param _detect_result: raw detection output
+    :param _detector_model_type: detector type name
+    :return: list of polygons after post-processing
+    """
+    to_return_polygons = []
+    if _detector_model_type in {'pse', 'pan'}:
+        # Extract and merge polygons from the label map.
+        pass
+    elif _detector_model_type in {'db'}:
+        # Extract and merge polygons from the shrink map.
+        pass
+    elif _detector_model_type in {'centernet', 'fcos', 'east', 'advancedeast'}:
+        # Detectors regressing rectangles directly: mainly NMS plus merging.
+        pass
+    elif _detector_model_type in {'ctpn', 'yolo_ctpn'}:
+        # Link CTPN-style boxes into polygon regions, then merge polygons.
+        pass
+    return to_return_polygons
+
+
+def extract_tensor_for_recognize(_img, _polygons, _transform):
+    """
+    Crop each detected region and convert it to a tensor for the recognizer.
+    :param _img: original image
+    :param _polygons: polygon regions of all detected text
+    :param _transform: transform applied to each cropped region
+    :return: yields one 1xCxHxW tensor at a time, since text regions differ in width
+    """
+    for m_polygon in _polygons:
+        # Future work: dispatch on polygon type and convert accordingly.
+        yield _transform(polygon_to_rectangle_basic(_img, m_polygon)).unsqueeze(0)
+
+
+def recognition_result_post_process(_recognition_result, _recognizer_model_type, _str_label_converter):
+    """
+    Post-process recognition output: CTC decoding, plus (future) correction
+    and dictionary-based fixes.
+    :param _recognition_result: raw recognizer output
+    :param _recognizer_model_type: recognizer network type
+    :param _str_label_converter: converter mapping label indices to characters
+    :return: post-processed recognition result; None for non-CRNN types
+        (those branches are unimplemented)
+    """
+    if 'crnn' in _recognizer_model_type:
+        return _str_label_converter.decode(_recognition_result)
+    else:
+        pass
+    pass
+
+
+def related_polygon_assembly(_polygons, _recognition_result):
+    """
+    Merge related polygons according to their recognized text.
+    Currently a passthrough stub.
+    :param _polygons: polygons
+    :param _recognition_result: recognition result for each polygon
+    :return: merged polygons and the corresponding recognized texts
+    """
+    return _polygons, _recognition_result
+
+
+if __name__ == '__main__':
+    # 配置参数
+    device_name = 'cpu'
+    eval_dataset_directory = './data'
+    eval_file = None
+    target_size = (1024, 1024)
+    eval_stds = [0.229, 0.224, 0.225]
+    eval_means = [0.485, 0.456, 0.406]
+
+
+    def _resize_img_for_detect(_img, _longer_edge_length, _shorter_edge_base_length):
+        """
+        将图像按照最长边进行等比缩小且最短边能够被特定基数整除
+        :param _img:    需要进行resize的图像
+        :param _longer_edge_length:     最长边长度
+        :param _shorter_edge_base_length:   短边的基数
+        :return:    resize后的图像
+        """
+        h, w = _img.size
+        if h > w:
+            new_h = _longer_edge_length
+            new_w = (new_h / h * w) // _shorter_edge_base_length * _shorter_edge_base_length
+        else:
+            new_w = _longer_edge_length
+            new_h = (new_w / w * h) // _shorter_edge_base_length * _shorter_edge_base_length
+        return _img.resize((new_w, new_h))
+
+
+    def _resize_img_for_recognize(_img, _height=32):
+        h, w = _img.size
+        ratio = h / _height
+        return _img.resize((w // ratio, _height))
+
+
+    eval_detect_transformer = transforms.Compose([
+        lambda x: _resize_img_for_detect(x, 2240, 32),
+        transforms.Normalize(std=eval_stds, mean=eval_means),
+        transforms.ToTensor(),
+    ])
+    eval_recognize_transformer = transforms.Compose([
+        transforms.ToPILImage(),
+        lambda x: _resize_img_for_recognize(x, 32),
+        transforms.Normalize(std=[1, 1, 1], mean=[0.5, 0.5, 0.5]),
+        transforms.ToTensor(),
+    ])
+    detector_model_type = ''
+    detector_config = AttrDict()
+    recognizer_model_type = ''
+    recognizer_config = AttrDict()
+    detector_pretrained_model_file = ''
+    recognizer_pretrained_model_file = ''
+    annotate_on_image = True
+    need_rectify_on_single_character = True
+    labels = ''.join([f'{i}' for i in range(10)] + [chr(97 + i) for i in range(26)])
+    # 模型推断
+    label_converter = CTCLabelConverter(labels)
+    device = torch.device(device_name)
+    detector = DetModel(detector_config).to(device)
+    detector.load_state_dict(torch.load(detector_pretrained_model_file, map_location='cpu'))
+    recognizer = RecModel(recognizer_config).to(device)
+    recognizer.load_state_dict(torch.load(recognizer_pretrained_model_file, map_location='cpu'))
+    detector.eval()
+    recognizer.eval()
+    with torch.no_grad():
+        for m_path, m_pil_img, m_eval_tensor in tqdm(
+                get_data(eval_dataset_directory, eval_file, eval_detect_transformer)
+        ):
+            m_eval_tensor = m_eval_tensor.to(device)
+            # 获得检测需要的相关信息
+            m_detect_result = detector(m_eval_tensor)
+            # 根据网络类型,处理检测的相关信息,最后转换为一堆多边形
+            m_polygons = detect_result_post_process(m_detect_result, detector_model_type)
+            m_recognized_results = []
+            # 提取所有的文本区域
+            for m_region_tensor in extract_tensor_for_recognize(m_pil_img, m_polygons, eval_recognize_transformer):
+                refined_recognized_result = \
+                    recognition_result_post_process(recognizer(m_region_tensor), recognizer_model_type, label_converter)
+                m_recognized_results.append(refined_recognized_result)
+            m_final_polygons, m_final_text = related_polygon_assembly(m_polygons, m_recognized_results)
+            if annotate_on_image:
+                annotated_img = plot_detect_result_on_img(m_pil_img, m_final_polygons)
+                m_base_name, m_ext = os.path.splitext(m_path)
+                annotated_img.save(f'{m_base_name}_result{m_ext}')
+                with open(f'{m_base_name}_result.txt', mode='w', encoding='utf-8') as to_write:
+                    to_write.write('index,text\n')
+                    for m_index, m_text in enumerate(m_final_text):
+                        to_write.write(f'{m_index},{m_text}\n')

+ 113 - 0
tools/ocr_infer.py

@@ -0,0 +1,113 @@
+import argparse
+
+import cv2
+import numpy as np
+from line_profiler import LineProfiler
+from memory_profiler import profile
+
+from det_infer import DetInfer
+from rec_infer import RecInfer
+from torchocr.utils.vis import draw_ocr_box_txt
+
+def get_rotate_crop_image(img, points):
+    """Perspective-crop a quadrilateral text region to an axis-aligned patch.
+
+    :param img: source image (ndarray)
+    :param points: 4 corner points; presumably ordered top-left, top-right,
+        bottom-right, bottom-left — TODO confirm with the detector output.
+    :return: cropped (and, for tall crops, rotated) patch
+    """
+    # Previous bounding-box crop variant, kept for reference:
+    # img_height, img_width = img.shape[0:2]
+    # left = int(np.min(points[:, 0]))
+    # right = int(np.max(points[:, 0]))
+    # top = int(np.min(points[:, 1]))
+    # bottom = int(np.max(points[:, 1]))
+    # img_crop = img[top:bottom, left:right, :].copy()
+    # points[:, 0] = points[:, 0] - left
+    # points[:, 1] = points[:, 1] - top
+    points = points.astype(np.float32)
+    img_crop_width = int(
+        max(
+            np.linalg.norm(points[0] - points[1]),
+            np.linalg.norm(points[2] - points[3])))
+    img_crop_height = int(
+        max(
+            np.linalg.norm(points[0] - points[3]),
+            np.linalg.norm(points[1] - points[2])))
+    pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                          [img_crop_width, img_crop_height],
+                          [0, img_crop_height]])
+    M = cv2.getPerspectiveTransform(points, pts_std)
+    dst_img = cv2.warpPerspective(
+        img,
+        M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_REPLICATE,
+        flags=cv2.INTER_CUBIC)
+    dst_img_height, dst_img_width = dst_img.shape[0:2]
+    # Rotate tall crops (likely vertical text) to horizontal.
+    if dst_img_height * 1.0 / dst_img_width >= 1.5:
+        dst_img = np.rot90(dst_img)
+    return dst_img
+
+
+class OCRInfer(object):
+    """Two-stage OCR inference: text detection followed by recognition.
+
+    Optionally wraps prediction with line_profiler (time) or memory_profiler
+    (memory); the two modes are mutually exclusive.
+    """
+    def __init__(self, det_path, rec_path, rec_batch_size=16, time_profile=False, mem_profile=False ,**kwargs):
+        super().__init__()
+        self.det_model = DetInfer(det_path)
+        self.rec_model = RecInfer(rec_path, rec_batch_size)
+        assert not(time_profile and mem_profile),"can not profile memory and time at the same time"
+        self.line_profiler = None
+        if time_profile:
+            self.line_profiler = LineProfiler()
+            # Swap in the profiling variant of predict.
+            self.predict = self.predict_time_profile
+        if mem_profile:
+            self.predict = self.predict_mem_profile
+
+    def do_predict(self, img):
+        """Detect boxes, crop them, recognize, and draw a debug image.
+
+        Returns (box_list, score_list, debug_img).
+        """
+        box_list, score_list = self.det_model.predict(img)
+        if len(box_list) == 0:
+            return [], [], img
+        draw_box_list = [tuple(map(tuple, box)) for box in box_list]
+        imgs =[get_rotate_crop_image(img, box) for box in box_list]
+        texts = self.rec_model.predict(imgs)
+        # Keep only the top-1 text of each crop.
+        texts = [txt[0][0] for txt in texts]
+        debug_img = draw_ocr_box_txt(img, draw_box_list, texts)
+        return box_list, score_list, debug_img
+
+    def predict(self, img):
+        # Default (unprofiled) entry point; may be rebound in __init__.
+        return self.do_predict(img)
+
+    def predict_mem_profile(self, img):
+        # memory_profiler variant: wrap the call with @profile.
+        wapper = profile(self.do_predict)
+        return wapper(img)
+
+    def predict_time_profile(self, img):
+        # run multi time
+        for i in range(8):
+            print("*********** {} profile time *************".format(i))
+            lp = LineProfiler()
+            lp_wrapper = lp(self.do_predict)
+            ret = lp_wrapper(img)
+            lp.print_stats()
+        return ret
+
+
+def init_args():
+    """Build and parse CLI arguments; return them as a plain dict."""
+    import argparse
+    parser = argparse.ArgumentParser(description='OCR infer')
+    parser.add_argument('--det_path', required=True, type=str, help='det model path')
+    parser.add_argument('--rec_path', required=True, type=str, help='rec model path')
+    parser.add_argument('--img_path', required=True, type=str, help='img path for predict')
+    parser.add_argument('--rec_batch_size', type=int, help='rec batch_size', default=16)
+    parser.add_argument('-time_profile', action='store_true', help='enable time profile mode')
+    parser.add_argument('-mem_profile', action='store_true', help='enable memory profile mode')
+    args = parser.parse_args()
+    return vars(args)
+
+
+if __name__ == '__main__':
+    import cv2
+    args = init_args()
+    img = cv2.imread(args['img_path'])
+    model = OCRInfer(**args)
+    txts, boxes, debug_img = model.predict(img)
+    h,w,_, = debug_img.shape
+    raido = 1
+    if w > 1200:
+        raido = 600.0/w
+    debug_img = cv2.resize(debug_img, (int(w*raido), int(h*raido)))
+    if not(args['mem_profile'] or args['time_profile']):
+        cv2.imshow("debug", debug_img)
+        cv2.waitKey()
+

+ 349 - 0
tools/rec_fineturn.py

@@ -0,0 +1,349 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+import os
+import sys
+import pathlib
+
+# 将 torchocr路径加到python路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+import random
+import time
+import shutil
+import traceback
+from importlib import import_module
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch import nn
+from torch import optim
+from torchocr.networks import build_model, build_loss
+from torchocr.datasets import build_dataloader
+from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
+
+
def parse_args():
    """Parse CLI args and load the Python config module named by ``--config``.

    The config file must be a ``.py`` file that defines a module-level
    ``config`` object, which is returned as-is.

    Returns:
        The ``config`` object exported by the configuration module.

    Raises:
        AssertionError: if the config path does not exist.
        IOError: if the config file is not a ``.py`` file.
    """
    import argparse
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--config', type=str, default='/data2/znj/PytorchOCR/config/cfg_rec_crnn_doc_fineturn.py',
                        help='train config file path')
    args = parser.parse_args()
    config_path = os.path.abspath(os.path.expanduser(args.config))
    assert os.path.isfile(config_path)
    if not config_path.endswith('.py'):
        raise IOError('Only py type are supported now!')
    module_name = os.path.basename(config_path)[:-3]
    config_dir = os.path.dirname(config_path)
    # Temporarily prepend the config directory so import_module can find it;
    # the try/finally guarantees sys.path is restored even if the import fails.
    sys.path.insert(0, config_dir)
    try:
        mod = import_module(module_name)
    finally:
        sys.path.pop(0)
    return mod.config
+
+
def set_random_seed(seed, use_cuda=True, deterministic=False):
    """Seed the python, numpy and (optionally) torch RNGs.

    Args:
        seed (int): Seed to be used.
        use_cuda: whether to also seed the torch / CUDA generators.
        deterministic (bool): Whether to set the deterministic option for
            the CUDNN backend, i.e., set ``torch.backends.cudnn.deterministic``
            to True and ``torch.backends.cudnn.benchmark`` to False.
            Default: False.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    if not use_cuda:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
+
+
def build_optimizer(params, config):
    """Create a ``torch.optim`` optimizer from a config dict.

    Args:
        params: iterable of parameters (or param groups) to optimize.
        config: dict whose ``type`` key names a class in ``torch.optim``;
            every other key is forwarded as a keyword argument.
            NOTE: ``type`` is popped, so ``config`` is mutated.

    Returns:
        The instantiated optimizer.
    """
    optimizer_cls = getattr(optim, config.pop('type'))
    return optimizer_cls(params, **config)
+
+
def build_scheduler(optimizer, config):
    """Create a learning-rate scheduler from a config dict.

    Supported ``config['type']`` values: ``LambdaLR`` (quartic burn-in
    warm-up, then two step drops), ``StepLR`` (multiply the LR by ``gamma``
    every ``step_size`` epochs) and ``ReduceLROnPlateau`` (drop the LR when
    a maximized metric stops improving). Any other type yields ``None``.
    NOTE: ``type`` is popped, so ``config`` is mutated.
    """
    sch_type = config.pop('type')
    if sch_type == 'LambdaLR':
        burn_in, steps = config['burn_in'], config['steps']

        def burnin_schedule(i):
            # Quartic ramp until burn_in, then 1.0 / 0.1 / 0.01 plateaus.
            if i < burn_in:
                return pow(i / burn_in, 4)
            if i < steps[0]:
                return 1.0
            if i < steps[1]:
                return 0.1
            return 0.01

        return optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
    if sch_type == 'StepLR':
        return optim.lr_scheduler.StepLR(optimizer, step_size=config['step_size'], gamma=config['gamma'])
    if sch_type == 'ReduceLROnPlateau':
        # Drop the LR when the monitored (maximized) metric plateaus.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                          patience=3, verbose=True, threshold=1e-4)
    return None
+
+
def get_fine_tune_params(net, finetune_stage):
    """Collect the parameters of the sub-modules selected for fine-tuning.

    Args:
        net: the network; usually wrapped in ``nn.DataParallel``, in which
            case stages are looked up on ``net.module``. A plain module
            (no ``.module`` attribute) is now handled as well.
        finetune_stage: iterable of attribute names (e.g. ``['neck', 'head']``)
            naming the sub-modules whose parameters should be optimized.

    Returns:
        list: parameters of every requested stage that exists. Missing stage
        names are skipped (the original crashed with ``AttributeError`` on
        ``None.parameters()`` in that case).
    """
    base = getattr(net, 'module', net)  # unwrap DataParallel if present
    params = []
    for stage in finetune_stage:
        attr = getattr(base, stage, None)
        if attr is None:
            continue
        params.extend(attr.parameters())
    return params
+
+
def evaluate(net, val_loader, loss_func, to_use_device, logger, converter, metric):
    """Evaluate the model on the validation set.

    :param net: the network
    :param val_loader: validation DataLoader
    :param loss_func: loss function
    :param to_use_device: device to run on
    :param logger: logger instance
    :param converter: label converter (encodes text labels into targets)
    :param metric: object computing accuracy metrics from outputs and labels
    :return: dict with ``eval_loss``, ``eval_acc`` and ``norm_edit_dis``,
        e.g. {'eval_loss': 0, 'eval_acc': 0.99, 'norm_edit_dis': 0.9999}
    """
    logger.info('start evaluate')
    net.eval()
    nums = 0  # total number of evaluated samples
    result_dict = {'eval_loss': 0., 'eval_acc': 0., 'norm_edit_dis': 0.}
    show_str = []
    with torch.no_grad():
        for batch_data in tqdm(val_loader):
            targets, targets_lengths = converter.encode(batch_data['label'])
            batch_data['targets'] = targets
            batch_data['targets_lengths'] = targets_lengths
            output = net.forward(batch_data['img'].to(to_use_device))
            loss = loss_func(output, batch_data)

            nums += batch_data['img'].shape[0]
            # output[1] presumably holds the per-step class scores — TODO confirm
            acc_dict = metric(output[1], batch_data['label'])
            result_dict['eval_loss'] += loss['loss'].item()
            result_dict['eval_acc'] += acc_dict['n_correct']
            result_dict['norm_edit_dis'] += acc_dict['norm_edit_dis']
            show_str.extend(acc_dict['show_str'])

    print('nums:', nums, 'right_nums:', result_dict['eval_acc'])
    # Convert running sums into per-batch / per-sample averages.
    result_dict['eval_loss'] /= len(val_loader)
    result_dict['eval_acc'] /= nums
    result_dict['norm_edit_dis'] = 1 - result_dict['norm_edit_dis'] / nums
    logger.info(f"eval_loss:{result_dict['eval_loss']}")
    logger.info(f"eval_acc:{result_dict['eval_acc']}")
    logger.info(f"norm_edit_dis:{result_dict['norm_edit_dis']}")

    # Log a few sample predictions for qualitative inspection.
    for s in show_str[:10]:
        logger.info(s)
    net.train()
    return result_dict
+
+
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger):
    """Main training loop.

    :param net: the network
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler (stepped once per epoch)
    :param loss_func: loss function
    :param train_loader: training DataLoader
    :param eval_loader: validation DataLoader
    :param to_use_device: device
    :param cfg: the configuration used for this training run
    :param global_state: global training state such as cur_epoch, cur_iter
        and best-model info; empty dict when starting from scratch
    :param logger: logger instance
    :return: None
    """

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> parameters controlling loss printing
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    # Resume bookkeeping from a previous run when global_state is non-empty.
    if len(global_state) > 0:
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0
        global_step = 0
    # Start training.
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                # Zero the gradients and run backprop with gradient clipping.
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output[1], batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    # Periodically checkpoint (and optionally evaluate) the model.
                    global_state['start_epoch'] = epoch
                    global_state['best_model'] = best_model
                    global_state['global_step'] = global_step
                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # val
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            best_model['models'] = net_save_path

                            global_state['start_epoch'] = epoch
                            global_state['best_model'] = best_model
                            global_state['global_step'] = global_step
                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options[
                        'ckpt_save_epoch'] == 0:
                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                global_step += 1
            scheduler.step()
    except KeyboardInterrupt:
        # Save a final checkpoint when training is interrupted manually.
        import os
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg,
                        global_state=global_state)
    except:
        # NOTE(review): bare except logs the traceback and swallows the error.
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
+
+
def main():
    """Entry point: build everything from the config file and start training."""
    # ===> load the configuration file
    cfg = parse_args()
    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))

    # ===> log the training configuration
    train_options = cfg.train_options
    logger.info(cfg)
    # ===>
    to_use_device = torch.device(
        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)

    # ===> build network
    net = build_model(cfg['model'])

    # ===> initialize the model and move it to the target device
    if not cfg['model']['backbone']['pretrained']:  # random-init only when no pretrained backbone is used
        net.apply(weight_init)
    # if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
    net = net.to(to_use_device)
    net.train()

    # ===> get fine tune layers
    params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
    # ===> solver and lr scheduler
    optimizer = build_optimizer(params_to_train, cfg['optimizer'])
    scheduler = build_scheduler(optimizer, cfg['lr_scheduler'])

    # ===> whether to resume from checkpoint
    resume_from = train_options['resume_from']
    if resume_from:
        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
                                                                third_name=train_options['third_party_name'])
        if _resumed_optimizer:
            optimizer = _resumed_optimizer
        logger.info(f'net resume from {resume_from}')
    else:
        global_state = {}
        logger.info(f'net resume from scratch.')

    # ===> loss function
    loss_func = build_loss(cfg['loss'])
    loss_func = loss_func.to(to_use_device)

    # ===> data loader (propagate the shared alphabet into both datasets)
    cfg.dataset.train.dataset.alphabet = cfg.dataset.alphabet
    train_loader = build_dataloader(cfg.dataset.train)
    cfg.dataset.eval.dataset.alphabet = cfg.dataset.alphabet
    eval_loader = build_dataloader(cfg.dataset.eval)

    # ===> train
    train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger)
+
+
# Script entry point.
if __name__ == '__main__':
    main()

+ 92 - 0
tools/rec_infer.py

@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/6/16 10:57
+# @Author  : zhoujun
+import os
+import sys
+import pathlib
+
+# Add the torchocr root directory to the python path
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+
+import numpy as np
+
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+
+import torch
+from torch import nn
+from torchocr.networks import build_model
+from torchocr.datasets.RecDataSet import RecDataProcess
+from torchocr.utils import CTCLabelConverter
+
+
class RecInfer:
    """CRNN text recognizer loaded from a training checkpoint."""

    def __init__(self, model_path, batch_size=16):
        """Load the checkpoint and prepare the model for inference.

        :param model_path: path to a checkpoint containing 'cfg' and 'state_dict'
        :param batch_size: number of crops recognized per forward pass
        """
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        # Strip the DataParallel 'module.' prefix so the plain model loads it.
        state_dict = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items()}
        self.model.load_state_dict(state_dict)

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        # Reuse the training-time preprocessing so inference matches training.
        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        # self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
        # NOTE(review): the alphabet is hard-coded to a local Windows path and
        # overrides the checkpoint's cfg alphabet (commented out above) — confirm.
        self.converter = CTCLabelConverter("C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\char_std_7782.txt")
        self.batch_size = batch_size

    def predict(self, imgs):
        """Recognize text in one image or a list of images.

        Images are resized to the training height, sorted by width and padded
        per batch so each forward pass wastes as little padding as possible;
        results are returned in the caller's original order.

        :param imgs: a single RGB image (ndarray) or a list of them
        :return: list of decode results, one per input image
        """
        if not isinstance(imgs, list):
            imgs = [imgs]
        # Preprocess exactly as during training: fixed height + normalization.
        imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img)) for img in imgs]
        # Sort by width so images batched together need minimal padding.
        widths = np.array([img.shape[1] for img in imgs])
        idxs = np.argsort(widths)
        txts = []
        # Renamed the outer loop variable (was "idx") so it no longer shadows
        # the comprehension variable below.
        for start in range(0, len(imgs), self.batch_size):
            batch_idxs = idxs[start:min(len(imgs), start + self.batch_size)]
            # Pad every image in the batch to the widest one (last after sort).
            batch_imgs = [self.process.width_pad_img(imgs[idx], imgs[batch_idxs[-1]].shape[1]) for idx in batch_idxs]
            batch_imgs = np.stack(batch_imgs)
            tensor = torch.from_numpy(batch_imgs.transpose([0, 3, 1, 2])).float()
            tensor = tensor.to(self.device)
            with torch.no_grad():
                out = self.model(tensor)
                # The CTC head returns the logits tensor directly; removed the
                # dead "out = out" left over from the attention-head variant.
                out = out.softmax(dim=2)
            out = out.cpu().numpy()
            txts.extend([self.converter.decode(np.expand_dims(txt, 0)) for txt in out])
        # Restore the caller's original image order.
        idxs = np.argsort(idxs)
        return [txts[idx] for idx in idxs]
+
+
def init_args():
    """Parse CLI arguments for recognition inference.

    NOTE(review): both arguments default to hard-coded local Windows paths for
    development convenience; pass --model_path / --img_path explicitly.
    (Removed the commented-out ``required=True`` duplicate definitions.)
    """
    import argparse
    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    parser.add_argument('--model_path', required=False, type=str,
        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best2.pth", help='rec model path')
    parser.add_argument('--img_path', required=False, type=str,
        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\test_image/Snipaste_2023-09-08_10-45-17.jpg", help='img path for predict')
    args = parser.parse_args()
    return args
+
+
if __name__ == '__main__':
    import cv2

    args = init_args()
    img = cv2.imread(args.img_path)
    if img is None:  # cv2.imread returns None instead of raising on a bad path
        raise FileNotFoundError('cannot read image: {}'.format(args.img_path))
    # The model was trained on RGB crops; OpenCV loads BGR, so convert first.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    model = RecInfer(args.model_path)
    out = model.predict(img)
    print(out)

+ 141 - 0
tools/rec_infer_att_test.py

@@ -0,0 +1,141 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/6/16 10:57
+# @Author  : zhoujun
+import os
+import sys
+import pathlib
+
+# Add the torchocr root directory to the python path
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+
+import numpy as np
+
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+
+import torch
+from torch import nn
+from torchocr.networks import build_model
+from torchocr.datasets.RecDataSet import RecDataProcess
+from torchocr.utils import CTCLabelConverter
+
+
class RecInfer:
    """Text recognizer (attention/CTC head variant) loaded from a checkpoint."""

    def __init__(self, model_path, batch_size=16):
        """Load the checkpoint and prepare the model for inference.

        :param model_path: path to a checkpoint containing 'cfg' and 'state_dict'
        :param batch_size: number of crops recognized per forward pass
        """
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        state_dict = {}
        # Strip the DataParallel 'module.' prefix so the plain model loads it.
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v
        self.model.load_state_dict(state_dict)

        # NOTE(review): pinned to the second GPU (cuda:1) — confirm intended.
        self.device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        # Reuse the training-time preprocessing so inference matches training.
        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
        # self.converter = CTCLabelConverter("C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\char_std_7782.txt")
        self.batch_size = batch_size

    def predict(self, imgs):
        """Recognize text in one image or a list of images.

        Images are resized to the training height, sorted by width, padded per
        batch, then decoded; results are returned in the original input order.

        :param imgs: a single RGB image (ndarray) or a list of them
        :return: list of decode results, one per input image
        """
        # Preprocessing must mirror training (fixed height + normalization).
        if not isinstance(imgs,list):
            imgs = [imgs]
        imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img)) for img in imgs]
        # Sort by width so padding inside each batch is minimal.
        widths = np.array([img.shape[1] for img in imgs])
        idxs = np.argsort(widths)
        txts = []
        for idx in range(0, len(imgs), self.batch_size):
            batch_idxs = idxs[idx:min(len(imgs), idx+self.batch_size)]
            # Pad each image to the widest image of the batch (last after sort).
            batch_imgs = [self.process.width_pad_img(imgs[idx], imgs[batch_idxs[-1]].shape[1]) for idx in batch_idxs]
            batch_imgs = np.stack(batch_imgs)
            tensor = torch.from_numpy(batch_imgs.transpose([0,3, 1, 2])).float()
            tensor = tensor.to(self.device)
            with torch.no_grad():
                out = self.model(tensor)
                # The head returns a tuple; element [1] holds the final
                # per-step scores.
                out = out[1]
                out = out.softmax(dim=2)
            out = out.cpu().numpy()
            txts.extend([self.converter.decode(np.expand_dims(txt, 0)) for txt in out])
        # Restore the caller's original image order.
        idxs = np.argsort(idxs)
        out_txts = [txts[idx] for idx in idxs]
        return out_txts
+
+
def init_args():
    """Parse CLI arguments for the attention recognition test script."""
    import argparse

    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    # Hard-coded development defaults; override with --model_path / --img_path.
    default_model = "C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best.pth"
    default_img = "C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\test_image/Snipaste_2023-08-21_17-07-58.jpg"
    parser.add_argument('--model_path', required=False, type=str, default=default_model, help='rec model path')
    parser.add_argument('--img_path', required=False, type=str, default=default_img, help='img path for predict')
    return parser.parse_args()
+
+
if __name__ == '__main__':
    import cv2
    import re
    from unicodedata import normalize

    # Evaluate the recognizer against a ground-truth list file. Each line is
    # "<image_path> <text> [[x1,y1],...]" where <text> may itself contain
    # spaces, so the box is split off at the first " [[".
    model_path = '/data2/znj/PytorchOCR/tools/output/CRNN/checkpoint_resnet3/best.pth'
    model = RecInfer(model_path)
    cnt = 0
    right_cnt = 0
    error_cnt = 0
    with open("/data2/znj/ocr_data/image2.txt", mode='r') as f:
        for line in f.readlines():
            line = line.strip()
            image_path, rest = re.split(" ", line, maxsplit=1)
            # Raw strings: the originals were invalid escape sequences.
            text, box = re.split(r" \[\[", rest, maxsplit=1)
            box = '[[' + box
            try:
                img = cv2.imread(image_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

                # Crop the axis-aligned bounding box of the annotated quad.
                bbox = eval(box)  # NOTE(review): eval on file content — trusted data only
                x1 = int(min(p[0] for p in bbox))
                x2 = int(max(p[0] for p in bbox))
                y1 = int(min(p[1] for p in bbox))
                y2 = int(max(p[1] for p in bbox))
                img = img[y1:y2, x1:x2]

                out = model.predict(img)
                out = out[0][0][0]
                cnt += 1
                # Unicode-normalize both sides so visually identical strings match.
                if normalize('NFKD', out) == normalize('NFKD', text):
                    right_cnt += 1
                elif not re.search(r"\s", out) and len(out) > 0:
                    print(image_path + " " + text + " " + box + " rec_res:" + out)
            except Exception:  # narrowed from a bare except; still best-effort per line
                error_cnt += 1
            # Stop once enough wrong samples have been collected.
            if cnt - right_cnt >= 400000:
                break

    print('count_num:', cnt)
    print('right_num:', right_cnt)
    if cnt:  # guard against division by zero on an empty/unreadable list file
        print('right_%:', right_cnt / cnt)
    print('process_error_cnt:', error_cnt)

+ 347 - 0
tools/rec_train.py

@@ -0,0 +1,347 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+import os
+import sys
+import pathlib
+
+# 将 torchocr路径加到python路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+import random
+import time
+import shutil
+import traceback
+from importlib import import_module
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch import nn
+from torch import optim
+from torchocr.networks import build_model, build_loss
+from torchocr.datasets import build_dataloader
+from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
+
+
def parse_args():
    """Parse CLI args and load the Python config module named by ``--config``.

    The config file must be a ``.py`` file that defines a module-level
    ``config`` object, which is returned as-is.

    Returns:
        The ``config`` object exported by the configuration module.

    Raises:
        AssertionError: if the config path does not exist.
        IOError: if the config file is not a ``.py`` file.
    """
    import argparse
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--config', type=str, default='/data2/znj/PytorchOCR/config/cfg_rec_crnn.py', help='train config file path')
    args = parser.parse_args()
    config_path = os.path.abspath(os.path.expanduser(args.config))
    assert os.path.isfile(config_path)
    if not config_path.endswith('.py'):
        raise IOError('Only py type are supported now!')
    module_name = os.path.basename(config_path)[:-3]
    config_dir = os.path.dirname(config_path)
    # Temporarily prepend the config directory so import_module can find it;
    # the try/finally guarantees sys.path is restored even if the import fails.
    sys.path.insert(0, config_dir)
    try:
        mod = import_module(module_name)
    finally:
        sys.path.pop(0)
    return mod.config
+
+
def set_random_seed(seed, use_cuda=True, deterministic=False):
    """Seed the python, numpy and (optionally) torch/CUDA RNGs.

    Args:
        seed (int): Seed to be used.
        use_cuda: whether to also seed the torch / CUDA generators.
        deterministic (bool): Whether to set the deterministic option for
            the CUDNN backend, i.e., set ``torch.backends.cudnn.deterministic``
            to True and ``torch.backends.cudnn.benchmark`` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    if not use_cuda:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if not deterministic:
        return
    # Trade throughput for reproducibility in cuDNN.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
+
+
def build_optimizer(params, config):
    """Instantiate a ``torch.optim`` optimizer described by ``config``.

    ``config['type']`` names the optimizer class; every other key is
    forwarded as a keyword argument. The ``type`` key is popped, so
    ``config`` is mutated.
    """
    opt_type = config.pop('type')
    opt_cls = getattr(optim, opt_type)
    return opt_cls(params, **config)
+
+
def build_scheduler(optimizer, config):
    """Create a learning-rate scheduler from a config dict.

    Supported ``config['type']`` values: ``LambdaLR`` (quartic burn-in
    warm-up followed by two step drops), ``StepLR`` (decay by ``gamma``
    every ``step_size`` epochs) and ``ReduceLROnPlateau`` (drop the LR when
    a maximized metric stops improving). Any other type yields ``None``.
    NOTE: ``type`` is popped, so ``config`` is mutated.
    """
    scheduler = None
    sch_type = config.pop('type')
    if sch_type == 'LambdaLR':
        burn_in, steps = config['burn_in'], config['steps']

        def burnin_schedule(i):
            # Quartic ramp until burn_in, then plateaus at 1.0 / 0.1 / 0.01.
            if i < burn_in:
                factor = pow(i / burn_in, 4)
            elif i < steps[0]:
                factor = 1.0
            elif i < steps[1]:
                factor = 0.1
            else:
                factor = 0.01
            return factor

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
    elif sch_type == 'StepLR':
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=config['step_size'], gamma=config['gamma'])
    elif sch_type == 'ReduceLROnPlateau':
        # Drop the LR when the monitored (maximized) metric plateaus.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                               patience=3, verbose=True, threshold=1e-4)
    return scheduler
+
+
def get_fine_tune_params(net, finetune_stage):
    """Collect the parameters of the sub-modules selected for fine-tuning.

    Args:
        net: the network; usually wrapped in ``nn.DataParallel``, in which
            case stages are looked up on ``net.module``. A plain module
            (no ``.module`` attribute) is now handled as well.
        finetune_stage: iterable of attribute names (e.g. ``['neck', 'head']``)
            naming the sub-modules whose parameters should be optimized.

    Returns:
        list: parameters of every requested stage that exists. Missing stage
        names are skipped (the original crashed with ``AttributeError`` on
        ``None.parameters()`` in that case).
    """
    base = getattr(net, 'module', net)  # unwrap DataParallel if present
    params = []
    for stage in finetune_stage:
        attr = getattr(base, stage, None)
        if attr is None:
            continue
        params.extend(attr.parameters())
    return params
+
+
def evaluate(net, val_loader, loss_func, to_use_device, logger, converter, metric):
    """Evaluate the model on the validation set.

    :param net: the network
    :param val_loader: validation DataLoader
    :param loss_func: loss function
    :param to_use_device: device to run on
    :param logger: logger instance
    :param converter: label converter (encodes text labels into targets)
    :param metric: object computing accuracy metrics from outputs and labels
    :return: dict with ``eval_loss``, ``eval_acc`` and ``norm_edit_dis``,
        e.g. {'eval_loss': 0, 'eval_acc': 0.99, 'norm_edit_dis': 0.9999}
    """
    logger.info('start evaluate')
    net.eval()
    nums = 0  # total number of evaluated samples
    result_dict = {'eval_loss': 0., 'eval_acc': 0., 'norm_edit_dis': 0.}
    show_str = []
    with torch.no_grad():
        for batch_data in tqdm(val_loader):
            targets, targets_lengths = converter.encode(batch_data['label'])
            batch_data['targets'] = targets
            batch_data['targets_lengths'] = targets_lengths
            output = net.forward(batch_data['img'].to(to_use_device))
            loss = loss_func(output, batch_data)

            nums += batch_data['img'].shape[0]
            # output[1] presumably holds the per-step class scores — TODO confirm
            acc_dict = metric(output[1], batch_data['label'])
            result_dict['eval_loss'] += loss['loss'].item()
            result_dict['eval_acc'] += acc_dict['n_correct']
            result_dict['norm_edit_dis'] += acc_dict['norm_edit_dis']
            show_str.extend(acc_dict['show_str'])

    print('nums:',nums,'right_nums:',result_dict['eval_acc'])
    # Convert running sums into per-batch / per-sample averages.
    result_dict['eval_loss'] /= len(val_loader)
    result_dict['eval_acc'] /= nums
    result_dict['norm_edit_dis'] = 1 - result_dict['norm_edit_dis'] / nums
    logger.info(f"eval_loss:{result_dict['eval_loss']}")
    logger.info(f"eval_acc:{result_dict['eval_acc']}")
    logger.info(f"norm_edit_dis:{result_dict['norm_edit_dis']}")

    # Log a few sample predictions for qualitative inspection.
    for s in show_str[:10]:
        logger.info(s)
    net.train()
    return result_dict
+
+
+def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
+          cfg, global_state, logger):
+    """
+    训练函数
+
+    :param net: 网络
+    :param optimizer: 优化器
+    :param scheduler: 学习率更新器
+    :param loss_func: loss函数
+    :param train_loader: 训练数据集 dataloader
+    :param eval_loader: 验证数据集 dataloader
+    :param to_use_device: device
+    :param cfg: 当前训练所使用的配置
+    :param global_state: 训练过程中的一些全局状态,如cur_epoch,cur_iter,最优模型的相关信息
+    :param logger: logger 对象
+    :return: None
+    """
+
+    from torchocr.metrics import RecMetric
+    from torchocr.utils import CTCLabelConverter
+    converter = CTCLabelConverter(cfg.dataset.alphabet)
+    train_options = cfg.train_options
+    metric = RecMetric(converter)
+    # ===>
+    logger.info('Training...')
+    # ===> print loss信息的参数
+    all_step = len(train_loader)
+    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
+    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
+    if len(global_state) > 0:
+        best_model = global_state['best_model']
+        start_epoch = global_state['start_epoch']
+        global_step = global_state['global_step']
+    else:
+        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
+        start_epoch = 0
+        global_step = 0
+    # 开始训练
+    try:
+        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
+            net.train()  # train mode
+            start = time.time()
+            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
+                current_lr = optimizer.param_groups[0]['lr']
+                cur_batch_size = batch_data['img'].shape[0]
+                targets, targets_lengths = converter.encode(batch_data['label'])
+                batch_data['targets'] = targets
+                batch_data['targets_lengths'] = targets_lengths
+                # 清零梯度及反向传播
+                optimizer.zero_grad()
+                output = net.forward(batch_data['img'].to(to_use_device))
+                loss_dict = loss_func(output, batch_data)
+                loss_dict['loss'].backward()
+                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
+                optimizer.step()
+                # statistic loss for print
+                acc_dict = metric(output[1], batch_data['label'])
+                acc = acc_dict['n_correct'] / cur_batch_size
+                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
+                if (i + 1) % train_options['print_interval'] == 0:
+                    interval_batch_time = time.time() - start
+                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
+                                f"[{i + 1}/{all_step}] - "
+                                f"lr:{current_lr} - "
+                                f"loss:{loss_dict['loss'].item():.4f} - "
+                                f"acc:{acc:.4f} - "
+                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
+                                f"time:{interval_batch_time:.4f}")
+                    start = time.time()
+                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
+                    global_state['start_epoch'] = epoch
+                    global_state['best_model'] = best_model
+                    global_state['global_step'] = global_step
+                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
+                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
+                    if train_options['ckpt_save_type'] == 'HighestAcc':
+                        # val
+                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
+                        if eval_dict['eval_acc'] > best_model['eval_acc']:
+                            best_model.update(eval_dict)
+                            best_model['best_model_epoch'] = epoch
+                            best_model['models'] = net_save_path
+
+                            global_state['start_epoch'] = epoch
+                            global_state['best_model'] = best_model
+                            global_state['global_step'] = global_step
+                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
+                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
+                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
+                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
+                global_step += 1
+            scheduler.step()
+    except KeyboardInterrupt:
+        import os
+        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg, global_state=global_state)
+    except:
+        error_msg = traceback.format_exc()
+        logger.error(error_msg)
+    finally:
+        for k, v in best_model.items():
+            logger.info(f'{k}: {v}')
+
+
def main():
    """Entry point for CRNN recognition training: build everything from the
    config file, optionally resume from a checkpoint, then run `train`."""
    # ===> load parameters from the config file
    cfg = parse_args()
    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))

    # ===> print training information
    train_options = cfg.train_options
    logger.info(cfg)
    # ===> select device; fall back to CPU when the configured CUDA device is unavailable
    to_use_device = torch.device(
        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)

    # ===> build network
    net = build_model(cfg['model'])

    # ===> weight init and deployment to the target device
    if not cfg['model']['backbone']['pretrained']:  # random init only when no pretrained backbone is used
        net.apply(weight_init)
    # NOTE(review): DataParallel is applied unconditionally, even for a single GPU / CPU
    # if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
    net = net.to(to_use_device)
    net.train()

    # ===> get fine tune layers
    params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
    # ===> solver and lr scheduler
    optimizer = build_optimizer(params_to_train, cfg['optimizer'])
    scheduler = build_scheduler(optimizer, cfg['lr_scheduler'])

    # ===> whether to resume from checkpoint
    resume_from = train_options['resume_from']
    if resume_from:
        net, _resumed_optimizer,global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
                                                                 third_name=train_options['third_party_name'])
        if _resumed_optimizer:
            optimizer = _resumed_optimizer
        logger.info(f'net resume from {resume_from}')
    else:
        global_state = {}
        logger.info(f'net resume from scratch.')

    # ===> loss function
    loss_func = build_loss(cfg['loss'])
    loss_func = loss_func.to(to_use_device)


    # ===> data loaders; propagate the shared alphabet into both dataset configs
    cfg.dataset.train.dataset.alphabet = cfg.dataset.alphabet
    train_loader = build_dataloader(cfg.dataset.train)
    cfg.dataset.eval.dataset.alphabet = cfg.dataset.alphabet
    eval_loader = build_dataloader(cfg.dataset.eval)

    # ===> train
    train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device, cfg, global_state, logger)


if __name__ == '__main__':
    main()

+ 365 - 0
tools/test_one.py

@@ -0,0 +1,365 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/19 21:44
+# @Author  : xiangjing
+
+import os
+import sys
+import pathlib
+
+# 将 torchocr路径加到python路径里
+__dir__ = pathlib.Path(os.path.abspath(__file__))
+sys.path.append(str(__dir__))
+sys.path.append(str(__dir__.parent.parent))
+import random
+import time
+import shutil
+import traceback
+from importlib import import_module
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from torch import nn
+from torch import optim
+from torchocr.networks import build_model, build_loss
+from torchocr.datasets import build_dataloader
+from torchocr.utils import get_logger, weight_init, load_checkpoint, save_checkpoint
+
+
def parse_args():
    """Parse the command line and load the python config file.

    Returns:
        The ``config`` object defined inside the config module.

    Raises:
        IOError: if the config file is not a ``.py`` file.
        AssertionError: if the config file does not exist.
    """
    import argparse
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument('--config', type=str,
                        default='C:/Users/Administrator/Desktop/OCR_pytorch/PytorchOCR1/config/cfg_rec_crnn_test1.py',
                        help='train config file path')
    args = parser.parse_args()
    # Resolve to an absolute path so import_module works regardless of cwd.
    config_path = os.path.abspath(os.path.expanduser(args.config))
    assert os.path.isfile(config_path), f'config file not found: {config_path}'
    if not config_path.endswith('.py'):
        raise IOError('Only py type are supported now!')
    module_name = os.path.basename(config_path)[:-3]
    config_dir = os.path.dirname(config_path)
    # Temporarily prepend the config dir; restore sys.path even if the
    # import raises (the original leaked the entry on failure).
    sys.path.insert(0, config_dir)
    try:
        mod = import_module(module_name)
    finally:
        sys.path.pop(0)
    return mod.config
+
+
def set_random_seed(seed, use_cuda=True, deterministic=False):
    """Set random seed for python, numpy and torch.

    Args:
        seed (int): Seed to be used.
        use_cuda (bool): whether to also seed the CUDA RNGs.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed also seeds the CPU RNG, so it must run even without
    # CUDA; the original only seeded torch when use_cuda was True, leaving
    # CPU-only runs non-reproducible.
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)
        if deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
+
+
def build_optimizer(params, config):
    """Build a torch optimizer from config.

    Args:
        params: iterable of parameters (or param groups) to optimize.
        config (dict): ``{'type': <torch.optim class name>, **kwargs}``.

    Returns:
        A ``torch.optim.Optimizer`` instance.
    """
    # Copy first: the original popped 'type' from the caller's dict, so a
    # second call with the same config raised KeyError.
    cfg = dict(config)
    opt_type = cfg.pop('type')
    return getattr(optim, opt_type)(params, **cfg)
+
+
def build_scheduler(optimizer, config):
    """Build a learning-rate scheduler from config.

    Supported types: ``LambdaLR`` (burn-in warmup + two-step decay),
    ``StepLR`` and ``ReduceLROnPlateau``.

    Args:
        optimizer: the optimizer the scheduler drives.
        config (dict): ``{'type': <name>, **type-specific keys}``.

    Returns:
        A scheduler instance, or None for an unknown type (callers should check).
    """
    # Copy first: popping from the caller's dict made the config single-use.
    cfg = dict(config)
    scheduler = None
    sch_type = cfg.pop('type')
    if sch_type == 'LambdaLR':
        burn_in, steps = cfg['burn_in'], cfg['steps']

        def burnin_schedule(i):
            # Quartic warmup until burn_in, then stepwise decay at steps[0] / steps[1].
            if i < burn_in:
                factor = pow(i / burn_in, 4)
            elif i < steps[0]:
                factor = 1.0
            elif i < steps[1]:
                factor = 0.1
            else:
                factor = 0.01
            return factor

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
    elif sch_type == 'StepLR':
        # Decay lr by `gamma` every `step_size` scheduler steps (usually epochs).
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=cfg['step_size'], gamma=cfg['gamma'])
    elif sch_type == 'ReduceLROnPlateau':
        # Reduce lr when the monitored metric (mode='max', e.g. accuracy) plateaus.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                               patience=3, verbose=True, threshold=1e-4)
    return scheduler
+
+
def get_fine_tune_params(net, finetune_stage):
    """Collect the parameters of the sub-modules that should be optimized.

    Args:
        net: the (DataParallel-wrapped) model; stage names are looked up on ``net.module``.
        finetune_stage: iterable of sub-module attribute names, e.g. ``['backbone', 'neck']``.

    Returns:
        A list with the parameters of the requested stages. Stage names that
        do not exist on the model are skipped (previously this crashed with
        ``AttributeError: 'NoneType' object has no attribute 'parameters'``).
    """
    to_return_parameters = []
    for stage in finetune_stage:
        attr = getattr(net.module, stage, None)
        if attr is None:
            # Misspelled / absent stage in the config: skip rather than crash.
            continue
        to_return_parameters.extend(attr.parameters())
    return to_return_parameters
+
+
def evaluate(net, val_loader, loss_func, to_use_device, logger, converter, metric):
    """Evaluate the model on the validation set.

    :param net: network
    :param val_loader: validation dataloader
    :param loss_func: loss function
    :param to_use_device: device
    :param logger: logger object
    :param converter: label converter object
    :param metric: object computing acc etc. from the network output and labels
    :return: dict with eval_loss, eval_acc and norm_edit_dis, e.g.
        {'eval_loss': 0, 'eval_acc': 0.99, 'norm_edit_dis': 0.9999}
    """
    logger.info('start evaluate')
    net.eval()
    nums = 0
    result_dict = {'eval_loss': 0., 'eval_acc': 0., 'norm_edit_dis': 0.}
    show_str = []
    with torch.no_grad():
        for batch_data in tqdm(val_loader):
            targets, targets_lengths = converter.encode(batch_data['label'])
            batch_data['targets'] = targets
            batch_data['targets_lengths'] = targets_lengths
            output = net.forward(batch_data['img'].to(to_use_device))
            loss = loss_func(output, batch_data)

            nums += batch_data['img'].shape[0]
            # output[1] is assumed to be the decodable prediction branch — as in train()
            acc_dict = metric(output[1], batch_data['label'])
            result_dict['eval_loss'] += loss['loss'].item()
            result_dict['eval_acc'] += acc_dict['n_correct']
            result_dict['norm_edit_dis'] += acc_dict['norm_edit_dis']
            show_str.extend(acc_dict['show_str'])

    if nums == 0:
        # Empty validation loader: return zeroed metrics instead of ZeroDivisionError.
        logger.warning('evaluate: validation loader produced no samples')
        net.train()
        return result_dict

    # Log (instead of print) so the counts also reach the train.log file.
    logger.info(f"nums: {nums} right_nums: {result_dict['eval_acc']}")
    result_dict['eval_loss'] /= len(val_loader)
    result_dict['eval_acc'] /= nums
    result_dict['norm_edit_dis'] = 1 - result_dict['norm_edit_dis'] / nums
    logger.info(f"eval_loss:{result_dict['eval_loss']}")
    logger.info(f"eval_acc:{result_dict['eval_acc']}")
    logger.info(f"norm_edit_dis:{result_dict['norm_edit_dis']}")

    # Show a few sample predictions for qualitative inspection.
    for s in show_str[:10]:
        logger.info(s)
    net.train()
    return result_dict
+
+
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger):
    """Training loop.

    :param net: network
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler
    :param loss_func: loss function
    :param train_loader: training dataloader
    :param eval_loader: validation dataloader
    :param to_use_device: device
    :param cfg: config of the current run
    :param global_state: global training state (cur_epoch, cur_iter, best model info)
    :param logger: logger object
    :return: None
    """

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> batches per epoch, used by the progress log below
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    if len(global_state) > 0:
        # resuming: restore bookkeeping from the checkpoint state
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0
        global_step = 0
    # start training
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                # zero gradients and back-propagate
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output[1], batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
                    # always checkpoint 'latest' before (possibly) evaluating
                    global_state['start_epoch'] = epoch
                    global_state['best_model'] = best_model
                    global_state['global_step'] = global_step
                    net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
                    save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # val
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            best_model['models'] = net_save_path

                            global_state['start_epoch'] = epoch
                            global_state['best_model'] = best_model
                            global_state['global_step'] = global_step
                            net_save_path = f"{train_options['checkpoint_save_dir']}/best.pth"
                            save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options[
                        'ckpt_save_epoch'] == 0:
                        shutil.copy(net_save_path, net_save_path.replace('latest.pth', f'{epoch}.pth'))
                global_step += 1
            scheduler.step()
    except KeyboardInterrupt:
        # user interrupt: persist a final checkpoint before leaving
        save_checkpoint(os.path.join(train_options['checkpoint_save_dir'], 'final.pth'), net, optimizer, logger, cfg,
                        global_state=global_state)
    except Exception:  # was a bare `except:` that also swallowed SystemExit
        error_msg = traceback.format_exc()
        logger.error(error_msg)
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
+
+
def main():
    """Evaluation-only entry point: build the model from the config, load the
    eval dataloader and run a single `evaluate` pass (training is disabled)."""
    # ===> load parameters from the config file
    cfg = parse_args()
    os.makedirs(cfg.train_options['checkpoint_save_dir'], exist_ok=True)
    logger = get_logger('torchocr', log_file=os.path.join(cfg.train_options['checkpoint_save_dir'], 'train.log'))

    # ===> print training information
    train_options = cfg.train_options
    logger.info(cfg)
    # ===> select device; fall back to CPU when the configured CUDA device is unavailable
    to_use_device = torch.device(
        train_options['device'] if torch.cuda.is_available() and ('cuda' in train_options['device']) else 'cpu')
    set_random_seed(cfg['SEED'], 'cuda' in train_options['device'], deterministic=True)

    # ===> build network
    net = build_model(cfg['model'])

    # ===> weight init and deployment to the target device
    if not cfg['model']['backbone']['pretrained']:  # random init only when no pretrained backbone is used
        net.apply(weight_init)
    net = nn.DataParallel(net)
    net = net.to(to_use_device)
    net.train()

    # ===> get fine tune layers
    params_to_train = get_fine_tune_params(net, train_options['fine_tune_stage'])
    # ===> solver and lr scheduler (scheduler kept for interface parity; unused here)
    optimizer = build_optimizer(params_to_train, cfg['optimizer'])
    scheduler = build_scheduler(optimizer, cfg['lr_scheduler'])

    # ===> whether to resume from checkpoint
    resume_from = train_options['resume_from']
    if resume_from:
        net, _resumed_optimizer, global_state = load_checkpoint(net, resume_from, to_use_device, optimizer,
                                                                third_name=train_options['third_party_name'])
        if _resumed_optimizer:
            optimizer = _resumed_optimizer
        logger.info(f'net resume from {resume_from}')
    else:
        global_state = {}
        logger.info(f'net resume from scratch.')

    # ===> loss function
    loss_func = build_loss(cfg['loss'])
    loss_func = loss_func.to(to_use_device)

    # ===> eval data loader only (no training in this script)
    # TODO(review): hard-coded absolute alphabet path overrides the config
    # value; move this into the config file.
    cfg.dataset.eval.dataset.alphabet = "C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\char_std_7782.txt"
    eval_loader = build_dataloader(cfg.dataset.eval)

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    # Build converter/metric once (the original constructed the converter twice).
    _converter = CTCLabelConverter(cfg.dataset.alphabet)
    _metric = RecMetric(_converter)
    eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, _converter, _metric)
    print(eval_dict)

+ 3 - 0
torchocr/__init__.py

@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/15 17:44
+# @Author  : zhoujun

+ 33 - 0
torchocr/datasets/DetCollateFN.py

@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/6/22 14:16
+# @Author  : zhoujun
+import PIL
+import numpy as np
+import torch
+from torchvision import transforms
+
+__all__ = ['DetCollectFN']
+
+
class DetCollectFN:
    """Collate function for detection dataloaders.

    Groups a list of per-sample dicts into one batch dict: values that are
    numpy arrays, torch tensors or PIL images are converted to tensors and
    stacked along a new batch dimension; everything else stays a plain list.
    """

    def __init__(self, *args, **kwargs):
        # Stateless: construction arguments are accepted and ignored.
        pass

    def __call__(self, batch):
        batched = {}
        stack_keys = []  # keys whose values get torch.stack'ed at the end
        for sample in batch:
            for key, value in sample.items():
                if isinstance(value, (np.ndarray, torch.Tensor, PIL.Image.Image)):
                    if key not in stack_keys:
                        stack_keys.append(key)
                    if isinstance(value, np.ndarray):
                        value = torch.tensor(value)
                    elif isinstance(value, PIL.Image.Image):
                        value = transforms.ToTensor()(value)
                batched.setdefault(key, []).append(value)
        for key in stack_keys:
            batched[key] = torch.stack(batched[key], 0)
        return batched

+ 155 - 0
torchocr/datasets/DetDataSet.py

@@ -0,0 +1,155 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/6/22 10:53
+# @Author  : zhoujun
+import os
+import cv2
+import json
+import copy
+import numpy as np
+from tqdm import tqdm
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchocr.datasets.det_modules import *
+
+
+def load_json(file_path: str):
+    with open(file_path, 'r', encoding='utf8') as f:
+        content = json.load(f)
+    return content
+
+
class JsonDataset(Dataset):
    """Detection dataset backed by a single json annotation file.

    from https://github.com/WenmuZhou/OCR_DataSet/blob/master/dataset/det.py
    """

    def __init__(self, config):
        # config: attribute-style object with img_mode, ignore_tags, file,
        # filter_keys, pre_processes, mean, std
        assert config.img_mode in ['RGB', 'BRG', 'GRAY']
        self.ignore_tags = config.ignore_tags
        # whether to also load character-level annotations (disabled)
        self.load_char_annotation = False

        self.data_list = self.load_data(config.file)
        item_keys = ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags']
        for item in item_keys:
            assert item in self.data_list[0], 'data_list from load_data must contains {}'.format(item_keys)
        self.img_mode = config.img_mode
        self.filter_keys = config.filter_keys
        self._init_pre_processes(config.pre_processes)
        # ToTensor + per-channel normalization applied after the augmentations
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=config.mean, std=config.std)
        ])

    def _init_pre_processes(self, pre_processes):
        # Instantiate augmentation objects from the config list.
        # NOTE(review): aug['type'] is passed to eval(); config files are
        # trusted code here, but this fails for any type name not imported
        # from torchocr.datasets.det_modules.
        self.aug = []
        if pre_processes is not None:
            for aug in pre_processes:
                if 'args' not in aug:
                    args = {}
                else:
                    args = aug['args']
                if isinstance(args, dict):
                    cls = eval(aug['type'])(**args)
                else:
                    cls = eval(aug['type'])(args)
                self.aug.append(cls)

    def load_data(self, path: str) -> list:
        """
        Read text-line polygons and gt (optionally char-level ones) from a json file.
        :params path: file storing the data
        :return: list of dicts with keys 'img_path','img_name','text_polys','texts','ignore_tags'
        """
        data_list = []
        content = load_json(path)
        for gt in tqdm(content['data_list'], desc='read file {}'.format(path)):
            try:
                img_path = os.path.join(content['data_root'], gt['img_name'])
                polygons = []
                texts = []
                illegibility_list = []
                language_list = []
                for annotation in gt['annotations']:
                    # skip annotations with an empty polygon or empty text
                    if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
                        continue
                    polygons.append(annotation['polygon'])
                    texts.append(annotation['text'])
                    illegibility_list.append(annotation['illegibility'])
                    language_list.append(annotation['language'])
                    if self.load_char_annotation:
                        for char_annotation in annotation['chars']:
                            if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
                                continue
                            polygons.append(char_annotation['polygon'])
                            texts.append(char_annotation['char'])
                            illegibility_list.append(char_annotation['illegibility'])
                            language_list.append(char_annotation['language'])
                data_list.append({'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': polygons,
                                  'texts': texts, 'ignore_tags': illegibility_list})
            except:
                # NOTE(review): bare except hides the real error on malformed
                # annotations; only the offending path is reported
                print(f'error gt:{img_path}')
        return data_list

    def apply_pre_processes(self, data):
        # Run the configured augmentations sequentially on the sample dict.
        for aug in self.aug:
            data = aug(data)
        return data

    def __getitem__(self, index):
        # deepcopy so augmentations never mutate the cached annotation list
        data = copy.deepcopy(self.data_list[index])
        im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
        if self.img_mode == 'RGB':
            try:
                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            except:
                # cvtColor raises when imread returned None (unreadable image)
                print(data['img_path'])
        data['img'] = im
        data['shape'] = [im.shape[0], im.shape[1]]
        data = self.apply_pre_processes(data)

        if self.transform:
            data['img'] = self.transform(data['img'])
        data['text_polys'] = data['text_polys']
        if len(self.filter_keys):
            # drop the keys listed in filter_keys from the returned sample
            data_dict = {}
            for k, v in data.items():
                if k not in self.filter_keys:
                    data_dict[k] = v
            return data_dict
        else:
            return data

    def __len__(self):
        return len(self.data_list)
+
+
if __name__ == '__main__':
    # Smoke test: iterate the training dataset and visualize the image and
    # its shrink / threshold label maps (DBNet-style targets).
    import torch
    from torch.utils.data import DataLoader
    # from config.cfg_det_db import config
    from local.cfg.cfg_det_db_latin import config
    from torchocr.utils import show_img, draw_bbox

    from matplotlib import pyplot as plt

    dataset = JsonDataset(config.dataset.train.dataset)
    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
    for i, data in enumerate(tqdm(train_loader)):
        # print(data['img_path'])
        # CHW tensor -> HWC numpy for display
        img = data['img'][0].numpy().transpose(1, 2, 0)
        shrink_label = data['shrink_map'].numpy().transpose(1, 2, 0)
        threshold_label = data['threshold_map'].numpy().transpose(1, 2, 0)
        show_img(img, title='img')
        show_img(shrink_label, title='shrink_label')
        show_img(threshold_label, title='threshold_label')
        plt.show()
        # print(threshold_label.shape, threshold_label.shape, img.shape)
        # show_img(img[0].numpy().transpose(1, 2, 0), title='img')
        # show_img((shrink_label[0].to(torch.float)).numpy(), title='shrink_label')
        # show_img((threshold_label[0].to(torch.float)).numpy(), title='threshold_label')
        # img = draw_bbox(img[0].numpy().transpose(1, 2, 0), np.array(data['text_polys']))
        # show_img(img, title='draw_bbox')
        # plt.show()

        pass

+ 160 - 0
torchocr/datasets/DetDataSetFce.py

@@ -0,0 +1,160 @@
+import os
+import cv2
+import json
+import copy
+import numpy as np
+from tqdm import tqdm
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchocr.datasets.det_modules import *
+
+
+def load_json(file_path: str):
+    with open(file_path, 'r', encoding='utf8') as f:
+        content = json.load(f)
+    return content
+
+
class FCEDataset(Dataset):
    """Detection dataset for FCENet.

    Like JsonDataset, but pads every polygon of an image to the same number
    of points (by repeating the last point) so they stack into one float32
    ndarray.
    """

    def __init__(self, config):
        # config: attribute-style object with img_mode, ignore_tags, file,
        # filter_keys, pre_processes, mean, std
        assert config.img_mode in ['RGB', 'BRG', 'GRAY']
        self.ignore_tags = config.ignore_tags
        # whether to also load character-level annotations (disabled)
        self.load_char_annotation = False

        self.data_list = self.load_data(config.file)
        item_keys = ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags']
        for item in item_keys:
            assert item in self.data_list[0], 'data_list from load_data must contains {}'.format(item_keys)
        self.img_mode = config.img_mode
        self.filter_keys = config.filter_keys
        self._init_pre_processes(config.pre_processes)
        # ToTensor + per-channel normalization applied after the augmentations
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=config.mean, std=config.std)
        ])

    def _init_pre_processes(self, pre_processes):
        # Instantiate augmentation objects from the config list.
        # NOTE(review): aug['type'] is passed to eval(); config files are
        # trusted code here, but this fails for any type name not imported
        # from torchocr.datasets.det_modules.
        self.aug = []
        if pre_processes is not None:
            for aug in pre_processes:
                if 'args' not in aug:
                    args = {}
                else:
                    args = aug['args']
                if isinstance(args, dict):
                    cls = eval(aug['type'])(**args)
                else:
                    cls = eval(aug['type'])(args)
                self.aug.append(cls)

    def load_data(self, path: str) -> list:
        """
        Read text-line polygons and gt (optionally char-level ones) from a json file.
        :params path: file storing the data
        :return: list of dicts with keys 'img_path','img_name','text_polys','texts','ignore_tags'
        """
        data_list = []
        content = load_json(path)
        for gt in tqdm(content['data_list'], desc='read file {}'.format(path)):
            try:
                img_path = os.path.join(content['data_root'], gt['img_name'])
                polygons = []
                texts = []
                illegibility_list = []
                language_list = []
                # longest polygon (in points) of this image; used for padding below
                max_poly_len = 0
                if len( gt['annotations'])==0:
                    print(img_path)
                    continue
                for annotation in gt['annotations']:
                    # skip annotations with an empty polygon or empty text
                    if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
                        continue
                    max_poly_len = max(max_poly_len, len(annotation['polygon']))
                    polygons.append(annotation['polygon'])
                    texts.append(annotation['text'])
                    illegibility_list.append(annotation['illegibility'])
                    language_list.append(annotation['language'])
                    if self.load_char_annotation:
                        for char_annotation in annotation['chars']:
                            if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
                                continue
                            polygons.append(char_annotation['polygon'])
                            texts.append(char_annotation['char'])
                            illegibility_list.append(char_annotation['illegibility'])
                            language_list.append(char_annotation['language'])
                # pad every polygon to max_poly_len points by repeating its
                # last point, so all polygons stack into a single ndarray
                ex_polygons = []
                for pl in polygons:
                    ex_pl = pl + [pl[-1]] * (max_poly_len - len(pl))
                    ex_polygons.append(ex_pl)

                data_list.append(
                    {'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': np.array(ex_polygons, dtype=np.float32),
                     'texts': texts, 'ignore_tags': illegibility_list})
            except:
                # NOTE(review): bare except hides the real error on malformed
                # annotations; only the offending path is reported
                print(f'error gt:{img_path}')
        return data_list

    def apply_pre_processes(self, data):
        # Run the configured augmentations sequentially on the sample dict.
        for aug in self.aug:
            data = aug(data)
        return data

    def __getitem__(self, index):
        # try:
        # deepcopy so augmentations never mutate the cached annotation list
        data = copy.deepcopy(self.data_list[index])
        im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
        if self.img_mode == 'RGB':
            try:
                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            except:
                # cvtColor raises when imread returned None (unreadable image)
                print(data['img_path'])
        data['img'] = im
        data['shape'] = [im.shape[0], im.shape[1]]
        data = self.apply_pre_processes(data)

        if self.transform:
            data['img'] = self.transform(data['img'])
        data['text_polys'] = data['text_polys']
        if len(self.filter_keys):
            # drop the keys listed in filter_keys from the returned sample
            data_dict = {}
            for k, v in data.items():
                if k not in self.filter_keys:
                    data_dict[k] = v
            return data_dict
        else:
            return data

    def __len__(self):
        return len(self.data_list)
+
+
if __name__ == '__main__':
    # Smoke test: build the FCE dataset from the local config and iterate it once.
    import torch
    from torch.utils.data import DataLoader
    from local.cfg.cfg_det_fce import config
    from torchocr.utils import show_img, draw_bbox

    from matplotlib import pyplot as plt

    # BUG FIX: this module defines FCEDataset; the old code instantiated
    # JsonDataset, which is neither defined nor imported here -> NameError.
    dataset = FCEDataset(config.dataset.train.dataset)
    train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
    for i, data in enumerate(tqdm(train_loader)):
        pass

+ 757 - 0
torchocr/datasets/DetDateSetPse.py

@@ -0,0 +1,757 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2018/6/11 15:54
+# @Author  : zhoujun
+
+import os
+import math
+import random
+import numbers
+import pathlib
+import pyclipper
+from torch.utils import data
+import glob
+import numpy as np
+import cv2
+from skimage.util import random_noise
+import json
+from tqdm import tqdm
+from torchvision import transforms
+
+
+# from utils.utils import draw_bbox
+
+# All images below are read with cv2, i.e. HxWxC uint8 ndarrays.
+class DataAugment():
+    """Detection-side augmentations that transform an image together with its
+    text polygons (expected shape (N, 4, 2)). Stateless; one shared module-level
+    instance is created below."""
+    def __init__(self):
+        pass
+
+    def add_noise(self, im: np.ndarray):
+        """
+        Add gaussian noise to an image.
+        :param img: image array
+        :return: noisy image; skimage's random_noise outputs floats in [0, 1],
+                 hence the *255 rescale back to the input dtype
+        """
+        return (random_noise(im, mode='gaussian', clip=True) * 255).astype(im.dtype)
+
+    def random_scale(self, im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray or list) -> tuple:
+        """
+        Randomly pick one factor from `scales` and rescale image and boxes by it.
+        :param im: original image
+        :param text_polys: text polygons
+        :param scales: candidate scale factors
+        :return: (scaled image, scaled polygons); the input polygons are not mutated
+        """
+        tmp_text_polys = text_polys.copy()
+        rd_scale = float(np.random.choice(scales))
+        im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
+        tmp_text_polys *= rd_scale
+        return im, tmp_text_polys
+
+    def random_rotate_img_bbox(self, img, text_polys, degrees: numbers.Number or list or tuple or np.ndarray,
+                               same_size=False):
+        """
+        Rotate image and boxes by an angle sampled uniformly from `degrees`.
+        :param img: image
+        :param text_polys: text polygons
+        :param degrees: a single non-negative number (symmetric range) or a (min, max) pair
+        :param same_size: keep the output canvas the same size as the input
+        :return: rotated image and rotated polygons (float32)
+        """
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError("If degrees is a single number, it must be positive.")
+            degrees = (-degrees, degrees)
+        elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray):
+            if len(degrees) != 2:
+                raise ValueError("If degrees is a sequence, it must be of len 2.")
+            degrees = degrees
+        else:
+            raise Exception('degrees must in Number or list or tuple or np.ndarray')
+        # ---------------------- rotate the image ----------------------
+        w = img.shape[1]
+        h = img.shape[0]
+        angle = np.random.uniform(degrees[0], degrees[1])
+
+        if same_size:
+            nw = w
+            nh = h
+        else:
+            # degrees -> radians
+            rangle = np.deg2rad(angle)
+            # size of the bounding canvas of the rotated image
+            nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
+            nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
+        # build the affine matrix
+        rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
+        # translation from the old image centre to the new canvas centre
+        rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
+        # fold the translation into the affine matrix
+        rot_mat[0, 2] += rot_move[0]
+        rot_mat[1, 2] += rot_move[1]
+        # warp
+        rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
+
+        # ---------------------- transform the bbox coords ----------------------
+        # rot_mat is the final rotation matrix
+        # map each of the four corner points into the rotated frame
+        rot_text_polys = list()
+        for bbox in text_polys:
+            point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
+            point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
+            point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
+            point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
+            rot_text_polys.append([point1, point2, point3, point4])
+        return rot_img, np.array(rot_text_polys, dtype=np.float32)
+
+    def random_crop_img_bboxes(self, im: np.ndarray, text_polys: np.ndarray, max_tries=50) -> tuple:
+        """
+        Crop a random text-containing region whose borders do not cross any box.
+        :param im: image
+        :param text_polys: text polygons
+        :param max_tries: number of attempts before giving up
+        :return: cropped image and the polygons fully inside it (or the originals)
+        """
+        h, w, _ = im.shape
+        pad_h = h // 10
+        pad_w = w // 10
+        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+        for poly in text_polys:
+            poly = np.round(poly, decimals=0).astype(np.int32)  # round to nearest int
+            minx = np.min(poly[:, 0])
+            maxx = np.max(poly[:, 0])
+            w_array[minx + pad_w:maxx + pad_w] = 1  # mark the x-extent of this box as "has text"
+            miny = np.min(poly[:, 1])
+            maxy = np.max(poly[:, 1])
+            h_array[miny + pad_h:maxy + pad_h] = 1  # mark the y-extent of this box as "has text"
+        # sample crop edges only from background rows/columns so a cut never crosses text
+        h_axis = np.where(h_array == 0)[0]
+        w_axis = np.where(w_array == 0)[0]
+        if len(h_axis) == 0 or len(w_axis) == 0:
+            # the whole image is text: return unchanged
+            return im, text_polys
+        for i in range(max_tries):
+            xx = np.random.choice(w_axis, size=2)
+            # clamp the selected span to the image bounds
+            xmin = np.min(xx) - pad_w
+            xmax = np.max(xx) - pad_w
+            xmin = np.clip(xmin, 0, w - 1)
+            xmax = np.clip(xmax, 0, w - 1)
+            yy = np.random.choice(h_axis, size=2)
+            ymin = np.min(yy) - pad_h
+            ymax = np.max(yy) - pad_h
+            ymin = np.clip(ymin, 0, h - 1)
+            ymax = np.clip(ymax, 0, h - 1)
+            if xmax - xmin < 0.1 * w or ymax - ymin < 0.1 * h:
+                # selected region too small
+                # area too small
+                continue
+            if text_polys.shape[0] != 0:  # original author's note said the purpose of this guard was unclear; it handles the empty-polys case
+                poly_axis_in_area = (text_polys[:, :, 0] >= xmin) & (text_polys[:, :, 0] <= xmax) \
+                                    & (text_polys[:, :, 1] >= ymin) & (text_polys[:, :, 1] <= ymax)
+                selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
+            else:
+                selected_polys = []
+            if len(selected_polys) == 0:
+                # no text inside this candidate crop
+                continue
+            im = im[ymin:ymax + 1, xmin:xmax + 1, :]
+            polys = text_polys[selected_polys]
+            # shift coordinates into the cropped image frame
+            polys[:, :, 0] -= xmin
+            polys[:, :, 1] -= ymin
+            return im, polys
+        return im, text_polys
+
+    def random_crop_image_pse(self, im: np.ndarray, text_polys: np.ndarray, input_size) -> tuple:
+        """
+        Crop an input_size x input_size patch and clip polygons into it.
+        :param im: image
+        :param text_polys: text polygons (mutated in place when crossing the crop)
+        :param input_size: output patch side length
+        :return: cropped image and polygons
+        """
+        h, w, _ = im.shape
+        short_edge = min(h, w)
+        if short_edge < input_size:
+            # upscale so the short edge is >= input_size
+            scale = input_size / short_edge
+            im = cv2.resize(im, dsize=None, fx=scale, fy=scale)
+            text_polys *= scale
+            h, w, _ = im.shape
+        # valid random offset range
+        w_range = w - input_size
+        h_range = h - input_size
+        for _ in range(50):
+            xmin = random.randint(0, w_range)
+            ymin = random.randint(0, h_range)
+            xmax = xmin + input_size
+            ymax = ymin + input_size
+            if text_polys.shape[0] != 0:
+                selected_polys = []
+                for poly in text_polys:
+                    if poly[:, 0].max() < xmin or poly[:, 0].min() > xmax or \
+                            poly[:, 1].max() < ymin or poly[:, 1].min() > ymax:
+                        continue
+                    # area_p = cv2.contourArea(poly)
+                    poly[:, 0] -= xmin
+                    poly[:, 1] -= ymin
+                    poly[:, 0] = np.clip(poly[:, 0], 0, input_size)
+                    poly[:, 1] = np.clip(poly[:, 1], 0, input_size)
+                    # rect = cv2.minAreaRect(poly)
+                    # area_n = cv2.contourArea(poly)
+                    # h1, w1 = rect[1]
+                    # if w1 < 10 or h1 < 10 or area_n / area_p < 0.5:
+                    #     continue
+                    selected_polys.append(poly)
+            else:
+                selected_polys = []
+            # if len(selected_polys) == 0:
+            # (no text inside the region)
+            # continue
+            # NOTE(review): with the `continue` above commented out, this loop
+            # always returns on its first iteration, possibly with 0 polygons.
+            im = im[ymin:ymax, xmin:xmax, :]
+            polys = np.array(selected_polys)
+            return im, polys
+        return im, text_polys
+
+    def random_crop_author(self, imgs, img_size):
+        # Crop all maps in `imgs` (image + label maps) with one shared window,
+        # biased so the window usually contains text (PSENet author's code).
+        h, w = imgs[0].shape[0:2]
+        th, tw = img_size
+        if w == tw and h == th:
+            return imgs
+
+        # if the label contains text instances, crop around them with probability 5/8
+        if np.max(imgs[1][:, :, -1]) > 0 and random.random() > 3.0 / 8.0:
+            # top-left bound of the text instances
+            tl = np.min(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
+            tl[tl < 0] = 0
+            # bottom-right bound of the text instances
+            br = np.max(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
+            br[br < 0] = 0
+            # keep enough room below/right of the chosen corner for a full crop
+            br[0] = min(br[0], h - th)
+            br[1] = min(br[1], w - tw)
+            for _ in range(50000):
+                i = random.randint(tl[0], br[0])
+                j = random.randint(tl[1], br[1])
+                # insist that the smallest kernel map has text inside the crop
+                if imgs[1][:, :, 0][i:i + th, j:j + tw].sum() <= 0:
+                    continue
+                else:
+                    break
+        else:
+            i = random.randint(0, h - th)
+            j = random.randint(0, w - tw)
+
+        # return i, j, th, tw
+        for idx in range(len(imgs)):
+            if len(imgs[idx].shape) == 3:
+                imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
+            else:
+                imgs[idx] = imgs[idx][i:i + th, j:j + tw]
+        return imgs
+
+    def resize(self, im: np.ndarray, text_polys: np.ndarray,
+               input_size: numbers.Number or list or tuple or np.ndarray, keep_ratio: bool = False) -> tuple:
+        """
+        Resize image and polygons.
+        :param im: image
+        :param text_polys: text polygons
+        :param input_size: target size, a number or a [w, h] sequence
+        :param keep_ratio: pad the short side to square first to keep aspect ratio
+        :return: resized image and polygons
+        """
+        if isinstance(input_size, numbers.Number):
+            if input_size < 0:
+                raise ValueError("If input_size is a single number, it must be positive.")
+            input_size = (input_size, input_size)
+        elif isinstance(input_size, list) or isinstance(input_size, tuple) or isinstance(input_size, np.ndarray):
+            if len(input_size) != 2:
+                raise ValueError("If input_size is a sequence, it must be of len 2.")
+            input_size = (input_size[0], input_size[1])
+        else:
+            raise Exception('input_size must in Number or list or tuple or np.ndarray')
+        if keep_ratio:
+            # pad the short side with black so the image becomes long-side square
+            h, w, c = im.shape
+            max_h = max(h, input_size[0])
+            max_w = max(w, input_size[1])
+            im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8)
+            im_padded[:h, :w] = im.copy()
+            im = im_padded
+        text_polys = text_polys.astype(np.float32)
+        h, w, _ = im.shape
+        im = cv2.resize(im, input_size)
+        w_scale = input_size[0] / float(w)
+        h_scale = input_size[1] / float(h)
+        text_polys[:, :, 0] *= w_scale
+        text_polys[:, :, 1] *= h_scale
+        return im, text_polys
+
+    def horizontal_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple:
+        """
+        Flip image and polygons horizontally.
+        :param im: image
+        :param text_polys: text polygons
+        :return: flipped image and polygons (inputs are not mutated)
+        """
+        flip_text_polys = text_polys.copy()
+        flip_im = cv2.flip(im, 1)
+        h, w, _ = flip_im.shape
+        flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0]
+        return flip_im, flip_text_polys
+
+    def vertical_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple:
+        """
+        Flip image and polygons vertically.
+        :param im: image
+        :param text_polys: text polygons
+        :return: flipped image and polygons (inputs are not mutated)
+        """
+        flip_text_polys = text_polys.copy()
+        flip_im = cv2.flip(im, 0)
+        h, w, _ = flip_im.shape
+        flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1]
+        return flip_im, flip_text_polys
+
+    def test(self, im: np.ndarray, text_polys: np.ndarray):
+        # Visual self-test of every augmentation.
+        # NOTE(review): `show_pic` is not defined in this module; this method
+        # raises NameError unless such a helper is provided externally.
+        print('随机尺度缩放')
+        t_im, t_text_polys = self.random_scale(im, text_polys, [0.5, 1, 2, 3])
+        print(t_im.shape, t_text_polys.dtype)
+        show_pic(t_im, t_text_polys, 'random_scale')
+
+        print('随机旋转')
+        t_im, t_text_polys = self.random_rotate_img_bbox(im, text_polys, 10)
+        print(t_im.shape, t_text_polys.dtype)
+        show_pic(t_im, t_text_polys, 'random_rotate_img_bbox')
+
+        print('随机裁剪')
+        t_im, t_text_polys = self.random_crop_img_bboxes(im, text_polys)
+        print(t_im.shape, t_text_polys.dtype)
+        show_pic(t_im, t_text_polys, 'random_crop_img_bboxes')
+
+        print('水平翻转')
+        t_im, t_text_polys = self.horizontal_flip(im, text_polys)
+        print(t_im.shape, t_text_polys.dtype)
+        show_pic(t_im, t_text_polys, 'horizontal_flip')
+
+        print('竖直翻转')
+        t_im, t_text_polys = self.vertical_flip(im, text_polys)
+        print(t_im.shape, t_text_polys.dtype)
+        show_pic(t_im, t_text_polys, 'vertical_flip')
+        show_pic(im, text_polys, 'vertical_flip_ori')
+
+        print('加噪声')
+        t_im = self.add_noise(im)
+        print(t_im.shape)
+        show_pic(t_im, text_polys, 'add_noise')
+        show_pic(im, text_polys, 'add_noise_ori')
+
+
+# Module-level augmenter instance shared by augmentation() below (stateless).
+data_aug = DataAugment()
+
+
+def load_json(file_path: str):
+    """Read a UTF-8 JSON file and return the parsed object."""
+    with open(file_path, 'r', encoding='utf8') as f:
+        content = json.load(f)
+    return content
+
+
+def check_and_validate_polys(polys, xxx_todo_changeme):
+    '''
+    Clip polygons to the image bounds and drop degenerate ones.
+    :param polys: (N, 4, 2) float array of quadrilaterals; clipped in place
+    :param xxx_todo_changeme: (h, w) image size tuple (parameter name is a
+           leftover of 2to3's tuple-argument conversion)
+    :return: array of polygons whose area is >= 1 pixel
+    '''
+    (h, w) = xxx_todo_changeme
+    if polys.shape[0] == 0:
+        return polys
+    polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)  # x coord not max w-1, and not min 0
+    polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)  # y coord not max h-1, and not min 0
+
+    validated_polys = []
+    for poly in polys:
+        p_area = cv2.contourArea(poly)
+        if abs(p_area) < 1:
+            # degenerate polygon (area < 1 px) — discard
+            continue
+        validated_polys.append(poly)
+    return np.array(validated_polys)
+
+
+def generate_rbox(im_size, text_polys, text_tags, training_mask, i, n, m):
+    """
+    Generate the i-th PSE kernel map: 1 (white) where the shrunk text lies,
+    0 (black) for background; also zero ignored boxes out of the training mask.
+    :param im_size: (h, w) of the image
+    :param text_polys: polygon coordinates
+    :param text_tags: per-box ignore flags (True = box excluded from training)
+    :param training_mask: uint8 mask, updated in place for ignored boxes
+    :param i: kernel index in [1, n]
+    :param n: total number of kernels
+    :param m: minimal shrink ratio
+    :return: (score_map, training_mask)
+    """
+    h, w = im_size
+    score_map = np.zeros((h, w), dtype=np.uint8)
+    for poly, tag in zip(text_polys, text_tags):
+        # NOTE(review): np.int was removed in NumPy >= 1.24; should be np.int32.
+        poly = poly.astype(np.int)
+        # PSENet shrink ratio; NOTE(review): n == 1 raises ZeroDivisionError here.
+        r_i = 1 - (1 - m) * (n - i) / (n - 1)
+        # offset distance from area/perimeter, as in the PSENet label generation
+        d_i = cv2.contourArea(poly) * (1 - r_i * r_i) / cv2.arcLength(poly, True)
+        pco = pyclipper.PyclipperOffset()
+        # pco.AddPath(pyclipper.scale_to_clipper(poly), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        # shrinked_poly = np.floor(np.array(pyclipper.scale_from_clipper(pco.Execute(-d_i)))).astype(np.int)
+        pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        # Execute may return several polygons; np.array can then be ragged — TODO confirm inputs are convex enough
+        shrinked_poly = np.array(pco.Execute(-d_i))
+        cv2.fillPoly(score_map, shrinked_poly, 1)
+        # build the mask
+        # rect = cv2.minAreaRect(shrinked_poly)
+        # poly_h, poly_w = rect[1]
+
+        # if min(poly_h, poly_w) < 10:
+        #     cv2.fillPoly(training_mask, shrinked_poly, 0)
+        if tag:
+            # ignored (illegible) box: exclude its area from the loss
+            cv2.fillPoly(training_mask, shrinked_poly, 0)
+        # closing operation to fill small interior holes
+        # kernel = np.ones((3, 3), np.uint8)
+        # score_map = cv2.morphologyEx(score_map, cv2.MORPH_CLOSE, kernel)
+    return score_map, training_mask
+
+
+def augmentation(im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray, degrees: int, input_size: int) -> tuple:
+    # Random scale + optional horizontal flip + optional rotation.
+    # NOTE(review): input_size is only used by the commented-out crop/resize
+    # calls below; cropping now happens later in EastRandomCropData.
+    # the images are rescaled with ratio {0.5, 1.0, 2.0, 3.0} randomly
+    im, text_polys = data_aug.random_scale(im, text_polys, scales)
+    # the images are horizontally fliped and rotated in range [−10◦, 10◦] randomly
+    if random.random() < 0.5:
+        im, text_polys = data_aug.horizontal_flip(im, text_polys)
+    if random.random() < 0.5:
+        im, text_polys = data_aug.random_rotate_img_bbox(im, text_polys, degrees)
+    # 640 × 640 random samples are cropped from the transformed images
+    # im, text_polys = data_aug.random_crop_img_bboxes(im, text_polys)
+
+    # im, text_polys = data_aug.resize(im, text_polys, input_size, keep_ratio=False)
+    # im, text_polys = data_aug.random_crop_image_pse(im, text_polys, input_size)
+
+    return im, text_polys
+class EastRandomCropData():
+    """EAST-style random crop: pick a window that avoids cutting through text,
+    then resize/pad it to `size`, cropping the already-generated PSE
+    `score_maps` and `training_mask` in lockstep with the image."""
+    def __init__(self, size=(640, 640), max_tries=50, min_crop_side_ratio=0.1, require_original_image=False, keep_ratio=True):
+        self.size = size
+        self.max_tries = max_tries
+        self.min_crop_side_ratio = min_crop_side_ratio
+        self.require_original_image = require_original_image
+        self.keep_ratio = keep_ratio
+
+    def __call__(self, data: dict) -> dict:
+        """
+        Crop the sample and return it.
+        :param data: {'img', 'text_polys', 'texts', 'ignore_tags',
+                      'training_mask', 'score_maps'}
+        :return: the same dict with cropped/rescaled entries
+        """
+        im = data['img']
+        training_mask = data['training_mask']
+        score_maps = data['score_maps'].transpose((1,2,0))
+        text_polys = data['text_polys']
+        ignore_tags = data['ignore_tags']
+        texts = data['texts']
+        all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
+        # choose the crop window based only on non-ignored polygons
+        crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
+        # scale the crop into self.size, preserving aspect ratio and padding
+        scale_w = self.size[0] / crop_w
+        scale_h = self.size[1] / crop_h
+        scale = min(scale_w, scale_h)
+        h = int(crop_h * scale)
+        w = int(crop_w * scale)
+        try:
+            if self.keep_ratio:
+                if len(im.shape) == 3:
+                    padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
+                else:
+                    padimg = np.zeros((self.size[1], self.size[0]), im.dtype)
+                padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+                img = padimg
+
+                padimg2 = np.zeros((self.size[1], self.size[0]), im.dtype)
+                padimg2[:h, :w] = cv2.resize(training_mask[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+                data['training_mask'] = padimg2
+
+                # NOTE(review): the 6 hard-codes the number of PSE kernels; it
+                # must match config.n or this assignment raises at runtime.
+                padimg2 = np.zeros((self.size[1], self.size[0],6), im.dtype)
+                padimg2[:h, :w] = cv2.resize(score_maps[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+                data['score_maps'] = padimg2.transpose((2,0,1))
+            else:
+                img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], tuple(self.size))
+        except Exception:
+            # NOTE(review): the error is printed but swallowed; `img` may then
+            # be unbound below, turning the failure into a NameError.
+            import traceback
+            traceback.print_exc()
+        # crop the polygon annotations into the new frame
+        text_polys_crop = []
+        ignore_tags_crop = []
+        texts_crop = []
+        try:
+            for poly, text, tag in zip(text_polys, texts, ignore_tags):
+                poly = ((np.array(poly) - (crop_x, crop_y)) * scale).astype('float32')
+                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
+                    text_polys_crop.append(poly)
+                    ignore_tags_crop.append(tag)
+                    texts_crop.append(text)
+            data['img'] = img
+            data['text_polys'] = text_polys_crop
+            data['ignore_tags'] = ignore_tags_crop
+            data['texts'] = texts_crop
+        except:
+            # NOTE(review): bare except silently drops annotation errors and
+            # leaves `data` partially updated.
+            a = 1
+        return data
+
+    def is_poly_in_rect(self, poly, x, y, w, h):
+        # True iff every vertex of `poly` lies inside the axis-aligned rect.
+        poly = np.array(poly)
+        if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
+            return False
+        if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
+            return False
+        return True
+
+    def is_poly_outside_rect(self, poly, x, y, w, h):
+        # True iff `poly` lies entirely outside the axis-aligned rect.
+        poly = np.array(poly)
+        if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
+            return True
+        if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
+            return True
+        return False
+
+    def split_regions(self, axis):
+        # Split a sorted index array into runs of consecutive indices;
+        # the gaps between runs correspond to text spans that must not be cut.
+        regions = []
+        min_axis = 0
+        for i in range(1, axis.shape[0]):
+            if axis[i] != axis[i - 1] + 1:
+                region = axis[min_axis:i]
+                min_axis = i
+                regions.append(region)
+        return regions
+
+    def random_select(self, axis, max_size):
+        # Pick two random background coordinates and return them ordered/clamped.
+        xx = np.random.choice(axis, size=2)
+        xmin = np.min(xx)
+        xmax = np.max(xx)
+        xmin = np.clip(xmin, 0, max_size - 1)
+        xmax = np.clip(xmax, 0, max_size - 1)
+        return xmin, xmax
+
+    def region_wise_random_select(self, regions, max_size):
+        # Pick one coordinate from each of two randomly chosen background runs.
+        selected_index = list(np.random.choice(len(regions), 2))
+        selected_values = []
+        for index in selected_index:
+            axis = regions[index]
+            xx = int(np.random.choice(axis, size=1))
+            selected_values.append(xx)
+        xmin = min(selected_values)
+        xmax = max(selected_values)
+        return xmin, xmax
+
+    def crop_area(self, im, text_polys):
+        # Find an (x, y, w, h) window containing at least one polygon whose
+        # borders do not cross text; fall back to the full image after max_tries.
+        h, w = im.shape[:2]
+        h_array = np.zeros(h, dtype=np.int32)
+        w_array = np.zeros(w, dtype=np.int32)
+        for points in text_polys:
+            points = np.round(points, decimals=0).astype(np.int32)
+            minx = np.min(points[:, 0])
+            maxx = np.max(points[:, 0])
+            w_array[minx:maxx] = 1
+            miny = np.min(points[:, 1])
+            maxy = np.max(points[:, 1])
+            h_array[miny:maxy] = 1
+        # ensure the cropped area not across a text
+        h_axis = np.where(h_array == 0)[0]
+        w_axis = np.where(w_array == 0)[0]
+
+        if len(h_axis) == 0 or len(w_axis) == 0:
+            return 0, 0, w, h
+
+        h_regions = self.split_regions(h_axis)
+        w_regions = self.split_regions(w_axis)
+
+        for i in range(self.max_tries):
+            if len(w_regions) > 1:
+                xmin, xmax = self.region_wise_random_select(w_regions, w)
+            else:
+                xmin, xmax = self.random_select(w_axis, w)
+            if len(h_regions) > 1:
+                ymin, ymax = self.region_wise_random_select(h_regions, h)
+            else:
+                ymin, ymax = self.random_select(h_axis, h)
+
+            if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
+                # area too small
+                continue
+            num_poly_in_rect = 0
+            for poly in text_polys:
+                if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
+                    num_poly_in_rect += 1
+                    break
+
+            if num_poly_in_rect > 0:
+                return xmin, ymin, xmax - xmin, ymax - ymin
+
+        return 0, 0, w, h
+
+# Module-level crop helper used by image_label (default 640x640 window).
+erc=EastRandomCropData()
+def image_label(data, n: int, m: float, input_size: int,
+                defrees: int = 10,
+                scales: np.ndarray = np.array([0.5, 1, 2.0, 3.0])) -> tuple:
+    '''
+    Build the PSE training sample for one record: read the image, augment it,
+    generate the n shrunk kernel maps and the training mask, then random-crop
+    everything to input_size via the module-level EastRandomCropData.
+    (`defrees` is a typo for `degrees`; kept so existing callers keep working.)
+    :param data: dict with at least 'img_path', 'text_polys', 'ignore_tags'
+    :return: the same dict enriched with 'img', 'score_maps', 'training_mask'
+    '''
+
+
+    im = cv2.imread(data['img_path'])
+    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    h, w, _ = im.shape
+    # clip polygons to the image and drop degenerate ones
+    data['text_polys'] = check_and_validate_polys(data['text_polys'], (h, w))
+    data['img'], data['text_polys'], = augmentation(im, data['text_polys'], scales, defrees, input_size)
+
+    h, w, _ = data['img'].shape
+    short_edge = min(h, w)
+    if isinstance(input_size, dict):
+        # NOTE(review): debug leftover — input_size should always be an int here.
+        print(input_size)
+        pass
+    if short_edge < input_size:
+        # upscale so the short edge is >= input_size
+        scale = input_size / short_edge
+        data['img'] = cv2.resize(data['img'], dsize=None, fx=scale, fy=scale)
+        data['text_polys'] *= scale
+    h, w, _ = data['img'].shape
+    training_mask = np.ones((h, w), dtype=np.uint8)
+    score_maps = []
+    for i in range(1, n + 1):
+        # kernels s1..sn, from smallest to largest shrink
+        score_map, training_mask = generate_rbox((h, w), data['text_polys'], data['ignore_tags'], training_mask, i, n, m)
+        score_maps.append(score_map)
+    score_maps = np.array(score_maps, dtype=np.float32)
+    data['training_mask']=training_mask
+    data['score_maps']=score_maps
+    data=erc(data)
+    return data
+
+
+    # imgs = data_aug.random_crop_author([im, score_maps.transpose((1, 2, 0)), training_mask], (input_size, input_size))
+    # return imgs[0], imgs[1].transpose((2, 0, 1)), imgs[2], text_polys, text_tags  # im,score_maps,training_mask#
+
+import torch
+class MyDataset(data.Dataset):
+    """PSE detection dataset: loads a JSON annotation file and yields dicts
+    with a normalized image tensor, PSE score maps and the training mask."""
+    def __init__(self, config):
+        # config fields used: file, data_shape, filter_keys, mean, std, n, m
+        self.load_char_annotation = False
+        self.data_list = self.load_data(config.file)
+        self.data_shape = config.data_shape
+        self.filter_keys = config.filter_keys
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=config.mean, std=config.std)
+        ])
+        self.n = config.n
+        self.m = config.m
+
+    def __getitem__(self, index):
+        # print(self.image_list[index])
+        data = self.data_list[index]
+        img_path, text_polys, text_tags = self.data_list[index]['img_path'], self.data_list[index]['text_polys'], self.data_list[index]['ignore_tags']
+        data = image_label(data, input_size=self.data_shape,n=self.n,m=self.m)
+
+        # NOTE(review): this second imread result is never used (dead I/O).
+        im = cv2.imread(img_path)
+        if self.transform:
+            img = self.transform(data['img'])
+        shape = (data['img'].shape[0], data['img'].shape[1])
+
+        data['img'] = img
+        data['shape'] = shape
+        # data['score_maps'] = score_maps
+        # data['training_mask'] = training_mask
+        # data['text_polys'] =torch.Tensor(list(text_polys))
+        # data['ignore_tags'] = [text_tags]
+        # data['shape'] = shape
+        # data['texts'] = [data['texts']]
+
+        if len(self.filter_keys):
+            # drop keys the collate function / training loop does not need
+            data_dict = {}
+            for k, v in data.items():
+                if k not in self.filter_keys:
+                    data_dict[k] = v
+            return data_dict
+        else:
+            # NOTE(review): with no filter_keys configured, the sample is
+            # discarded entirely and an empty dict is returned.
+            # return {'img': img, 'score_maps': score_maps, 'training_mask': training_mask, 'shape': shape, 'text_polys': list(text_polys), 'ignore_tags': text_tags}
+            return {}
+
+    def load_data(self, path: str) -> list:
+        # Parse the annotation JSON into a list of per-image record dicts.
+        data_list = []
+        content = load_json(path)
+        for gt in tqdm(content['data_list'], desc='read file {}'.format(path)):
+            img_path = os.path.join(content['data_root'], gt['img_name'])
+            polygons = []
+            texts = []
+            illegibility_list = []
+            language_list = []
+            for annotation in gt['annotations']:
+                if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
+                    continue
+                polygons.append(annotation['polygon'])
+                texts.append(annotation['text'])
+                illegibility_list.append(annotation['illegibility'])
+                language_list.append(annotation['language'])
+                if self.load_char_annotation:
+                    # optionally also treat per-character boxes as instances
+                    for char_annotation in annotation['chars']:
+                        if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
+                            continue
+                        polygons.append(char_annotation['polygon'])
+                        texts.append(char_annotation['char'])
+                        illegibility_list.append(char_annotation['illegibility'])
+                        language_list.append(char_annotation['language'])
+            data_list.append({'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': np.array(polygons, dtype=np.float32),
+                              'texts': texts, 'ignore_tags': illegibility_list})
+        return data_list
+
+    def __len__(self):
+        # Number of annotated images.
+        return len(self.data_list)
+
+    def save_label(self, img_path, label):
+        # NOTE(review): draw_bbox's import is commented out at the top of this
+        # module, so calling this helper raises NameError.
+        save_path = img_path.replace('img', 'save')
+        if not os.path.exists(os.path.split(save_path)[0]):
+            os.makedirs(os.path.split(save_path)[0])
+        img = draw_bbox(img_path, label)
+        cv2.imwrite(save_path, img)
+        return img
+
+
+def show_img(imgs: np.ndarray, color=False):
+    # Display one gray (HxW) or color (HxWx3) image — or a stack of them — one
+    # matplotlib figure each; the caller is responsible for plt and plt.show().
+    if (len(imgs.shape) == 3 and color) or (len(imgs.shape) == 2 and not color):
+        # single image: add a leading batch axis so the loop below works
+        imgs = np.expand_dims(imgs, axis=0)
+    for img in imgs:
+        plt.figure()
+        plt.imshow(img, cmap=None if color else 'gray')
+
+
+if __name__ == '__main__':
+    # Smoke test: build the PSE dataset from config, run it through a
+    # DataLoader, and visualize image, kernel maps and training mask.
+    import torch
+    import config
+    from config.cfg_det_pse import config
+    from tqdm import tqdm
+    from torch.utils.data import DataLoader
+    import matplotlib.pyplot as plt
+    from torchvision import transforms
+
+    train_data = MyDataset(config.dataset.train.dataset)
+    train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, num_workers=0)
+
+    pbar = tqdm(total=len(train_loader))
+    for i, batch_data in enumerate(train_loader):
+        img, label, mask = batch_data['img'], batch_data['score_maps'], batch_data['training_mask']
+        print(label.shape)
+        print(img.shape)
+        print(label[0][-1].sum())
+        print(mask[0].shape)
+        pbar.update(1)
+        show_img((img[0] * mask[0].to(torch.float)).numpy().transpose(1, 2, 0), color=True)
+        show_img(label[0])
+        show_img(mask[0])
+        plt.show()
+
+    pbar.close()

+ 86 - 0
torchocr/datasets/RecCollateFn.py

@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/6/16 17:06
+# @Author  : zhoujun
+import torch
+import numpy as np
+import cv2
+from torchvision import transforms
+
+class Resize:
+    """Resize a recognition crop to (img_h, img_w), padding narrow images with
+    black on the right instead of stretching when pad=True."""
+    def __init__(self, img_h, img_w, pad=True, **kwargs):
+        self.img_h = img_h
+        self.img_w = img_w
+        self.pad = pad
+
+    def __call__(self, img: np.ndarray):
+        """
+        Resize by height first; if the resulting width is below the target and
+        padding is enabled, pad with black pixels, otherwise force-resize to
+        the exact target width.
+        :param img_path: image ndarray (gray HxW or color HxWxC)
+        :return: image with the configured height and width
+        """
+        img_h = self.img_h
+        img_w = self.img_w
+        h, w = img.shape[:2]
+        ratio_h = self.img_h / h
+        new_w = int(w * ratio_h)
+        if new_w < img_w and self.pad:
+            img = cv2.resize(img, (new_w, img_h))
+            if len(img.shape) == 2:
+                img = np.expand_dims(img, 2)
+            # right-pad with zeros (black) up to the target width
+            step = np.zeros((img_h, img_w - new_w, img.shape[-1]), dtype=img.dtype)
+            img = np.column_stack((img, step))
+        else:
+            img = cv2.resize(img, (img_w, img_h))
+            if len(img.shape) == 2:
+                img = np.expand_dims(img, 2)
+        if img.shape[-1] == 1:
+            # squeeze single-channel back to HxW
+            img = img[:, :, 0]
+        return img
+
+
+class RecCollateFn:
+    """Batch collate for recognition: resize every crop to a common height via
+    the dataset's `process` helper, pad widths to the batch maximum (rounded up
+    to a multiple of 8), and stack into one float tensor. Labels stay a list."""
+    def __init__(self, *args, **kwargs):
+        # kwargs['dataset'].process must provide resize_with_specific_height,
+        # normalize_img and width_pad_img — TODO confirm against RecDataProcess.
+        self.process = kwargs['dataset'].process
+        self.t = transforms.ToTensor()
+
+    def __call__(self, batch):
+        resize_images = []
+
+        all_same_height_images = [self.process.resize_with_specific_height(_['img']) for _ in batch]
+        max_img_w = max({m_img.shape[1] for m_img in all_same_height_images})
+        # make sure max_img_w is integral multiple of 8
+        max_img_w = int(np.ceil(max_img_w / 8) * 8)
+        labels = []
+        for i in range(len(batch)):
+            _label = batch[i]['label']
+            labels.append(_label)
+            img = self.process.normalize_img(all_same_height_images[i])
+            img = self.process.width_pad_img(img, max_img_w)
+
+            # HWC -> CHW for the model
+            img = img.transpose([2, 0, 1])
+            resize_images.append(torch.tensor(img, dtype=torch.float))
+        resize_images = torch.stack(resize_images)
+        return {'img': resize_images, 'label': labels}
+
+class RecCollateFnWithResize:
+    """
+    Collate that resizes every image in the batch to a fixed (img_h, img_w),
+    optionally padding instead of stretching (see Resize above).
+    """
+    def __init__(self, *args, **kwargs):
+        from torchvision import transforms
+        self.img_h = kwargs.get('img_h', 32)
+        self.img_w = kwargs.get('img_w', 320)
+        self.pad = kwargs.get('pad', True)
+        self.t = transforms.ToTensor()
+
+    def __call__(self, batch):
+        resize_images = []
+        resize_image_class = Resize(self.img_h, self.img_w, self.pad)
+        labels = []
+        for data in batch:
+            labels.append(data['label'])
+            resize_image = resize_image_class(data['img'])
+            resize_image = self.t(resize_image)
+            resize_images.append(resize_image)
+        resize_images = torch.cat([t.unsqueeze(0) for t in resize_images], 0)
+        return {'img':resize_images,'label':labels}

+ 461 - 0
torchocr/datasets/RecDataSet.py

@@ -0,0 +1,461 @@
+# -*-coding:utf-8-*-
+"""
+@Author: Jeffery Sheng (Zhenfei Sheng)
+@Time:   2020/5/21 19:44
+@File:   RecDataSet.py
+"""
+import six
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader
+from torchocr.utils.CreateRecAug import cv2pil, pil2cv, RandomBrightness, RandomContrast, \
+    RandomLine, RandomSharpness, Compress, Rotate, \
+    Blur, MotionBlur, Salt, AdjustResolution
+import re
+
+
class RecTextLineDataset(Dataset):
    def __init__(self, config):
        """
        Text-line dataset for recognition.

        Annotation lines are expected as `img_name idx1 idx2 ...`, where the
        indices point into the alphabet file and are decoded back into a
        character string here.

        :param config: usually config['dataset']['train']['dataset'] or
                config['dataset']['eval']['dataset'], expected fields:
                    file: annotation file path
                    alphabet: alphabet file path
                    input_h: target image height
                    mean: normalization mean
                    std: normalization std
                    augmentation: whether to apply data augmentation
        :return None
        """
        self.augmentation = config.augmentation
        self.process = RecDataProcess(config)
        with open(config.alphabet, 'r', encoding='utf-8') as file:
            alphabet = ''.join([s.strip('\n') for s in file.readlines()])
        # alphabet += ' '
        # the literal token "blank" in the alphabet file stands for the space character
        alphabet = alphabet.replace("blank"," ") #add
        self.str2idx = {c: i for i, c in enumerate(alphabet)}
        self.labels = []
        # First pass: decode index-encoded labels; samples containing any
        # out-of-alphabet index (marked '#none#') are skipped entirely.
        # NOTE(review): the image-path prefix below is a hard-coded absolute
        # path to one machine's data layout — parameterize before reuse.
        # if "test.txt" in config.file:
        with open(config.file, 'r', encoding='utf-8') as f_reader:
            for m_line in f_reader.readlines():
                m_line=m_line.strip()
                params = m_line.split(' ')
                # print(params)
                if len(params) >= 2:
                    m_image_name = params[0]
                    m_image_name = '/data2/znj/CRNN_Chinese_Characters_Rec/data/data/python_znj/Lets_OCR/recognizer/crnn/data/images/'+m_image_name
                    m_gt_text = params[1:]
                    # print(m_gt_text)
                    m_gt_text = "".join([alphabet[int(idx)] if int(idx)<len(alphabet) else '#none#' for idx in m_gt_text])
                    # if True in [c not in self.str2idx for c in m_gt_text]:
                    #     continue
                    if "#none#" in m_gt_text:
                        continue
                    self.labels.append((m_image_name, m_gt_text))

        # with open(config.file, 'r', encoding='utf-8') as f_reader:
        #     for m_line in f_reader.readlines():
        #         is_skip = False
        #         m_line=m_line.strip()
        #         params = m_line.split(' ')
        #         if len(params) >= 2:
        #             m_image_name = params[0]
        #             m_image_name = '/data2/znj/text_renderer/output2/default/' + m_image_name + '.jpg'
        #             m_gt_text = params[1]
        #             for w in m_gt_text:
        #                 if w not in alphabet:
        #                     is_skip = True
        #                     break
        #             if is_skip:
        #                 continue
        #             self.labels.append((m_image_name, m_gt_text))

        # add
        # if "train.txt" in config.file:
        #     with open("/data2/znj/PytorchOCR/data/train.txt", 'r', encoding='utf-8') as f_reader:
        #         for m_line in f_reader.readlines():
        #             m_line=m_line.strip()
        #             params = m_line.split('\t')
        #             # print(params)
        #             if len(params) >= 3:
        #                 m_image_name = params[2]
        #                 m_image_name = '/data2/znj/PytorchOCR/data/image/'+m_image_name
        #                 m_gt_text = params[3]
        #                 # print(m_gt_text)
        #                 # m_gt_text = "".join([alphabet[int(idx)] if int(idx)<len(alphabet) else '#none#' for idx in m_gt_text])
        #                 # if True in [c not in self.str2idx for c in m_gt_text]:
        #                 #     continue
        #                 _m_gt_text = "".join([str(self.str2idx.get(i,'#none#')) for i in m_gt_text])
        #                 if "#none#" in _m_gt_text:
        #                     continue
        #                 self.labels.append((m_image_name, m_gt_text))
        # add
        # if "train.txt" in config.file:
        #     with open("/data2/znj/text_renderer/output3/default/tmp_labels.txt", 'r', encoding='utf-8') as f_reader:
        #         for m_line in f_reader.readlines():
        #             is_skip = False
        #             m_line=m_line.strip()
        #             params = m_line.split(' ')
        #             # print(params)
        #             if len(params) >= 2:
        #                 m_image_name = params[0]
        #                 m_image_name = '/data2/znj/text_renderer/output3/default/' + m_image_name + '.jpg'
        #                 m_gt_text = params[1]
        #                 for w in m_gt_text:
        #                     if w not in alphabet:
        #                         is_skip = True
        #                         break
        #                 if is_skip:
        #                     continue
        #                 self.labels.append((m_image_name, m_gt_text))

        # test
        # NOTE(review): this second pass re-reads the SAME config.file and
        # appends every sample again with the placeholder label 'a123'
        # (raw image name, no path prefix) — apparently debug leftover that
        # duplicates entries; confirm before training.
        with open(config.file, 'r', encoding='utf-8') as f_reader:
            for m_line in f_reader.readlines():
                m_line=m_line.strip()
                params = m_line.split(' ')
                # print(params)
                if len(params) >= 2:
                    m_image_name = params[0]

                    m_gt_text = params[1:]
                    # print(m_gt_text)
                    m_gt_text = "".join([alphabet[int(idx)] if int(idx)<len(alphabet) else '#none#' for idx in m_gt_text])
                    # if True in [c not in self.str2idx for c in m_gt_text]:
                    #     continue
                    if "#none#" in m_gt_text:
                        continue
                    # self.labels.append((m_image_name, m_gt_text))
                    self.labels.append((m_image_name, 'a123'))
        print(self.labels)
        # PaddleOCR recognition results: lines look like
        # "<img_path> <text> [[box...]] rec_res: ..." and are stored as
        # 3-tuples (path, text, box) so __getitem__ can crop the box.
        if "log4.log" in config.file:
            with open(config.file, 'r', encoding='utf-8') as f_reader:
                for m_line in f_reader.readlines():
                    m_line=m_line.strip()
                    iamge_path, line_split2 = re.split(" ", m_line, maxsplit=1)
                    text, box = re.split(" \[\[", line_split2, maxsplit=1)
                    box = '[[' + box
                    box,_ = re.split(" rec_res:", box, maxsplit=1)
                    self.labels.append((iamge_path, text,box))
        print(self.labels)

    def _find_max_length(self):
        # length of the longest label string in the dataset
        return max({len(_[1]) for _ in self.labels})

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return {'img': HWC RGB ndarray, 'label': str} for the given index."""
        # get img_path and trans
        # img_path, trans = self.labels[index]
        label = self.labels[index]
        if len(label)==2:
            img_path, trans = label
            # read img
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif len(label)==3:
            # 3-tuple entries come from the paddle-log branch and carry a
            # quadrilateral box string; crop its axis-aligned bounding box.
            img_path, trans,box = label
            # read img
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # NOTE(review): eval() on a string parsed from a log file —
            # only safe while the log is trusted input.
            bbox = eval(box)
            x1 = int(min([i[0] for i in bbox]))
            x2 = int(max([i[0] for i in bbox]))
            y1 = int(min([i[1] for i in bbox]))
            y2 = int(max([i[1] for i in bbox]))
            img = img[y1:y2, x1:x2]
        # do aug (2-tuple samples only; cropped paddle samples are never augmented)
        if len(label)==2:
            if self.augmentation:
                img = pil2cv(self.process.aug_img(cv2pil(img)))
        return {'img': img, 'label': trans}
+
class RecTextLineDataset2(Dataset):
    def __init__(self, config):
        """
        Text-line dataset variant reading index-encoded labels.

        :param config: usually config['dataset']['train']['dataset'] or
                config['dataset']['eval']['dataset'], expected fields:
                    file: annotation file path
                    alphabet: alphabet file path
                    input_h: target image height
                    mean: normalization mean
                    std: normalization std
                    augmentation: whether to apply data augmentation
        :return None
        """
        self.augmentation = config.augmentation
        self.process = RecDataProcess(config)

        # FIX: the original body referenced an undefined name `alphabet`
        # (guaranteed NameError) and opened the empty path "" (guaranteed
        # FileNotFoundError). Load the alphabet and the annotation file from
        # config, consistent with RecTextLineDataset.
        with open(config.alphabet, 'r', encoding='utf-8') as alpha_file:
            alphabet = ''.join(s.strip('\n') for s in alpha_file.readlines())
        # the literal token "blank" in the alphabet file stands for the space character
        alphabet = alphabet.replace("blank", " ")
        self.str2idx = {c: i for i, c in enumerate(alphabet)}

        self.labels = []
        with open(config.file, 'r', encoding='utf-8') as f_reader:
            for m_line in f_reader.readlines():
                m_line = m_line.strip()
                params = m_line.split(' ')
                if len(params) >= 2:
                    m_image_name = params[0]

                    m_gt_text = params[1:]
                    # decode indices; out-of-range indices mark the sample for skipping
                    m_gt_text = "".join([alphabet[int(idx)] if int(idx) < len(alphabet) else '#none#' for idx in m_gt_text])
                    if "#none#" in m_gt_text:
                        continue
                    # NOTE(review): the original stores the placeholder label
                    # 'a123' instead of the decoded text — kept for parity,
                    # confirm before real training.
                    self.labels.append((m_image_name, 'a123'))
        print(self.labels)

    def _find_max_length(self):
        # length of the longest label string in the dataset
        return max({len(_[1]) for _ in self.labels})

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return {'img': HWC RGB ndarray, 'label': str} for the given index."""
        # get img_path and trans
        img_path, trans = self.labels[index]
        # read img
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # do aug
        if self.augmentation:
            img = pil2cv(self.process.aug_img(cv2pil(img)))
        return {'img': img, 'label': trans}
+
class RecLmdbDataset(Dataset):
    def __init__(self, config):
        """
        Lmdb DataSet, for datasets that have been converted into an lmdb file.

        :param config: usually config['dataset']['train']['dataset'] or
                config['dataset']['eval']['dataset'], expected fields:
                    file: lmdb directory path
                    alphabet: character set (iterable of characters)
                    input_h: target image height
                    mean: normalization mean
                    std: normalization std
                    augmentation: whether to apply data augmentation
        :return None
        """
        import lmdb, sys
        self.env = lmdb.open(config.file, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
        if not self.env:
            print('cannot create lmdb from %s' % (config.file))
            sys.exit(0)

        self.augmentation = config.augmentation
        self.process = RecDataProcess(config)
        self.filtered_index_list = []   # lmdb indices (1-based) that survived filtering
        self.labels = []                # label strings, parallel to filtered_index_list
        self.str2idx = {c: i for i, c in enumerate(config.alphabet)}
        with self.env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'.encode()))
            self.nSamples = nSamples
            for index in range(self.nSamples):
                index += 1  # lmdb starts with 1
                label_key = 'label-%09d'.encode() % index
                label = txn.get(label_key).decode('utf-8')
                # todo add filtering of over-long labels
                # if len(label) > config.max_len:
                #     # print(f'The length of the label is longer than max_length: length
                #     # {len(label)}, {label} in dataset {self.root}')
                #     continue
                if True in [c not in self.str2idx for c in label]:
                    continue
                # By default, images containing characters which are not in opt.character are filtered.
                # You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
                self.labels.append(label)
                self.filtered_index_list.append(index)

    def _find_max_length(self):
        # length of the longest label string in the dataset
        return max({len(_) for _ in self.labels})

    def __getitem__(self, index):
        """Return {'img': HWC RGB ndarray, 'label': str}; image is decoded from the lmdb blob."""
        index = self.filtered_index_list[index]
        with self.env.begin(write=False) as txn:
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')
            img_key = 'image-%09d'.encode() % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            img = Image.open(buf).convert('RGB')  # for color image
            # We only train and evaluate on alphanumerics (or pre-defined character set in rec_train.py)
            img = np.array(img)
            if self.augmentation:
                img = pil2cv(self.process.aug_img(cv2pil(img)))
        return {'img': img, 'label': label}

    def __len__(self):
        return len(self.filtered_index_list)
+
+
class RecDataLoader:
    def __init__(self, dataset, batch_size, shuffle, num_workers, **kwargs):
        """
        Custom DataLoader that buckets samples by label length so that each
        batch contains text lines of similar length (short lines go to
        queue_1, long lines to queue_2).

        :param dataset: a torch.utils.data.Dataset subclass instance; must
                expose `process` (a RecDataProcess) and `_find_max_length()`
        :param batch_size: number of images per batch
        :param shuffle: whether to shuffle the dataset
        :param num_workers: number of background worker processes
        :param kwargs: unused extras, accepted for config compatibility
        """
        self.dataset = dataset
        self.process = dataset.process
        # labels no longer than half the maximum length count as "short"
        self.len_thresh = self.dataset._find_max_length() // 2
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.iteration = 0
        self.dataiter = None
        self.queue_1 = list()   # pending short samples
        self.queue_2 = list()   # pending long samples

    def __len__(self):
        # number of batches, counting a final partial batch
        return len(self.dataset) // self.batch_size if len(self.dataset) % self.batch_size == 0 \
            else len(self.dataset) // self.batch_size + 1

    def __iter__(self):
        return self

    def pack(self, batch_data):
        """Resize samples to a common height, pad widths to a multiple of 8 and stack."""
        batch = {'img': [], 'label': []}
        # img tensor current shape: B,H,W,C
        all_same_height_images = [self.process.resize_with_specific_height(_['img'][0].numpy()) for _ in batch_data]
        max_img_w = max({m_img.shape[1] for m_img in all_same_height_images})
        # make sure max_img_w is integral multiple of 8
        max_img_w = int(np.ceil(max_img_w / 8) * 8)
        for i in range(len(batch_data)):
            _label = batch_data[i]['label'][0]
            img = self.process.normalize_img(self.process.width_pad_img(all_same_height_images[i], max_img_w))
            img = img.transpose([2, 0, 1])
            batch['img'].append(torch.tensor(img, dtype=torch.float))
            batch['label'].append(_label)
        batch['img'] = torch.stack(batch['img'])
        return batch

    def build(self):
        # underlying loader yields one sample at a time; bucketing happens here
        self.dataiter = DataLoader(self.dataset, batch_size=1, shuffle=self.shuffle,
                                   num_workers=self.num_workers).__iter__()

    def __next__(self):
        if self.dataiter is None:
            self.build()
        # dataset exhausted: flush any leftover long-sample batch first
        if self.iteration == len(self.dataset) and len(self.queue_2):
            batch_data = self.queue_2
            self.queue_2 = list()
            return self.pack(batch_data)
        # nothing pending anywhere: reset state for the next epoch and stop
        if not len(self.queue_2) and not len(self.queue_1) and self.iteration == len(self.dataset):
            self.iteration = 0
            self.dataiter = None
            raise StopIteration
        # start iteration
        try:
            while True:
                # get data from origin dataloader
                temp = self.dataiter.__next__()
                self.iteration += 1
                # route to the short or long queue
                if len(temp['label'][0]) <= self.len_thresh:
                    self.queue_1.append(temp)
                else:
                    self.queue_2.append(temp)

                # to store batch data
                batch_data = None
                # queue_1 full, push to batch_data
                if len(self.queue_1) == self.batch_size:
                    batch_data = self.queue_1
                    self.queue_1 = list()
                # or queue_2 full, push to batch_data
                elif len(self.queue_2) == self.batch_size:
                    batch_data = self.queue_2
                    self.queue_2 = list()

                # start to process batch
                if batch_data is not None:
                    return self.pack(batch_data)
        # deal with last (partial) batch
        except StopIteration:
            # FIX: the original raised here whenever queue_1 was empty, which
            # silently dropped a pending partial batch in queue_2 and left
            # `iteration`/`dataiter` stale, so the next epoch's first
            # __next__() returned stale queue_2 data. Flush both queues,
            # then reset state before signalling exhaustion.
            if self.queue_1:
                batch_data = self.queue_1
                self.queue_1 = list()
                return self.pack(batch_data)
            if self.queue_2:
                batch_data = self.queue_2
                self.queue_2 = list()
                return self.pack(batch_data)
            self.iteration = 0
            self.dataiter = None
            raise
+
+
class RecDataProcess:
    def __init__(self, config):
        """
        Text-line preprocessing and augmentation helper.

        :param config: config object; the fields used are input_h, mean and std
        """
        self.config = config
        # instantiate every augmenter with its firing probability
        self.random_contrast = RandomContrast(probability=0.3)
        self.random_brightness = RandomBrightness(probability=0.3)
        self.random_sharpness = RandomSharpness(probability=0.3)
        self.compress = Compress(probability=0.3)
        self.rotate = Rotate(probability=0.5)
        self.blur = Blur(probability=0.3)
        self.motion_blur = MotionBlur(probability=0.3)
        self.salt = Salt(probability=0.3)
        self.adjust_resolution = AdjustResolution(probability=0.3)
        self.random_line = RandomLine(probability=0.3)
        # pre-sample parameters for every augmenter except random_line,
        # in the same order as the original initialisation
        for augmenter in (self.random_contrast, self.random_brightness,
                          self.random_sharpness, self.compress, self.rotate,
                          self.blur, self.motion_blur, self.salt,
                          self.adjust_resolution):
            augmenter.setparam()

    def aug_img(self, img):
        """Apply the augmentation pipeline to a PIL image and return the result."""
        for augmenter in (self.random_contrast, self.random_brightness,
                          self.random_sharpness, self.random_line):
            img = augmenter.process(img)
        # degradation-style augmenters only make sense on images >= 32px tall
        if img.size[1] >= 32:
            for augmenter in (self.compress, self.adjust_resolution,
                              self.motion_blur, self.blur):
                img = augmenter.process(img)
        img = self.rotate.process(img)
        return self.salt.process(img)

    def resize_with_specific_height(self, _img):
        """
        Resize an image so its height equals config.input_h, keeping aspect ratio.

        :param _img: image to resize (H, W, C ndarray)
        :return: resized image
        """
        scale = self.config.input_h / _img.shape[0]
        return cv2.resize(_img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)

    def normalize_img(self, _img):
        """
        Scale pixel values to [0, 1], then normalize with the configured mean and std.

        :param _img: image to normalize
        :return: normalized float32 image
        """
        scaled = _img.astype(np.float32) / 255
        return (scaled - self.config.mean) / self.config.std

    def width_pad_img(self, _img, _target_width, _pad_value=0):
        """
        Pad an image on the right to the target width; height is unchanged.

        :param _img: image to pad
        :param _target_width: target width
        :param _pad_value: fill value for the padded area
        :return: padded image
        """
        height, width, channels = _img.shape
        padded = np.ones([height, _target_width, channels], dtype=_img.dtype) * _pad_value
        padded[:, :width, :] = _img
        return padded

+ 99 - 0
torchocr/datasets/__init__.py

@@ -0,0 +1,99 @@
+import copy
+from addict import Dict
+from torch.utils.data import DataLoader
+
+from .RecDataSet import RecDataLoader, RecTextLineDataset, RecLmdbDataset
+from .DetDataSet import JsonDataset
+from .RecCollateFn import RecCollateFn
+from .DetCollateFN import DetCollectFN
+from .DetDateSetPse import MyDataset
+from .DetDataSetFce import FCEDataset
+
+__all__ = ['build_dataloader']
+
+support_dataset = ['RecTextLineDataset', 'RecLmdbDataset', 'DetTextLineDataset','JsonDataset','MyDataset','FCEDataset']
+support_loader = ['RecDataLoader', 'DataLoader']
+
+
def build_dataset(config):
    """
    Build a dataset object from its config.

    :param config: dataset config, usually config['dataset']['train']['dataset']
                   or config['dataset']['eval']['dataset']; its 'type' key is
                   popped and used to pick the class
    :return: the constructed DataSet instance
    """
    dataset_type = config.pop('type')
    assert dataset_type in support_dataset, f'{dataset_type} is not developed yet!, only {support_dataset} are support now'
    # resolve the class by name; the whitelist assert above keeps eval() safe
    return eval(dataset_type)(config)
+
+
def build_loader(dataset, config):
    """
    Build a dataloader from config in two steps: the collate_fn, then the loader.

    :param dataset: a torch.utils.data.Dataset subclass instance
    :param config: loader config, usually config['dataset']['train']['loader']
                   or config['dataset']['eval']['loader']
    :return: the constructed DataLoader instance
    """
    dataloader_type = config.pop('type')
    assert dataloader_type in support_loader, f'{dataloader_type} is not developed yet!, only {support_loader} are support now'

    # build collate_fn (optional; it receives the dataset so it can reuse its process)
    collate_fn = None
    if 'collate_fn' in config:
        config['collate_fn']['dataset'] = dataset
        collate_fn = build_collate_fn(config.pop('collate_fn'))
    # remaining config entries become loader kwargs
    return eval(dataloader_type)(dataset=dataset, collate_fn=collate_fn, **config, pin_memory=True)
+
def build_loader_add(dataset, config):
    """
    Build a dataloader from config (collate_fn first, then the loader).

    NOTE: this function was a byte-for-byte duplicate of `build_loader`;
    it now delegates to it so the two implementations cannot drift apart.
    The interface and behavior are unchanged.

    :param dataset: a torch.utils.data.Dataset subclass instance
    :param config: loader config, usually config['dataset']['train']['loader']
                   or config['dataset']['eval']['loader']
    :return: the constructed DataLoader instance
    """
    return build_loader(dataset, config)
+
+
def build_collate_fn(config):
    """
    Build a collate_fn object from its config.

    :param config: collate_fn config; its 'type' key names the class, and an
                   empty type string means "no custom collate_fn"
    :return: the constructed collate_fn instance, or None for an empty type
    """
    collate_fn_type = config.pop('type')
    if len(collate_fn_type) == 0:
        return None
    # resolve the collate class by name and build it from the remaining kwargs
    return eval(collate_fn_type)(**config)
+
+
def build_dataloader(config):
    """
    Build a dataloader from config in two steps: dataset first, then loader.

    :param config: usually config['dataset']['train'] or config['dataset']['eval']
    :return: the constructed DataLoader instance
    """
    # deep-copy so the pop() calls downstream don't mutate the caller's config
    cfg = Dict(copy.deepcopy(config))

    dataset = build_dataset(cfg.dataset)
    return build_loader(dataset, cfg.loader)

+ 91 - 0
torchocr/datasets/alphabets/dict.txt

@@ -0,0 +1,91 @@
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~

+ 3827 - 0
torchocr/datasets/alphabets/dict_baidu.txt

@@ -0,0 +1,3827 @@
+ⅱ
+①
+②
+③
+④
+〇
+の
+サ
+シ
+ジ
+マ
+㸃
+一
+丁
+七
+万
+丈
+三
+上
+下
+不
+与
+丑
+专
+且
+世
+丘
+丙
+业
+丛
+东
+丝
+丞
+丢
+两
+严
+丧
+丨
+个
+丫
+中
+丰
+串
+临
+丶
+丸
+丹
+为
+主
+丼
+丽
+举
+乃
+久
+么
+义
+之
+乌
+乍
+乎
+乐
+乒
+乓
+乔
+乖
+乘
+乙
+九
+乞
+也
+习
+乡
+书
+买
+乱
+乳
+乾
+了
+予
+争
+事
+二
+于
+亏
+云
+互
+五
+井
+亘
+亚
+些
+亢
+交
+亦
+产
+亨
+亩
+享
+京
+亭
+亮
+亲
+亳
+人
+亿
+什
+仁
+仅
+仆
+仇
+今
+介
+从
+仑
+仓
+仔
+仕
+他
+仗
+付
+仙
+仞
+仟
+代
+令
+以
+仨
+仪
+们
+仰
+仲
+件
+价
+仺
+任
+份
+仿
+企
+伊
+伍
+伏
+伐
+休
+众
+优
+伙
+会
+伞
+伟
+传
+伢
+伤
+伦
+伪
+伯
+估
+伴
+伶
+伸
+似
+伽
+但
+位
+低
+住
+佐
+佑
+体
+何
+佗
+佘
+佛
+作
+你
+佣
+佧
+佩
+佬
+佰
+佳
+佶
+使
+侃
+侈
+例
+侍
+供
+依
+侠
+侣
+侦
+侧
+侨
+侬
+侯
+侵
+便
+促
+俄
+俊
+俏
+俗
+保
+俞
+信
+俩
+俪
+俬
+俭
+修
+俱
+俵
+俺
+倌
+倍
+倒
+候
+倚
+借
+倡
+倩
+倪
+债
+值
+倾
+假
+偏
+做
+停
+健
+偶
+偷
+偿
+傅
+傣
+储
+催
+傲
+傻
+像
+僧
+儒
+儿
+允
+元
+兄
+充
+兆
+先
+光
+克
+免
+兑
+兔
+党
+兜
+入
+全
+八
+公
+六
+兮
+兰
+共
+关
+兴
+兵
+其
+具
+典
+兹
+养
+兼
+兽
+兿
+冀
+内
+冈
+冉
+册
+再
+冒
+写
+冚
+军
+农
+冠
+冥
+冬
+冮
+冯
+冰
+冲
+决
+况
+冶
+冷
+冻
+净
+准
+凉
+凌
+减
+凝
+几
+凡
+凤
+処
+凭
+凯
+凰
+凳
+凸
+凹
+出
+击
+函
+凿
+刀
+刁
+刃
+分
+切
+刊
+刑
+划
+列
+刘
+则
+刚
+创
+初
+删
+判
+刨
+利
+别
+刮
+到
+制
+刷
+券
+刹
+刺
+刻
+剁
+剂
+剃
+削
+前
+剐
+剑
+剔
+剖
+剥
+剧
+剩
+剪
+副
+割
+劈
+力
+劝
+办
+功
+加
+务
+动
+助
+努
+劫
+励
+劲
+劳
+劵
+势
+勃
+勅
+勇
+勋
+勐
+勑
+勒
+勘
+募
+勤
+勺
+勾
+勿
+匀
+包
+匆
+匍
+化
+北
+匙
+匝
+匠
+匡
+匪
+匹
+区
+医
+匾
+十
+千
+升
+午
+卉
+半
+华
+协
+卑
+卒
+卓
+单
+卖
+南
+博
+卜
+占
+卡
+卢
+卤
+卦
+卧
+卫
+卯
+印
+危
+即
+却
+卵
+卷
+卸
+卿
+厂
+厅
+历
+厉
+压
+厕
+厘
+厚
+厝
+原
+厢
+厦
+厨
+去
+县
+叁
+参
+又
+叉
+及
+友
+双
+反
+发
+叔
+取
+受
+变
+叙
+叠
+口
+古
+句
+另
+只
+叫
+召
+叭
+叮
+可
+台
+史
+右
+叶
+号
+司
+叹
+叽
+吁
+吃
+各
+合
+吉
+吊
+同
+名
+后
+吐
+向
+吓
+吕
+吖
+吗
+君
+吞
+吟
+否
+吧
+吨
+含
+听
+启
+吴
+吵
+吸
+吹
+吻
+吾
+呀
+呆
+呈
+告
+呐
+呕
+呗
+员
+呛
+呢
+呦
+周
+呱
+味
+呵
+呷
+呼
+命
+咀
+咋
+和
+咏
+咔
+咕
+咖
+咚
+咤
+咨
+咪
+咱
+咳
+咸
+咻
+咽
+哀
+品
+哆
+哇
+哈
+响
+哎
+哒
+哓
+哔
+哗
+哚
+哟
+哥
+哦
+哨
+哪
+哮
+哲
+哺
+唇
+唐
+唔
+唛
+唢
+唤
+售
+唯
+唱
+唻
+啃
+啄
+商
+啊
+啡
+啤
+啥
+啦
+啵
+啸
+啼
+啾
+喀
+善
+喆
+喇
+喉
+喊
+喔
+喘
+喜
+喝
+喨
+喱
+喵
+喷
+喻
+喽
+嗅
+嗒
+嗝
+嗡
+嗨
+嗲
+嗽
+嘀
+嘉
+嘎
+嘛
+嘟
+嘢
+嘬
+嘴
+嘻
+嘿
+噜
+噢
+器
+噪
+嚏
+嚼
+囊
+囍
+囗
+四
+回
+因
+囡
+团
+囧
+园
+困
+囱
+围
+固
+国
+图
+囿
+圃
+圆
+圈
+圗
+圜
+土
+圣
+圧
+在
+圩
+圪
+地
+圳
+场
+圾
+址
+坂
+均
+坊
+坋
+坎
+坏
+坐
+坑
+块
+坚
+坛
+坝
+坞
+坟
+坠
+坡
+坤
+坦
+坨
+坪
+坯
+垂
+垃
+型
+垒
+垓
+垛
+垢
+垣
+垦
+垧
+垫
+埃
+埋
+城
+埔
+埗
+域
+埠
+培
+基
+堂
+堃
+堆
+堑
+堡
+堤
+堰
+堵
+塑
+塔
+塘
+塞
+填
+塾
+境
+墅
+墓
+墙
+增
+墟
+墨
+墩
+壁
+壕
+壤
+士
+壮
+声
+壳
+壶
+壹
+处
+备
+夌
+复
+夏
+夕
+外
+夙
+多
+夜
+够
+大
+天
+太
+夫
+央
+夯
+失
+头
+夷
+夹
+夺
+奂
+奇
+奈
+奉
+奋
+奎
+奏
+契
+奓
+奔
+奕
+奖
+套
+奠
+奢
+奥
+女
+奴
+奶
+她
+好
+如
+妃
+妆
+妇
+妈
+妊
+妍
+妖
+妙
+妞
+妤
+妥
+妮
+妯
+妹
+妻
+姆
+始
+姐
+姑
+姓
+委
+姗
+姚
+姜
+姣
+姥
+姨
+姬
+姻
+姿
+威
+娃
+娄
+娅
+娇
+娌
+娓
+娘
+娜
+娟
+娠
+娣
+娱
+娲
+娽
+婆
+婉
+婕
+婚
+婴
+婵
+婷
+媄
+媒
+媚
+媛
+媳
+嫁
+嫂
+嫚
+嫩
+嬉
+子
+孔
+孕
+孖
+字
+存
+孙
+孚
+孜
+孝
+孟
+季
+孤
+学
+孩
+孬
+孵
+宁
+它
+宅
+宇
+守
+安
+宋
+完
+宏
+宓
+宗
+官
+宙
+定
+宛
+宜
+宝
+实
+宠
+审
+客
+宣
+室
+宦
+宪
+宫
+宰
+害
+宴
+宵
+家
+宸
+容
+宽
+宾
+宿
+寂
+寄
+寅
+密
+寇
+富
+寒
+寓
+寝
+察
+寨
+寰
+寳
+寸
+对
+寺
+寻
+导
+寿
+封
+専
+射
+将
+尊
+小
+少
+尓
+尔
+尕
+尖
+尘
+尚
+尝
+尤
+尧
+尬
+就
+尹
+尺
+尼
+尽
+尾
+尿
+局
+层
+居
+屈
+届
+屋
+屌
+屏
+屑
+展
+属
+屠
+履
+屯
+山
+屹
+屿
+岁
+岂
+岐
+岑
+岔
+岗
+岙
+岚
+岛
+岢
+岩
+岭
+岱
+岳
+岸
+岽
+峙
+峡
+峨
+峪
+峯
+峰
+峻
+崂
+崇
+崋
+崎
+崔
+崖
+崛
+崧
+崴
+崽
+嵊
+嵋
+嵌
+嵘
+嵩
+嶪
+巅
+巍
+川
+州
+巡
+巢
+工
+左
+巧
+巨
+巩
+巫
+差
+己
+已
+巴
+巷
+巾
+币
+市
+布
+帅
+帆
+师
+希
+帐
+帕
+帘
+帛
+帜
+帝
+带
+席
+帮
+常
+帼
+帽
+幂
+幅
+幕
+幛
+幢
+干
+平
+年
+并
+幸
+幻
+幼
+幽
+广
+庄
+庆
+床
+序
+庐
+库
+应
+底
+店
+庙
+庚
+府
+庞
+废
+度
+座
+庭
+庵
+康
+廉
+廊
+廓
+廖
+延
+廷
+建
+廿
+开
+异
+弃
+弄
+弈
+式
+弓
+引
+弗
+弘
+弛
+弟
+张
+弥
+弦
+弧
+弯
+弱
+弹
+强
+归
+当
+录
+彝
+形
+彤
+彦
+彩
+彪
+彬
+彭
+影
+役
+彻
+彼
+往
+征
+径
+待
+很
+徉
+律
+徐
+徒
+徕
+得
+徜
+御
+循
+微
+徳
+德
+徽
+心
+必
+忆
+志
+忘
+忙
+忠
+忧
+快
+忱
+念
+忻
+怀
+态
+怎
+怕
+思
+怡
+急
+性
+怪
+总
+恂
+恋
+恍
+恒
+恕
+恢
+恤
+恩
+恭
+息
+恰
+恶
+恺
+悉
+悍
+悟
+悠
+患
+悦
+您
+悬
+悸
+情
+惊
+惑
+惕
+惚
+惜
+惟
+惠
+惦
+惩
+惯
+想
+愁
+愈
+愉
+意
+愚
+感
+愧
+愿
+慈
+慌
+慎
+慕
+慢
+慧
+慰
+慷
+憨
+憩
+憬
+憾
+懂
+懋
+懒
+懮
+懿
+戈
+戊
+戎
+戏
+成
+我
+戒
+或
+战
+戚
+截
+戴
+户
+房
+所
+扁
+扇
+手
+才
+扎
+扑
+扒
+打
+扔
+托
+扙
+扛
+扞
+扣
+扦
+执
+扩
+扪
+扫
+扬
+扭
+扮
+扯
+扰
+扳
+扶
+批
+找
+承
+技
+抄
+把
+抒
+抓
+投
+抖
+抗
+折
+抚
+抛
+抢
+护
+报
+披
+抱
+抵
+抹
+押
+抽
+抿
+担
+拆
+拇
+拉
+拌
+拍
+拎
+拐
+拒
+拓
+拔
+拖
+招
+拜
+拟
+拢
+拥
+拦
+拨
+择
+括
+拱
+拳
+拴
+拷
+拼
+拽
+拾
+拿
+持
+挂
+指
+按
+挑
+挖
+挚
+挞
+挡
+挣
+挤
+挥
+挪
+振
+挺
+挽
+捆
+捉
+捌
+捍
+捏
+捐
+捕
+捞
+损
+捡
+换
+捧
+据
+捶
+捷
+掂
+授
+掉
+掌
+掏
+排
+掘
+探
+接
+控
+推
+掩
+措
+掺
+揉
+描
+提
+插
+握
+揪
+揭
+援
+揸
+揽
+搂
+搅
+搏
+搓
+搜
+搞
+搬
+搭
+携
+摄
+摆
+摇
+摊
+摔
+摘
+摩
+摸
+撑
+撒
+撕
+撞
+撤
+撬
+播
+撮
+撸
+撼
+擀
+擂
+擅
+操
+擎
+擦
+攀
+支
+收
+攸
+改
+攻
+放
+政
+故
+效
+敌
+敏
+救
+敖
+教
+敞
+敟
+敢
+散
+敦
+敬
+数
+敲
+整
+敷
+文
+斋
+斌
+斐
+斑
+斓
+斗
+料
+斛
+斜
+斤
+断
+斯
+新
+方
+施
+旁
+旅
+旋
+旌
+族
+旗
+无
+既
+日
+旦
+旧
+旨
+早
+旬
+旭
+旱
+时
+旺
+昀
+昂
+昆
+昇
+昊
+昌
+明
+昏
+易
+昔
+昕
+星
+映
+春
+昭
+是
+昱
+昵
+昶
+昼
+昽
+显
+晃
+晋
+晏
+晒
+晓
+晕
+晖
+晗
+晚
+晞
+晟
+晨
+普
+景
+晰
+晴
+晶
+智
+晾
+暂
+暄
+暑
+暖
+暗
+暨
+暴
+曙
+曜
+曦
+曰
+曲
+更
+曹
+曼
+曾
+替
+最
+月
+有
+朋
+服
+朔
+朕
+朗
+望
+朝
+期
+木
+未
+末
+本
+札
+术
+朱
+朴
+朵
+机
+朽
+杀
+杂
+权
+杆
+杉
+李
+杏
+材
+村
+杖
+杜
+杞
+束
+杠
+条
+来
+杨
+杭
+杯
+杰
+杳
+杷
+松
+板
+极
+构
+枇
+枉
+枋
+析
+枕
+林
+果
+枝
+枞
+枢
+枣
+枪
+枫
+枭
+枯
+枱
+架
+枸
+柏
+某
+柑
+柒
+染
+柔
+柘
+柚
+柛
+柜
+柠
+查
+柬
+柯
+柱
+柳
+柴
+柿
+栅
+标
+栈
+栋
+栏
+树
+栓
+栖
+栗
+校
+株
+样
+核
+根
+格
+栽
+栾
+桁
+桂
+桃
+框
+案
+桉
+桌
+桐
+桑
+桔
+桖
+桢
+档
+桥
+桦
+桩
+桶
+梁
+梅
+梏
+梓
+梗
+梦
+梧
+梨
+梭
+梯
+械
+梳
+梵
+检
+棉
+棋
+棍
+棒
+棕
+棚
+棛
+棠
+森
+棵
+棺
+椅
+植
+椎
+椒
+椰
+椹
+椿
+楂
+楚
+楠
+楷
+楼
+楽
+概
+榄
+榆
+榈
+榔
+榕
+榜
+榞
+榨
+榭
+榴
+榻
+槎
+槐
+様
+槟
+槽
+槿
+樊
+樟
+模
+横
+樱
+樵
+樽
+橄
+橘
+橙
+橡
+橦
+橱
+檀
+檬
+欠
+次
+欢
+欣
+欧
+欲
+款
+歆
+歇
+歌
+止
+正
+此
+步
+武
+歧
+歪
+死
+殊
+残
+殖
+殡
+段
+殿
+毁
+毂
+毅
+母
+每
+毒
+毓
+比
+毕
+毛
+毡
+毫
+毯
+毽
+氏
+民
+气
+氙
+氛
+氟
+氢
+氧
+氨
+氩
+氪
+氮
+氯
+氰
+水
+永
+氽
+汀
+汁
+求
+汆
+汇
+汉
+汊
+汐
+汕
+汗
+汛
+汝
+江
+池
+污
+汤
+汪
+汴
+汶
+汽
+汾
+沁
+沂
+沃
+沅
+沈
+沉
+沌
+沐
+沓
+沔
+沙
+沛
+沟
+没
+沣
+沥
+沧
+沪
+沫
+沱
+河
+沸
+油
+治
+沼
+沽
+沾
+沿
+泄
+泉
+泊
+泌
+泓
+法
+泗
+泛
+泡
+波
+泥
+注
+泩
+泪
+泰
+泳
+泵
+泷
+泸
+泺
+泻
+泼
+泽
+泾
+洁
+洋
+洒
+洗
+洛
+洞
+津
+洪
+洱
+洲
+洺
+活
+洼
+洽
+派
+流
+浅
+浆
+浇
+测
+济
+浏
+浑
+浒
+浓
+浔
+浙
+浚
+浜
+浠
+浣
+浦
+浩
+浪
+浮
+浴
+海
+浸
+涂
+消
+涉
+涌
+涕
+涛
+涞
+涡
+涤
+润
+涧
+涩
+涫
+涮
+涯
+液
+涵
+淀
+淄
+淅
+淇
+淋
+淑
+淘
+淞
+淡
+淮
+深
+淳
+混
+添
+淼
+清
+渊
+渍
+渐
+渔
+渗
+渝
+渠
+渡
+渣
+温
+渭
+港
+渴
+游
+湃
+湄
+湓
+湖
+湘
+湛
+湟
+湾
+湿
+溃
+溉
+源
+溜
+溢
+溧
+溪
+溯
+溶
+溸
+滁
+滇
+滋
+滏
+滑
+滕
+滘
+滙
+滚
+滟
+满
+滢
+滤
+滨
+滩
+滴
+漂
+漆
+漏
+演
+漕
+漠
+漫
+漯
+漳
+漾
+潇
+潍
+潘
+潜
+潞
+潢
+潭
+潮
+潼
+澄
+澈
+澎
+澜
+澡
+澧
+澳
+激
+濑
+濠
+瀑
+瀚
+瀛
+灌
+灏
+灡
+火
+灭
+灯
+灰
+灵
+灶
+灸
+灼
+灾
+灿
+炀
+炉
+炊
+炎
+炒
+炔
+炕
+炖
+炙
+炜
+炝
+炤
+炫
+炬
+炭
+炮
+炳
+炸
+点
+炼
+炽
+烂
+烈
+烊
+烘
+烙
+烛
+烜
+烟
+烤
+烦
+烧
+烨
+烩
+烫
+热
+烯
+烹
+烽
+焊
+焕
+焖
+焗
+焙
+焦
+焰
+焱
+然
+煊
+煋
+煌
+煎
+煜
+煤
+煦
+照
+煨
+煮
+煲
+煸
+熄
+熊
+熏
+熔
+熙
+熟
+熠
+熨
+熬
+熹
+燃
+燎
+燕
+燘
+燚
+燥
+爆
+爪
+爬
+爱
+爵
+父
+爷
+爸
+爽
+片
+版
+牌
+牙
+牛
+牟
+牡
+牢
+牧
+物
+牵
+特
+犀
+犇
+犬
+犯
+状
+犹
+狂
+狄
+狐
+狗
+狠
+独
+狭
+狮
+狱
+狸
+狼
+猎
+猗
+猛
+猜
+猪
+猫
+猬
+献
+猴
+獭
+玄
+率
+玉
+王
+玖
+玛
+玥
+玩
+玫
+玮
+环
+现
+玲
+玺
+玻
+珀
+珂
+珈
+珊
+珍
+珏
+珑
+珙
+珞
+珠
+班
+珺
+球
+琅
+理
+琉
+琛
+琢
+琥
+琦
+琨
+琪
+琳
+琴
+琵
+琶
+琼
+瑁
+瑙
+瑚
+瑛
+瑜
+瑞
+瑟
+瑭
+瑰
+瑶
+瑾
+璃
+璇
+璐
+璜
+璞
+璟
+璧
+璨
+瓜
+瓢
+瓣
+瓦
+瓮
+瓶
+瓷
+甄
+甏
+甘
+甚
+甜
+生
+用
+甩
+甫
+甬
+田
+由
+甲
+申
+电
+男
+甸
+町
+画
+畅
+畈
+界
+畏
+畔
+留
+畜
+略
+番
+畸
+疆
+疏
+疑
+疗
+疙
+疡
+疣
+疤
+疥
+疫
+疮
+疯
+疱
+疲
+疹
+疼
+疾
+病
+症
+痒
+痔
+痕
+痘
+痛
+痣
+痤
+痧
+痨
+痰
+痹
+痿
+瘁
+瘊
+瘘
+瘙
+瘢
+瘤
+瘦
+瘩
+瘫
+瘾
+癌
+癜
+癣
+癫
+癸
+登
+白
+百
+皂
+的
+皆
+皇
+皈
+皋
+皓
+皖
+皙
+皮
+皱
+皲
+皿
+盅
+盆
+盈
+益
+盐
+监
+盒
+盔
+盖
+盗
+盘
+盛
+盟
+目
+盱
+盲
+直
+相
+盼
+盾
+省
+眉
+看
+眙
+真
+眠
+眩
+眸
+眼
+着
+睛
+睡
+睢
+督
+睦
+睫
+睿
+瞄
+瞬
+瞰
+瞳
+瞻
+瞿
+矗
+知
+矩
+矫
+短
+矮
+石
+矶
+矽
+矾
+矿
+砀
+码
+砂
+砋
+砌
+砍
+研
+砖
+砚
+砥
+砭
+砰
+破
+砸
+砺
+砼
+础
+硂
+硅
+硒
+硕
+硚
+硝
+硫
+硬
+确
+碁
+碍
+碎
+碑
+碗
+碚
+碟
+碣
+碧
+碰
+碱
+碳
+碾
+磁
+磅
+磊
+磐
+磨
+磷
+磺
+礴
+示
+礼
+社
+祁
+祈
+祎
+祖
+祙
+祛
+祝
+神
+祠
+祥
+票
+祭
+祸
+祺
+禁
+禄
+禅
+福
+禧
+禹
+禺
+离
+禽
+禾
+秀
+私
+秃
+秆
+秉
+秋
+种
+科
+秒
+秘
+租
+秤
+秦
+秩
+秭
+积
+称
+秸
+移
+稀
+程
+稍
+税
+稚
+稞
+稠
+稳
+稷
+稻
+稼
+稽
+穂
+穆
+穗
+穴
+究
+穷
+空
+穿
+突
+窃
+窄
+窑
+窖
+窗
+窝
+窦
+窨
+立
+竖
+站
+竞
+竟
+章
+竣
+童
+竭
+端
+竹
+竺
+竿
+笃
+笈
+笋
+笑
+笔
+笙
+笛
+笠
+符
+笨
+第
+笺
+笼
+等
+筋
+筑
+筒
+答
+策
+筛
+筝
+筠
+筱
+筷
+筹
+签
+简
+箍
+箔
+箕
+算
+管
+箫
+箭
+箱
+篆
+篇
+篓
+篮
+篱
+篷
+簧
+簸
+籁
+籍
+籣
+米
+类
+籽
+粉
+粑
+粒
+粗
+粘
+粟
+粢
+粤
+粥
+粪
+粮
+粱
+粹
+粼
+粽
+精
+糁
+糊
+糕
+糖
+糜
+糟
+糠
+糯
+系
+素
+索
+紧
+紫
+累
+絮
+繁
+纂
+纠
+红
+纤
+约
+级
+纪
+纫
+纬
+纯
+纱
+纲
+纳
+纵
+纶
+纷
+纸
+纹
+纺
+纽
+线
+练
+组
+绅
+细
+织
+终
+绍
+绎
+经
+绑
+绒
+结
+绕
+绘
+给
+络
+绝
+绞
+统
+绢
+绣
+绥
+继
+绨
+绩
+绪
+续
+绮
+绰
+绳
+维
+绵
+绶
+绸
+综
+绽
+绿
+缅
+缆
+缇
+缎
+缐
+缓
+缔
+编
+缘
+缙
+缝
+缠
+缤
+缦
+缨
+缩
+缪
+缴
+缸
+缺
+罐
+网
+罕
+罗
+罚
+罡
+罩
+罪
+置
+署
+羊
+羌
+美
+羔
+羚
+羡
+群
+羲
+羹
+羽
+羿
+翁
+翅
+翊
+翌
+翎
+翔
+翘
+翟
+翠
+翡
+翥
+翰
+翱
+翻
+翼
+耀
+老
+考
+耄
+者
+而
+耍
+耐
+耒
+耕
+耗
+耘
+耙
+耦
+耳
+耵
+耶
+耸
+耿
+聂
+聊
+聋
+职
+联
+聘
+聚
+聪
+聿
+肃
+肆
+肇
+肉
+肋
+肌
+肖
+肘
+肚
+肛
+肝
+肠
+股
+肢
+肤
+肥
+肩
+肪
+肯
+育
+肴
+肺
+肽
+肾
+肿
+胀
+胁
+胃
+胆
+背
+胎
+胖
+胗
+胚
+胜
+胞
+胡
+胪
+胭
+胳
+胶
+胸
+胺
+能
+脂
+脆
+脉
+脊
+脏
+脐
+脑
+脖
+脚
+脯
+脱
+脸
+脾
+腊
+腋
+腌
+腐
+腔
+腕
+腩
+腰
+腱
+腹
+腺
+腻
+腾
+腿
+膊
+膏
+膜
+膝
+膨
+膳
+臀
+臂
+臊
+臣
+自
+臭
+至
+致
+臻
+臼
+舅
+舌
+舍
+舒
+舜
+舞
+舟
+航
+舫
+般
+舰
+舱
+舵
+舶
+船
+艇
+良
+艰
+色
+艳
+艺
+艾
+节
+芃
+芊
+芋
+芎
+芒
+芙
+芜
+芝
+芥
+芦
+芪
+芬
+芭
+芮
+芯
+花
+芳
+芷
+芸
+芹
+芽
+苍
+苏
+苑
+苓
+苔
+苗
+苝
+苞
+苟
+若
+苦
+苪
+英
+苹
+茁
+茂
+范
+茄
+茅
+茉
+茗
+茜
+茨
+茭
+茯
+茱
+茴
+茵
+茶
+茸
+茹
+荃
+荆
+草
+荐
+荒
+荔
+荘
+荞
+荟
+荠
+荡
+荣
+荤
+荥
+荧
+荨
+荪
+荫
+药
+荷
+荻
+莅
+莆
+莉
+莎
+莓
+莘
+莜
+莞
+莫
+莱
+莲
+获
+莹
+莺
+莽
+菀
+菁
+菇
+菊
+菌
+菏
+菓
+菘
+菜
+菠
+菡
+菩
+菱
+菲
+萃
+萄
+萌
+萍
+萘
+萝
+营
+萧
+萨
+萱
+落
+葆
+著
+葛
+葡
+董
+葫
+葬
+葱
+葳
+葵
+蒂
+蒄
+蒋
+蒙
+蒜
+蒡
+蒲
+蒸
+蓄
+蓉
+蓓
+蓝
+蓟
+蓬
+蔓
+蔗
+蔚
+蔡
+蔬
+蔷
+蔺
+蔻
+蔽
+蕃
+蕉
+蕊
+蕙
+蕲
+蕳
+蕴
+蕾
+薄
+薇
+薏
+薛
+薪
+薯
+薰
+藏
+藓
+藕
+藜
+藤
+藻
+蘑
+蘘
+蘸
+虎
+虏
+虐
+虑
+虔
+虚
+虞
+虫
+虱
+虹
+虽
+虾
+蚀
+蚁
+蚂
+蚊
+蚌
+蚕
+蚝
+蚨
+蛇
+蛊
+蛋
+蛙
+蛛
+蛟
+蛮
+蛳
+蜀
+蜂
+蜓
+蜕
+蜗
+蜘
+蜜
+蜡
+蜻
+蝇
+蝉
+蝎
+蝴
+蝶
+螃
+融
+螺
+蟠
+蟹
+蟾
+蠡
+蠢
+血
+行
+衔
+衖
+街
+衡
+衢
+衣
+补
+表
+衫
+衬
+衰
+袁
+袄
+袋
+袍
+袖
+袜
+被
+袭
+裁
+裂
+装
+裔
+裕
+裘
+裙
+裤
+裱
+裳
+裴
+裹
+褀
+褂
+褊
+褐
+褔
+褥
+襄
+西
+要
+覃
+见
+观
+规
+觅
+视
+览
+觉
+角
+解
+触
+言
+詹
+誉
+警
+计
+订
+认
+讨
+让
+训
+议
+讯
+记
+讲
+讴
+许
+论
+讼
+设
+访
+证
+评
+识
+诈
+诉
+诊
+译
+试
+诗
+诚
+诛
+话
+诞
+询
+该
+详
+语
+误
+诱
+说
+诵
+请
+诸
+诺
+读
+课
+谁
+调
+谅
+谈
+谊
+谋
+谌
+谐
+谕
+谚
+谛
+谜
+谢
+谣
+谦
+谨
+谭
+谱
+谷
+豆
+豉
+豊
+豌
+豚
+象
+豪
+豫
+豹
+貂
+貌
+贝
+贞
+负
+贡
+财
+责
+贤
+账
+货
+质
+贩
+贪
+贫
+购
+贯
+贰
+贴
+贵
+贷
+贸
+费
+贺
+贼
+贾
+赁
+资
+赉
+赊
+赋
+赌
+赎
+赏
+赐
+赔
+赖
+赚
+赛
+赞
+赠
+赢
+赣
+赤
+赫
+走
+赵
+赶
+起
+趁
+超
+越
+趋
+趟
+趣
+足
+趴
+趵
+趾
+跃
+跆
+跌
+跑
+跖
+距
+跟
+跤
+跨
+路
+跳
+践
+跷
+踏
+踢
+踩
+踪
+踺
+蹄
+蹈
+蹦
+蹭
+躁
+身
+躺
+车
+轧
+轨
+轩
+转
+轮
+软
+轰
+轴
+轶
+轻
+载
+轿
+较
+辅
+辆
+辈
+辉
+辊
+辋
+辐
+辑
+输
+辛
+辜
+辞
+辟
+辣
+辨
+辩
+辫
+辰
+边
+辽
+达
+迁
+迅
+过
+迈
+迎
+运
+近
+返
+还
+这
+进
+远
+违
+连
+迟
+迦
+迩
+迪
+迫
+迮
+迷
+迹
+迺
+追
+退
+送
+适
+逃
+逅
+逆
+选
+逊
+逍
+透
+逐
+递
+途
+逗
+通
+逛
+速
+造
+逢
+逸
+逹
+逻
+逾
+遁
+遂
+遇
+遍
+遐
+道
+遗
+遣
+遥
+遮
+遵
+避
+邀
+邂
+邑
+邓
+邝
+邢
+那
+邦
+邨
+邪
+邮
+邯
+邱
+邳
+邵
+邸
+邹
+邺
+邻
+郁
+郅
+郊
+郎
+郏
+郑
+郝
+郡
+郦
+部
+郫
+郭
+郸
+都
+鄂
+鄢
+鄯
+鄱
+酉
+酌
+配
+酒
+酝
+酥
+酩
+酪
+酬
+酯
+酰
+酱
+酵
+酷
+酸
+酿
+醇
+醉
+醋
+醍
+醒
+醛
+醪
+醴
+采
+釉
+释
+里
+重
+野
+量
+金
+釜
+鉴
+錧
+鍊
+鑫
+针
+钉
+钊
+钎
+钒
+钓
+钙
+钛
+钜
+钝
+钞
+钟
+钠
+钢
+钣
+钥
+钦
+钧
+钩
+钮
+钯
+钰
+钱
+钳
+钵
+钻
+钾
+钿
+铁
+铂
+铃
+铅
+铆
+铍
+铎
+铕
+铖
+铛
+铜
+铝
+铞
+铠
+铣
+铧
+铪
+铬
+铭
+铮
+铰
+铱
+铲
+铵
+银
+铸
+铺
+链
+销
+锁
+锂
+锄
+锅
+锈
+锋
+锌
+锐
+错
+锚
+锟
+锡
+锣
+锤
+锦
+锭
+键
+锯
+锰
+锴
+锻
+镀
+镁
+镂
+镇
+镍
+镐
+镒
+镖
+镗
+镜
+镦
+镭
+镯
+镶
+长
+门
+闪
+闫
+闭
+问
+闯
+闰
+闲
+间
+闵
+闷
+闸
+闹
+闺
+闻
+闽
+闿
+阀
+阁
+阅
+阎
+阔
+阖
+阙
+阜
+队
+阡
+阮
+防
+阳
+阴
+阵
+阶
+阻
+阿
+陀
+陂
+附
+际
+陆
+陇
+陈
+陋
+陌
+降
+限
+陕
+陟
+陡
+院
+除
+险
+陪
+陵
+陶
+陷
+隅
+隆
+隋
+隍
+随
+隐
+隔
+障
+隧
+隽
+难
+雀
+雁
+雄
+雅
+集
+雍
+雎
+雏
+雕
+雨
+雪
+雯
+零
+雷
+雾
+需
+霄
+霆
+震
+霉
+霍
+霏
+霓
+霖
+霜
+霞
+露
+霸
+青
+靓
+靖
+静
+靛
+非
+靠
+靡
+面
+革
+靳
+靴
+鞋
+鞍
+鞘
+鞭
+韦
+韧
+韩
+韬
+韭
+音
+韵
+韶
+頔
+页
+顶
+项
+顺
+须
+顽
+顾
+顿
+颁
+颂
+预
+颅
+领
+颈
+颊
+颌
+颍
+颐
+频
+颓
+颖
+颗
+题
+颜
+额
+颠
+颢
+风
+飘
+飙
+飚
+飞
+食
+飨
+餐
+饦
+饨
+饪
+饭
+饮
+饯
+饰
+饱
+饲
+饵
+饶
+饸
+饹
+饺
+饼
+饿
+馀
+馄
+馅
+馆
+馈
+馋
+馍
+馏
+馐
+馒
+馕
+首
+香
+馥
+馨
+馫
+马
+驭
+驰
+驱
+驳
+驴
+驶
+驹
+驻
+驼
+驾
+驿
+骁
+骄
+骆
+骋
+验
+骏
+骐
+骑
+骗
+骚
+骜
+骝
+骞
+骨
+髓
+高
+髪
+鬼
+魁
+魂
+魄
+魅
+魏
+魔
+鱬
+鱼
+鱿
+鲁
+鲈
+鲍
+鲜
+鲢
+鲤
+鲨
+鲩
+鲫
+鲮
+鲱
+鲲
+鲶
+鲸
+鳄
+鳅
+鳌
+鳗
+鳝
+鳞
+鳯
+鶏
+鸟
+鸡
+鸣
+鸥
+鸦
+鸭
+鸯
+鸳
+鸽
+鸿
+鹃
+鹅
+鹉
+鹊
+鹌
+鹏
+鹑
+鹞
+鹤
+鹦
+鹭
+鹰
+鹿
+麋
+麒
+麟
+麦
+麸
+麺
+麻
+麾
+黄
+黍
+黎
+黑
+黔
+默
+黛
+鼎
+鼓
+鼠
+鼻
+鼾
+齐
+齿
+龄
+龅
+龙
+龚
+龛
+龟
+龢
+凉
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z

+ 10 - 0
torchocr/datasets/alphabets/digit.txt

@@ -0,0 +1,10 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9

+ 92 - 0
torchocr/datasets/alphabets/enAlphaNumPunc90.txt

@@ -0,0 +1,92 @@
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+1
+2
+3
+4
+5
+6
+7
+8
+9
+0
+!
+;
+:
+?
+#
+"
+'
+&
+)
+(
++
+*
+-
+,
+/
+.
+|
+$
+%
+'
+[
+]
+@
+`
+ 
+=
+\

+ 6623 - 0
torchocr/datasets/alphabets/ppocr_keys_v1.txt

@@ -0,0 +1,6623 @@
+'
+疗
+绚
+诚
+娇
+溜
+题
+贿
+者
+廖
+更
+纳
+加
+奉
+公
+一
+就
+汴
+计
+与
+路
+房
+原
+妇
+2
+0
+8
+-
+7
+其
+>
+:
+]
+,
+,
+骑
+刈
+全
+消
+昏
+傈
+安
+久
+钟
+嗅
+不
+影
+处
+驽
+蜿
+资
+关
+椤
+地
+瘸
+专
+问
+忖
+票
+嫉
+炎
+韵
+要
+月
+田
+节
+陂
+鄙
+捌
+备
+拳
+伺
+眼
+网
+盎
+大
+傍
+心
+东
+愉
+汇
+蹿
+科
+每
+业
+里
+航
+晏
+字
+平
+录
+先
+1
+3
+彤
+鲶
+产
+稍
+督
+腴
+有
+象
+岳
+注
+绍
+在
+泺
+文
+定
+核
+名
+水
+过
+理
+让
+偷
+率
+等
+这
+发
+”
+为
+含
+肥
+酉
+相
+鄱
+七
+编
+猥
+锛
+日
+镀
+蒂
+掰
+倒
+辆
+栾
+栗
+综
+涩
+州
+雌
+滑
+馀
+了
+机
+块
+司
+宰
+甙
+兴
+矽
+抚
+保
+用
+沧
+秩
+如
+收
+息
+滥
+页
+疑
+埠
+!
+!
+姥
+异
+橹
+钇
+向
+下
+跄
+的
+椴
+沫
+国
+绥
+獠
+报
+开
+民
+蜇
+何
+分
+凇
+长
+讥
+藏
+掏
+施
+羽
+中
+讲
+派
+嘟
+人
+提
+浼
+间
+世
+而
+古
+多
+倪
+唇
+饯
+控
+庚
+首
+赛
+蜓
+味
+断
+制
+觉
+技
+替
+艰
+溢
+潮
+夕
+钺
+外
+摘
+枋
+动
+双
+单
+啮
+户
+枇
+确
+锦
+曜
+杜
+或
+能
+效
+霜
+盒
+然
+侗
+电
+晁
+放
+步
+鹃
+新
+杖
+蜂
+吒
+濂
+瞬
+评
+总
+隍
+对
+独
+合
+也
+是
+府
+青
+天
+诲
+墙
+组
+滴
+级
+邀
+帘
+示
+已
+时
+骸
+仄
+泅
+和
+遨
+店
+雇
+疫
+持
+巍
+踮
+境
+只
+亨
+目
+鉴
+崤
+闲
+体
+泄
+杂
+作
+般
+轰
+化
+解
+迂
+诿
+蛭
+璀
+腾
+告
+版
+服
+省
+师
+小
+规
+程
+线
+海
+办
+引
+二
+桧
+牌
+砺
+洄
+裴
+修
+图
+痫
+胡
+许
+犊
+事
+郛
+基
+柴
+呼
+食
+研
+奶
+律
+蛋
+因
+葆
+察
+戏
+褒
+戒
+再
+李
+骁
+工
+貂
+油
+鹅
+章
+啄
+休
+场
+给
+睡
+纷
+豆
+器
+捎
+说
+敏
+学
+会
+浒
+设
+诊
+格
+廓
+查
+来
+霓
+室
+溆
+¢
+诡
+寥
+焕
+舜
+柒
+狐
+回
+戟
+砾
+厄
+实
+翩
+尿
+五
+入
+径
+惭
+喹
+股
+宇
+篝
+|
+;
+美
+期
+云
+九
+祺
+扮
+靠
+锝
+槌
+系
+企
+酰
+阊
+暂
+蚕
+忻
+豁
+本
+羹
+执
+条
+钦
+H
+獒
+限
+进
+季
+楦
+于
+芘
+玖
+铋
+茯
+未
+答
+粘
+括
+样
+精
+欠
+矢
+甥
+帷
+嵩
+扣
+令
+仔
+风
+皈
+行
+支
+部
+蓉
+刮
+站
+蜡
+救
+钊
+汗
+松
+嫌
+成
+可
+.
+鹤
+院
+从
+交
+政
+怕
+活
+调
+球
+局
+验
+髌
+第
+韫
+谗
+串
+到
+圆
+年
+米
+/
+*
+友
+忿
+检
+区
+看
+自
+敢
+刃
+个
+兹
+弄
+流
+留
+同
+没
+齿
+星
+聆
+轼
+湖
+什
+三
+建
+蛔
+儿
+椋
+汕
+震
+颧
+鲤
+跟
+力
+情
+璺
+铨
+陪
+务
+指
+族
+训
+滦
+鄣
+濮
+扒
+商
+箱
+十
+召
+慷
+辗
+所
+莞
+管
+护
+臭
+横
+硒
+嗓
+接
+侦
+六
+露
+党
+馋
+驾
+剖
+高
+侬
+妪
+幂
+猗
+绺
+骐
+央
+酐
+孝
+筝
+课
+徇
+缰
+门
+男
+西
+项
+句
+谙
+瞒
+秃
+篇
+教
+碲
+罚
+声
+呐
+景
+前
+富
+嘴
+鳌
+稀
+免
+朋
+啬
+睐
+去
+赈
+鱼
+住
+肩
+愕
+速
+旁
+波
+厅
+健
+茼
+厥
+鲟
+谅
+投
+攸
+炔
+数
+方
+击
+呋
+谈
+绩
+别
+愫
+僚
+躬
+鹧
+胪
+炳
+招
+喇
+膨
+泵
+蹦
+毛
+结
+5
+4
+谱
+识
+陕
+粽
+婚
+拟
+构
+且
+搜
+任
+潘
+比
+郢
+妨
+醪
+陀
+桔
+碘
+扎
+选
+哈
+骷
+楷
+亿
+明
+缆
+脯
+监
+睫
+逻
+婵
+共
+赴
+淝
+凡
+惦
+及
+达
+揖
+谩
+澹
+减
+焰
+蛹
+番
+祁
+柏
+员
+禄
+怡
+峤
+龙
+白
+叽
+生
+闯
+起
+细
+装
+谕
+竟
+聚
+钙
+上
+导
+渊
+按
+艾
+辘
+挡
+耒
+盹
+饪
+臀
+记
+邮
+蕙
+受
+各
+医
+搂
+普
+滇
+朗
+茸
+带
+翻
+酚
+(
+光
+堤
+墟
+蔷
+万
+幻
+〓
+瑙
+辈
+昧
+盏
+亘
+蛀
+吉
+铰
+请
+子
+假
+闻
+税
+井
+诩
+哨
+嫂
+好
+面
+琐
+校
+馊
+鬣
+缂
+营
+访
+炖
+占
+农
+缀
+否
+经
+钚
+棵
+趟
+张
+亟
+吏
+茶
+谨
+捻
+论
+迸
+堂
+玉
+信
+吧
+瞠
+乡
+姬
+寺
+咬
+溏
+苄
+皿
+意
+赉
+宝
+尔
+钰
+艺
+特
+唳
+踉
+都
+荣
+倚
+登
+荐
+丧
+奇
+涵
+批
+炭
+近
+符
+傩
+感
+道
+着
+菊
+虹
+仲
+众
+懈
+濯
+颞
+眺
+南
+释
+北
+缝
+标
+既
+茗
+整
+撼
+迤
+贲
+挎
+耱
+拒
+某
+妍
+卫
+哇
+英
+矶
+藩
+治
+他
+元
+领
+膜
+遮
+穗
+蛾
+飞
+荒
+棺
+劫
+么
+市
+火
+温
+拈
+棚
+洼
+转
+果
+奕
+卸
+迪
+伸
+泳
+斗
+邡
+侄
+涨
+屯
+萋
+胭
+氡
+崮
+枞
+惧
+冒
+彩
+斜
+手
+豚
+随
+旭
+淑
+妞
+形
+菌
+吲
+沱
+争
+驯
+歹
+挟
+兆
+柱
+传
+至
+包
+内
+响
+临
+红
+功
+弩
+衡
+寂
+禁
+老
+棍
+耆
+渍
+织
+害
+氵
+渑
+布
+载
+靥
+嗬
+虽
+苹
+咨
+娄
+库
+雉
+榜
+帜
+嘲
+套
+瑚
+亲
+簸
+欧
+边
+6
+腿
+旮
+抛
+吹
+瞳
+得
+镓
+梗
+厨
+继
+漾
+愣
+憨
+士
+策
+窑
+抑
+躯
+襟
+脏
+参
+贸
+言
+干
+绸
+鳄
+穷
+藜
+音
+折
+详
+)
+举
+悍
+甸
+癌
+黎
+谴
+死
+罩
+迁
+寒
+驷
+袖
+媒
+蒋
+掘
+模
+纠
+恣
+观
+祖
+蛆
+碍
+位
+稿
+主
+澧
+跌
+筏
+京
+锏
+帝
+贴
+证
+糠
+才
+黄
+鲸
+略
+炯
+饱
+四
+出
+园
+犀
+牧
+容
+汉
+杆
+浈
+汰
+瑷
+造
+虫
+瘩
+怪
+驴
+济
+应
+花
+沣
+谔
+夙
+旅
+价
+矿
+以
+考
+s
+u
+呦
+晒
+巡
+茅
+准
+肟
+瓴
+詹
+仟
+褂
+译
+桌
+混
+宁
+怦
+郑
+抿
+些
+余
+鄂
+饴
+攒
+珑
+群
+阖
+岔
+琨
+藓
+预
+环
+洮
+岌
+宀
+杲
+瀵
+最
+常
+囡
+周
+踊
+女
+鼓
+袭
+喉
+简
+范
+薯
+遐
+疏
+粱
+黜
+禧
+法
+箔
+斤
+遥
+汝
+奥
+直
+贞
+撑
+置
+绱
+集
+她
+馅
+逗
+钧
+橱
+魉
+[
+恙
+躁
+唤
+9
+旺
+膘
+待
+脾
+惫
+购
+吗
+依
+盲
+度
+瘿
+蠖
+俾
+之
+镗
+拇
+鲵
+厝
+簧
+续
+款
+展
+啃
+表
+剔
+品
+钻
+腭
+损
+清
+锶
+统
+涌
+寸
+滨
+贪
+链
+吠
+冈
+伎
+迥
+咏
+吁
+览
+防
+迅
+失
+汾
+阔
+逵
+绀
+蔑
+列
+川
+凭
+努
+熨
+揪
+利
+俱
+绉
+抢
+鸨
+我
+即
+责
+膦
+易
+毓
+鹊
+刹
+玷
+岿
+空
+嘞
+绊
+排
+术
+估
+锷
+违
+们
+苟
+铜
+播
+肘
+件
+烫
+审
+鲂
+广
+像
+铌
+惰
+铟
+巳
+胍
+鲍
+康
+憧
+色
+恢
+想
+拷
+尤
+疳
+知
+S
+Y
+F
+D
+A
+峄
+裕
+帮
+握
+搔
+氐
+氘
+难
+墒
+沮
+雨
+叁
+缥
+悴
+藐
+湫
+娟
+苑
+稠
+颛
+簇
+后
+阕
+闭
+蕤
+缚
+怎
+佞
+码
+嘤
+蔡
+痊
+舱
+螯
+帕
+赫
+昵
+升
+烬
+岫
+、
+疵
+蜻
+髁
+蕨
+隶
+烛
+械
+丑
+盂
+梁
+强
+鲛
+由
+拘
+揉
+劭
+龟
+撤
+钩
+呕
+孛
+费
+妻
+漂
+求
+阑
+崖
+秤
+甘
+通
+深
+补
+赃
+坎
+床
+啪
+承
+吼
+量
+暇
+钼
+烨
+阂
+擎
+脱
+逮
+称
+P
+神
+属
+矗
+华
+届
+狍
+葑
+汹
+育
+患
+窒
+蛰
+佼
+静
+槎
+运
+鳗
+庆
+逝
+曼
+疱
+克
+代
+官
+此
+麸
+耧
+蚌
+晟
+例
+础
+榛
+副
+测
+唰
+缢
+迹
+灬
+霁
+身
+岁
+赭
+扛
+又
+菡
+乜
+雾
+板
+读
+陷
+徉
+贯
+郁
+虑
+变
+钓
+菜
+圾
+现
+琢
+式
+乐
+维
+渔
+浜
+左
+吾
+脑
+钡
+警
+T
+啵
+拴
+偌
+漱
+湿
+硕
+止
+骼
+魄
+积
+燥
+联
+踢
+玛
+则
+窿
+见
+振
+畿
+送
+班
+钽
+您
+赵
+刨
+印
+讨
+踝
+籍
+谡
+舌
+崧
+汽
+蔽
+沪
+酥
+绒
+怖
+财
+帖
+肱
+私
+莎
+勋
+羔
+霸
+励
+哼
+帐
+将
+帅
+渠
+纪
+婴
+娩
+岭
+厘
+滕
+吻
+伤
+坝
+冠
+戊
+隆
+瘁
+介
+涧
+物
+黍
+并
+姗
+奢
+蹑
+掣
+垸
+锴
+命
+箍
+捉
+病
+辖
+琰
+眭
+迩
+艘
+绌
+繁
+寅
+若
+毋
+思
+诉
+类
+诈
+燮
+轲
+酮
+狂
+重
+反
+职
+筱
+县
+委
+磕
+绣
+奖
+晋
+濉
+志
+徽
+肠
+呈
+獐
+坻
+口
+片
+碰
+几
+村
+柿
+劳
+料
+获
+亩
+惕
+晕
+厌
+号
+罢
+池
+正
+鏖
+煨
+家
+棕
+复
+尝
+懋
+蜥
+锅
+岛
+扰
+队
+坠
+瘾
+钬
+@
+卧
+疣
+镇
+譬
+冰
+彷
+频
+黯
+据
+垄
+采
+八
+缪
+瘫
+型
+熹
+砰
+楠
+襁
+箐
+但
+嘶
+绳
+啤
+拍
+盥
+穆
+傲
+洗
+盯
+塘
+怔
+筛
+丿
+台
+恒
+喂
+葛
+永
+¥
+烟
+酒
+桦
+书
+砂
+蚝
+缉
+态
+瀚
+袄
+圳
+轻
+蛛
+超
+榧
+遛
+姒
+奘
+铮
+右
+荽
+望
+偻
+卡
+丶
+氰
+附
+做
+革
+索
+戚
+坨
+桷
+唁
+垅
+榻
+岐
+偎
+坛
+莨
+山
+殊
+微
+骇
+陈
+爨
+推
+嗝
+驹
+澡
+藁
+呤
+卤
+嘻
+糅
+逛
+侵
+郓
+酌
+德
+摇
+※
+鬃
+被
+慨
+殡
+羸
+昌
+泡
+戛
+鞋
+河
+宪
+沿
+玲
+鲨
+翅
+哽
+源
+铅
+语
+照
+邯
+址
+荃
+佬
+顺
+鸳
+町
+霭
+睾
+瓢
+夸
+椁
+晓
+酿
+痈
+咔
+侏
+券
+噎
+湍
+签
+嚷
+离
+午
+尚
+社
+锤
+背
+孟
+使
+浪
+缦
+潍
+鞅
+军
+姹
+驶
+笑
+鳟
+鲁
+》
+孽
+钜
+绿
+洱
+礴
+焯
+椰
+颖
+囔
+乌
+孔
+巴
+互
+性
+椽
+哞
+聘
+昨
+早
+暮
+胶
+炀
+隧
+低
+彗
+昝
+铁
+呓
+氽
+藉
+喔
+癖
+瑗
+姨
+权
+胱
+韦
+堑
+蜜
+酋
+楝
+砝
+毁
+靓
+歙
+锲
+究
+屋
+喳
+骨
+辨
+碑
+武
+鸠
+宫
+辜
+烊
+适
+坡
+殃
+培
+佩
+供
+走
+蜈
+迟
+翼
+况
+姣
+凛
+浔
+吃
+飘
+债
+犟
+金
+促
+苛
+崇
+坂
+莳
+畔
+绂
+兵
+蠕
+斋
+根
+砍
+亢
+欢
+恬
+崔
+剁
+餐
+榫
+快
+扶
+‖
+濒
+缠
+鳜
+当
+彭
+驭
+浦
+篮
+昀
+锆
+秸
+钳
+弋
+娣
+瞑
+夷
+龛
+苫
+拱
+致
+%
+嵊
+障
+隐
+弑
+初
+娓
+抉
+汩
+累
+蓖
+"
+唬
+助
+苓
+昙
+押
+毙
+破
+城
+郧
+逢
+嚏
+獭
+瞻
+溱
+婿
+赊
+跨
+恼
+璧
+萃
+姻
+貉
+灵
+炉
+密
+氛
+陶
+砸
+谬
+衔
+点
+琛
+沛
+枳
+层
+岱
+诺
+脍
+榈
+埂
+征
+冷
+裁
+打
+蹴
+素
+瘘
+逞
+蛐
+聊
+激
+腱
+萘
+踵
+飒
+蓟
+吆
+取
+咙
+簋
+涓
+矩
+曝
+挺
+揣
+座
+你
+史
+舵
+焱
+尘
+苏
+笈
+脚
+溉
+榨
+诵
+樊
+邓
+焊
+义
+庶
+儋
+蟋
+蒲
+赦
+呷
+杞
+诠
+豪
+还
+试
+颓
+茉
+太
+除
+紫
+逃
+痴
+草
+充
+鳕
+珉
+祗
+墨
+渭
+烩
+蘸
+慕
+璇
+镶
+穴
+嵘
+恶
+骂
+险
+绋
+幕
+碉
+肺
+戳
+刘
+潞
+秣
+纾
+潜
+銮
+洛
+须
+罘
+销
+瘪
+汞
+兮
+屉
+r
+林
+厕
+质
+探
+划
+狸
+殚
+善
+煊
+烹
+〒
+锈
+逯
+宸
+辍
+泱
+柚
+袍
+远
+蹋
+嶙
+绝
+峥
+娥
+缍
+雀
+徵
+认
+镱
+谷
+=
+贩
+勉
+撩
+鄯
+斐
+洋
+非
+祚
+泾
+诒
+饿
+撬
+威
+晷
+搭
+芍
+锥
+笺
+蓦
+候
+琊
+档
+礁
+沼
+卵
+荠
+忑
+朝
+凹
+瑞
+头
+仪
+弧
+孵
+畏
+铆
+突
+衲
+车
+浩
+气
+茂
+悖
+厢
+枕
+酝
+戴
+湾
+邹
+飚
+攘
+锂
+写
+宵
+翁
+岷
+无
+喜
+丈
+挑
+嗟
+绛
+殉
+议
+槽
+具
+醇
+淞
+笃
+郴
+阅
+饼
+底
+壕
+砚
+弈
+询
+缕
+庹
+翟
+零
+筷
+暨
+舟
+闺
+甯
+撞
+麂
+茌
+蔼
+很
+珲
+捕
+棠
+角
+阉
+媛
+娲
+诽
+剿
+尉
+爵
+睬
+韩
+诰
+匣
+危
+糍
+镯
+立
+浏
+阳
+少
+盆
+舔
+擘
+匪
+申
+尬
+铣
+旯
+抖
+赘
+瓯
+居
+哮
+游
+锭
+茏
+歌
+坏
+甚
+秒
+舞
+沙
+仗
+劲
+潺
+阿
+燧
+郭
+嗖
+霏
+忠
+材
+奂
+耐
+跺
+砀
+输
+岖
+媳
+氟
+极
+摆
+灿
+今
+扔
+腻
+枝
+奎
+药
+熄
+吨
+话
+q
+额
+慑
+嘌
+协
+喀
+壳
+埭
+视
+著
+於
+愧
+陲
+翌
+峁
+颅
+佛
+腹
+聋
+侯
+咎
+叟
+秀
+颇
+存
+较
+罪
+哄
+岗
+扫
+栏
+钾
+羌
+己
+璨
+枭
+霉
+煌
+涸
+衿
+键
+镝
+益
+岢
+奏
+连
+夯
+睿
+冥
+均
+糖
+狞
+蹊
+稻
+爸
+刿
+胥
+煜
+丽
+肿
+璃
+掸
+跚
+灾
+垂
+樾
+濑
+乎
+莲
+窄
+犹
+撮
+战
+馄
+软
+络
+显
+鸢
+胸
+宾
+妲
+恕
+埔
+蝌
+份
+遇
+巧
+瞟
+粒
+恰
+剥
+桡
+博
+讯
+凯
+堇
+阶
+滤
+卖
+斌
+骚
+彬
+兑
+磺
+樱
+舷
+两
+娱
+福
+仃
+差
+找
+桁
+净
+把
+阴
+污
+戬
+雷
+碓
+蕲
+楚
+罡
+焖
+抽
+妫
+咒
+仑
+闱
+尽
+邑
+菁
+爱
+贷
+沥
+鞑
+牡
+嗉
+崴
+骤
+塌
+嗦
+订
+拮
+滓
+捡
+锻
+次
+坪
+杩
+臃
+箬
+融
+珂
+鹗
+宗
+枚
+降
+鸬
+妯
+阄
+堰
+盐
+毅
+必
+杨
+崃
+俺
+甬
+状
+莘
+货
+耸
+菱
+腼
+铸
+唏
+痤
+孚
+澳
+懒
+溅
+翘
+疙
+杷
+淼
+缙
+骰
+喊
+悉
+砻
+坷
+艇
+赁
+界
+谤
+纣
+宴
+晃
+茹
+归
+饭
+梢
+铡
+街
+抄
+肼
+鬟
+苯
+颂
+撷
+戈
+炒
+咆
+茭
+瘙
+负
+仰
+客
+琉
+铢
+封
+卑
+珥
+椿
+镧
+窨
+鬲
+寿
+御
+袤
+铃
+萎
+砖
+餮
+脒
+裳
+肪
+孕
+嫣
+馗
+嵇
+恳
+氯
+江
+石
+褶
+冢
+祸
+阻
+狈
+羞
+银
+靳
+透
+咳
+叼
+敷
+芷
+啥
+它
+瓤
+兰
+痘
+懊
+逑
+肌
+往
+捺
+坊
+甩
+呻
+〃
+沦
+忘
+膻
+祟
+菅
+剧
+崆
+智
+坯
+臧
+霍
+墅
+攻
+眯
+倘
+拢
+骠
+铐
+庭
+岙
+瓠
+′
+缺
+泥
+迢
+捶
+?
+?
+郏
+喙
+掷
+沌
+纯
+秘
+种
+听
+绘
+固
+螨
+团
+香
+盗
+妒
+埚
+蓝
+拖
+旱
+荞
+铀
+血
+遏
+汲
+辰
+叩
+拽
+幅
+硬
+惶
+桀
+漠
+措
+泼
+唑
+齐
+肾
+念
+酱
+虚
+屁
+耶
+旗
+砦
+闵
+婉
+馆
+拭
+绅
+韧
+忏
+窝
+醋
+葺
+顾
+辞
+倜
+堆
+辋
+逆
+玟
+贱
+疾
+董
+惘
+倌
+锕
+淘
+嘀
+莽
+俭
+笏
+绑
+鲷
+杈
+择
+蟀
+粥
+嗯
+驰
+逾
+案
+谪
+褓
+胫
+哩
+昕
+颚
+鲢
+绠
+躺
+鹄
+崂
+儒
+俨
+丝
+尕
+泌
+啊
+萸
+彰
+幺
+吟
+骄
+苣
+弦
+脊
+瑰
+〈
+诛
+镁
+析
+闪
+剪
+侧
+哟
+框
+螃
+守
+嬗
+燕
+狭
+铈
+缮
+概
+迳
+痧
+鲲
+俯
+售
+笼
+痣
+扉
+挖
+满
+咋
+援
+邱
+扇
+歪
+便
+玑
+绦
+峡
+蛇
+叨
+〖
+泽
+胃
+斓
+喋
+怂
+坟
+猪
+该
+蚬
+炕
+弥
+赞
+棣
+晔
+娠
+挲
+狡
+创
+疖
+铕
+镭
+稷
+挫
+弭
+啾
+翔
+粉
+履
+苘
+哦
+楼
+秕
+铂
+土
+锣
+瘟
+挣
+栉
+习
+享
+桢
+袅
+磨
+桂
+谦
+延
+坚
+蔚
+噗
+署
+谟
+猬
+钎
+恐
+嬉
+雒
+倦
+衅
+亏
+璩
+睹
+刻
+殿
+王
+算
+雕
+麻
+丘
+柯
+骆
+丸
+塍
+谚
+添
+鲈
+垓
+桎
+蚯
+芥
+予
+飕
+镦
+谌
+窗
+醚
+菀
+亮
+搪
+莺
+蒿
+羁
+足
+J
+真
+轶
+悬
+衷
+靛
+翊
+掩
+哒
+炅
+掐
+冼
+妮
+l
+谐
+稚
+荆
+擒
+犯
+陵
+虏
+浓
+崽
+刍
+陌
+傻
+孜
+千
+靖
+演
+矜
+钕
+煽
+杰
+酗
+渗
+伞
+栋
+俗
+泫
+戍
+罕
+沾
+疽
+灏
+煦
+芬
+磴
+叱
+阱
+榉
+湃
+蜀
+叉
+醒
+彪
+租
+郡
+篷
+屎
+良
+垢
+隗
+弱
+陨
+峪
+砷
+掴
+颁
+胎
+雯
+绵
+贬
+沐
+撵
+隘
+篙
+暖
+曹
+陡
+栓
+填
+臼
+彦
+瓶
+琪
+潼
+哪
+鸡
+摩
+啦
+俟
+锋
+域
+耻
+蔫
+疯
+纹
+撇
+毒
+绶
+痛
+酯
+忍
+爪
+赳
+歆
+嘹
+辕
+烈
+册
+朴
+钱
+吮
+毯
+癜
+娃
+谀
+邵
+厮
+炽
+璞
+邃
+丐
+追
+词
+瓒
+忆
+轧
+芫
+谯
+喷
+弟
+半
+冕
+裙
+掖
+墉
+绮
+寝
+苔
+势
+顷
+褥
+切
+衮
+君
+佳
+嫒
+蚩
+霞
+佚
+洙
+逊
+镖
+暹
+唛
+&
+殒
+顶
+碗
+獗
+轭
+铺
+蛊
+废
+恹
+汨
+崩
+珍
+那
+杵
+曲
+纺
+夏
+薰
+傀
+闳
+淬
+姘
+舀
+拧
+卷
+楂
+恍
+讪
+厩
+寮
+篪
+赓
+乘
+灭
+盅
+鞣
+沟
+慎
+挂
+饺
+鼾
+杳
+树
+缨
+丛
+絮
+娌
+臻
+嗳
+篡
+侩
+述
+衰
+矛
+圈
+蚜
+匕
+筹
+匿
+濞
+晨
+叶
+骋
+郝
+挚
+蚴
+滞
+增
+侍
+描
+瓣
+吖
+嫦
+蟒
+匾
+圣
+赌
+毡
+癞
+恺
+百
+曳
+需
+篓
+肮
+庖
+帏
+卿
+驿
+遗
+蹬
+鬓
+骡
+歉
+芎
+胳
+屐
+禽
+烦
+晌
+寄
+媾
+狄
+翡
+苒
+船
+廉
+终
+痞
+殇
+々
+畦
+饶
+改
+拆
+悻
+萄
+£
+瓿
+乃
+訾
+桅
+匮
+溧
+拥
+纱
+铍
+骗
+蕃
+龋
+缬
+父
+佐
+疚
+栎
+醍
+掳
+蓄
+x
+惆
+颜
+鲆
+榆
+〔
+猎
+敌
+暴
+谥
+鲫
+贾
+罗
+玻
+缄
+扦
+芪
+癣
+落
+徒
+臾
+恿
+猩
+托
+邴
+肄
+牵
+春
+陛
+耀
+刊
+拓
+蓓
+邳
+堕
+寇
+枉
+淌
+啡
+湄
+兽
+酷
+萼
+碚
+濠
+萤
+夹
+旬
+戮
+梭
+琥
+椭
+昔
+勺
+蜊
+绐
+晚
+孺
+僵
+宣
+摄
+冽
+旨
+萌
+忙
+蚤
+眉
+噼
+蟑
+付
+契
+瓜
+悼
+颡
+壁
+曾
+窕
+颢
+澎
+仿
+俑
+浑
+嵌
+浣
+乍
+碌
+褪
+乱
+蔟
+隙
+玩
+剐
+葫
+箫
+纲
+围
+伐
+决
+伙
+漩
+瑟
+刑
+肓
+镳
+缓
+蹭
+氨
+皓
+典
+畲
+坍
+铑
+檐
+塑
+洞
+倬
+储
+胴
+淳
+戾
+吐
+灼
+惺
+妙
+毕
+珐
+缈
+虱
+盖
+羰
+鸿
+磅
+谓
+髅
+娴
+苴
+唷
+蚣
+霹
+抨
+贤
+唠
+犬
+誓
+逍
+庠
+逼
+麓
+籼
+釉
+呜
+碧
+秧
+氩
+摔
+霄
+穸
+纨
+辟
+妈
+映
+完
+牛
+缴
+嗷
+炊
+恩
+荔
+茆
+掉
+紊
+慌
+莓
+羟
+阙
+萁
+磐
+另
+蕹
+辱
+鳐
+湮
+吡
+吩
+唐
+睦
+垠
+舒
+圜
+冗
+瞿
+溺
+芾
+囱
+匠
+僳
+汐
+菩
+饬
+漓
+黑
+霰
+浸
+濡
+窥
+毂
+蒡
+兢
+驻
+鹉
+芮
+诙
+迫
+雳
+厂
+忐
+臆
+猴
+鸣
+蚪
+栈
+箕
+羡
+渐
+莆
+捍
+眈
+哓
+趴
+蹼
+埕
+嚣
+骛
+宏
+淄
+斑
+噜
+严
+瑛
+垃
+椎
+诱
+压
+庾
+绞
+焘
+廿
+抡
+迄
+棘
+夫
+纬
+锹
+眨
+瞌
+侠
+脐
+竞
+瀑
+孳
+骧
+遁
+姜
+颦
+荪
+滚
+萦
+伪
+逸
+粳
+爬
+锁
+矣
+役
+趣
+洒
+颔
+诏
+逐
+奸
+甭
+惠
+攀
+蹄
+泛
+尼
+拼
+阮
+鹰
+亚
+颈
+惑
+勒
+〉
+际
+肛
+爷
+刚
+钨
+丰
+养
+冶
+鲽
+辉
+蔻
+画
+覆
+皴
+妊
+麦
+返
+醉
+皂
+擀
+〗
+酶
+凑
+粹
+悟
+诀
+硖
+港
+卜
+z
+杀
+涕
+舍
+铠
+抵
+弛
+段
+敝
+镐
+奠
+拂
+轴
+跛
+袱
+e
+t
+沉
+菇
+俎
+薪
+峦
+秭
+蟹
+历
+盟
+菠
+寡
+液
+肢
+喻
+染
+裱
+悱
+抱
+氙
+赤
+捅
+猛
+跑
+氮
+谣
+仁
+尺
+辊
+窍
+烙
+衍
+架
+擦
+倏
+璐
+瑁
+币
+楞
+胖
+夔
+趸
+邛
+惴
+饕
+虔
+蝎
+哉
+贝
+宽
+辫
+炮
+扩
+饲
+籽
+魏
+菟
+锰
+伍
+猝
+末
+琳
+哚
+蛎
+邂
+呀
+姿
+鄞
+却
+歧
+仙
+恸
+椐
+森
+牒
+寤
+袒
+婆
+虢
+雅
+钉
+朵
+贼
+欲
+苞
+寰
+故
+龚
+坭
+嘘
+咫
+礼
+硷
+兀
+睢
+汶
+’
+铲
+烧
+绕
+诃
+浃
+钿
+哺
+柜
+讼
+颊
+璁
+腔
+洽
+咐
+脲
+簌
+筠
+镣
+玮
+鞠
+谁
+兼
+姆
+挥
+梯
+蝴
+谘
+漕
+刷
+躏
+宦
+弼
+b
+垌
+劈
+麟
+莉
+揭
+笙
+渎
+仕
+嗤
+仓
+配
+怏
+抬
+错
+泯
+镊
+孰
+猿
+邪
+仍
+秋
+鼬
+壹
+歇
+吵
+炼
+<
+尧
+射
+柬
+廷
+胧
+霾
+凳
+隋
+肚
+浮
+梦
+祥
+株
+堵
+退
+L
+鹫
+跎
+凶
+毽
+荟
+炫
+栩
+玳
+甜
+沂
+鹿
+顽
+伯
+爹
+赔
+蛴
+徐
+匡
+欣
+狰
+缸
+雹
+蟆
+疤
+默
+沤
+啜
+痂
+衣
+禅
+w
+i
+h
+辽
+葳
+黝
+钗
+停
+沽
+棒
+馨
+颌
+肉
+吴
+硫
+悯
+劾
+娈
+马
+啧
+吊
+悌
+镑
+峭
+帆
+瀣
+涉
+咸
+疸
+滋
+泣
+翦
+拙
+癸
+钥
+蜒
++
+尾
+庄
+凝
+泉
+婢
+渴
+谊
+乞
+陆
+锉
+糊
+鸦
+淮
+I
+B
+N
+晦
+弗
+乔
+庥
+葡
+尻
+席
+橡
+傣
+渣
+拿
+惩
+麋
+斛
+缃
+矮
+蛏
+岘
+鸽
+姐
+膏
+催
+奔
+镒
+喱
+蠡
+摧
+钯
+胤
+柠
+拐
+璋
+鸥
+卢
+荡
+倾
+^
+_
+珀
+逄
+萧
+塾
+掇
+贮
+笆
+聂
+圃
+冲
+嵬
+M
+滔
+笕
+值
+炙
+偶
+蜱
+搐
+梆
+汪
+蔬
+腑
+鸯
+蹇
+敞
+绯
+仨
+祯
+谆
+梧
+糗
+鑫
+啸
+豺
+囹
+猾
+巢
+柄
+瀛
+筑
+踌
+沭
+暗
+苁
+鱿
+蹉
+脂
+蘖
+牢
+热
+木
+吸
+溃
+宠
+序
+泞
+偿
+拜
+檩
+厚
+朐
+毗
+螳
+吞
+媚
+朽
+担
+蝗
+橘
+畴
+祈
+糟
+盱
+隼
+郜
+惜
+珠
+裨
+铵
+焙
+琚
+唯
+咚
+噪
+骊
+丫
+滢
+勤
+棉
+呸
+咣
+淀
+隔
+蕾
+窈
+饨
+挨
+煅
+短
+匙
+粕
+镜
+赣
+撕
+墩
+酬
+馁
+豌
+颐
+抗
+酣
+氓
+佑
+搁
+哭
+递
+耷
+涡
+桃
+贻
+碣
+截
+瘦
+昭
+镌
+蔓
+氚
+甲
+猕
+蕴
+蓬
+散
+拾
+纛
+狼
+猷
+铎
+埋
+旖
+矾
+讳
+囊
+糜
+迈
+粟
+蚂
+紧
+鲳
+瘢
+栽
+稼
+羊
+锄
+斟
+睁
+桥
+瓮
+蹙
+祉
+醺
+鼻
+昱
+剃
+跳
+篱
+跷
+蒜
+翎
+宅
+晖
+嗑
+壑
+峻
+癫
+屏
+狠
+陋
+袜
+途
+憎
+祀
+莹
+滟
+佶
+溥
+臣
+约
+盛
+峰
+磁
+慵
+婪
+拦
+莅
+朕
+鹦
+粲
+裤
+哎
+疡
+嫖
+琵
+窟
+堪
+谛
+嘉
+儡
+鳝
+斩
+郾
+驸
+酊
+妄
+胜
+贺
+徙
+傅
+噌
+钢
+栅
+庇
+恋
+匝
+巯
+邈
+尸
+锚
+粗
+佟
+蛟
+薹
+纵
+蚊
+郅
+绢
+锐
+苗
+俞
+篆
+淆
+膀
+鲜
+煎
+诶
+秽
+寻
+涮
+刺
+怀
+噶
+巨
+褰
+魅
+灶
+灌
+桉
+藕
+谜
+舸
+薄
+搀
+恽
+借
+牯
+痉
+渥
+愿
+亓
+耘
+杠
+柩
+锔
+蚶
+钣
+珈
+喘
+蹒
+幽
+赐
+稗
+晤
+莱
+泔
+扯
+肯
+菪
+裆
+腩
+豉
+疆
+骜
+腐
+倭
+珏
+唔
+粮
+亡
+润
+慰
+伽
+橄
+玄
+誉
+醐
+胆
+龊
+粼
+塬
+陇
+彼
+削
+嗣
+绾
+芽
+妗
+垭
+瘴
+爽
+薏
+寨
+龈
+泠
+弹
+赢
+漪
+猫
+嘧
+涂
+恤
+圭
+茧
+烽
+屑
+痕
+巾
+赖
+荸
+凰
+腮
+畈
+亵
+蹲
+偃
+苇
+澜
+艮
+换
+骺
+烘
+苕
+梓
+颉
+肇
+哗
+悄
+氤
+涠
+葬
+屠
+鹭
+植
+竺
+佯
+诣
+鲇
+瘀
+鲅
+邦
+移
+滁
+冯
+耕
+癔
+戌
+茬
+沁
+巩
+悠
+湘
+洪
+痹
+锟
+循
+谋
+腕
+鳃
+钠
+捞
+焉
+迎
+碱
+伫
+急
+榷
+奈
+邝
+卯
+辄
+皲
+卟
+醛
+畹
+忧
+稳
+雄
+昼
+缩
+阈
+睑
+扌
+耗
+曦
+涅
+捏
+瞧
+邕
+淖
+漉
+铝
+耦
+禹
+湛
+喽
+莼
+琅
+诸
+苎
+纂
+硅
+始
+嗨
+傥
+燃
+臂
+赅
+嘈
+呆
+贵
+屹
+壮
+肋
+亍
+蚀
+卅
+豹
+腆
+邬
+迭
+浊
+}
+童
+螂
+捐
+圩
+勐
+触
+寞
+汊
+壤
+荫
+膺
+渌
+芳
+懿
+遴
+螈
+泰
+蓼
+蛤
+茜
+舅
+枫
+朔
+膝
+眙
+避
+梅
+判
+鹜
+璜
+牍
+缅
+垫
+藻
+黔
+侥
+惚
+懂
+踩
+腰
+腈
+札
+丞
+唾
+慈
+顿
+摹
+荻
+琬
+~
+斧
+沈
+滂
+胁
+胀
+幄
+莜
+Z
+匀
+鄄
+掌
+绰
+茎
+焚
+赋
+萱
+谑
+汁
+铒
+瞎
+夺
+蜗
+野
+娆
+冀
+弯
+篁
+懵
+灞
+隽
+芡
+脘
+俐
+辩
+芯
+掺
+喏
+膈
+蝈
+觐
+悚
+踹
+蔗
+熠
+鼠
+呵
+抓
+橼
+峨
+畜
+缔
+禾
+崭
+弃
+熊
+摒
+凸
+拗
+穹
+蒙
+抒
+祛
+劝
+闫
+扳
+阵
+醌
+踪
+喵
+侣
+搬
+仅
+荧
+赎
+蝾
+琦
+买
+婧
+瞄
+寓
+皎
+冻
+赝
+箩
+莫
+瞰
+郊
+笫
+姝
+筒
+枪
+遣
+煸
+袋
+舆
+痱
+涛
+母
+〇
+启
+践
+耙
+绲
+盘
+遂
+昊
+搞
+槿
+诬
+纰
+泓
+惨
+檬
+亻
+越
+C
+o
+憩
+熵
+祷
+钒
+暧
+塔
+阗
+胰
+咄
+娶
+魔
+琶
+钞
+邻
+扬
+杉
+殴
+咽
+弓
+〆
+髻
+】
+吭
+揽
+霆
+拄
+殖
+脆
+彻
+岩
+芝
+勃
+辣
+剌
+钝
+嘎
+甄
+佘
+皖
+伦
+授
+徕
+憔
+挪
+皇
+庞
+稔
+芜
+踏
+溴
+兖
+卒
+擢
+饥
+鳞
+煲
+‰
+账
+颗
+叻
+斯
+捧
+鳍
+琮
+讹
+蛙
+纽
+谭
+酸
+兔
+莒
+睇
+伟
+觑
+羲
+嗜
+宜
+褐
+旎
+辛
+卦
+诘
+筋
+鎏
+溪
+挛
+熔
+阜
+晰
+鳅
+丢
+奚
+灸
+呱
+献
+陉
+黛
+鸪
+甾
+萨
+疮
+拯
+洲
+疹
+辑
+叙
+恻
+谒
+允
+柔
+烂
+氏
+逅
+漆
+拎
+惋
+扈
+湟
+纭
+啕
+掬
+擞
+哥
+忽
+涤
+鸵
+靡
+郗
+瓷
+扁
+廊
+怨
+雏
+钮
+敦
+E
+懦
+憋
+汀
+拚
+啉
+腌
+岸
+f
+痼
+瞅
+尊
+咀
+眩
+飙
+忌
+仝
+迦
+熬
+毫
+胯
+篑
+茄
+腺
+凄
+舛
+碴
+锵
+诧
+羯
+後
+漏
+汤
+宓
+仞
+蚁
+壶
+谰
+皑
+铄
+棰
+罔
+辅
+晶
+苦
+牟
+闽
+\
+烃
+饮
+聿
+丙
+蛳
+朱
+煤
+涔
+鳖
+犁
+罐
+荼
+砒
+淦
+妤
+黏
+戎
+孑
+婕
+瑾
+戢
+钵
+枣
+捋
+砥
+衩
+狙
+桠
+稣
+阎
+肃
+梏
+诫
+孪
+昶
+婊
+衫
+嗔
+侃
+塞
+蜃
+樵
+峒
+貌
+屿
+欺
+缫
+阐
+栖
+诟
+珞
+荭
+吝
+萍
+嗽
+恂
+啻
+蜴
+磬
+峋
+俸
+豫
+谎
+徊
+镍
+韬
+魇
+晴
+U
+囟
+猜
+蛮
+坐
+囿
+伴
+亭
+肝
+佗
+蝠
+妃
+胞
+滩
+榴
+氖
+垩
+苋
+砣
+扪
+馏
+姓
+轩
+厉
+夥
+侈
+禀
+垒
+岑
+赏
+钛
+辐
+痔
+披
+纸
+碳
+“
+坞
+蠓
+挤
+荥
+沅
+悔
+铧
+帼
+蒌
+蝇
+a
+p
+y
+n
+g
+哀
+浆
+瑶
+凿
+桶
+馈
+皮
+奴
+苜
+佤
+伶
+晗
+铱
+炬
+优
+弊
+氢
+恃
+甫
+攥
+端
+锌
+灰
+稹
+炝
+曙
+邋
+亥
+眶
+碾
+拉
+萝
+绔
+捷
+浍
+腋
+姑
+菖
+凌
+涞
+麽
+锢
+桨
+潢
+绎
+镰
+殆
+锑
+渝
+铬
+困
+绽
+觎
+匈
+糙
+暑
+裹
+鸟
+盔
+肽
+迷
+綦
+『
+亳
+佝
+俘
+钴
+觇
+骥
+仆
+疝
+跪
+婶
+郯
+瀹
+唉
+脖
+踞
+针
+晾
+忒
+扼
+瞩
+叛
+椒
+疟
+嗡
+邗
+肆
+跆
+玫
+忡
+捣
+咧
+唆
+艄
+蘑
+潦
+笛
+阚
+沸
+泻
+掊
+菽
+贫
+斥
+髂
+孢
+镂
+赂
+麝
+鸾
+屡
+衬
+苷
+恪
+叠
+希
+粤
+爻
+喝
+茫
+惬
+郸
+绻
+庸
+撅
+碟
+宄
+妹
+膛
+叮
+饵
+崛
+嗲
+椅
+冤
+搅
+咕
+敛
+尹
+垦
+闷
+蝉
+霎
+勰
+败
+蓑
+泸
+肤
+鹌
+幌
+焦
+浠
+鞍
+刁
+舰
+乙
+竿
+裔
+。
+茵
+函
+伊
+兄
+丨
+娜
+匍
+謇
+莪
+宥
+似
+蝽
+翳
+酪
+翠
+粑
+薇
+祢
+骏
+赠
+叫
+Q
+噤
+噻
+竖
+芗
+莠
+潭
+俊
+羿
+耜
+O
+郫
+趁
+嗪
+囚
+蹶
+芒
+洁
+笋
+鹑
+敲
+硝
+啶
+堡
+渲
+揩
+』
+携
+宿
+遒
+颍
+扭
+棱
+割
+萜
+蔸
+葵
+琴
+捂
+饰
+衙
+耿
+掠
+募
+岂
+窖
+涟
+蔺
+瘤
+柞
+瞪
+怜
+匹
+距
+楔
+炜
+哆
+秦
+缎
+幼
+茁
+绪
+痨
+恨
+楸
+娅
+瓦
+桩
+雪
+嬴
+伏
+榔
+妥
+铿
+拌
+眠
+雍
+缇
+‘
+卓
+搓
+哌
+觞
+噩
+屈
+哧
+髓
+咦
+巅
+娑
+侑
+淫
+膳
+祝
+勾
+姊
+莴
+胄
+疃
+薛
+蜷
+胛
+巷
+芙
+芋
+熙
+闰
+勿
+窃
+狱
+剩
+钏
+幢
+陟
+铛
+慧
+靴
+耍
+k
+浙
+浇
+飨
+惟
+绗
+祜
+澈
+啼
+咪
+磷
+摞
+诅
+郦
+抹
+跃
+壬
+吕
+肖
+琏
+颤
+尴
+剡
+抠
+凋
+赚
+泊
+津
+宕
+殷
+倔
+氲
+漫
+邺
+涎
+怠
+$
+垮
+荬
+遵
+俏
+叹
+噢
+饽
+蜘
+孙
+筵
+疼
+鞭
+羧
+牦
+箭
+潴
+c
+眸
+祭
+髯
+啖
+坳
+愁
+芩
+驮
+倡
+巽
+穰
+沃
+胚
+怒
+凤
+槛
+剂
+趵
+嫁
+v
+邢
+灯
+鄢
+桐
+睽
+檗
+锯
+槟
+婷
+嵋
+圻
+诗
+蕈
+颠
+遭
+痢
+芸
+怯
+馥
+竭
+锗
+徜
+恭
+遍
+籁
+剑
+嘱
+苡
+龄
+僧
+桑
+潸
+弘
+澶
+楹
+悲
+讫
+愤
+腥
+悸
+谍
+椹
+呢
+桓
+葭
+攫
+阀
+翰
+躲
+敖
+柑
+郎
+笨
+橇
+呃
+魁
+燎
+脓
+葩
+磋
+垛
+玺
+狮
+沓
+砜
+蕊
+锺
+罹
+蕉
+翱
+虐
+闾
+巫
+旦
+茱
+嬷
+枯
+鹏
+贡
+芹
+汛
+矫
+绁
+拣
+禺
+佃
+讣
+舫
+惯
+乳
+趋
+疲
+挽
+岚
+虾
+衾
+蠹
+蹂
+飓
+氦
+铖
+孩
+稞
+瑜
+壅
+掀
+勘
+妓
+畅
+髋
+W
+庐
+牲
+蓿
+榕
+练
+垣
+唱
+邸
+菲
+昆
+婺
+穿
+绡
+麒
+蚱
+掂
+愚
+泷
+涪
+漳
+妩
+娉
+榄
+讷
+觅
+旧
+藤
+煮
+呛
+柳
+腓
+叭
+庵
+烷
+阡
+罂
+蜕
+擂
+猖
+咿
+媲
+脉
+【
+沏
+貅
+黠
+熏
+哲
+烁
+坦
+酵
+兜
+潇
+撒
+剽
+珩
+圹
+乾
+摸
+樟
+帽
+嗒
+襄
+魂
+轿
+憬
+锡
+〕
+喃
+皆
+咖
+隅
+脸
+残
+泮
+袂
+鹂
+珊
+囤
+捆
+咤
+误
+徨
+闹
+淙
+芊
+淋
+怆
+囗
+拨
+梳
+渤
+R
+G
+绨
+蚓
+婀
+幡
+狩
+麾
+谢
+唢
+裸
+旌
+伉
+纶
+裂
+驳
+砼
+咛
+澄
+樨
+蹈
+宙
+澍
+倍
+貔
+操
+勇
+蟠
+摈
+砧
+虬
+够
+缁
+悦
+藿
+撸
+艹
+摁
+淹
+豇
+虎
+榭
+吱
+d
+喧
+荀
+踱
+侮
+奋
+偕
+饷
+犍
+惮
+坑
+璎
+徘
+宛
+妆
+袈
+倩
+窦
+昂
+荏
+乖
+K
+怅
+撰
+鳙
+牙
+袁
+酞
+X
+痿
+琼
+闸
+雁
+趾
+荚
+虻
+涝
+《
+杏
+韭
+偈
+烤
+绫
+鞘
+卉
+症
+遢
+蓥
+诋
+杭
+荨
+匆
+竣
+簪
+辙
+敕
+虞
+丹
+缭
+咩
+黟
+m
+淤
+瑕
+咂
+铉
+硼
+茨
+嶂
+痒
+畸
+敬
+涿
+粪
+窘
+熟
+叔
+嫔
+盾
+忱
+裘
+憾
+梵
+赡
+珙
+咯
+娘
+庙
+溯
+胺
+葱
+痪
+摊
+荷
+卞
+乒
+髦
+寐
+铭
+坩
+胗
+枷
+爆
+溟
+嚼
+羚
+砬
+轨
+惊
+挠
+罄
+竽
+菏
+氧
+浅
+楣
+盼
+枢
+炸
+阆
+杯
+谏
+噬
+淇
+渺
+俪
+秆
+墓
+泪
+跻
+砌
+痰
+垡
+渡
+耽
+釜
+讶
+鳎
+煞
+呗
+韶
+舶
+绷
+鹳
+缜
+旷
+铊
+皱
+龌
+檀
+霖
+奄
+槐
+艳
+蝶
+旋
+哝
+赶
+骞
+蚧
+腊
+盈
+丁
+`
+蜚
+矸
+蝙
+睨
+嚓
+僻
+鬼
+醴
+夜
+彝
+磊
+笔
+拔
+栀
+糕
+厦
+邰
+纫
+逭
+纤
+眦
+膊
+馍
+躇
+烯
+蘼
+冬
+诤
+暄
+骶
+哑
+瘠
+」
+臊
+丕
+愈
+咱
+螺
+擅
+跋
+搏
+硪
+谄
+笠
+淡
+嘿
+骅
+谧
+鼎
+皋
+姚
+歼
+蠢
+驼
+耳
+胬
+挝
+涯
+狗
+蒽
+孓
+犷
+凉
+芦
+箴
+铤
+孤
+嘛
+坤
+V
+茴
+朦
+挞
+尖
+橙
+诞
+搴
+碇
+洵
+浚
+帚
+蜍
+漯
+柘
+嚎
+讽
+芭
+荤
+咻
+祠
+秉
+跖
+埃
+吓
+糯
+眷
+馒
+惹
+娼
+鲑
+嫩
+讴
+轮
+瞥
+靶
+褚
+乏
+缤
+宋
+帧
+删
+驱
+碎
+扑
+俩
+俄
+偏
+涣
+竹
+噱
+皙
+佰
+渚
+唧
+斡
+#
+镉
+刀
+崎
+筐
+佣
+夭
+贰
+肴
+峙
+哔
+艿
+匐
+牺
+镛
+缘
+仡
+嫡
+劣
+枸
+堀
+梨
+簿
+鸭
+蒸
+亦
+稽
+浴
+{
+衢
+束
+槲
+j
+阁
+揍
+疥
+棋
+潋
+聪
+窜
+乓
+睛
+插
+冉
+阪
+苍
+搽
+「
+蟾
+螟
+幸
+仇
+樽
+撂
+慢
+跤
+幔
+俚
+淅
+覃
+觊
+溶
+妖
+帛
+侨
+曰
+妾
+泗
+:
+瀘
+風
+(
+)
+∶
+紅
+紗
+瑭
+雲
+頭
+鶏
+財
+許
+•
+樂
+焗
+麗
+—
+;
+滙
+東
+榮
+繪
+興
+…
+門
+業
+楊
+國
+顧
+盤
+寳
+龍
+鳳
+島
+誌
+緣
+結
+銭
+萬
+勝
+祎
+璟
+優
+歡
+臨
+時
+購
+=
+★
+藍
+昇
+鐵
+觀
+勅
+農
+聲
+畫
+兿
+術
+發
+劉
+記
+專
+耑
+園
+書
+壴
+種
+●
+褀
+號
+銀
+匯
+敟
+锘
+葉
+橪
+廣
+進
+蒄
+鑽
+阝
+祙
+貢
+鍋
+豊
+夬
+喆
+團
+閣
+開
+燁
+賓
+館
+酡
+沔
+順
++
+硚
+劵
+饸
+陽
+車
+湓
+復
+萊
+氣
+軒
+華
+堃
+迮
+纟
+戶
+馬
+學
+裡
+電
+嶽
+獨
+マ
+シ
+サ
+ジ
+燘
+袪
+環
+❤
+臺
+灣
+専
+賣
+孖
+聖
+攝
+線
+▪
+傢
+俬
+夢
+達
+莊
+喬
+貝
+薩
+劍
+羅
+壓
+棛
+饦
+尃
+璈
+囍
+醫
+G
+I
+A
+#
+N
+鷄
+髙
+嬰
+啓
+約
+隹
+潔
+賴
+藝
+~
+寶
+籣
+麺
+ 
+嶺
+√
+義
+網
+峩
+長
+∧
+魚
+機
+構
+②
+鳯
+偉
+L
+B
+㙟
+畵
+鴿
+'
+詩
+溝
+嚞
+屌
+藔
+佧
+玥
+蘭
+織
+1
+3
+9
+0
+7
+點
+砭
+鴨
+鋪
+銘
+廳
+弍
+‧
+創
+湯
+坶
+℃
+卩
+骝
+&
+烜
+荘
+當
+潤
+扞
+係
+懷
+碶
+钅
+蚨
+讠
+☆
+叢
+爲
+埗
+涫
+塗
+→
+楽
+現
+鯨
+愛
+瑪
+鈺
+忄
+悶
+藥
+飾
+樓
+視
+孬
+ㆍ
+燚
+苪
+師
+①
+丼
+锽
+│
+韓
+標
+兒
+閏
+匋
+張
+漢
+髪
+會
+閑
+檔
+習
+裝
+の
+峯
+菘
+輝
+雞
+釣
+億
+浐
+K
+O
+R
+8
+H
+E
+P
+T
+W
+D
+S
+C
+M
+F
+姌
+饹
+晞
+廰
+嵯
+鷹
+負
+飲
+絲
+冚
+楗
+澤
+綫
+區
+❋
+←
+質
+靑
+揚
+③
+滬
+統
+産
+協
+﹑
+乸
+畐
+經
+運
+際
+洺
+岽
+為
+粵
+諾
+崋
+豐
+碁
+V
+2
+6
+齋
+誠
+訂
+勑
+雙
+陳
+無
+泩
+媄
+夌
+刂
+i
+c
+t
+o
+r
+a
+嘢
+耄
+燴
+暃
+壽
+媽
+靈
+抻
+體
+唻
+冮
+甹
+鎮
+錦
+蜛
+蠄
+尓
+駕
+戀
+飬
+逹
+倫
+貴
+極
+寬
+磚
+嶪
+郎
+職
+|
+間
+n
+d
+剎
+伈
+課
+飛
+橋
+瘊
+№
+譜
+骓
+圗
+滘
+縣
+粿
+咅
+養
+濤
+彳
+%
+Ⅱ
+啰
+㴪
+見
+矞
+薬
+糁
+邨
+鲮
+顔
+罱
+選
+話
+贏
+氪
+俵
+競
+瑩
+繡
+枱
+綉
+獅
+爾
+™
+麵
+戋
+淩
+徳
+個
+劇
+場
+務
+簡
+寵
+h
+實
+膠
+轱
+圖
+築
+嘣
+樹
+㸃
+營
+耵
+孫
+饃
+鄺
+飯
+麯
+遠
+輸
+坫
+孃
+乚
+閃
+鏢
+㎡
+題
+廠
+關
+↑
+爺
+將
+軍
+連
+篦
+覌
+參
+箸
+-
+窠
+棽
+寕
+夀
+爰
+歐
+呙
+閥
+頡
+熱
+雎
+垟
+裟
+凬
+勁
+帑
+馕
+夆
+疌
+枼
+馮
+貨
+蒤
+樸
+彧
+旸
+靜
+龢
+暢
+㐱
+鳥
+珺
+鏡
+灡
+爭
+堷
+廚
+騰
+診
+┅
+蘇
+褔
+凱
+頂
+豕
+亞
+帥
+嘬
+⊥
+仺
+桖
+複
+饣
+絡
+穂
+顏
+棟
+納
+▏
+濟
+親
+設
+計
+攵
+埌
+烺
+頤
+燦
+蓮
+撻
+節
+講
+濱
+濃
+娽
+洳
+朿
+燈
+鈴
+護
+膚
+铔
+過
+補
+Z
+U
+5
+4
+坋
+闿
+䖝
+餘
+缐
+铞
+貿
+铪
+桼
+趙
+鍊
+[
+㐂
+垚
+菓
+揸
+捲
+鐘
+滏
+𣇉
+爍
+輪
+燜
+鴻
+鮮
+動
+鹞
+鷗
+丄
+慶
+鉌
+翥
+飮
+腸
+⇋
+漁
+覺
+來
+熘
+昴
+翏
+鲱
+圧
+鄉
+萭
+頔
+爐
+嫚
+貭
+類
+聯
+幛
+輕
+訓
+鑒
+夋
+锨
+芃
+珣
+䝉
+扙
+嵐
+銷
+處
+ㄱ
+語
+誘
+苝
+歸
+儀
+燒
+楿
+內
+粢
+葒
+奧
+麥
+礻
+滿
+蠔
+穵
+瞭
+態
+鱬
+榞
+硂
+鄭
+黃
+煙
+祐
+奓
+逺
+*
+瑄
+獲
+聞
+薦
+讀
+這
+樣
+決
+問
+啟
+們
+執
+説
+轉
+單
+隨
+唘
+帶
+倉
+庫
+還
+贈
+尙
+皺
+■
+餅
+產
+○
+∈
+報
+狀
+楓
+賠
+琯
+嗮
+禮
+`
+傳
+>
+≤
+嗞
+≥
+換
+咭
+∣
+↓
+曬
+応
+寫
+″
+終
+様
+純
+費
+療
+聨
+凍
+壐
+郵
+黒
+∫
+製
+塊
+調
+軽
+確
+撃
+級
+馴
+Ⅲ
+涇
+繹
+數
+碼
+證
+狒
+処
+劑
+<
+晧
+賀
+衆
+]
+櫥
+兩
+陰
+絶
+對
+鯉
+憶
+◎
+p
+e
+Y
+蕒
+煖
+頓
+測
+試
+鼽
+僑
+碩
+妝
+帯
+≈
+鐡
+舖
+權
+喫
+倆
+該
+悅
+俫
+.
+f
+s
+b
+m
+k
+g
+u
+j
+貼
+淨
+濕
+針
+適
+備
+l
+/
+給
+謢
+強
+觸
+衛
+與
+⊙
+$
+緯
+變
+⑴
+⑵
+⑶
+㎏
+殺
+∩
+幚
+─
+價
+▲
+離
+飄
+烏
+関
+閟
+﹝
+﹞
+邏
+輯
+鍵
+驗
+訣
+導
+歷
+屆
+層
+▼
+儱
+錄
+熳
+艦
+吋
+錶
+辧
+飼
+顯
+④
+禦
+販
+気
+対
+枰
+閩
+紀
+幹
+瞓
+貊
+淚
+△
+眞
+墊
+獻
+褲
+縫
+緑
+亜
+鉅
+餠
+{
+}
+◆
+蘆
+薈
+█
+◇
+溫
+彈
+晳
+粧
+犸
+穩
+訊
+崬
+凖
+熥
+舊
+條
+紋
+圍
+Ⅳ
+筆
+尷
+難
+雜
+錯
+綁
+識
+頰
+鎖
+艶
+□
+殁
+殼
+⑧
+├
+▕
+鵬
+糝
+綱
+▎
+盜
+饅
+醬
+籤
+蓋
+釀
+鹽
+據
+辦
+◥
+彐
+┌
+婦
+獸
+鲩
+伱
+蒟
+蒻
+齊
+袆
+腦
+寧
+凈
+妳
+煥
+詢
+偽
+謹
+啫
+鯽
+騷
+鱸
+損
+傷
+鎻
+髮
+買
+冏
+儥
+両
+﹢
+∞
+載
+喰
+z
+羙
+悵
+燙
+曉
+員
+組
+徹
+艷
+痠
+鋼
+鼙
+縮
+細
+嚒
+爯
+≠
+維
+"
+鱻
+壇
+厍
+帰
+浥
+犇
+薡
+軎
+應
+醜
+刪
+緻
+鶴
+賜
+噁
+軌
+尨
+镔
+鷺
+槗
+彌
+葚
+濛
+請
+溇
+緹
+賢
+訪
+獴
+瑅
+資
+縤
+陣
+蕟
+栢
+韻
+祼
+恁
+伢
+謝
+劃
+涑
+總
+衖
+踺
+砋
+凉
+籃
+駿
+苼
+瘋
+昽
+紡
+驊
+腎
+﹗
+響
+杋
+剛
+嚴
+禪
+歓
+槍
+傘
+檸
+檫
+炣
+勢
+鏜
+鎢
+銑
+尐
+減
+奪
+惡
+僮
+婭
+臘
+殻
+鉄
+∑
+蛲
+焼
+緖
+續
+紹
+懮

+ 732 - 0
torchocr/datasets/det_modules/FCE_aug.py

@@ -0,0 +1,732 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/transforms.py
+"""
+import numpy as np
+import torchvision.transforms
+from PIL import Image, ImageDraw
+import cv2
+from shapely.geometry import Polygon
+import math
+from torchocr.utils.poly_nms import poly_intersection
+from torchvision.transforms import ColorJitter as Jitter
+
class Pad(object):
    """Pad an image at its bottom/right edges with zeros.

    The padded size is either a fixed ``size`` (which must be strictly
    larger than the input image) or the input size rounded up to the
    nearest multiple of ``size_div``.
    """

    def __init__(self, size=None, size_div=32, **kwargs):
        # Accept an int (square target), a list/tuple (h, w), or None.
        if size is not None and not isinstance(size, (int, list, tuple)):
            raise TypeError("Type of target_size is invalid. Now is {}".format(
                type(size)))
        self.size = [size, size] if isinstance(size, int) else size
        self.size_div = size_div

    def __call__(self, data):
        """Pad ``data['img']`` in place and return ``data``."""
        img = data['img']
        h, w = img.shape[0], img.shape[1]
        if self.size:
            target_h, target_w = self.size
            assert (
                h < target_h and w < target_w
            ), '(h, w) of target size should be greater than (img_h, img_w)'
        else:
            # Round each side up to a multiple of size_div (at least one unit).
            div = self.size_div
            target_h = max(int(math.ceil(h / div) * div), div)
            target_w = max(int(math.ceil(w / div) * div), div)
        data['img'] = cv2.copyMakeBorder(
            img,
            0,
            target_h - h,
            0,
            target_w - w,
            cv2.BORDER_CONSTANT,
            value=0)
        return data
+
class ColorJitter(object):
    """Randomly jitter brightness/contrast/saturation/hue of ``data['img']``."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
        # Delegate the actual augmentation to torchvision's ColorJitter
        # (imported at module level as ``Jitter``).
        self.aug = Jitter(brightness, contrast, saturation, hue)

    def __call__(self, data):
        # torchvision transforms operate on PIL images, so round-trip
        # ndarray -> PIL(RGB) -> ndarray.
        jittered = self.aug(Image.fromarray(data['img']).convert('RGB'))
        data['img'] = np.asarray(jittered)
        return data
+
+
class RandomScaling:
    def __init__(self, size=800, scale=(3. / 4, 5. / 2), **kwargs):
        """Random scale the image while keeping aspect.

        Args:
            size (int) : Base size before scaling.
            scale (tuple(float)) : The range of scaling. A single float f
                is expanded to the range (1 - f, 1 + f).
        """
        assert isinstance(size, int)
        assert isinstance(scale, float) or isinstance(scale, tuple)
        self.size = size
        self.scale = scale if isinstance(scale, tuple) \
            else (1 - scale, 1 + scale)

    def __call__(self, data):
        """Resize ``data['img']`` by a random isotropic factor and scale
        ``data['text_polys']`` by the same factor."""
        image = data['img']
        text_polys = data['text_polys']
        h, w, _ = image.shape

        # One scale factor relative to the base size, same for both axes.
        aspect_ratio = np.random.uniform(min(self.scale), max(self.scale))
        factor = self.size * 1.0 / max(h, w) * aspect_ratio
        scales = np.array([factor, factor])
        out_size = (int(h * scales[1]), int(w * scales[0]))
        data['img'] = cv2.resize(image, out_size[::-1])

        # Scale polygon coordinates best-effort. Ragged/empty polygon
        # arrays (object dtype) cannot be sliced vectorized; the original
        # caught *everything* here and printed a stray debug '1'.
        # NOTE(review): ragged polygons are left unscaled relative to the
        # resized image -- confirm upstream guarantees rectangular arrays.
        try:
            text_polys[:, :, 0::2] = text_polys[:, :, 0::2] * scales[1]
            text_polys[:, :, 1::2] = text_polys[:, :, 1::2] * scales[0]
        except (IndexError, TypeError, ValueError):
            pass
        data['text_polys'] = text_polys

        return data
+
+
class RandomCropFlip:
    def __init__(self,
                 pad_ratio=0.1,
                 crop_ratio=0.5,
                 iter_num=1,
                 min_area_ratio=0.2,
                 **kwargs):
        """Random crop and flip a patch of the image.

        Args:
            pad_ratio (float): Fraction of image height/width used as
                padding when building the crop-position histograms.
            crop_ratio (float): The ratio of cropping.
            iter_num (int): Number of operations.
            min_area_ratio (float): Minimal area ratio between cropped patch
                and original image.
        """
        assert isinstance(crop_ratio, float)
        assert isinstance(iter_num, int)
        assert isinstance(min_area_ratio, float)

        self.pad_ratio = pad_ratio
        # Tolerance used when comparing polygon/crop intersection areas.
        self.epsilon = 1e-2
        self.crop_ratio = crop_ratio
        self.iter_num = iter_num
        self.min_area_ratio = min_area_ratio

    def __call__(self, results):
        # Apply the crop-and-flip operation ``iter_num`` times in sequence.
        for i in range(self.iter_num):
            results = self.random_crop_flip(results)

        return results

    def random_crop_flip(self, results):
        """Crop a random region and flip it in place.

        A candidate region is accepted only when every polygon is either
        fully inside it (those are flipped with the patch) or fully
        outside it (those are kept unchanged).
        """
        image = results['img']
        polygons = results['text_polys']
        ignore_tags = results['ignore_tags']
        if len(polygons) == 0:
            return results

        # Skip the op with probability (1 - crop_ratio).
        if np.random.random() >= self.crop_ratio:
            return results

        h, w, _ = image.shape
        area = h * w
        pad_h = int(h * self.pad_ratio)
        pad_w = int(w * self.pad_ratio)
        # Coordinates (in padded space) where a crop edge can be placed
        # without cutting through any polygon.
        h_axis, w_axis = self.generate_crop_target(image, polygons, pad_h,
                                                   pad_w)
        if len(h_axis) == 0 or len(w_axis) == 0:
            return results

        attempt = 0
        while attempt < 50:
            attempt += 1
            polys_keep = []
            polys_new = []
            ignore_tags_keep = []
            ignore_tags_new = []
            # Sample two valid x positions and two valid y positions, then
            # shift back from padded to image coordinates and clip.
            xx = np.random.choice(w_axis, size=2)
            xmin = np.min(xx) - pad_w
            xmax = np.max(xx) - pad_w
            xmin = np.clip(xmin, 0, w - 1)
            xmax = np.clip(xmax, 0, w - 1)
            yy = np.random.choice(h_axis, size=2)
            ymin = np.min(yy) - pad_h
            ymax = np.max(yy) - pad_h
            ymin = np.clip(ymin, 0, h - 1)
            ymax = np.clip(ymax, 0, h - 1)
            if (xmax - xmin) * (ymax - ymin) < area * self.min_area_ratio:
                # area too small
                continue

            pts = np.stack([[xmin, xmax, xmax, xmin],
                            [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
            pp = Polygon(pts)
            fail_flag = False
            for polygon, ignore_tag in zip(polygons, ignore_tags):
                ppi = Polygon(polygon.reshape(-1, 2))
                ppiou, _ = poly_intersection(ppi, pp, buffer=0)
                # Partial overlap: this crop would cut through the polygon.
                if np.abs(ppiou - float(ppi.area)) > self.epsilon and \
                        np.abs(ppiou) > self.epsilon:
                    fail_flag = True
                    break
                elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
                    # Polygon fully inside the crop: flip it with the patch.
                    polys_new.append(polygon)
                    ignore_tags_new.append(ignore_tag)
                else:
                    # Polygon fully outside the crop: keep unchanged.
                    polys_keep.append(polygon)
                    ignore_tags_keep.append(ignore_tag)

            if fail_flag:
                continue
            else:
                break

        # NOTE(review): if all 50 attempts fail, the last rejected region is
        # still cropped and flipped below, which can mirror a patch through a
        # text instance -- confirm this is intended.
        cropped = image[ymin:ymax, xmin:xmax, :]
        # select_type: 0 = horizontal flip, 1 = vertical flip, 2 = both.
        select_type = np.random.randint(3)
        if select_type == 0:
            img = np.ascontiguousarray(cropped[:, ::-1])
        elif select_type == 1:
            img = np.ascontiguousarray(cropped[::-1, :])
        else:
            img = np.ascontiguousarray(cropped[::-1, ::-1])
        image[ymin:ymax, xmin:xmax, :] = img
        results['img'] = image

        if len(polys_new) != 0:
            height, width, _ = cropped.shape
            # Mirror the coordinates of the flipped polygons to match the
            # in-place flip applied to the patch.
            if select_type == 0:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
                    polys_new[idx] = poly
            elif select_type == 1:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
                    polys_new[idx] = poly
            else:
                for idx, polygon in enumerate(polys_new):
                    poly = polygon.reshape(-1, 2)
                    poly[:, 0] = width - poly[:, 0] + 2 * xmin
                    poly[:, 1] = height - poly[:, 1] + 2 * ymin
                    polys_new[idx] = poly
            polygons = polys_keep + polys_new
            ignore_tags = ignore_tags_keep + ignore_tags_new
            results['text_polys'] = np.array(polygons)
            results['ignore_tags'] = ignore_tags

        return results

    def generate_crop_target(self, image, all_polys, pad_h, pad_w):
        """Generate crop target and make sure not to crop the polygon
        instances.

        Args:
            image (ndarray): The image waited to be crop.
            all_polys (list[list[ndarray]]): All polygons including ground
                truth polygons and ground truth ignored polygons.
            pad_h (int): Padding length of height.
            pad_w (int): Padding length of width.
        Returns:
            h_axis (ndarray): Vertical cropping range.
            w_axis (ndarray): Horizontal cropping range.
        """
        h, w, _ = image.shape
        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)

        text_polys = []
        for polygon in all_polys:
            # Reduce each polygon to its minimum-area bounding rectangle.
            rect = cv2.minAreaRect(polygon.astype(np.int32).reshape(-1, 2))
            box = cv2.boxPoints(rect)
            box = np.int0(box)
            text_polys.append([box[0], box[1], box[2], box[3]])

        polys = np.array(text_polys, dtype=np.int32)
        for poly in polys:
            poly = np.round(poly, decimals=0).astype(np.int32)
            # Mark the (padded) rows/columns covered by any text box.
            minx = np.min(poly[:, 0])
            maxx = np.max(poly[:, 0])
            w_array[minx + pad_w:maxx + pad_w] = 1
            miny = np.min(poly[:, 1])
            maxy = np.max(poly[:, 1])
            h_array[miny + pad_h:maxy + pad_h] = 1

        # Positions still 0 are safe locations for a crop boundary.
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        return h_axis, w_axis
+
+
class RandomCropPolyInstances:
    """Randomly crop images and make sure to contain at least one intact
    instance."""

    def __init__(self, crop_ratio=5.0 / 8.0, min_side_ratio=0.4, **kwargs):
        super().__init__()
        # Probability of applying the crop at all.
        self.crop_ratio = crop_ratio
        # Minimal crop side length as a fraction of the original side.
        self.min_side_ratio = min_side_ratio

    def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
        """Sample a (start, end) interval along one axis.

        ``valid_array`` holds 1 at positions where a crop boundary may be
        placed (outside every text instance). ``max_start`` / ``min_end``
        bound the interval so a pre-selected instance remains inside it.
        """
        assert isinstance(min_len, int)
        assert len(valid_array) > min_len

        start_array = valid_array.copy()
        max_start = min(len(start_array) - min_len, max_start)
        start_array[max_start:] = 0
        start_array[0] = 1
        # diff < 0 marks where runs of valid positions begin, diff > 0
        # where they end; pick a run, then sample uniformly inside it.
        diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
        region_starts = np.where(diff_array < 0)[0]
        region_ends = np.where(diff_array > 0)[0]
        region_ind = np.random.randint(0, len(region_starts))
        start = np.random.randint(region_starts[region_ind],
                                  region_ends[region_ind])

        end_array = valid_array.copy()
        min_end = max(start + min_len, min_end)
        end_array[:min_end] = 0
        end_array[-1] = 1
        diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
        region_starts = np.where(diff_array < 0)[0]
        region_ends = np.where(diff_array > 0)[0]
        region_ind = np.random.randint(0, len(region_starts))
        end = np.random.randint(region_starts[region_ind],
                                region_ends[region_ind])
        return start, end

    def sample_crop_box(self, img_size, results):
        """Generate crop box and make sure not to crop the polygon instances.

        Args:
            img_size (tuple(int)): The image size (h, w).
            results (dict): The results dict.
        """

        assert isinstance(img_size, tuple)
        h, w = img_size[:2]

        key_masks = results['text_polys']

        x_valid_array = np.ones(w, dtype=np.int32)
        y_valid_array = np.ones(h, dtype=np.int32)

        # Pick one instance that must stay intact inside the crop; its
        # extent bounds where the crop may start and end.
        selected_mask = key_masks[np.random.randint(0, len(key_masks))]
        selected_mask = selected_mask.reshape((-1, 2)).astype(np.int32)
        max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
        min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
        max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
        min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)

        # Invalidate positions covered by any instance (2px margin) so
        # crop edges never cut through text.
        for mask in key_masks:
            mask = mask.reshape((-1, 2)).astype(np.int32)
            clip_x = np.clip(mask[:, 0], 0, w - 1)
            clip_y = np.clip(mask[:, 1], 0, h - 1)
            min_x, max_x = np.min(clip_x), np.max(clip_x)
            min_y, max_y = np.min(clip_y), np.max(clip_y)

            x_valid_array[min_x - 2:max_x + 3] = 0
            y_valid_array[min_y - 2:max_y + 3] = 0

        min_w = int(w * self.min_side_ratio)
        min_h = int(h * self.min_side_ratio)

        x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
                                             min_x_end)
        y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
                                             min_y_end)

        return np.array([x1, y1, x2, y2])

    def crop_img(self, img, bbox):
        """Crop ``img`` to ``bbox`` given as [x1, y1, x2, y2]."""
        assert img.ndim == 3
        h, w, _ = img.shape
        assert 0 <= bbox[1] < bbox[3] <= h
        assert 0 <= bbox[0] < bbox[2] <= w
        return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]

    def __call__(self, results):
        image = results['img']
        polygons = results['text_polys']
        ignore_tags = results['ignore_tags']
        if len(polygons) < 1:
            return results

        if np.random.random_sample() < self.crop_ratio:

            crop_box = self.sample_crop_box(image.shape, results)
            img = self.crop_img(image, crop_box)
            results['img'] = img
            # crop and filter masks
            x1, y1, x2, y2 = crop_box
            w = max(x2 - x1, 1)
            h = max(y2 - y1, 1)
            # Shift polygon coordinates into the cropped frame.
            polygons[:, :, 0::2] = polygons[:, :, 0::2] - x1
            polygons[:, :, 1::2] = polygons[:, :, 1::2] - y1

            valid_masks_list = []
            valid_tags_list = []
            for ind, polygon in enumerate(polygons):
                # Keep polygons lying inside the crop (4px tolerance),
                # clipping their coordinates to the crop bounds.
                if (polygon[:, ::2] > -4).all() and (
                        polygon[:, ::2] < w + 4).all() and (
                        polygon[:, 1::2] > -4).all() and (
                        polygon[:, 1::2] < h + 4).all():
                    polygon[:, ::2] = np.clip(polygon[:, ::2], 0, w)
                    polygon[:, 1::2] = np.clip(polygon[:, 1::2], 0, h)
                    valid_masks_list.append(polygon)
                    valid_tags_list.append(ignore_tags[ind])

            results['text_polys'] = np.array(valid_masks_list)
            results['ignore_tags'] = valid_tags_list

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
+
+
class RandomRotatePolyInstances:
    def __init__(self,
                 rotate_ratio=0.5,
                 max_angle=10,
                 pad_with_fixed_color=False,
                 pad_value=(0, 0, 0),
                 **kwargs):
        """Randomly rotate images and polygon masks.

        Args:
            rotate_ratio (float): The ratio of samples to operate rotation.
            max_angle (int): The maximum rotation angle.
            pad_with_fixed_color (bool): The flag for whether to pad rotated
               image with fixed value. If set to False, the rotated image will
               be padded onto cropped image.
            pad_value (tuple(int)): The color value for padding rotated image.
        """
        self.rotate_ratio = rotate_ratio
        self.max_angle = max_angle
        self.pad_with_fixed_color = pad_with_fixed_color
        self.pad_value = pad_value

    def rotate(self, center, points, theta, center_shift=(0, 0)):
        """Rotate polygon ``points`` by ``theta`` degrees around ``center``.

        Works internally in a y-up coordinate system (hence the sign
        flips) and applies ``center_shift`` to account for the enlarged
        canvas. ``points`` is modified in place and also returned.
        """
        # rotate points.
        (center_x, center_y) = center
        center_y = -center_y
        x, y = points[:, ::2], points[:, 1::2]
        y = -y

        theta = theta / 180 * math.pi
        cos = math.cos(theta)
        sin = math.sin(theta)

        x = (x - center_x)
        y = (y - center_y)

        _x = center_x + x * cos - y * sin + center_shift[0]
        _y = -(center_y + x * sin + y * cos) + center_shift[1]

        points[:, ::2], points[:, 1::2] = _x, _y
        return points

    def cal_canvas_size(self, ori_size, degree):
        """Return the (h, w) canvas size that fits the rotated image."""
        assert isinstance(ori_size, tuple)
        angle = degree * math.pi / 180.0
        h, w = ori_size[:2]

        cos = math.cos(angle)
        sin = math.sin(angle)
        canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
        canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))

        canvas_size = (canvas_h, canvas_w)
        return canvas_size

    def sample_angle(self, max_angle):
        """Sample an angle uniformly from [-max_angle, max_angle)."""
        angle = np.random.random_sample() * 2 * max_angle - max_angle
        return angle

    def rotate_img(self, img, angle, canvas_size):
        """Rotate ``img`` by ``angle`` degrees onto a (h, w) canvas."""
        h, w = img.shape[:2]
        rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        # Translate so the rotated image is centered on the larger canvas.
        rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
        rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)

        if self.pad_with_fixed_color:
            target_img = cv2.warpAffine(
                img,
                rotation_matrix, (canvas_size[1], canvas_size[0]),
                flags=cv2.INTER_NEAREST,
                borderValue=self.pad_value)
        else:
            # Fill the border with a randomly sampled, resized patch of the
            # image itself so the padding resembles natural background.
            mask = np.zeros_like(img)
            (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
                              np.random.randint(0, w * 7 // 8))
            img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
            img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))

            # After warping, mask is 1 outside the rotated image, 0 inside.
            mask = cv2.warpAffine(
                mask,
                rotation_matrix, (canvas_size[1], canvas_size[0]),
                borderValue=[1, 1, 1])
            target_img = cv2.warpAffine(
                img,
                rotation_matrix, (canvas_size[1], canvas_size[0]),
                borderValue=[0, 0, 0])
            target_img = target_img + img_cut * mask

        return target_img

    def __call__(self, results):
        # Rotate with probability ``rotate_ratio``; otherwise pass through.
        if np.random.random_sample() < self.rotate_ratio:
            image = results['img']
            polygons = results['text_polys']
            h, w = image.shape[:2]

            angle = self.sample_angle(self.max_angle)
            canvas_size = self.cal_canvas_size((h, w), angle)
            center_shift = (int((canvas_size[1] - w) / 2), int(
                (canvas_size[0] - h) / 2))
            image = self.rotate_img(image, angle, canvas_size)
            results['img'] = image
            # rotate polygons
            rotated_masks = []
            for mask in polygons:
                rotated_mask = self.rotate((w / 2, h / 2), mask, angle,
                                           center_shift)
                rotated_masks.append(rotated_mask)
            results['text_polys'] = np.array(rotated_masks)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
+
+
class SquareResizePad:
    def __init__(self,
                 target_size,
                 pad_ratio=0.6,
                 pad_with_fixed_color=False,
                 pad_value=(0, 0, 0),
                 **kwargs):
        """Resize or pad images to be square shape.

        Args:
            target_size (int): The target size of square shaped image.
            pad_ratio (float): Probability of the keep-ratio-then-pad branch
                (vs. a direct non-uniform square resize).
            pad_with_fixed_color (bool): The flag for whether to pad rotated
               image with fixed value. If set to False, the rescales image will
               be padded onto cropped image.
            pad_value (tuple(int)): The color value for padding rotated image.
        """
        assert isinstance(target_size, int)
        assert isinstance(pad_ratio, float)
        assert isinstance(pad_with_fixed_color, bool)
        assert isinstance(pad_value, tuple)

        self.target_size = target_size
        self.pad_ratio = pad_ratio
        self.pad_with_fixed_color = pad_with_fixed_color
        self.pad_value = pad_value

    def resize_img(self, img, keep_ratio=True):
        """Resize ``img`` so its longer side equals ``target_size`` (or both
        sides if ``keep_ratio`` is False); returns (img, (new_h, new_w))."""
        h, w, _ = img.shape
        if keep_ratio:
            t_h = self.target_size if h >= w else int(h * self.target_size / w)
            t_w = self.target_size if h <= w else int(w * self.target_size / h)
        else:
            t_h = t_w = self.target_size
        img = cv2.resize(img, (t_w, t_h))
        return img, (t_h, t_w)

    def square_pad(self, img):
        """Pad ``img`` to a square; returns the padded image and the (x, y)
        offset of the original image inside it."""
        h, w = img.shape[:2]
        if h == w:
            return img, (0, 0)
        pad_size = max(h, w)
        if self.pad_with_fixed_color:
            expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
            expand_img[:] = self.pad_value
        else:
            # Pad with a random resized patch of the image itself so the
            # border resembles natural background.
            (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
                              np.random.randint(0, w * 7 // 8))
            img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
            expand_img = cv2.resize(img_cut, (pad_size, pad_size))
        if h > w:
            y0, x0 = 0, (h - w) // 2
        else:
            y0, x0 = (w - h) // 2, 0
        expand_img[y0:y0 + h, x0:x0 + w] = img
        offset = (x0, y0)

        return expand_img, offset

    def square_pad_mask(self, points, offset):
        """Shift flat [x0, y0, x1, y1, ...] coordinates by ``offset``."""
        x0, y0 = offset
        pad_points = points.copy()
        pad_points[::2] = pad_points[::2] + x0
        pad_points[1::2] = pad_points[1::2] + y0
        return pad_points

    def __call__(self, results):
        """Resize (and possibly square-pad) ``results['img']``, mapping
        ``results['text_polys']`` into the new coordinate frame."""
        image = results['img']
        polygons = results['text_polys']
        h, w = image.shape[:2]

        if np.random.random_sample() < self.pad_ratio:
            image, out_size = self.resize_img(image, keep_ratio=True)
            image, offset = self.square_pad(image)
        else:
            image, out_size = self.resize_img(image, keep_ratio=False)
            offset = (0, 0)
        results['img'] = image
        # Scale and shift polygon coordinates best-effort. Ragged/empty
        # polygon arrays (object dtype) cannot be sliced vectorized; the
        # original hid every failure behind a bare ``except: pass``.
        # NOTE(review): ragged polygons stay in the old frame -- confirm
        # upstream guarantees rectangular arrays.
        try:
            polygons[:, :, 0::2] = polygons[:, :, 0::2] * out_size[
                1] / w + offset[0]
            polygons[:, :, 1::2] = polygons[:, :, 1::2] * out_size[
                0] / h + offset[1]
        except (IndexError, TypeError, ValueError):
            pass
        results['text_polys'] = polygons

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
+
+
+
class DetResizeForTest(object):
    """Resize an image for detection inference.

    The strategy is chosen from the constructor kwargs:
      * ``image_shape``    -> fixed (h, w) resize (resize_type 1),
      * ``limit_side_len`` -> limit the min/max side, snap to /32 (type 0),
      * ``resize_long``    -> limit the long side, snap to /128 (type 2),
      * none of the above  -> type 0 with limit_side_len=736, 'min'.
    """

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        """Resize ``data['img']`` and record the source shape plus the
        (ratio_h, ratio_w) scale factors in ``data['shape']``."""
        img = data['img']
        src_h, src_w, _ = img.shape

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data['img'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_type1(self, img):
        """Resize to the fixed ``image_shape`` (h, w)."""
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # limit the max side
        if self.limit_type == 'max':
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == 'min':
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('not support limit type, image ')
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # Snap both sides to a multiple of 32 (minimum 32).
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        if int(resize_w) <= 0 or int(resize_h) <= 0:
            return None, (None, None)
        try:
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception as exc:
            # The original handler printed the shapes and called
            # ``sys.exit(0)``, but ``sys`` is never imported in this module,
            # so it actually raised a NameError (and killing the whole
            # process from a transform is wrong anyway). Raise instead.
            raise ValueError('cannot resize image of shape {} to ({}, {})'
                             .format(img.shape, resize_w, resize_h)) from exc
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """Resize so the long side equals ``resize_long``, then snap both
        sides up to a multiple of 128."""
        h, w, _ = img.shape

        resize_w = w
        resize_h = h

        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]

+ 658 - 0
torchocr/datasets/det_modules/FCE_target.py

@@ -0,0 +1,658 @@
+import cv2
+import numpy as np
+from numpy.fft import fft
+from numpy.linalg import norm
+import sys
+
+
def vector_slope(vec):
    """Return the absolute slope ``|dy| / |dx|`` of a 2-D vector.

    A small epsilon keeps the division finite for (near-)vertical vectors.
    The epsilon is added to ``abs(vec[0])`` rather than to ``vec[0]``:
    the original form raised ``ZeroDivisionError`` when ``vec[0] == -1e-8``
    and shrank (instead of padded) the denominator for tiny negative x.

    Args:
        vec: A length-2 sequence (dx, dy).

    Returns:
        float: ``|dy| / (|dx| + 1e-8)``.
    """
    assert len(vec) == 2
    return abs(vec[1]) / (abs(vec[0]) + 1e-8)
+
+
+class FCENetTargets:
+    """Generate the ground truth targets of FCENet: Fourier Contour Embedding
+    for Arbitrary-Shaped Text Detection.
+
+    [https://arxiv.org/abs/2104.10442]
+
+    Args:
+        fourier_degree (int): The maximum Fourier transform degree k.
+        resample_step (float): The step size for resampling the text center
+            line (TCL). It's better not to exceed half of the minimum width.
+        center_region_shrink_ratio (float): The shrink ratio of text center
+            region.
+        level_size_divisors (tuple(int)): The downsample ratio on each level.
+        level_proportion_range (tuple(tuple(int))): The range of text sizes
+            assigned to each level.
+    """
+
+    def __init__(self,
+                 fourier_degree=5,
+                 resample_step=4.0,
+                 center_region_shrink_ratio=0.3,
+                 level_size_divisors=(8, 16, 32),
+                 level_proportion_range=((0, 0.25), (0.2, 0.65), (0.55, 1.0)),
+                 orientation_thr=2.0,
+                 **kwargs):
+
+        super().__init__()
+        assert isinstance(level_size_divisors, tuple)
+        assert isinstance(level_proportion_range, tuple)
+        assert len(level_size_divisors) == len(level_proportion_range)
+        self.fourier_degree = fourier_degree
+        self.resample_step = resample_step
+        self.center_region_shrink_ratio = center_region_shrink_ratio
+        self.level_size_divisors = level_size_divisors
+        self.level_proportion_range = level_proportion_range
+
+        self.orientation_thr = orientation_thr
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(
+            np.clip(
+                np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+    def resample_line(self, line, n):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+        for i in range(1, n):
+            current_line_len = i * delta_length
+
+            while current_line_len >= length_cumsum[current_edge_ind + 1]:
+                current_edge_ind += 1
+            current_edge_end_shift = current_line_len - length_cumsum[
+                current_edge_ind]
+            end_shift_ratio = current_edge_end_shift / length_list[
+                current_edge_ind]
+            current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
+                                                      - line[current_edge_ind]
+                                                      ) * end_shift_ratio
+            resampled_line.append(current_point)
+
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+
+        return resampled_line
+
+    def reorder_poly_edge(self, points):
+        """Get the respective points composing head edge, tail edge, top
+        sideline and bottom sideline.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+
+        Returns:
+            head_edge (ndarray): The two points composing the head edge of text
+                polygon.
+            tail_edge (ndarray): The two points composing the tail edge of text
+                polygon.
+            top_sideline (ndarray): The points composing top curved sideline of
+                text polygon.
+            bot_sideline (ndarray): The points composing bottom curved sideline
+                of text polygon.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+
+        head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
+        head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+        pad_points = np.vstack([points, points])
+        if tail_inds[1] < 1:
+            tail_inds[1] = len(points)
+        sideline1 = pad_points[head_inds[1]:tail_inds[1]]
+        sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
+        sideline_mean_shift = np.mean(
+            sideline1, axis=0) - np.mean(
+            sideline2, axis=0)
+
+        if sideline_mean_shift[1] > 0:
+            top_sideline, bot_sideline = sideline2, sideline1
+        else:
+            top_sideline, bot_sideline = sideline1, sideline2
+
+        return head_edge, tail_edge, top_sideline, bot_sideline
+
    def find_head_tail(self, points, orientation_thr):
        """Find the head edge and tail edge of a text polygon.

        Args:
            points (ndarray): The points composing a text polygon.
            orientation_thr (float): The threshold for distinguishing between
                head edge and tail edge among the horizontal and vertical edges
                of a quadrangle.

        Returns:
            head_inds (list): The indexes of two points composing head edge.
            tail_inds (list): The indexes of two points composing tail edge.
        """

        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2
        assert isinstance(orientation_thr, float)

        if len(points) > 4:
            # General polygon: score every edge as a head candidate.
            # Close the polygon so edge vectors include the last->first edge.
            pad_points = np.vstack([points, points[0]])
            edge_vec = pad_points[1:] - pad_points[:-1]

            theta_sum = []
            adjacent_vec_theta = []
            for i, edge_vec1 in enumerate(edge_vec):
                # Angles between an edge and its two neighbours (wrap-around).
                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
                adjacent_edge_vec = edge_vec[adjacent_ind]
                temp_theta_sum = np.sum(
                    self.vector_angle(edge_vec1, adjacent_edge_vec))
                temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
                                                        adjacent_edge_vec[1])
                theta_sum.append(temp_theta_sum)
                adjacent_vec_theta.append(temp_adjacent_theta)
            # Normalise angle-based scores to [0, ~1] by dividing by pi.
            theta_sum_score = np.array(theta_sum) / np.pi
            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
            poly_center = np.mean(points, axis=0)
            # Distance of each edge (via its farther endpoint) from the
            # polygon centre: head/tail edges tend to be the farthest.
            edge_dist = np.maximum(
                norm(
                    pad_points[1:] - poly_center, axis=-1),
                norm(
                    pad_points[:-1] - poly_center, axis=-1))
            dist_score = edge_dist / np.max(edge_dist)
            position_score = np.zeros(len(edge_vec))
            # Weighted combination of turn-angle, neighbour-angle and
            # distance scores (weights are empirical constants).
            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
            score += 0.35 * dist_score
            if len(points) % 2 == 0:
                # For even vertex counts, favour the two opposite edges.
                position_score[(len(score) // 2 - 1)] += 1
                position_score[-1] += 1
            score += 0.1 * position_score
            # Pair up head/tail candidates: score_matrix[i, j] combines edge
            # i with the edge (i + j + 2), Gaussian-weighted towards pairs
            # roughly half the polygon apart.
            pad_score = np.concatenate([score, score])
            score_matrix = np.zeros((len(score), len(score) - 3))
            x = np.arange(len(score) - 3) / float(len(score) - 4)
            gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
                (x - 0.5) / 0.5, 2.) / 2)
            gaussian = gaussian / np.max(gaussian)
            for i in range(len(score)):
                score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
                    score) - 1)] * gaussian * 0.3

            head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
                                                          score_matrix.shape)
            tail_start = (head_start + tail_increment + 2) % len(points)
            head_end = (head_start + 1) % len(points)
            tail_end = (tail_start + 1) % len(points)

            # Keep a canonical order: head edge before tail edge.
            if head_end > tail_end:
                head_start, tail_start = tail_start, head_start
                head_end, tail_end = tail_end, head_end
            head_inds = [head_start, head_end]
            tail_inds = [tail_start, tail_end]
        else:
            # Quadrangle: classify edge pairs as horizontal/vertical by slope.
            if vector_slope(points[1] - points[0]) + vector_slope(
                    points[3] - points[2]) < vector_slope(points[
                                                              2] - points[1]) + vector_slope(points[0] - points[3]):
                horizontal_edge_inds = [[0, 1], [2, 3]]
                vertical_edge_inds = [[3, 0], [1, 2]]
            else:
                horizontal_edge_inds = [[3, 0], [1, 2]]
                vertical_edge_inds = [[0, 1], [2, 3]]

            vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
                vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
                0]] - points[vertical_edge_inds[1][1]])
            horizontal_len_sum = norm(points[horizontal_edge_inds[0][
                0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
                                                                     horizontal_edge_inds[1][0]] - points[
                                                                     horizontal_edge_inds[1]
                                                                     [1]])

            # The head/tail are the shorter edge pair; orientation_thr biases
            # the decision towards horizontal text.
            if vertical_len_sum > horizontal_len_sum * orientation_thr:
                head_inds = horizontal_edge_inds[0]
                tail_inds = horizontal_edge_inds[1]
            else:
                head_inds = vertical_edge_inds[0]
                tail_inds = vertical_edge_inds[1]

        return head_inds, tail_inds
+
+    def resample_sidelines(self, sideline1, sideline2, resample_step):
+        """Resample two sidelines to be of the same points number according to
+        step size.
+
+        Args:
+            sideline1 (ndarray): The points composing a sideline of a text
+                polygon.
+            sideline2 (ndarray): The points composing another sideline of a
+                text polygon.
+            resample_step (float): The resampled step size.
+
+        Returns:
+            resampled_line1 (ndarray): The resampled line 1.
+            resampled_line2 (ndarray): The resampled line 2.
+        """
+
+        assert sideline1.ndim == sideline2.ndim == 2
+        assert sideline1.shape[1] == sideline2.shape[1] == 2
+        assert sideline1.shape[0] >= 2
+        assert sideline2.shape[0] >= 2
+        assert isinstance(resample_step, float)
+
+        length1 = sum([
+            norm(sideline1[i + 1] - sideline1[i])
+            for i in range(len(sideline1) - 1)
+        ])
+        length2 = sum([
+            norm(sideline2[i + 1] - sideline2[i])
+            for i in range(len(sideline2) - 1)
+        ])
+
+        total_length = (length1 + length2) / 2
+        resample_point_num = max(int(float(total_length) / resample_step), 1)
+
+        resampled_line1 = self.resample_line(sideline1, resample_point_num)
+        resampled_line2 = self.resample_line(sideline2, resample_point_num)
+
+        return resampled_line1, resampled_line2
+
    def generate_center_region_mask(self, img_size, text_polys):
        """Generate text center region mask.

        Args:
            img_size (tuple): The image size of (height, width).
            text_polys (list[list[ndarray]]): The list of text polygons.

        Returns:
            center_region_mask (ndarray): The text center region mask.
        """

        assert isinstance(img_size, tuple)
        # assert check_argument.is_2dlist(text_polys)

        h, w = img_size

        center_region_mask = np.zeros((h, w), np.uint8)

        center_region_boxes = []
        for poly in text_polys:
            # assert len(poly) == 1
            # Deduplicate vertices; a usable polygon needs at least 4 points.
            polygon_points = np.unique(poly.reshape(-1, 2), axis=0)
            if polygon_points.shape[0] < 4:
                continue
            _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
            resampled_top_line, resampled_bot_line = self.resample_sidelines(
                top_line, bot_line, self.resample_step)
            # Reverse the bottom line so both sidelines run the same way,
            # then take their midpoint as the text center line.
            resampled_bot_line = resampled_bot_line[::-1]
            center_line = (resampled_top_line + resampled_bot_line) / 2

            # Shrink the center line away from the head and tail by roughly a
            # quarter of the local text height, measured in resample steps.
            line_head_shrink_len = norm(resampled_top_line[0] -
                                        resampled_bot_line[0]) / 4.0
            line_tail_shrink_len = norm(resampled_top_line[-1] -
                                        resampled_bot_line[-1]) / 4.0
            head_shrink_num = int(line_head_shrink_len // self.resample_step)
            tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
            if len(center_line) > head_shrink_num + tail_shrink_num + 2:
                center_line = center_line[head_shrink_num:len(center_line) -
                                                          tail_shrink_num]
                resampled_top_line = resampled_top_line[head_shrink_num:len(
                    resampled_top_line) - tail_shrink_num]
                resampled_bot_line = resampled_bot_line[head_shrink_num:len(
                    resampled_bot_line) - tail_shrink_num]

            # Build one quad per center-line segment, with its corners pulled
            # towards the center line by center_region_shrink_ratio.
            for i in range(0, len(center_line) - 1):
                tl = center_line[i] + (resampled_top_line[i] - center_line[i]
                                       ) * self.center_region_shrink_ratio
                tr = center_line[i + 1] + (resampled_top_line[i + 1] -
                                           center_line[i + 1]
                                           ) * self.center_region_shrink_ratio
                br = center_line[i + 1] + (resampled_bot_line[i + 1] -
                                           center_line[i + 1]
                                           ) * self.center_region_shrink_ratio
                bl = center_line[i] + (resampled_bot_line[i] - center_line[i]
                                       ) * self.center_region_shrink_ratio
                current_center_box = np.vstack([tl, tr, br,
                                                bl]).astype(np.int32)
                center_region_boxes.append(current_center_box)

        cv2.fillPoly(center_region_mask, center_region_boxes, 1)
        return center_region_mask
+
+    def resample_polygon(self, polygon, n=400):
+        """Resample one polygon with n points on its boundary.
+
+        Args:
+            polygon (list[float]): The input polygon.
+            n (int): The number of resampled points.
+        Returns:
+            resampled_polygon (list[float]): The resampled polygon.
+        """
+        length = []
+
+        for i in range(len(polygon)):
+            p1 = polygon[i]
+            if i == len(polygon) - 1:
+                p2 = polygon[0]
+            else:
+                p2 = polygon[i + 1]
+            length.append(((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5)
+
+        total_length = sum(length)
+        n_on_each_line = (np.array(length) / (total_length + 1e-8)) * n
+        n_on_each_line = n_on_each_line.astype(np.int32)
+        new_polygon = []
+
+        for i in range(len(polygon)):
+            num = n_on_each_line[i]
+            p1 = polygon[i]
+            if i == len(polygon) - 1:
+                p2 = polygon[0]
+            else:
+                p2 = polygon[i + 1]
+
+            if num == 0:
+                continue
+
+            dxdy = (p2 - p1) / num
+            for j in range(num):
+                point = p1 + dxdy * j
+                new_polygon.append(point)
+
+        return np.array(new_polygon)
+
+    def normalize_polygon(self, polygon):
+        """Normalize one polygon so that its start point is at right most.
+
+        Args:
+            polygon (list[float]): The origin polygon.
+        Returns:
+            new_polygon (lost[float]): The polygon with start point at right.
+        """
+        try:
+            temp_polygon = polygon - polygon.mean(axis=0)
+            x = np.abs(temp_polygon[:, 0])
+            y = temp_polygon[:, 1]
+            index_x = np.argsort(x)
+            index_y = np.argmin(y[index_x[:8]])
+            index = index_x[index_y]
+            new_polygon = np.concatenate([polygon[index:], polygon[:index]])
+        except:
+            print(polygon.shape)
+        return new_polygon
+
+    def poly2fourier(self, polygon, fourier_degree):
+        """Perform Fourier transformation to generate Fourier coefficients ck
+        from polygon.
+
+        Args:
+            polygon (ndarray): An input polygon.
+            fourier_degree (int): The maximum Fourier degree K.
+        Returns:
+            c (ndarray(complex)): Fourier coefficients.
+        """
+        points = polygon[:, 0] + polygon[:, 1] * 1j
+        c_fft = fft(points) / len(points)
+        c = np.hstack((c_fft[-fourier_degree:], c_fft[:fourier_degree + 1]))
+        return c
+
+    def clockwise(self, c, fourier_degree):
+        """Make sure the polygon reconstructed from Fourier coefficients c in
+        the clockwise direction.
+
+        Args:
+            polygon (list[float]): The origin polygon.
+        Returns:
+            new_polygon (lost[float]): The polygon in clockwise point order.
+        """
+        if np.abs(c[fourier_degree + 1]) > np.abs(c[fourier_degree - 1]):
+            return c
+        elif np.abs(c[fourier_degree + 1]) < np.abs(c[fourier_degree - 1]):
+            return c[::-1]
+        else:
+            if np.abs(c[fourier_degree + 2]) > np.abs(c[fourier_degree - 2]):
+                return c
+            else:
+                return c[::-1]
+
+    def cal_fourier_signature(self, polygon, fourier_degree):
+        """Calculate Fourier signature from input polygon.
+
+        Args:
+              polygon (ndarray): The input polygon.
+              fourier_degree (int): The maximum Fourier degree K.
+        Returns:
+              fourier_signature (ndarray): An array shaped (2k+1, 2) containing
+                  real part and image part of 2k+1 Fourier coefficients.
+        """
+        resampled_polygon = self.resample_polygon(polygon)
+        if len(resampled_polygon) == 0:
+            print('111')
+            return None
+
+        resampled_polygon = self.normalize_polygon(resampled_polygon)
+
+        fourier_coeff = self.poly2fourier(resampled_polygon, fourier_degree)
+        fourier_coeff = self.clockwise(fourier_coeff, fourier_degree)
+
+        real_part = np.real(fourier_coeff).reshape((-1, 1))
+        image_part = np.imag(fourier_coeff).reshape((-1, 1))
+        fourier_signature = np.hstack([real_part, image_part])
+
+        return fourier_signature
+
+    def generate_fourier_maps(self, img_size, text_polys):
+        """Generate Fourier coefficient maps.
+
+        Args:
+            img_size (tuple): The image size of (height, width).
+            text_polys (list[list[ndarray]]): The list of text polygons.
+
+        Returns:
+            fourier_real_map (ndarray): The Fourier coefficient real part maps.
+            fourier_image_map (ndarray): The Fourier coefficient image part
+                maps.
+        """
+
+        assert isinstance(img_size, tuple)
+
+        h, w = img_size
+        k = self.fourier_degree
+        real_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
+        imag_map = np.zeros((k * 2 + 1, h, w), dtype=np.float32)
+
+        for poly in text_polys:
+            mask = np.zeros((h, w), dtype=np.uint8)
+            polygon = np.array(poly).reshape((1, -1, 2))
+            if polygon.shape[0] == 0:
+                print('xx')
+                continue
+            cv2.fillPoly(mask, polygon.astype(np.int32), 1)
+            fourier_coeff = self.cal_fourier_signature(polygon[0], k)
+            if fourier_coeff is None:
+                continue
+            for i in range(-k, k + 1):
+                if i != 0:
+                    real_map[i + k, :, :] = mask * fourier_coeff[i + k, 0] + (
+                            1 - mask) * real_map[i + k, :, :]
+                    imag_map[i + k, :, :] = mask * fourier_coeff[i + k, 1] + (
+                            1 - mask) * imag_map[i + k, :, :]
+                else:
+                    yx = np.argwhere(mask > 0.5)
+                    k_ind = np.ones((len(yx)), dtype=np.int64) * k
+                    y, x = yx[:, 0], yx[:, 1]
+                    real_map[k_ind, y, x] = fourier_coeff[k, 0] - x
+                    imag_map[k_ind, y, x] = fourier_coeff[k, 1] - y
+
+        return real_map, imag_map
+
+    def generate_text_region_mask(self, img_size, text_polys):
+        """Generate text center region mask and geometry attribute maps.
+
+        Args:
+            img_size (tuple): The image size (height, width).
+            text_polys (list[list[ndarray]]): The list of text polygons.
+
+        Returns:
+            text_region_mask (ndarray): The text region mask.
+        """
+
+        assert isinstance(img_size, tuple)
+
+        h, w = img_size
+        text_region_mask = np.zeros((h, w), dtype=np.uint8)
+
+        for poly in text_polys:
+            polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
+            cv2.fillPoly(text_region_mask, polygon, 1)
+
+        return text_region_mask
+
+    def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
+        """Generate effective mask by setting the ineffective regions to 0 and
+        effective regions to 1.
+
+        Args:
+            mask_size (tuple): The mask size.
+            polygons_ignore (list[[ndarray]]: The list of ignored text
+                polygons.
+
+        Returns:
+            mask (ndarray): The effective mask of (height, width).
+        """
+
+        mask = np.ones(mask_size, dtype=np.uint8)
+
+        for poly in polygons_ignore:
+            instance = poly.reshape(-1, 2).astype(np.int32).reshape(1, -1, 2)
+            cv2.fillPoly(mask, instance, 0)
+
+        return mask
+
+    def generate_level_targets(self, img_size, text_polys, ignore_polys):
+        """Generate ground truth target on each level.
+
+        Args:
+            img_size (list[int]): Shape of input image.
+            text_polys (list[list[ndarray]]): A list of ground truth polygons.
+            ignore_polys (list[list[ndarray]]): A list of ignored polygons.
+        Returns:
+            level_maps (list(ndarray)): A list of ground target on each level.
+        """
+        h, w = img_size
+        lv_size_divs = self.level_size_divisors
+        lv_proportion_range = self.level_proportion_range
+        lv_text_polys = [[] for i in range(len(lv_size_divs))]
+        lv_ignore_polys = [[] for i in range(len(lv_size_divs))]
+        level_maps = []
+        for poly in text_polys:
+            polygon = np.array(poly, dtype=np.int).reshape((1, -1, 2))
+            _, _, box_w, box_h = cv2.boundingRect(polygon)
+            proportion = max(box_h, box_w) / (h + 1e-8)
+
+            for ind, proportion_range in enumerate(lv_proportion_range):
+                if proportion_range[0] < proportion < proportion_range[1]:
+                    lv_text_polys[ind].append(poly / lv_size_divs[ind])
+
+        for ignore_poly in ignore_polys:
+            polygon = np.array(ignore_poly, dtype=np.int).reshape((1, -1, 2))
+            _, _, box_w, box_h = cv2.boundingRect(polygon)
+            proportion = max(box_h, box_w) / (h + 1e-8)
+
+            for ind, proportion_range in enumerate(lv_proportion_range):
+                if proportion_range[0] < proportion < proportion_range[1]:
+                    lv_ignore_polys[ind].append(ignore_poly / lv_size_divs[ind])
+
+        for ind, size_divisor in enumerate(lv_size_divs):
+            current_level_maps = []
+            level_img_size = (h // size_divisor, w // size_divisor)
+
+            text_region = self.generate_text_region_mask(
+                level_img_size, lv_text_polys[ind])[None]
+            current_level_maps.append(text_region)
+
+            center_region = self.generate_center_region_mask(
+                level_img_size, lv_text_polys[ind])[None]
+            current_level_maps.append(center_region)
+
+            effective_mask = self.generate_effective_mask(
+                level_img_size, lv_ignore_polys[ind])[None]
+            current_level_maps.append(effective_mask)
+
+            fourier_real_map, fourier_image_maps = self.generate_fourier_maps(
+                level_img_size, lv_text_polys[ind])
+            current_level_maps.append(fourier_real_map)
+            current_level_maps.append(fourier_image_maps)
+
+            level_maps.append(np.concatenate(current_level_maps))
+
+        return level_maps
+
+    def generate_targets(self, results):
+        """Generate the ground truth targets for FCENet.
+
+        Args:
+            results (dict): The input result dictionary.
+
+        Returns:
+            results (dict): The output result dictionary.
+        """
+
+        assert isinstance(results, dict)
+        image = results['img']
+        polygons = results['text_polys']
+        ignore_tags = results['ignore_tags']
+        h, w, _ = image.shape
+
+        polygon_masks = []
+        polygon_masks_ignore = []
+        for tag, polygon in zip(ignore_tags, polygons):
+            if tag is True:
+                polygon_masks_ignore.append(polygon)
+            else:
+                polygon_masks.append(polygon)
+
+        level_maps = self.generate_level_targets((h, w), polygon_masks,
+                                                 polygon_masks_ignore)
+
+        mapping = {
+            'p3_maps': level_maps[0],
+            'p4_maps': level_maps[1],
+            'p5_maps': level_maps[2]
+        }
+        for key, value in mapping.items():
+            results[key] = value
+
+        return results
+
+    def __call__(self, results):
+        results = self.generate_targets(results)
+        return results

+ 14 - 0
torchocr/datasets/det_modules/__init__.py

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# @Time    : 2019/12/4 10:53
# @Author  : zhoujun

"""
This module contains the image pre-processing components used by the
detection algorithms, such as random cropping, random scaling, random
rotation and label (target) generation.
"""
from .iaa_augment import IaaAugment
from .augment import *
from .random_crop_data import EastRandomCropData,PSERandomCrop
from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .FCE_aug import *
from .FCE_target import *

+ 385 - 0
torchocr/datasets/det_modules/augment.py

@@ -0,0 +1,385 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2019/8/23 21:52
+# @Author  : zhoujun
+
+import math
+import numbers
+import random
+
+import cv2
+import numpy as np
+from skimage.util import random_noise
+
+__all__ = ['RandomNoise', 'RandomResize', 'RandomScale', 'ResizeShortSize', 'RandomRotateImgBox', 'HorizontalFlip',
+           'VerticallFlip', 'ResizeFixedSize', 'ResizeLongSize']
+
+
class RandomNoise:
    """Add Gaussian noise to the image with probability `random_rate`."""

    def __init__(self, random_rate):
        self.random_rate = random_rate

    def __call__(self, data: dict):
        """
        Apply Gaussian noise to data['img'].
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return: the (possibly) modified data dict
        """
        if random.random() > self.random_rate:
            return data
        noisy = random_noise(data['img'], mode='gaussian', clip=True)
        data['img'] = (noisy * 255).astype(data['img'].dtype)
        return data
+
+
class RandomScale:
    def __init__(self, scales, random_rate):
        """
        :param scales: candidate scale factors to choose from
        :param random_rate: probability of applying the transform
        """
        self.random_rate = random_rate
        self.scales = scales

    def __call__(self, data: dict) -> dict:
        """
        Pick one scale factor at random from `scales` and resize both the
        image and the text polygons by it.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        if random.random() > self.random_rate:
            return data

        factor = float(np.random.choice(self.scales))
        scaled_polys = data['text_polys'].copy()
        scaled_polys *= factor
        data['img'] = cv2.resize(data['img'], dsize=None, fx=factor, fy=factor)
        data['text_polys'] = scaled_polys
        return data
+
+
class RandomRotateImgBox:
    def __init__(self, degrees, random_rate, same_size=False):
        """
        :param degrees: rotation range; a single positive number d means
            (-d, d), otherwise a 2-element sequence (min, max)
        :param random_rate: probability of applying the rotation
        :param same_size: keep the output canvas the same size as the input
            (True), or enlarge it to fully contain the rotated image (False)
        :return:
        """
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            degrees = (-degrees, degrees)
        elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray):
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            degrees = degrees
        else:
            raise Exception('degrees must in Number or list or tuple or np.ndarray')
        self.degrees = degrees
        self.same_size = same_size
        self.random_rate = random_rate

    def __call__(self, data: dict) -> dict:
        """
        Rotate the image by a random angle drawn from `degrees` and rotate
        the text polygons with it.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        if random.random() > self.random_rate:
            return data
        im = data['img']
        text_polys = data['text_polys']

        # ---------------------- rotate the image ----------------------
        w = im.shape[1]
        h = im.shape[0]
        angle = np.random.uniform(self.degrees[0], self.degrees[1])

        if self.same_size:
            nw = w
            nh = h
        else:
            # degrees -> radians
            rangle = np.deg2rad(angle)
            # size of the canvas that fully contains the rotated image
            nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
            nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
        # rotation matrix about the centre of the new canvas
        rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
        # offset from the old image centre to the new canvas centre
        rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
        # fold the offset into the affine matrix
        rot_mat[0, 2] += rot_move[0]
        rot_mat[1, 2] += rot_move[1]
        # apply the affine warp
        rot_img = cv2.warpAffine(im, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)

        # ---------------------- remap the bbox coordinates ----------------------
        # rot_mat is the final affine matrix; map each polygon's four corner
        # points into the rotated coordinate frame (the original comment
        # called these "midpoints", but they are the four corners).
        rot_text_polys = list()
        for bbox in text_polys:
            point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
            point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
            point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
            point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
            rot_text_polys.append([point1, point2, point3, point4])
        data['img'] = rot_img
        data['text_polys'] = np.array(rot_text_polys)
        return data
+
+
class RandomResize:
    def __init__(self, size, random_rate, keep_ratio=False):
        """
        :param size: target size; a single number means a square, a
            2-element sequence is passed straight to cv2.resize, whose
            dsize order is (width, height)
        :param random_rate: probability of applying the resize
        :param keep_ratio: pad the short side first so the aspect ratio is
            preserved before resizing
        :return:
        """
        if isinstance(size, numbers.Number):
            if size < 0:
                raise ValueError("If input_size is a single number, it must be positive.")
            size = (size, size)
        elif isinstance(size, list) or isinstance(size, tuple) or isinstance(size, np.ndarray):
            if len(size) != 2:
                raise ValueError("If input_size is a sequence, it must be of len 2.")
            size = (size[0], size[1])
        else:
            raise Exception('input_size must in Number or list or tuple or np.ndarray')
        self.size = size
        self.keep_ratio = keep_ratio
        self.random_rate = random_rate

    def __call__(self, data: dict) -> dict:
        """
        Resize the image (and the text polygons with it) to `self.size`.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        if random.random() > self.random_rate:
            return data
        im = data['img']
        text_polys = data['text_polys']

        if self.keep_ratio:
            # pad the short side so the image covers at least self.size
            # NOTE(review): height is compared against size[0] and width
            # against size[1]; since self.size is used as cv2 dsize (w, h)
            # below, these two indices look swapped — confirm intended order.
            h, w, c = im.shape
            max_h = max(h, self.size[0])
            max_w = max(w, self.size[1])
            im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8)
            im_padded[:h, :w] = im.copy()
            im = im_padded
        text_polys = text_polys.astype(np.float32)
        h, w, _ = im.shape
        im = cv2.resize(im, self.size)
        # per-axis scale factors from the (possibly padded) image to the target
        w_scale = self.size[0] / float(w)
        h_scale = self.size[1] / float(h)
        text_polys[:, :, 0] *= w_scale
        text_polys[:, :, 1] *= h_scale

        data['img'] = im
        data['text_polys'] = text_polys
        return data
+
+
def resize_image(img, short_size):
    """Resize `img` so its shorter side is about `short_size`.

    Both output dimensions are rounded to multiples of 32 (the typical
    detection-network stride).

    :param img: H x W x C image array
    :param short_size: target length of the shorter side before rounding
    :return: (resized image, (w_ratio, h_ratio)) scale factors applied
    """
    height, width, _ = img.shape
    if height < width:
        new_height, new_width = short_size, short_size / height * width
    else:
        new_width, new_height = short_size, short_size / width * height
    # snap both dimensions to the 32-pixel stride
    new_height = int(round(new_height / 32) * 32)
    new_width = int(round(new_width / 32) * 32)
    resized_img = cv2.resize(img, (new_width, new_height))
    return resized_img, (new_width / width, new_height / height)
+
+
class ResizeShortSize:
    def __init__(self, short_size, resize_text_polys=True):
        """
        :param short_size: minimum length of the shorter image side
        :param resize_text_polys: whether to rescale the text polygons too
        :return:
        """
        self.short_size = short_size
        self.resize_text_polys = resize_text_polys

    def __call__(self, data: dict) -> dict:
        """
        Upscale the image so its shorter side reaches `short_size`; both
        output dimensions are rounded to multiples of 32 (minimum 32).
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        im = data['img']
        text_polys = data['text_polys']
        h, w, _ = im.shape
        if min(h, w) < self.short_size:
            if h < w:
                ratio = float(self.short_size) / h
            else:
                ratio = float(self.short_size) / w
        else:
            ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        # snap to the 32-pixel network stride
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)
        img = cv2.resize(im, (int(resize_w), int(resize_h)))
        if self.resize_text_polys:
            # NOTE(review): polygons are scaled by the pre-rounding `ratio`,
            # so after the snap-to-32 above they can drift slightly from the
            # resized pixels. Also, for a (N, 4, 2) poly array `[:, 0]`
            # selects the first point of each polygon rather than the x
            # column — confirm the expected text_polys shape.
            text_polys[:, 0] *= ratio
            text_polys[:, 1] *= ratio
        data['img'] = img
        data['text_polys'] = text_polys
        return data
+
+
class HorizontalFlip:
    def __init__(self, random_rate):
        """
        :param random_rate: probability of applying the flip
        """
        self.random_rate = random_rate

    def __call__(self, data: dict) -> dict:
        """
        Mirror the image left-right and mirror the x coordinates of the
        text polygons to match.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        if random.random() > self.random_rate:
            return data

        mirrored = cv2.flip(data['img'], 1)
        width = mirrored.shape[1]
        polys = data['text_polys'].copy()
        polys[:, :, 0] = width - polys[:, :, 0]

        data['img'] = mirrored
        data['text_polys'] = polys
        return data
+
+
class VerticallFlip:
    def __init__(self, random_rate):
        """
        :param random_rate: probability of applying the flip
        """
        self.random_rate = random_rate

    def __call__(self, data: dict) -> dict:
        """
        Mirror the image top-bottom and mirror the y coordinates of the
        text polygons to match.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        if random.random() > self.random_rate:
            return data

        mirrored = cv2.flip(data['img'], 0)
        height = mirrored.shape[0]
        polys = data['text_polys'].copy()
        polys[:, :, 1] = height - polys[:, :, 1]

        data['img'] = mirrored
        data['text_polys'] = polys
        return data
+
+
class ResizeFixedSize:
    def __init__(self, short_size, resize_text_polys=True):
        """
        :param short_size: minimum length of the shorter image side
        :param resize_text_polys: whether to rescale the text polygons too
        :return:
        """
        self.short_size = short_size
        self.resize_text_polys = resize_text_polys

    def __call__(self, data: dict) -> dict:
        """
        Upscale the image so its shorter side reaches `short_size`, with
        both output dimensions rounded to multiples of 32 (minimum 32),
        and rescale the text polygons with the actual per-axis factors.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        im = data['img']
        text_polys = data['text_polys']
        h, w, _ = im.shape
        if min(h, w) < self.short_size:
            if h < w:
                ratio = float(self.short_size) / h
            else:
                ratio = float(self.short_size) / w
        else:
            ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        # snap to the 32-pixel network stride; max(..., 32) guarantees both
        # dimensions are positive, so cv2.resize cannot fail on size 0
        # (the original guarded this with dead code that returned a tuple,
        # and an except that printed an undefined variable).
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)
        img = cv2.resize(im, (resize_w, resize_h))

        # Because of the rounding above, the true per-axis scale factors can
        # differ from `ratio`; use them so polygons stay aligned with pixels.
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        if self.resize_text_polys:
            # BUG FIX: the original applied ratio_h to index 0 and ratio_w
            # to index 1 (swapped), and `[:, 0]` only touched the first
            # point of each polygon for (N, 4, 2) arrays. `...` scales the
            # x/y columns for both (N, 2) and (N, 4, 2) shapes.
            text_polys[..., 0] *= ratio_w
            text_polys[..., 1] *= ratio_h

        data['img'] = img
        data['text_polys'] = text_polys
        return data
+
+
class ResizeLongSize:
    def __init__(self, long_size, resize_text_polys=True):  # short_size,
        """
        :param long_size: maximum allowed length of the longer image side
        :param resize_text_polys: whether to rescale the text polygons too
        :return:
        """
        # self.short_size = short_size
        self.long_size = long_size
        self.resize_text_polys = resize_text_polys

    def __call__(self, data: dict) -> dict:
        """
        Shrink the image so its longer side does not exceed `long_size`;
        both output dimensions are rounded to multiples of 32 (minimum 32).
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        im = data['img']
        text_polys = data['text_polys']
        h, w, _ = im.shape
        ratio = 1.
        if max(h, w) > self.long_size:
            ratio = float(self.long_size) / max(h, w)
        # snap to the 32-pixel network stride
        resize_h = max(int(round(int(h * ratio) / 32) * 32), 32)
        resize_w = max(int(round(int(w * ratio) / 32) * 32), 32)
        data['img'] = cv2.resize(im, (resize_w, resize_h))
        if self.resize_text_polys:
            text_polys[:, 0] *= ratio
            text_polys[:, 1] *= ratio
        data['text_polys'] = text_polys
        return data

+ 68 - 0
torchocr/datasets/det_modules/iaa_augment.py

@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2019/12/4 18:06
+# @Author  : zhoujun
+import numpy as np
+import imgaug
+import imgaug.augmenters as iaa
+
+
class AugmenterBuilder(object):
    """Builds imgaug augmenter pipelines from a declarative config."""

    def __init__(self):
        pass

    def build(self, args, root=True):
        """Recursively turn a list/dict spec into an iaa augmenter.

        Returns None for an empty or missing spec.
        """
        if args is None or len(args) == 0:
            return None
        if isinstance(args, list):
            if not root:
                positional = [self.to_tuple_if_list(a) for a in args[1:]]
                return getattr(iaa, args[0])(*positional)
            return iaa.Sequential([self.build(v, root=False) for v in args])
        if isinstance(args, dict):
            kwargs = {k: self.to_tuple_if_list(v) for k, v in args['args'].items()}
            return getattr(iaa, args['type'])(**kwargs)
        raise RuntimeError('unknown augmenter arg: ' + str(args))

    def to_tuple_if_list(self, obj):
        """Lists in the config denote ranges; imgaug expects tuples there."""
        return tuple(obj) if isinstance(obj, list) else obj
+
+
class IaaAugment():
    """Image + polygon augmentation driven by a declarative imgaug config."""

    def __init__(self, augmenter_args=None):
        # Default pipeline: random horizontal flip, small rotation, rescale.
        if augmenter_args is None:
            augmenter_args = [{'type': 'Fliplr', 'args': {'p': 0.5}},
                              {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
                              {'type': 'Resize', 'args': {'size': [0.5, 3]}}]
        self.augmenter = AugmenterBuilder().build(augmenter_args)

    def __call__(self, data):
        """Augment data['img'] and keep data['text_polys'] aligned with it."""
        image = data['img']
        shape = image.shape

        if self.augmenter:
            # to_deterministic() freezes the random state so the exact same
            # transform is applied to the image and to the keypoints below.
            aug = self.augmenter.to_deterministic()
            data['img'] = aug.augment_image(image)
            data = self.may_augment_annotation(aug, data, shape)
        return data

    def may_augment_annotation(self, aug, data, shape):
        """Apply the frozen augmenter `aug` to every text polygon."""
        if aug is None:
            return data

        line_polys = []
        for poly in data['text_polys']:
            new_poly = self.may_augment_poly(aug, shape, poly)
            line_polys.append(np.array(new_poly))
        data['text_polys'] = line_polys
        return data

    def may_augment_poly(self, aug, img_shape, poly):
        """Transform one polygon by pushing its vertices through `aug` as keypoints."""
        keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
        keypoints = aug.augment_keypoints(
            [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints
        poly = [(p.x, p.y) for p in keypoints]
        return poly

+ 122 - 0
torchocr/datasets/det_modules/make_border_map.py

@@ -0,0 +1,122 @@
+import cv2
+import numpy as np
+
+np.seterr(divide='ignore', invalid='ignore')
+import pyclipper
+from shapely.geometry import Polygon
+
+__all__ = ['MakeBorderMap']
+
+
class MakeBorderMap():
    """Build DBNet's threshold (border) map and the matching mask.

    For every non-ignored polygon, a band around its border is rendered
    whose value falls off with the distance to the polygon edge; the map
    is finally rescaled into [thresh_min, thresh_max].
    """

    def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
        # shrink_ratio: controls the border band width (the same ratio used
        # when shrinking polygons for the probability map).
        # thresh_min / thresh_max: output value range of the threshold map.
        self.shrink_ratio = shrink_ratio
        self.thresh_min = thresh_min
        self.thresh_max = thresh_max

    def __call__(self, data: dict) -> dict:
        """
        Render the threshold map and its mask for all non-ignored polygons.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return: data with 'threshold_map' and 'threshold_mask' added
        """
        im = data['img']
        text_polys = data['text_polys']
        ignore_tags = data['ignore_tags']

        canvas = np.zeros(im.shape[:2], dtype=np.float32)
        mask = np.zeros(im.shape[:2], dtype=np.float32)

        for i in range(len(text_polys)):
            if ignore_tags[i]:
                continue
            self.draw_border_map(text_polys[i], canvas, mask=mask)
        # rescale the distance-based values into [thresh_min, thresh_max]
        canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min

        data['threshold_map'] = canvas
        data['threshold_mask'] = mask
        return data

    def draw_border_map(self, polygon, canvas, mask):
        """Draw one polygon's border band onto `canvas` and `mask` in place."""
        polygon = np.array(polygon)
        assert polygon.ndim == 2
        assert polygon.shape[1] == 2

        polygon_shape = Polygon(polygon)
        if polygon_shape.area <= 0:
            return
        # offset distance derived from the DBNet shrink formula
        distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
        subject = [tuple(l) for l in polygon]
        padding = pyclipper.PyclipperOffset()
        padding.AddPath(subject, pyclipper.JT_ROUND,
                        pyclipper.ET_CLOSEDPOLYGON)
        try:
            # dilate the polygon outwards by `distance`
            padded_polygon = np.array(padding.Execute(distance)[0])
        except:
            # pyclipper can fail on degenerate polygons; skip them
            return
        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)

        xmin = padded_polygon[:, 0].min()
        xmax = padded_polygon[:, 0].max()
        ymin = padded_polygon[:, 1].min()
        ymax = padded_polygon[:, 1].max()
        width = xmax - xmin + 1
        height = ymax - ymin + 1

        # work in the local frame of the padded polygon's bounding box
        polygon[:, 0] = polygon[:, 0] - xmin
        polygon[:, 1] = polygon[:, 1] - ymin

        xs = np.broadcast_to(
            np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
        ys = np.broadcast_to(
            np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))

        # per-edge normalized distance of every pixel; keep the minimum
        distance_map = np.zeros(
            (polygon.shape[0], height, width), dtype=np.float32)
        for i in range(polygon.shape[0]):
            j = (i + 1) % polygon.shape[0]
            absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
            distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
        distance_map = distance_map.min(axis=0)

        # clip the box to the canvas and paste 1 - distance (closer to the
        # border -> higher value), keeping the per-pixel maximum
        xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
        xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
        ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
        ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
        rever_distance = 1 - distance_map[
                             ymin_valid - ymin:ymax_valid - ymax + height,
                             xmin_valid - xmin:xmax_valid - xmax + width]
        rever_distance[np.isnan(rever_distance)] = 0.99
        canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
            rever_distance,
            canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

    def distance(self, xs, ys, point_1, point_2):
        '''
        compute the distance from point to a line
        ys: coordinates in the first axis
        xs: coordinates in the second axis
        point_1, point_2: (x, y), the end of the line
        '''
        height, width = xs.shape[:2]
        square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
        square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
        square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])

        # cosine of the angle at the pixel between the two segment endpoints
        cosin = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2))
        square_sin = 1 - np.square(cosin)
        square_sin = np.nan_to_num(square_sin)

        result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance)
        # outside the segment's span, fall back to the nearer endpoint distance
        result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0]
        # self.extend_line(point_1, point_2, result)
        return result

    def extend_line(self, point_1, point_2, result):
        """Extend the segment past both endpoints and rasterize the
        extensions into `result` (currently unused; see commented call above)."""
        ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
                      int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
        cv2.line(result, tuple(ex_point_1), tuple(point_1), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
        ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
                      int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
        cv2.line(result, tuple(ex_point_2), tuple(point_2), 4096.0, 1, lineType=cv2.LINE_AA, shift=0)
        return ex_point_1, ex_point_2

+ 132 - 0
torchocr/datasets/det_modules/make_shrink_map.py

@@ -0,0 +1,132 @@
+import numpy as np
+import cv2
+
+__all__ = ['MakeShrinkMap']
+
+
def shrink_polygon_py(polygon, shrink_ratio):
    """Shrink a polygon towards its centroid by `shrink_ratio`.

    Every vertex is moved towards the vertex-average centre so its distance
    to the centre is multiplied by `shrink_ratio`; applying the inverse
    ratio (1 / shrink_ratio) restores the original polygon.

    Args:
        polygon (np.ndarray): (num_points, 2) array of x, y vertices.
        shrink_ratio (float): scale factor (< 1 shrinks, > 1 expands).

    Returns:
        np.ndarray: a new shrunk polygon. BUG FIX: the original mutated the
        caller's array in place; a copy is taken now (same dtype, so any
        integer truncation behaviour is unchanged).
    """
    polygon = polygon.copy()
    cx = polygon[:, 0].mean()
    cy = polygon[:, 1].mean()
    polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
    polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
    return polygon
+
+
def shrink_polygon_pyclipper(polygon, shrink_ratio):
    """Shrink a polygon with pyclipper's polygon offsetting.

    Tries progressively weaker shrink ratios until the offset yields a
    single polygon; the final ratio 1 gives a zero offset, so a valid
    result is always produced for well-formed input.

    Args:
        polygon: (num_points, 2) vertex array.
        shrink_ratio (float): base shrink ratio in (0, 1).

    Returns:
        list: pyclipper result paths ([] only if every attempt failed).
    """
    from shapely.geometry import Polygon
    import pyclipper
    polygon_shape = Polygon(polygon)
    subject = [tuple(l) for l in polygon]
    padding = pyclipper.PyclipperOffset()
    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    shrinked = []
    possible_ratios = np.arange(shrink_ratio, 1, shrink_ratio)
    # BUG FIX: np.append returns a new array; the original discarded the
    # result, so the zero-shrink fallback ratio 1 was never actually tried.
    possible_ratios = np.append(possible_ratios, 1)
    for ratio in possible_ratios:
        distance = polygon_shape.area * (
                1 - np.power(ratio, 2)) / polygon_shape.length
        shrinked = padding.Execute(-distance)
        if len(shrinked) == 1:
            break
    return shrinked
+
+
class MakeShrinkMap():
    r'''
    Making binary mask from detection data with ICDAR format.
    Typically following the process of class `MakeICDARData`.
    '''

    def __init__(self, min_text_size=8, shrink_ratio=0.4, shrink_type='pyclipper'):
        # shrink_type selects the polygon-shrinking backend:
        #   'py'        - simple scaling towards the centroid
        #   'pyclipper' - proper polygon offsetting (default)
        shrink_func_dict = {'py': shrink_polygon_py, 'pyclipper': shrink_polygon_pyclipper}
        self.shrink_func = shrink_func_dict[shrink_type]
        self.min_text_size = min_text_size
        self.shrink_ratio = shrink_ratio

    def __call__(self, data: dict) -> dict:
        """
        Build the shrunk text-region map ('shrink_map') and its valid-pixel
        mask ('shrink_mask'); polygons that are too small or that cannot be
        shrunk are flagged as ignored and masked out.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        image = data['img']
        text_polys = data['text_polys']
        ignore_tags = data['ignore_tags']

        h, w = image.shape[:2]
        text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
        gt = np.zeros((h, w), dtype=np.float32)
        mask = np.ones((h, w), dtype=np.float32)
        for i in range(len(text_polys)):
            polygon = text_polys[i]
            height = max(polygon[:, 1]) - min(polygon[:, 1])
            width = max(polygon[:, 0]) - min(polygon[:, 0])
            if ignore_tags[i] or min(height, width) < self.min_text_size:
                # too small to learn from: exclude the region from the loss
                cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
                ignore_tags[i] = True
            else:
                shrinked = self.shrink_func(polygon, self.shrink_ratio)
                if shrinked == []:
                    # shrinking failed: exclude the region from the loss
                    cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
                    ignore_tags[i] = True
                    continue
                for each_shirnk in shrinked:
                    shirnk = np.array(each_shirnk).reshape(-1, 2)
                    cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)

        data['shrink_map'] = gt
        data['shrink_mask'] = mask
        data['ignore_tags'] = ignore_tags
        return data

    def validate_polygons(self, polygons, ignore_tags, h, w):
        '''
        Clip polygons to the image bounds, ignore near-degenerate ones and
        normalize vertex order to counter-clockwise.
        polygons (numpy.array, required): of shape (num_instances, num_points, 2)
        '''
        if len(polygons) == 0:
            return polygons, ignore_tags
        assert len(polygons) == len(ignore_tags)
        for polygon in polygons:
            polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
            polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)

        for i in range(len(polygons)):
            area = self.polygon_area(polygons[i])
            # |area| < 1 px^2: degenerate polygon, ignore it
            if abs(area) < 1:
                ignore_tags[i] = True
            # positive signed area: reverse to canonical winding order
            if area > 0:
                polygons[i] = polygons[i][::-1, :]
        return polygons, ignore_tags

    def polygon_area(self, polygon):
        # signed area via the shoelace formula; the sign encodes winding order
        polygon = polygon.reshape(-1, 2)
        edge = 0
        for i in range(polygon.shape[0]):
            next_index = (i + 1) % polygon.shape[0]
            edge += (polygon[next_index, 0] - polygon[i, 0]) * (
                    polygon[next_index, 1] + polygon[i, 1])

        return edge / 2.
+
+
if __name__ == '__main__':
    # Manual sanity check of the two shrink implementations.
    from shapely.geometry import Polygon
    import pyclipper

    polygon = np.array([[0, 0], [100, 10], [100, 100], [10, 90]])
    a = shrink_polygon_py(polygon, 0.4)
    print(a)
    print(shrink_polygon_py(a, 1 / 0.4))  # expanding back should restore the original
    b = shrink_polygon_pyclipper(polygon, 0.4)
    print(b)
    # re-expand the pyclipper result and print its minimum-area bounding box
    poly = Polygon(b)
    distance = poly.area * 1.5 / poly.length
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(b, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = np.array(offset.Execute(distance))
    bounding_box = cv2.minAreaRect(expanded)
    points = cv2.boxPoints(bounding_box)
    print(points)

+ 200 - 0
torchocr/datasets/det_modules/random_crop_data.py

@@ -0,0 +1,200 @@
+import random
+
+import cv2
+import numpy as np
+
+__all__ = ['EastRandomCropData', 'PSERandomCrop']
+
+
+# random crop algorithm similar to https://github.com/argman/EAST
class EastRandomCropData():
    """Random crop (EAST-style) that keeps at least one text instance.

    The crop is chosen so it does not cut through any text region, then
    scaled (with optional aspect-preserving padding) to `size`; polygons
    entirely outside the crop are dropped.
    """

    def __init__(self, size=(640, 640), max_tries=50, min_crop_side_ratio=0.1, require_original_image=False, keep_ratio=True):
        """
        :param size: output (w, h)
        :param max_tries: attempts to find a crop that contains text
        :param min_crop_side_ratio: minimum crop side as a fraction of the image side
        :param require_original_image: kept for config compatibility (unused here)
        :param keep_ratio: pad instead of stretching when aspect ratios differ
        """
        self.size = size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio
        self.require_original_image = require_original_image
        self.keep_ratio = keep_ratio

    def __call__(self, data: dict) -> dict:
        """
        Crop the image to a text-containing region and scale it (with
        optional padding) to self.size, remapping the text polygons.
        :param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
        :return:
        """
        im = data['img']
        text_polys = data['text_polys']
        ignore_tags = data['ignore_tags']
        texts = data['texts']
        all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
        # choose the crop region
        crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
        # scale the crop into self.size, padding to keep the aspect ratio
        scale_w = self.size[0] / crop_w
        scale_h = self.size[1] / crop_h
        scale = min(scale_w, scale_h)
        h = int(crop_h * scale)
        w = int(crop_w * scale)
        if self.keep_ratio:
            if len(im.shape) == 3:
                padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
            else:
                padimg = np.zeros((self.size[1], self.size[0]), im.dtype)
            padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
            img = padimg
        else:
            img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], tuple(self.size))
        # remap the polygons; drop those falling entirely outside the crop
        text_polys_crop = []
        ignore_tags_crop = []
        texts_crop = []
        try:
            for poly, text, tag in zip(text_polys, texts, ignore_tags):
                poly = ((np.array(poly) - (crop_x, crop_y)) * scale).astype('float32')
                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                    text_polys_crop.append(poly)
                    ignore_tags_crop.append(tag)
                    texts_crop.append(text)
            data['img'] = img
            data['text_polys'] = text_polys_crop
            data['ignore_tags'] = ignore_tags_crop
            data['texts'] = texts_crop
        except Exception:
            # BUG FIX: the original used a bare `except: a = 1`, which also
            # swallowed KeyboardInterrupt/SystemExit. Best-effort behaviour
            # (return `data`, possibly partially updated) is kept; consider
            # removing the guard entirely if malformed polys cannot occur.
            pass
        return data

    def is_poly_in_rect(self, poly, x, y, w, h):
        """True iff the polygon lies fully inside the (x, y, w, h) rect."""
        poly = np.array(poly)
        if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
            return False
        if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
            return False
        return True

    def is_poly_outside_rect(self, poly, x, y, w, h):
        """True iff the polygon lies fully outside the (x, y, w, h) rect."""
        poly = np.array(poly)
        if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
            return True
        if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
            return True
        return False

    def split_regions(self, axis):
        """Split a sorted index array into runs of consecutive indices.

        NOTE(review): matches the reference EAST implementation — the final
        run is never appended, so it is silently dropped.
        """
        regions = []
        min_axis = 0
        for i in range(1, axis.shape[0]):
            if axis[i] != axis[i - 1] + 1:
                region = axis[min_axis:i]
                min_axis = i
                regions.append(region)
        return regions

    def random_select(self, axis, max_size):
        """Pick two random positions from `axis`, clipped to [0, max_size)."""
        xx = np.random.choice(axis, size=2)
        xmin = np.min(xx)
        xmax = np.max(xx)
        xmin = np.clip(xmin, 0, max_size - 1)
        xmax = np.clip(xmax, 0, max_size - 1)
        return xmin, xmax

    def region_wise_random_select(self, regions, max_size):
        """Pick one position from each of two randomly chosen regions."""
        selected_index = list(np.random.choice(len(regions), 2))
        selected_values = []
        for index in selected_index:
            axis = regions[index]
            xx = int(np.random.choice(axis, size=1))
            selected_values.append(xx)
        xmin = min(selected_values)
        xmax = max(selected_values)
        return xmin, xmax

    def crop_area(self, im, text_polys):
        """Find a crop rectangle whose borders avoid all text regions.

        Returns (x, y, w, h); falls back to the full image when no valid
        crop is found within `max_tries` attempts.
        """
        h, w = im.shape[:2]
        # mark rows/columns covered by any text polygon
        h_array = np.zeros(h, dtype=np.int32)
        w_array = np.zeros(w, dtype=np.int32)
        for points in text_polys:
            points = np.round(points, decimals=0).astype(np.int32)
            minx = np.min(points[:, 0])
            maxx = np.max(points[:, 0])
            w_array[minx:maxx] = 1
            miny = np.min(points[:, 1])
            maxy = np.max(points[:, 1])
            h_array[miny:maxy] = 1
        # ensure the cropped area does not cut across a text region
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]

        if len(h_axis) == 0 or len(w_axis) == 0:
            return 0, 0, w, h

        h_regions = self.split_regions(h_axis)
        w_regions = self.split_regions(w_axis)

        for i in range(self.max_tries):
            if len(w_regions) > 1:
                xmin, xmax = self.region_wise_random_select(w_regions, w)
            else:
                xmin, xmax = self.random_select(w_axis, w)
            if len(h_regions) > 1:
                ymin, ymax = self.region_wise_random_select(h_regions, h)
            else:
                ymin, ymax = self.random_select(h_axis, h)

            if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
                # area too small
                continue
            num_poly_in_rect = 0
            for poly in text_polys:
                if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
                    num_poly_in_rect += 1
                    break

            if num_poly_in_rect > 0:
                return xmin, ymin, xmax - xmin, ymax - ymin

        return 0, 0, w, h
+
+
class PSERandomCrop():
    def __init__(self, size):
        """
        :param size: (th, tw) target crop height and width
        """
        self.size = size

    def __call__(self, data):
        """
        Randomly crop every map in data['imgs'] to self.size. With high
        probability the crop is biased towards text, guided by imgs[2]
        (threshold label map) and validated against imgs[1] (shrink label
        map).
        :param data: {'imgs': [image, shrink_label_map, threshold_label_map, ...]}
        :return: data, with every entry of data['imgs'] cropped
        """
        imgs = data['imgs']

        h, w = imgs[0].shape[0:2]
        th, tw = self.size
        if w == tw and h == th:
            # Nothing to crop. BUG FIX: the original returned the bare
            # `imgs` list here while every other path returns the data dict.
            return data

        # if text exists, crop around it with probability 5/8
        if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
            # top-left corner of the text region, shifted so the crop fits
            tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
            tl[tl < 0] = 0
            # bottom-right corner of the text region, shifted likewise
            br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
            br[br < 0] = 0
            # leave enough room to crop when the bottom-right corner is picked
            br[0] = min(br[0], h - th)
            br[1] = min(br[1], w - tw)

            for _ in range(50000):
                i = random.randint(tl[0], br[0])
                j = random.randint(tl[1], br[1])
                # make sure the shrink label map actually contains text
                if imgs[1][i:i + th, j:j + tw].sum() <= 0:
                    continue
                else:
                    break
        else:
            i = random.randint(0, h - th)
            j = random.randint(0, w - tw)

        # apply the same crop window to every map
        for idx in range(len(imgs)):
            if len(imgs[idx].shape) == 3:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
            else:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw]
        data['imgs'] = imgs
        return data

+ 84 - 0
torchocr/datasets/icdar15/ICDAR15CropSave.py

@@ -0,0 +1,84 @@
+'''
+@Author: Jeffery Sheng (Zhenfei Sheng)
+@Time:   2020/5/21 18:34
+@File:   ICDAR15CropSave.py
+'''
+
+import os
+import cv2
+from glob import glob
+from tqdm import tqdm
+
+
class icdar2015CropSave:
    """Crop ICDAR2015 ground-truth text boxes out of the full images and
    save them as individual recognition samples plus label file(s).

    Each gt line is ``x1,y1,x2,y2,x3,y3,x4,y4,transcript``; boxes whose
    transcript is '', '*' or '###' are treated as don't-care and skipped.
    """

    def __init__(self, img_dir :str, gt_dir :str, save_data_dir :str,
                 train_val_split_ratio: float or None=0.1):
        # next id used to name saved text-box images (1-based)
        self.save_id = 1
        self.img_dir = os.path.abspath(img_dir)
        self.gt_dir = os.path.abspath(gt_dir)
        # makedirs(exist_ok=True) also creates missing parents
        os.makedirs(save_data_dir, exist_ok=True)
        self.save_data_dir = save_data_dir
        # fraction of samples that goes to val.txt; falsy -> train.txt only
        self.train_val_split_ratio = train_val_split_ratio

    def crop_save(self) -> None:
        """Crop every non-ignored text box and write the label file(s)."""
        all_img_paths = glob(os.path.join(self.img_dir, '*.jpg'))
        all_gt_paths = glob(os.path.join(self.gt_dir, '*.txt'))
        # every image must have a matching gt file
        assert len(all_img_paths) == len(all_gt_paths)
        # create the crop output directory once, outside the loops
        save_img_dir = os.path.join(self.save_data_dir, 'images')
        os.makedirs(save_img_dir, exist_ok=True)
        # accumulated "<image>\t<transcript>\n" label lines
        text_lines = list()
        for img_path in tqdm(all_img_paths):
            img = cv2.imread(img_path)
            gt_path = os.path.join(self.gt_dir, 'gt_' + os.path.basename(img_path).replace('.jpg', '.txt'))
            # utf-8-sig strips the BOM ICDAR gt files start with
            with open(gt_path, 'r', encoding='utf-8-sig') as file:
                lines = file.readlines()
            for line in lines:
                line = line.strip().split(',')
                # first 8 fields: the four corner points
                x1, y1, x2, y2, x3, y3, x4, y4 = list(map(int, line[: 8]))
                trans = line[8]
                if trans in {'', '*', '###'}:
                    continue  # ignore unreadable / don't-care boxes
                save_img_path = os.path.join(save_img_dir, f'textbox_{self.save_id}.jpg')
                # an axis-aligned rectangle has only 4 distinct coordinates
                if len({x1, y1, x2, y2, x3, y3, x4, y4}) == 4:
                    cv2.imwrite(save_img_path, img[y1: y4, x1: x2])
                else:
                    # general quadrilateral: save its bounding rectangle
                    x_min, x_max = min((x1, x2, x3, x4)), max((x1, x2, x3, x4))
                    y_min, y_max = min((y1, y2, y3, y4)), max((y1, y2, y3, y4))
                    cv2.imwrite(save_img_path, img[y_min: y_max, x_min: x_max])
                text_lines.append(f'textbox_{self.save_id}.jpg\t{trans}\n')
                self.save_id += 1
        if self.train_val_split_ratio:
            # split on the actual sample count; the original used
            # self.save_id, which is count + 1 (off by one)
            split = int(round((1 - self.train_val_split_ratio) * len(text_lines)))
            train = text_lines[:split]
            val = text_lines[split:]
            with open(os.path.join(self.save_data_dir, 'train.txt'), 'w', encoding='utf-8') as save_file:
                save_file.writelines(train)
            with open(os.path.join(self.save_data_dir, 'val.txt'), 'w', encoding='utf-8') as save_file:
                save_file.writelines(val)
            print(f'{self.save_id-1} text-box images and 2 text-line file are saved.')
        else:
            with open(os.path.join(self.save_data_dir, 'train.txt'), 'w', encoding='utf-8') as save_file:
                save_file.writelines(text_lines)
            print(f'{self.save_id-1} text-box images and 1 text-line file are saved.')
+
+
if __name__ == '__main__':
    # Example usage: crop ICDAR2015 training images into per-textbox
    # recognition samples (paths below are machine-specific).
    img_dir = '/data/disk7/private/szf/Datasets/ICDAR2015/train'
    gt_dir = '/data/disk7/private/szf/Datasets/ICDAR2015/train_local_trans'
    save_data_dir = '/data/disk7/private/szf/Datasets/ICDAR2015/data'
    icdar2015CropSave(img_dir, gt_dir, save_data_dir).crop_save()

+ 5 - 0
torchocr/datasets/icdar15/__init__.py

@@ -0,0 +1,5 @@
+'''
+@Author: Jeffery Sheng (Zhenfei Sheng)
+@Time:   2020/5/26 17:42
+@File:   __init__.py.py
+'''

+ 49 - 0
torchocr/datasets/icdar15/convert_icdar2015_rec.py

@@ -0,0 +1,49 @@
+import os
+import cv2
+import numpy as np
+
if __name__ == '__main__':
    # Convert the ICDAR2015 detection dataset into a text-recognition dataset:
    # each annotated quadrilateral is masked out of the full image, cropped to
    # its bounding rectangle and saved, with one "<image>\t<transcript>" line
    # written per sample.
    icdar2015_directory = '/data/OCR/ICDAR2015'
    target_directory = '/data/OCR/ICDAR2015/converted_data'
    target_image_directory = os.path.join(target_directory, 'image')
    train_image_directory = os.path.join(icdar2015_directory, 'ch4_training_images')
    train_gt_directory = os.path.join(icdar2015_directory, 'ch4_training_localization_transcription_gt')
    test_image_directory = os.path.join(icdar2015_directory, 'ch4_test_images')
    test_gt_directory = os.path.join(icdar2015_directory, 'Challenge4_Test_Task4_GT')
    os.makedirs(target_directory, exist_ok=True)
    os.makedirs(target_image_directory, exist_ok=True)
    for m_name, m_image_directory, m_gt_directory in zip(['train', 'eval'],
                                                         [train_image_directory, test_image_directory],
                                                         [train_gt_directory, test_gt_directory]):
        m_index = 0
        with open(os.path.join(target_directory, m_name + '.txt'), mode='w', encoding='utf-8') as to_write:
            for m_image_file in os.listdir(m_image_directory):
                m_gt_file = os.path.join(m_gt_directory, 'gt_' + os.path.splitext(m_image_file)[0] + '.txt')
                m_img = cv2.imread(os.path.join(m_image_directory, m_image_file))
                with open(m_gt_file, mode='r', encoding='utf-8') as to_read:
                    # For recognition, only keep lines whose transcript is not '###'
                    for m_line in to_read:
                        m_line = m_line.strip('\ufeff\n')
                        if not m_line.endswith('###'):
                            # First eight comma-separated fields are the four
                            # corner points of the (possibly non-rectangular)
                            # quadrilateral, starting from the top-left.
                            coordinates_and_transcript = m_line.split(',')
                            # collapse runs of whitespace to a single space
                            transcript = ' '.join(''.join(coordinates_and_transcript[8:]).split())
                            if len(transcript) == 0:
                                continue
                            np_coordinates = np.array([int(_) for _ in coordinates_and_transcript[:8]]).reshape((-1, 2))
                            min_x, min_y = np.min(np_coordinates, axis=0)
                            max_x, max_y = np.max(np_coordinates, axis=0)
                            m_width = max_x - min_x + 1
                            m_height = max_y - min_y + 1
                            # black canvas, fill the polygon white, then AND with
                            # the image crop: pixels outside the polygon become 0
                            m_target_roi = np.zeros((m_height, m_width, m_img.shape[2]), dtype=np.uint8)
                            m_region = np.array([np_coordinates - [min_x, min_y]], dtype=np.int32)
                            m_target_roi = cv2.fillPoly(m_target_roi,
                                                        m_region,
                                                        (255,) * m_img.shape[2])
                            m_target_roi = cv2.bitwise_and(m_img[min_y:max_y + 1, min_x:max_x + 1, ...], m_target_roi)
                            target_image_name = f'{m_name}_{m_index}.jpg'
                            cv2.imwrite(os.path.join(target_image_directory, target_image_name), m_target_roi)
                            m_index += 1
                            to_write.write(f'{target_image_name}\t{transcript}\n')
                    to_write.flush()
+ 104 - 0
torchocr/datasets/训练用数据集汇总.md

@@ -0,0 +1,104 @@
+# 训练用数据集汇总
+
+[TOC]
+
+## 随时会更新的百度网盘
+
+链接:https://pan.baidu.com/s/1Ed1xrviL3xsuXahVqnqycg  密码:ob01
+
+公开数据集汇总,随意下。
+
+## 常见数据
+
+### Chinese Text in the Wild(CTW)
+
+https://share.weiyun.com/50hF1Cc
+
+该数据集包含32285张图像,1018402个中文字符(来自于腾讯街景), 包含平面文本,凸起文本,城市文本,农村文本,低亮度文本,远处文本,部分遮挡文本。图像大小2048*2048,数据集大小为31GB。以(8:1:1)的比例将数据集分为训练集(25887张图像,812872个汉字),测试集(3269张图像,103519个汉字),验证集(3129张图像,103519个汉字)。
+
+### Reading Chinese Text in the Wild(RCTW-17)
+
+https://rctw.vlrlab.net/dataset/
+
+有12000张图片,包括用手机拍的街景、海报、菜单、室内场景以及手机截图等。
+
+### ICPR MWI 2018 挑战赛
+
+https://tianchi.aliyun.com/competition/entrance/231686/information
+
+大赛提供20000张图像作为数据集,其中50%作为训练集,50%作为测试集。主要由合成图像,产品描述,网络广告构成。该数据集数据量充分,中英文混合,涵盖数十种字体,字体大小不一,多种版式,背景复杂。文件大小为3.2GB。
+
+### SVHN
+
+http://ufldl.stanford.edu/housenumbers/
+
+训练集数据包括3W张照片,验证集数据包括1W张照片,每张照片包括颜色图像和对应的编码类别和具体位置;为了保证比赛的公平性,测试集A包括4W张照片,测试集B包括4W张照片。
+
+### 中文场景文字识别技术创新大赛
+
+https://aistudio.baidu.com/aistudio/competition/detail/8
+
+共29万张图片,其中21万张图片为训练集,8万张为测试集。所有图像经过一些预处理,将文字区域利用仿射变化,等比映射为一张高为48像素的图片。
+
+### Total-Text
+
+http://www.cs-chan.com/source/ICDAR2017/totaltext.zip
+
+该数据集共1555张图像,11459文本行,包含水平文本,倾斜文本,弯曲文本。文件大小441MB。大部分为英文文本,少量中文文本。训练集:1255张 测试集:300
+
+### Google FSNS(谷歌街景文本数据集)
+
+http://rrc.cvc.uab.es/?ch=6&com=downloads
+
+该数据集是从谷歌法国街景图片上获得的一百多万张街道名字标志,每一张包含同一街道标志牌的不同视角,图像大小为600*150,训练集1044868张,验证集16150张,测试集20404张。
+
+### **COCO-TEXT**
+
+https://vision.cornell.edu/se3/coco-text-2/
+
+该数据集,包括63686幅图像,173589个文本实例,包括手写版和打印版,清晰版和非清晰版。文件大小12.58GB,训练集:43686张,测试集:10000张,验证集:10000张
+
+### **Synthetic Data for Text Localisation**
+
+http://www.robots.ox.ac.uk/~vgg/data/scenetext/
+
+在复杂背景下人工合成的自然场景文本数据。包含858750张图像,共7266866个单词实例,28971487个字符,文件大小为41GB。该合成算法,不需要人工标注就可知道文字的label信息和位置信息,可得到大量自然场景文本标注数据。
+
+### **Synthetic Word Dataset**
+
+http://www.robots.ox.ac.uk/~vgg/data/text/
+
+合成文本识别数据集,包含9百万张图像,涵盖了9万个英语单词。文件大小为10GB
+
+### IIIT 5K-Words 2012
+
+http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html
+
+两千张训练,三千张测试,大小写不区分的crop好的图像
+
+### KAIST Scene_Text Database 2010
+
+http://www.iapr-tc11.org/mediawiki/index.php/KAIST_Scene_Text_Database
+
+3000张室内和室外场景文本,包括韩语、英语、数字。
+
+### ICDAR大礼包
+
+https://rrc.cvc.uab.es/
+
+## 其他语系数据
+
+### 阿语和英语混合的PPT中的文本
+
+https://gitlab.com/rex-yue-wu/ISI-PPT-Dataset
+
+有10692张图片,大约超过10W行文本。
+
+
+
+## 数据合成相关
+
+常用中文的词组:https://github.com/qingyujean/chinese_words_lib
+
+快递单数据:https://aistudio.baidu.com/aistudio/datasetdetail/16246
+

+ 110 - 0
torchocr/deprecated/FeaturePyramidNetwork.py

@@ -0,0 +1,110 @@
+# -*- coding:utf-8 -*-
+# @author :adolf
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+"""
+out_channels=96时,和现有的fpn相比,这个fpn精度差不多,但是模型尺寸会大500k
+"""
class FeaturePyramidNetwork(nn.Module):
    """Standard FPN: 1x1 lateral convs + top-down nearest-neighbour
    upsampling, followed by 3x3 smoothing convs per level.

    ``forward`` returns only the finest-resolution map (``results[0]``).
    """

    def __init__(self, in_channels, out_channels=256):
        """
        :param in_channels: iterable of backbone feature-map channel counts,
            ordered from finest to coarsest
        :param out_channels: channel count of every pyramid level
        """
        super(FeaturePyramidNetwork, self).__init__()
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        self.out_channels = out_channels
        # NOTE: loop variable renamed — the original shadowed the
        # `in_channels` argument with its own elements.
        for channels in in_channels:
            if channels == 0:
                raise ValueError("in_channels=0 is currently not supported")
            self.inner_blocks.append(nn.Conv2d(channels, out_channels, 1))
            self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, 3, padding=1))

        # initialize parameters now to avoid modifying the initialization of top_blocks.
        # BUGFIX: iterate modules() (recursive) — children() only yields the
        # two ModuleLists, so the original loop never matched any Conv2d and
        # the kaiming initialization was silently skipped.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def get_result_from_inner_blocks(self, x, idx):
        """Apply ``inner_blocks[idx]`` to ``x``.

        Written as a loop for TorchScript compatibility; supports negative
        ``idx`` and returns ``x`` unchanged when ``idx`` is out of range.
        """
        num_blocks = len(self.inner_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.inner_blocks):
            if i == idx:
                out = module(x)
        return out

    def get_result_from_layer_blocks(self, x, idx):
        """Apply ``layer_blocks[idx]`` to ``x`` (same semantics as above)."""
        num_blocks = len(self.layer_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.layer_blocks):
            if i == idx:
                out = module(x)
        return out

    def forward(self, x):
        """
        :param x: list of feature maps, finest first
        :return: the finest-resolution pyramid map of shape
            (N, out_channels, H, W)
        """
        # coarsest level seeds the top-down path
        last_inner = self.get_result_from_inner_blocks(x[-1], -1)
        results = [self.get_result_from_layer_blocks(last_inner, -1)]

        for idx in range(len(x) - 2, -1, -1):
            inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
            last_inner = inner_lateral + inner_top_down
            results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))

        # only the finest map is used downstream
        return results[0]
+
+
class ExtraFPNBlock(nn.Module):
    """Base class for blocks that append extra feature maps on top of the
    FPN output (see LastLevelMaxPool / LastLevelP6P7 below)."""

    def forward(self, results, x, names):
        # No-op in the base class; subclasses append extra levels and names.
        pass
+
+
class LastLevelMaxPool(ExtraFPNBlock):
    """Appends a stride-2 max-pooled copy of the last feature map as an
    extra ("pool") pyramid level."""

    def forward(self, x, y, names):
        names.append("pool")
        pooled = F.max_pool2d(x[-1], kernel_size=1, stride=2, padding=0)
        x.append(pooled)
        return x, names
+
+
class LastLevelP6P7(ExtraFPNBlock):
    """Adds P6 and P7 levels (RetinaNet-style) on top of the pyramid via
    two stride-2 3x3 convolutions."""

    def __init__(self, in_channels, out_channels):
        super(LastLevelP6P7, self).__init__()
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for conv in (self.p6, self.p7):
            nn.init.kaiming_uniform_(conv.weight, a=1)
            nn.init.constant_(conv.bias, 0)
        # P6 is computed from P5 when channel counts match, else from C5
        self.use_P5 = in_channels == out_channels

    def forward(self, p, c, names):
        source = p[-1] if self.use_P5 else c[-1]
        p6 = self.p6(source)
        p7 = self.p7(F.relu(p6))
        p.extend([p6, p7])
        names.extend(["p6", "p7"])
        return p, names

+ 3 - 0
torchocr/deprecated/__init__.py

@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/7/3 9:12
+# @Author  : zhoujun

+ 82 - 0
torchocr/metrics/DetMetric.py

@@ -0,0 +1,82 @@
+import numpy as np
+
+from torchocr.metrics.iou_utils import DetectionIoUEvaluator
+
+
class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times; returns self for chaining."""
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count
        return self
+
+
class DetMetric():
    """Detection metric built on :class:`DetectionIoUEvaluator`."""

    def __init__(self, is_output_polygon=False):
        # is_output_polygon: predictions are free-form polygons instead of
        # fixed 4-point boxes with confidence scores
        self.is_output_polygon = is_output_polygon
        self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon)

    def __call__(self, batch, output, box_thresh=0.6):
        '''
        Evaluate one batch of detection results against the ground truth.

        batch: a dict produced by dataloaders.
            image: tensor of shape (N, C, H, W).
            polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
            ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
            shape: the original shape of images.
            filename: the original filenames of images.
        output: (polygons, scores, ...)
        box_thresh: score threshold below which predicted boxes are dropped
            (box mode only).
        Returns a list of per-image metric dicts from the evaluator.
        '''
        results = []
        gt_polyons_batch = batch['text_polys']
        ignore_tags_batch = batch['ignore_tags']
        pred_polygons_batch = np.array(output[0])
        pred_scores_batch = np.array(output[1])
        for polygons, pred_polygons, pred_scores, ignore_tags in zip(gt_polyons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
            gt = [dict(points=np.int64(polygons[i]), ignore=ignore_tags[i]) for i in range(len(polygons))]
            if self.is_output_polygon:
                pred = [dict(points=pred_polygons[i]) for i in range(len(pred_polygons))]
            else:
                # keep only confident boxes; BUGFIX: np.int was removed in
                # NumPy >= 1.24 — use the explicit np.int64 (matching gt).
                pred = [dict(points=pred_polygons[i, :, :].astype(np.int64))
                        for i in range(pred_polygons.shape[0])
                        if pred_scores[i] >= box_thresh]
            results.append(self.evaluator.evaluate_image(gt, pred))
        return results

    def gather_measure(self, raw_metrics):
        """Flatten per-batch metric lists and aggregate them into
        precision / recall / f-measure AverageMeters."""
        raw_metrics = [image_metrics
                       for batch_metrics in raw_metrics
                       for image_metrics in batch_metrics]

        result = self.evaluator.combine_results(raw_metrics)

        precision = AverageMeter()
        recall = AverageMeter()
        fmeasure = AverageMeter()

        precision.update(result['precision'], n=len(raw_metrics))
        recall.update(result['recall'], n=len(raw_metrics))
        # 1e-8 avoids division by zero when both P and R are 0
        fmeasure_score = 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8)
        fmeasure.update(fmeasure_score)

        return {
            'precision': precision,
            'recall': recall,
            'fmeasure': fmeasure
        }

+ 28 - 0
torchocr/metrics/RecMetric.py

@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/6/15 14:07
+# @Author  : zhoujun
+import Levenshtein
+
+
class RecMetric:
    def __init__(self, converter):
        """
        文本识别相关指标计算类 (text-recognition metric: exact-match count
        and summed normalized edit distance).

        :param converter: 用于label转换的转换器 (label converter providing
            ``decode``)
        """
        self.converter = converter

    def __call__(self, predictions, labels):
        """
        :param predictions: raw network output of shape (N, T, C)
        :param labels: iterable of ground-truth strings
        :return: dict with ``n_correct``, summed ``norm_edit_dis`` and
            ``show_str`` ("pred -> target" lines for logging)
        """
        n_correct = 0
        norm_edit_dis = 0.0
        predictions = predictions.softmax(dim=2).detach().cpu().numpy()
        preds_str = self.converter.decode(predictions)
        show_str = []
        for (pred, pred_conf), target in zip(preds_str, labels):
            # max(..., 1) guards against ZeroDivisionError when both the
            # prediction and the target are empty strings
            norm_edit_dis += Levenshtein.distance(pred, target) / max(len(pred), len(target), 1)
            show_str.append(f'{pred} -> {target}')
            # NOTE: removed the leftover per-sample debug print — callers can
            # log show_str instead
            if pred == target:
                n_correct += 1
        return {'n_correct': n_correct, 'norm_edit_dis': norm_edit_dis, 'show_str': show_str}

+ 18 - 0
torchocr/metrics/__init__.py

@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/6/15 14:07
+# @Author  : zhoujun
+import copy
+from .RecMetric import RecMetric
+from .DetMetric import DetMetric
+from .distill_metric import DistillationMetric
+
+
def build_metric(config):
    """Build a metric instance from a config dict.

    :param config: dict containing a ``name`` key naming the metric class;
        the remaining items are passed to its constructor. The caller's dict
        is not mutated (a deep copy is popped instead).
    :return: the metric instance
    :raises AssertionError: if ``name`` is not a supported metric
    """
    support_dict = ["DistillationMetric"]

    config = copy.deepcopy(config)
    module_name = config.pop("name")
    assert module_name in support_dict, "metric only support {}".format(support_dict)
    # explicit mapping instead of eval() — safer and greppable
    metric_classes = {"DistillationMetric": DistillationMetric}
    return metric_classes[module_name](**config)

+ 42 - 0
torchocr/metrics/distill_metric.py

@@ -0,0 +1,42 @@
+import importlib
+from .DetMetric import  DetMetric
+
class DistillationMetric(object):
    """Metric wrapper for distillation models whose predictions are a dict
    of ``{sub_model_name: prediction}``; one base metric is kept per key."""

    def __init__(self, key=None, base_metric_name=None, main_indicator=None, **kwargs):
        # key: the sub-model whose metrics are reported without a prefix
        self.key = key
        self.main_indicator = main_indicator
        # name of the metric class (looked up in this module) built per key;
        # extra kwargs are forwarded to its constructor
        self.base_metric_name = base_metric_name
        self.kwargs = kwargs
        # lazily created on the first __call__, keyed like preds
        self.metrics = None
        self.out = dict()

    def _init_metric(self, preds):
        # Build one base metric per prediction key.
        self.metrics = dict()
        mod = importlib.import_module(__name__)
        for key in preds:
            self.metrics[key] = getattr(mod, self.base_metric_name)(**self.kwargs)

    def __call__(self, batch, preds, **kwargs):
        assert isinstance(preds, dict), f'preds should be dict,not {type(preds)}'
        if self.metrics is None:
            self._init_metric(preds)

        for key in preds:
            self.out.setdefault(key, []).append(self.metrics[key](batch, preds[key], **kwargs))

    def get_metric(self):
        """Aggregate accumulated results. Metrics of ``self.key`` are merged
        unprefixed; every other key is prefixed with its sub-model name."""
        output = dict()
        for key, val in self.out.items():
            metric = self.metrics[key].gather_measure(val)
            if key == self.key:
                output.update(metric)
            else:
                for sub_key in metric:
                    output['{}_{}'.format(key, sub_key)] = metric[sub_key]
        self.out.clear()
        return output

    def reset(self):
        # BUGFIX: before the first __call__ self.metrics is None and the
        # original iterated over it, raising TypeError.
        if self.metrics is None:
            return
        for key in self.metrics:
            self.metrics[key].reset()

+ 256 - 0
torchocr/metrics/iou_utils.py

@@ -0,0 +1,256 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from collections import namedtuple
+import numpy as np
+from shapely.geometry import Polygon
+import cv2
+
+
def iou_rotate(box_a, box_b, method='union'):
    """IoU of two point sets via their minimum-area rotated rectangles.

    ``method='union'`` divides by the union area; ``method='intersection'``
    divides by the smaller polygon area. Returns 0 when the rectangles do
    not intersect or either area term degenerates.
    """
    rect_a = cv2.minAreaRect(box_a)
    rect_b = cv2.minAreaRect(box_b)
    status, inter_pts = cv2.rotatedRectangleIntersection(rect_a, rect_b)
    if status == 0:
        # INTERSECT_NONE: rectangles are disjoint
        return 0
    inter_area = cv2.contourArea(inter_pts)
    area_a = cv2.contourArea(box_a)
    area_b = cv2.contourArea(box_b)
    union_area = area_a + area_b - inter_area
    if union_area == 0 or inter_area == 0:
        return 0
    if method == 'union':
        return inter_area / union_area
    if method == 'intersection':
        return inter_area / min(area_a, area_b)
    raise NotImplementedError
+
+
class DetectionIoUEvaluator(object):
    """ICDAR-style IoU evaluator for text detection.

    A prediction matches a ground-truth polygon when their IoU exceeds
    ``iou_constraint``. Ground-truth regions flagged as ignore, and
    detections whose area lies mostly inside such a region (area precision
    above ``area_precision_constraint``), are excluded from the counts.
    """

    def __init__(self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5):
        self.is_output_polygon = is_output_polygon
        self.iou_constraint = iou_constraint
        self.area_precision_constraint = area_precision_constraint

    def evaluate_image(self, gt, pred):
        """Evaluate a single image.

        :param gt: list of dicts with ``points`` (polygon vertices) and
            ``ignore`` (bool don't-care flag)
        :param pred: list of dicts with ``points`` (polygon vertices)
        :return: dict of per-sample metrics (precision/recall/hmean, matched
            pairs, care counts, evaluation log, ...)
        """

        def get_union(pD, pG):
            return Polygon(pD).union(Polygon(pG)).area

        def get_intersection_over_union(pD, pG):
            return get_intersection(pD, pG) / get_union(pD, pG)

        def get_intersection(pD, pG):
            return Polygon(pD).intersection(Polygon(pG)).area

        def compute_ap(confList, matchList, numGtCare):
            # Average precision over confidence-sorted matches (unused in the
            # current flow but kept for parity with the ICDAR script).
            correct = 0
            AP = 0
            if len(confList) > 0:
                confList = np.array(confList)
                matchList = np.array(matchList)
                sorted_ind = np.argsort(-confList)
                confList = confList[sorted_ind]
                matchList = matchList[sorted_ind]
                for n in range(len(confList)):
                    match = matchList[n]
                    if match:
                        correct += 1
                        AP += float(correct) / (n + 1)

                if numGtCare > 0:
                    AP /= numGtCare

            return AP

        perSampleMetrics = {}

        matchedSum = 0

        Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')

        numGlobalCareGt = 0
        numGlobalCareDet = 0

        arrGlobalConfidences = []
        arrGlobalMatches = []

        recall = 0
        precision = 0
        hmean = 0

        detMatched = 0

        iouMat = np.empty([1, 1])

        gtPols = []
        detPols = []

        gtPolPoints = []
        detPolPoints = []

        # Array of Ground Truth Polygons' keys marked as don't Care
        gtDontCarePolsNum = []
        # Array of Detected Polygons' matched with a don't Care GT
        detDontCarePolsNum = []

        pairs = []
        detMatchedNums = []

        arrSampleConfidences = []
        arrSampleMatch = []

        evaluationLog = ""

        # collect valid ground-truth polygons (invalid/self-intersecting are skipped)
        for n in range(len(gt)):
            points = gt[n]['points']
            # transcription = gt[n]['text']
            dontCare = gt[n]['ignore']

            if not Polygon(points).is_valid or not Polygon(points).is_simple:
                continue

            gtPol = points
            gtPols.append(gtPol)
            gtPolPoints.append(points)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols) - 1)

        evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(
            gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")

        # collect valid detections; those mostly covered by a don't-care GT
        # become don't-care themselves
        for n in range(len(pred)):
            points = pred[n]['points']
            if not Polygon(points).is_valid or not Polygon(points).is_simple:
                continue

            detPol = points
            detPols.append(detPol)
            detPolPoints.append(points)
            if len(gtDontCarePolsNum) > 0:
                for dontCarePol in gtDontCarePolsNum:
                    dontCarePol = gtPols[dontCarePol]
                    intersected_area = get_intersection(dontCarePol, detPol)
                    pdDimensions = Polygon(detPol).area
                    precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                    if (precision > self.area_precision_constraint):
                        detDontCarePolsNum.append(len(detPols) - 1)
                        break

        evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(
            detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")

        if len(gtPols) > 0 and len(detPols) > 0:
            # Calculate IoU and precision matrixs
            outputShape = [len(gtPols), len(detPols)]
            iouMat = np.empty(outputShape)
            gtRectMat = np.zeros(len(gtPols), np.int8)
            detRectMat = np.zeros(len(detPols), np.int8)
            # if self.is_output_polygon:
            for gtNum in range(len(gtPols)):
                for detNum in range(len(detPols)):
                    pG = gtPols[gtNum]
                    pD = detPols[detNum]
                    iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
            # else:
            #     # gtPols = np.float32(gtPols)
            #     # detPols = np.float32(detPols)
            #     for gtNum in range(len(gtPols)):
            #         for detNum in range(len(detPols)):
            #             pG = np.float32(gtPols[gtNum])
            #             pD = np.float32(detPols[detNum])
            #             iouMat[gtNum, detNum] = iou_rotate(pD, pG)
            # greedy one-to-one matching over the IoU matrix
            for gtNum in range(len(gtPols)):
                for detNum in range(len(detPols)):
                    if gtRectMat[gtNum] == 0 and detRectMat[
                        detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
                        if iouMat[gtNum, detNum] > self.iou_constraint:
                            gtRectMat[gtNum] = 1
                            detRectMat[detNum] = 1
                            detMatched += 1
                            pairs.append({'gt': gtNum, 'det': detNum})
                            detMatchedNums.append(detNum)
                            evaluationLog += "Match GT #" + \
                                             str(gtNum) + " with Det #" + str(detNum) + "\n"

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
        if numGtCare == 0:
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
        else:
            recall = float(detMatched) / numGtCare
            precision = 0 if numDetCare == 0 else float(
                detMatched) / numDetCare

        hmean = 0 if (precision + recall) == 0 else 2.0 * \
                                                    precision * recall / (precision + recall)

        matchedSum += detMatched
        numGlobalCareGt += numGtCare
        numGlobalCareDet += numDetCare

        perSampleMetrics = {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
            'pairs': pairs,
            # skip the (possibly huge) IoU matrix for crowded images
            'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
            'gtPolPoints': gtPolPoints,
            'detPolPoints': detPolPoints,
            'gtCare': numGtCare,
            'detCare': numDetCare,
            'gtDontCare': gtDontCarePolsNum,
            'detDontCare': detDontCarePolsNum,
            'detMatched': detMatched,
            'evaluationLog': evaluationLog
        }

        return perSampleMetrics

    def combine_results(self, results):
        """Combine per-image results (from :meth:`evaluate_image`) into
        dataset-level precision / recall / hmean."""
        numGlobalCareGt = 0
        numGlobalCareDet = 0
        matchedSum = 0
        for result in results:
            numGlobalCareGt += result['gtCare']
            numGlobalCareDet += result['detCare']
            matchedSum += result['detMatched']

        methodRecall = 0 if numGlobalCareGt == 0 else float(
            matchedSum) / numGlobalCareGt
        methodPrecision = 0 if numGlobalCareDet == 0 else float(
            matchedSum) / numGlobalCareDet
        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
                                                                    methodRecall * methodPrecision / (
                                                                            methodRecall + methodPrecision)

        methodMetrics = {'precision': methodPrecision,
                         'recall': methodRecall, 'hmean': methodHmean}

        return methodMetrics
+
+
if __name__ == '__main__':
    # Smoke test: two predicted half-boxes against one ground-truth box;
    # prints the combined precision / recall / hmean.
    evaluator = DetectionIoUEvaluator()
    preds = [[{
        'points': [(0.1, 0.1), (0.5, 0), (0.5, 1), (0, 1)],
        'text': 1234,
        'ignore': False,
    }, {
        'points': [(0.5, 0.1), (1, 0), (1, 1), (0.5, 1)],
        'text': 5678,
        'ignore': False,
    }]]
    gts = [[{
        'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
        'text': 123,
        'ignore': False,
    }]]
    results = []
    for gt, pred in zip(gts, preds):
        results.append(evaluator.evaluate_image(gt, pred))
    metrics = evaluator.combine_results(results)
    print(metrics)

+ 303 - 0
torchocr/networks/CommonModules.py

@@ -0,0 +1,303 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from collections import OrderedDict
+
class HSwish(nn.Module):
    """Hard-swish activation (MobileNetV3): x * relu6(x + 3) / 6."""

    def forward(self, x):
        gate = F.relu6(x + 3, inplace=True) / 6
        return x * gate
+
+
class HardSigmoid(nn.Module):
    """Hard sigmoid with two variants.

    'paddle' : clamp(1.2 * x + 3, 0, 6) / 6  (PaddleOCR formulation)
    otherwise: relu6(x + 3) / 6              (standard hard-sigmoid)
    """

    def __init__(self, type):
        super().__init__()
        self.type = type

    def forward(self, x):
        if self.type == 'paddle':
            # In-place ops are safe here: (1.2 * x) is a fresh tensor.
            return (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)
        # Bug fix: the original also called F.hardsigmoid() with no arguments
        # here, which raised a TypeError; the relu6 form already is the result.
        return F.relu6(x + 3, inplace=True) / 6
+
+
class HSigmoid(nn.Module):
    """Hard sigmoid: clamp(1.2 * x + 3, 0, 6) / 6."""

    def forward(self, x):
        return torch.clamp(1.2 * x + 3., 0., 6.) / 6.
+
+
class ConvBNACT(nn.Module):
    """Conv2d -> BatchNorm2d -> optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups:
            forwarded to ``nn.Conv2d`` (bias is always off; BN provides it).
        act: 'relu', 'hard_swish' or None (no activation).

    Raises:
        ValueError: for an unsupported ``act`` name. The original silently left
            ``self.act`` unset, deferring an AttributeError to forward().
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if act == 'relu':
            self.act = nn.ReLU()
        elif act == 'hard_swish':
            self.act = HSwish()
        elif act is None:
            self.act = None
        else:
            raise ValueError(f'unsupported activation: {act!r}')

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x
+
+
class SEBlock(nn.Module):
    """Squeeze-and-Excitation: rescale channels with a learned sigmoid gate."""

    def __init__(self, in_channels, ratio=4):
        super().__init__()
        mid_channels = in_channels // ratio
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=mid_channels, kernel_size=1, out_channels=in_channels, bias=True)
        self.relu2 = HardSigmoid(type='paddle')

    def forward(self, x):
        gate = self.pool(x)
        gate = self.relu1(self.conv1(gate))
        gate = self.relu2(self.conv2(gate))
        return x * gate
+
+
def global_avg_pool(x: torch.Tensor) -> torch.Tensor:
    """Global average pooling: mean over H and W, output shape (N, C, 1, 1)."""
    return x.mean(dim=(2, 3), keepdim=True)
+
+
def global_max_pool(x: torch.Tensor) -> torch.Tensor:
    """Global max pooling: max over H and W, output shape (N, C, 1, 1)."""
    return x.amax(dim=(2, 3), keepdim=True)
+
+
class ChannelAttention(nn.Module):
    """CBAM channel attention: sigmoid(MLP(avg_pool(x)) + MLP(max_pool(x)))."""

    def __init__(self, channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        # Shared two-layer bottleneck MLP implemented with 1x1 convolutions.
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // ratio, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_branch = self.fc(self.avg_pool(x))
        max_branch = self.fc(self.max_pool(x))
        return self.sigmoid(avg_branch + max_branch)
+
+
class SpatialAttention(nn.Module):
    """CBAM spatial attention: sigmoid(conv([mean_c(x); max_c(x)]))."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = (kernel_size - 1) // 2  # 'same' padding for 3 or 7
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv1(stacked))
+
+
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate then spatial gate."""

    def __init__(self, in_channels, ratio=16):
        super(CBAM, self).__init__()
        self.cam = ChannelAttention(in_channels, ratio)
        self.sam = SpatialAttention()

    def forward(self, x):
        out = x * self.cam(x)
        out = out * self.sam(out)
        return out
+
+
class eca_layer(nn.Module):
    """Efficient Channel Attention (ECA-Net): 1-D conv over channel descriptors."""

    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel descriptor from global average pooling: (N, C, 1, 1).
        descriptor = self.avg_pool(x)
        # Treat the channels as a 1-D sequence: (N, C, 1, 1) -> (N, 1, C).
        sequence = descriptor.squeeze(-1).transpose(-1, -2)
        gate = self.conv(sequence).transpose(-1, -2).unsqueeze(-1)
        gate = self.sigmoid(gate)
        return x * gate.expand_as(x)
+
+
+
class ScaleChannelAttention(nn.Module):
    """Predict per-scale channel weights, softmax-normalised over scales.

    avgpool -> 1x1 conv -> BN -> ReLU -> 1x1 conv -> softmax over dim 1.
    Output shape: (N, num_features, 1, 1), summing to 1 across dim 1.
    """

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleChannelAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Fix: removed a leftover debug print of self.avgpool here.
        self.fc1 = nn.Conv2d(in_planes, out_planes, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.fc2 = nn.Conv2d(out_planes, num_features, 1, bias=False)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init for convs, unit-scale/zero-shift for BN.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        global_x = self.avgpool(x)
        global_x = self.fc1(global_x)
        global_x = F.relu(self.bn(global_x))
        global_x = self.fc2(global_x)
        # Normalise across the scale dimension so the weights sum to 1.
        global_x = F.softmax(global_x, 1)
        return global_x
+
class ScaleChannelSpatialAttention(nn.Module):
    """Channel + spatial attention producing one sigmoid weight map per scale.

    Output shape: (N, num_features, H, W), values in (0, 1).
    """

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleChannelSpatialAttention, self).__init__()
        # Channel branch: squeeze to (N, C, 1, 1) through a bottleneck MLP.
        self.channel_wise = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_planes, out_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(out_planes, in_planes, 1, bias=False),
        )
        # Spatial branch: refine a single-channel (N, 1, H, W) map.
        self.spatial_wise = nn.Sequential(
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid(),
        )
        # Final projection to one attention map per scale.
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False),
            nn.Sigmoid(),
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Channel gate (N, C, 1, 1), broadcast-added onto the input.
        channel_gate = self.channel_wise(x).sigmoid()
        fused = channel_gate + x
        # Spatial gate computed from the channel-averaged map (N, 1, H, W).
        spatial_in = torch.mean(fused, dim=1, keepdim=True)
        fused = self.spatial_wise(spatial_in) + fused
        return self.attention_wise(fused)
+
class ScaleSpatialAttention(nn.Module):
    """Spatial-only attention producing one sigmoid weight map per scale.

    ``out_planes`` is accepted for signature parity with the sibling
    attention modules but is not used here.
    """

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleSpatialAttention, self).__init__()
        # Refine a single-channel (N, 1, H, W) map.
        self.spatial_wise = nn.Sequential(
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid(),
        )
        # Project to one attention map per scale.
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False),
            nn.Sigmoid(),
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        pooled = torch.mean(x, dim=1, keepdim=True)
        enhanced = self.spatial_wise(pooled) + x
        return self.attention_wise(enhanced)
+
class ScaleFeatureSelection(nn.Module):
    """Adaptive Scale Fusion (DBNet++): reweight per-scale FPN features.

    Args:
        in_channels: channels of the concatenated multi-scale feature map.
        inter_channels: channels after the fusing 3x3 conv.
        out_features_num: number of per-scale feature maps to reweight.
        attention_type: 'scale_spatial' | 'scale_channel_spatial' | 'scale_channel'.

    Raises:
        ValueError: for an unknown ``attention_type``. The original silently
            left ``self.enhanced_attention`` unset and failed later in forward().
    """

    def __init__(self, in_channels, inter_channels, out_features_num=4, attention_type='scale_spatial'):
        super(ScaleFeatureSelection, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.out_features_num = out_features_num
        self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
        self.type = attention_type
        if self.type == 'scale_spatial':
            self.enhanced_attention = ScaleSpatialAttention(inter_channels, inter_channels // 4, out_features_num)
        elif self.type == 'scale_channel_spatial':
            self.enhanced_attention = ScaleChannelSpatialAttention(inter_channels, inter_channels // 4, out_features_num)
        elif self.type == 'scale_channel':
            self.enhanced_attention = ScaleChannelAttention(inter_channels, inter_channels // 2, out_features_num)
        else:
            raise ValueError(f'unsupported attention_type: {attention_type!r}')

    def _initialize_weights(self, m):
        # NOTE(review): written for nn.Module.apply(m); never invoked in __init__.
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    def forward(self, concat_x, features_list):
        concat_x = self.conv(concat_x)
        score = self.enhanced_attention(concat_x)
        assert len(features_list) == self.out_features_num
        if self.type not in ['scale_channel_spatial', 'scale_spatial']:
            # Channel-only attention yields (N, K, 1, 1); upsample to feature size.
            shape = features_list[0].shape[2:]
            score = F.interpolate(score, size=shape, mode='bilinear')
        weighted = [score[:, i:i + 1] * features_list[i] for i in range(self.out_features_num)]
        return torch.cat(weighted, dim=1)

+ 4 - 0
torchocr/networks/__init__.py

@@ -0,0 +1,4 @@
+from .architectures import build_model
+from .losses import build_loss
+
+__all__ = ['build_model', 'build_loss']

+ 82 - 0
torchocr/networks/architectures/DetModel.py

@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/21 14:23
+# @Author  : zhoujun
+from torch import nn
+from addict import Dict as AttrDict
+from torchocr.networks.backbones.DetMobilenetV3 import MobileNetV3
+from torchocr.networks.backbones.DetResNetvd import ResNet
+from torchocr.networks.necks.DB_fpn import DB_fpn, RSEFPN, LKPAN
+from torchocr.networks.necks.FCE_Fpn import FCEFPN
+from torchocr.networks.necks.pse_fpn import PSEFpn
+from torchocr.networks.necks.DB_ASF import DB_Asf
+from torchocr.networks.heads.DetDbHead import DBHead
+from torchocr.networks.heads.FCEHead import FCEHead
+from torchocr.networks.heads.DetPseHead import PseHead
+from torchocr.networks.backbones.DetGhostNet import GhostNet
+from torchocr.networks.backbones.Transformer import *
+from torchocr.networks.backbones.ConvNext import ConvNeXt
+
# Registries mapping the ``type`` string of a model config to the class that
# implements that component; DetModel resolves backbone/neck/head through these.
backbone_dict = {'MobileNetV3': MobileNetV3,
                 'ResNet': ResNet,
                 'GhostNet': GhostNet,
                 'SwinTransformer': SwinTransformer,
                 'ConvNeXt': ConvNeXt
                 }
# Feature-fusion necks: FPN variants plus the DBNet++ adaptive-scale-fusion neck.
neck_dict = {'DB_fpn': DB_fpn,
             'pse_fpn': PSEFpn,
             'ASF': DB_Asf,
             'RSEFPN': RSEFPN,
             'LKPAN': LKPAN,
             'FCEFPN': FCEFPN
             }
# Prediction heads for the DB, PSE and FCE detection algorithms.
head_dict = {'DBHead': DBHead,
             'PseHead': PseHead,
             'FCEHead': FCEHead
             }
+
+
class DetModel(nn.Module):
    """Text-detection model assembled from config: backbone -> neck -> head.

    ``config`` must provide ``in_channels`` plus ``backbone``/``neck``/``head``
    sub-configs, each carrying a ``type`` key naming an entry in the registries
    above. NOTE: the ``type`` keys are popped, so the config dicts are mutated.
    """

    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must in model config'

        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must in {head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'DetModel_{backbone_type}_{neck_type}_{head_type}'

    def forward(self, x):
        return self.head(self.neck(self.backbone(x)))
+
+
if __name__ == '__main__':
    import torch

    # Build a PSE detector on a ResNet-50 backbone as a construction smoke test.
    db_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='ResNet', layers=50, pretrained=True),
        neck=AttrDict(type='pse_fpn', out_channels=256),
        head=AttrDict(type='PseHead', H=640, W=640, scale=1),
    )
    x = torch.zeros(1, 3, 640, 640)
    model = DetModel(db_config)

+ 58 - 0
torchocr/networks/architectures/DistillationDetModel.py

@@ -0,0 +1,58 @@
+import os
+import copy
+import torch
+from torch import nn
+from addict import Dict
+
+from .DetModel import DetModel
+from addict import Dict as AttrDict
+
+__all__ = ['DistillationModel']
+
+
def load_pretrained_params(_model, _path):
    """Load a checkpoint's ``state_dict`` into ``_model`` in place.

    Strips a leading ``module.`` (DataParallel) prefix from every key.

    Returns ``_model`` in all cases. Bug fix: the original returned ``False``
    when the path was None or missing, which clobbered the model at call sites
    doing ``model = load_pretrained_params(model, path)``.
    """
    if _path is None:
        return _model
    if not os.path.exists(_path):
        print(f'The pretrained_model {_path} does not exists')
        return _model
    # map_location avoids requiring the device the checkpoint was saved on.
    params = torch.load(_path, map_location='cpu')
    state_dict = params['state_dict']
    state_dict_no_module = {k.replace('module.', ''): v for k, v in state_dict.items()}
    _model.load_state_dict(state_dict_no_module)
    return _model
+
+
class DistillationModel(nn.Module):
    """Distillation wrapper: one DetModel per entry in ``config['models']``.

    Per-submodel keys ``freeze_params`` and ``pretrained`` are consumed here;
    the rest of each sub-config is forwarded to DetModel. ``forward`` returns
    a dict mapping submodel name -> submodel output.
    """

    def __init__(self, config):
        super(DistillationModel, self).__init__()
        self.model_dict = nn.ModuleDict()
        self.model_name_list = []

        sub_model_cfgs = config['models']
        for key in sub_model_cfgs:
            # Deep-copy so popping keys does not mutate the caller's config.
            sub_cfg = copy.deepcopy(sub_model_cfgs[key])
            sub_cfg.pop('type')
            freeze_params = sub_cfg.pop('freeze_params', False)
            pretrained = sub_cfg.pop('pretrained', None)

            model = DetModel(Dict(sub_cfg))
            if pretrained is not None:
                model = load_pretrained_params(model, pretrained)
            if freeze_params:
                for para in model.parameters():
                    para.requires_grad = False
                # Bug fix: the original set ``model.training = False`` directly,
                # which does not propagate to submodules, leaving BatchNorm /
                # Dropout children in training mode; eval() recurses properly.
                model.eval()

            self.model_dict[key] = model
            self.model_name_list.append(key)

    def forward(self, x):
        return {name: self.model_dict[name](x) for name in self.model_name_list}

+ 41 - 0
torchocr/networks/architectures/RecModel.py

@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/16 11:18
+# @Author  : zhoujun
+from torch import nn
+
+from torchocr.networks.backbones.RecMobileNetV3 import MobileNetV3
+from torchocr.networks.backbones.RecResNetvd import ResNet
+from torchocr.networks.necks.RNN import SequenceEncoder, Im2Seq
+
+from torchocr.networks.heads.RecCTCHead import CTC
+
# Registries mapping config ``type`` strings to component classes for RecModel.
backbone_dict = {'MobileNetV3': MobileNetV3, 'ResNet': ResNet}
# 'PPaddleRNN' = RNN sequence encoder; 'None' = plain feature-map-to-sequence reshape.
neck_dict = {'PPaddleRNN': SequenceEncoder, 'None': Im2Seq}
head_dict = {'CTC': CTC}
+
+
class RecModel(nn.Module):
    """Text-recognition model assembled from config: backbone -> neck -> head.

    ``config`` must provide ``in_channels`` plus ``backbone``/``neck``/``head``
    sub-configs whose ``type`` keys name entries in the registries above.
    NOTE: the ``type`` keys are popped, so the config dicts are mutated.
    """

    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must in model config'
        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must in {head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'

    def forward(self, x):
        # Cleanup: dropped the unused ``features`` alias and the commented-out
        # dual return from the original implementation.
        x = self.backbone(x)
        x = self.neck(x)
        return self.head(x)

+ 21 - 0
torchocr/networks/architectures/__init__.py

@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/15 17:42
+# @Author  : zhoujun
+from addict import Dict
+import copy
+from .RecModel import RecModel
+from .DetModel import DetModel
+from .DistillationDetModel import DistillationModel
+
# Architectures this factory knows how to build.
support_model = ['RecModel', 'DetModel', 'DistillationModel']


def build_model(config):
    """Build an architecture instance; ``config['type']`` selects the class.

    The input config is deep-copied first, so the caller's dict is not mutated.
    """
    copy_config = copy.deepcopy(config)
    arch_type = copy_config.pop('type')
    assert arch_type in support_model, f'{arch_type} is not developed yet!, only {support_model} are support now'
    # Explicit registry instead of eval(): safer and easier to trace.
    registry = {'RecModel': RecModel, 'DetModel': DetModel, 'DistillationModel': DistillationModel}
    arch_model = registry[arch_type](Dict(copy_config))
    return arch_model

+ 166 - 0
torchocr/networks/backbones/ConvNext.py

@@ -0,0 +1,166 @@
+from functools import partial
+import logging
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchocr.networks.backbones.Transformer import DropPath
+
+
class Block(nn.Module):
    r"""ConvNeXt block (the paper's second, channels-last formulation):
    DwConv -> permute to (N, H, W, C) -> LayerNorm -> Linear -> GELU -> Linear
    -> permute back, with layer scale and stochastic depth on the residual.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # 1x1 conv expressed as Linear on NHWC
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(y)
+
+
class ConvNeXt(nn.Module):
    r""" ConvNeXt backbone (multi-scale feature extractor).
        A PyTorch impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3)
        dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.4
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1.0
        out_indices (tuple(int)): Stages whose (normalised) outputs are returned.
    """

    def __init__(self, in_chans=3, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768),
                 drop_path_rate=0.4, layer_scale_init_value=1.0, out_indices=(0, 1, 2, 3), **kwargs
                 ):
        super().__init__()
        # Fix: the original used mutable list default arguments (shared across
        # calls — a classic Python pitfall); tuples are backward compatible.
        depths = list(depths)
        dims = list(dims)

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each of multiple residual blocks
        self.pretrained = kwargs.get('pretrained', True)
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # Fix: follow ``dims`` instead of hard-coding the tiny-variant values,
        # which silently broke any caller passing custom dims.
        self.out_channels = list(dims)
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

        # A LayerNorm per stage, applied to the outputs returned to the neck.
        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            self.add_module(f'norm{i_layer}', norm_layer(dims[i_layer]))

        if self.pretrained:
            ckpt_path = './weights/convnext_tiny_1k_512x512.pth'
            logger = logging.getLogger('torchocr')
            if os.path.exists(ckpt_path):
                logger.info('load convnext weights')
                self.load_state_dict(torch.load(ckpt_path), strict=True)
            else:
                logger.info(f'{ckpt_path} not exists')
                self.apply(self._init_weights)
        else:
            self.apply(self._init_weights)

    def _init_weights(self, m):
        # Trunc-normal for Linear, constant for nn.LayerNorm.
        # NOTE(review): the custom LayerNorm above is not an nn.LayerNorm, so
        # it keeps its own ones/zeros init — preserved from the original.
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Return the normalised outputs of the stages listed in out_indices."""
        outs = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                outs.append(norm_layer(x))
        return tuple(outs)

    def forward(self, x):
        return self.forward_features(x)
+
+
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two layouts: 'channels_last' normalises the
    trailing dim of (N, H, W, C) via F.layer_norm; 'channels_first'
    normalises dim 1 of (N, C, H, W) by hand.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalise over the channel dimension manually.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]

+ 311 - 0
torchocr/networks/backbones/DetGhostNet.py

@@ -0,0 +1,311 @@
+# 2020.06.09-Changed for building GhostNet
+#            Huawei Technologies Co., Ltd. <foss@huawei.com>
+"""
+Creates a GhostNet Model as defined in:
+GhostNet: More Features from Cheap Operations By Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, Chang Xu.
+https://arxiv.org/abs/1911.11907
+Modified from https://github.com/d-li14/mobilenetv3.pytorch and https://github.com/rwightman/pytorch-image-models
+"""
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import logging
+from collections import OrderedDict
+from torchocr.networks.CommonModules import CBAM
+
+
+
+def _make_divisible(v, divisor, min_value=None):
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
def hard_sigmoid(x, inplace: bool = False):
    """Piecewise-linear sigmoid approximation: relu6(x + 3) / 6.

    With ``inplace=True`` the input tensor itself is modified.
    """
    if not inplace:
        return F.relu6(x + 3.) / 6.
    return x.add_(3.).clamp_(0., 6.).div_(6.)
+
+
class SqueezeExcite(nn.Module):
    """Squeeze-and-Excitation with configurable activation and gate function."""

    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                 act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_):
        super(SqueezeExcite, self).__init__()
        self.gate_fn = gate_fn
        base_chs = reduced_base_chs or in_chs
        reduced_chs = _make_divisible(base_chs * se_ratio, divisor)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

    def forward(self, x):
        squeezed = self.avg_pool(x)
        gate = self.conv_expand(self.act1(self.conv_reduce(squeezed)))
        return x * self.gate_fn(gate)
+
+
class ConvBnAct(nn.Module):
    """Conv2d (implicit 'same' padding) -> BatchNorm2d -> activation."""

    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, act_layer=nn.ReLU):
        super(ConvBnAct, self).__init__()
        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False)
        self.bn1 = nn.BatchNorm2d(out_chs)
        self.act1 = act_layer(inplace=True)

    def forward(self, x):
        return self.act1(self.bn1(self.conv(x)))
+
+
class GhostModule(nn.Module):
    """Ghost convolution: a small primary conv plus a cheap depthwise expansion."""

    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super(GhostModule, self).__init__()
        self.oup = oup
        init_channels = math.ceil(oup / ratio)
        new_channels = init_channels * (ratio - 1)

        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

        # Depthwise conv generating "ghost" features from the primary ones.
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size // 2, groups=init_channels, bias=False),
            nn.BatchNorm2d(new_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

    def forward(self, x):
        primary = self.primary_conv(x)
        ghost = self.cheap_operation(primary)
        combined = torch.cat((primary, ghost), dim=1)
        # Trim in case primary + ghost channels overshoot the requested width.
        return combined[:, :self.oup, :, :]
+
+
class GhostBottleneck(nn.Module):
    """Ghost bottleneck (GhostNet): ghost expand -> optional depthwise conv
    (when downsampling) -> optional SE -> ghost project, plus a shortcut.
    """

    def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
                 stride=1, act_layer=nn.ReLU, se_ratio=0.):
        super(GhostBottleneck, self).__init__()
        use_se = se_ratio is not None and se_ratio > 0.
        self.stride = stride

        # Point-wise expansion via ghost convolution.
        self.ghost1 = GhostModule(in_chs, mid_chs, relu=True)

        # Depth-wise convolution only when downsampling.
        if self.stride > 1:
            self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride,
                                     padding=(dw_kernel_size - 1) // 2,
                                     groups=mid_chs, bias=False)
            self.bn_dw = nn.BatchNorm2d(mid_chs)

        # Optional squeeze-and-excitation on the expanded features.
        self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) if use_se else None

        # Point-wise linear projection back down (no activation).
        self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)

        # Identity shortcut when shapes match, otherwise DW + PW projection.
        if in_chs == out_chs and self.stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride,
                          padding=(dw_kernel_size - 1) // 2, groups=in_chs, bias=False),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.ghost1(x)
        if self.stride > 1:
            out = self.bn_dw(self.conv_dw(out))
        if self.se is not None:
            out = self.se(out)
        out = self.ghost2(out)

        return out + identity
+
+
class GhostNet(nn.Module):
    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, pretrained=True, **kwargs):
        """GhostNet backbone that returns four multi-scale feature maps for detection.

        Args:
            cfgs: list of stages, each a list of ``[k, exp_size, c, se_ratio, s]``
                block settings; used when ``model_name`` is not ``'default'``.
            num_classes: kept for interface compatibility (classifier head is disabled).
            width: channel width multiplier.
            dropout: kept for interface compatibility (classifier head is disabled).
            pretrained: load ImageNet weights from ``./weights`` when present.
            **kwargs: ``model_name`` ('default' uses the built-in config) and
                ``disable_se`` (filter SE weights out of the checkpoint).
        """
        super(GhostNet, self).__init__()
        model_name = kwargs.get('model_name', 'default')
        self.disable_se = kwargs.get('disable_se', False)
        if model_name == 'default':
            self.cfgs = [
                # k, t, c, SE, s
                # stage1
                [[3, 16, 16, 0, 1]],
                # stage2
                [[3, 48, 24, 0, 2]],
                [[3, 72, 24, 0, 1]],
                # stage3
                [[5, 72, 40, 0.25, 2]],
                [[5, 120, 40, 0.25, 1]],
                # stage4
                [[3, 240, 80, 0, 2]],
                [[3, 200, 80, 0, 1],
                 [3, 184, 80, 0, 1],
                 [3, 184, 80, 0, 1],
                 [3, 480, 112, 0.25, 1],
                 [3, 672, 112, 0.25, 1]
                 ],
                # stage5
                [[5, 672, 160, 0.25, 2]],
                [[5, 960, 160, 0, 1],
                 [5, 960, 160, 0.25, 1],
                 [5, 960, 160, 0, 1],
                 [5, 960, 160, 0.25, 1]
                 ]
            ]
        else:
            # BUGFIX: previously self.cfgs was assigned only for 'default', so any
            # other model_name crashed with AttributeError in the loop below.
            self.cfgs = cfgs

        # building first layer
        output_channel = _make_divisible(16 * width, 4)  # 16
        self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(output_channel)
        self.act1 = nn.ReLU(inplace=True)
        input_channel = output_channel

        # building inverted residual blocks
        stages = []
        block = GhostBottleneck
        self.keep_stages = []
        self.out_channels = []
        i = 0
        for cfg in self.cfgs:
            layers = []
            for k, exp_size, c, se_ratio, s in cfg:
                # record the channel count right before each later downsampling
                # block (skipping the first strided blocks) for FPN-style necks
                if s == 2 and i > 2:
                    self.out_channels.append(input_channel)
                output_channel = _make_divisible(c * width, 4)
                hidden_channel = _make_divisible(exp_size * width, 4)
                layers.append(block(input_channel, hidden_channel, output_channel, k, s,
                                    se_ratio=se_ratio))
                input_channel = output_channel
                i += 1
            stages.append(nn.Sequential(*layers))

        # final 1x1 conv stage; exp_size carries over from the last block above
        output_channel = _make_divisible(exp_size * width, 4)
        stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1)))
        input_channel = output_channel

        self.out_channels.append(input_channel)

        self.blocks = nn.Sequential(*stages)

        if pretrained:
            ckpt_path = f'./weights/state_dict_73.98.pth'
            logger = logging.getLogger('torchocr')
            if os.path.exists(ckpt_path):
                logger.info('load imagenet weights')
                # BUGFIX: map_location lets GPU-saved checkpoints load on CPU-only hosts
                dic_ckpt = torch.load(ckpt_path, map_location='cpu')
                filtered_dict = OrderedDict()
                for key in dic_ckpt.keys():
                    # optionally drop squeeze-excite weights from the checkpoint.
                    # NOTE(review): the model itself may still contain SE modules,
                    # so a strict load with filtered keys could raise — confirm.
                    if self.disable_se and key.find('se') != -1:
                        continue
                    filtered_dict[key] = dic_ckpt[key]
                self.load_state_dict(filtered_dict)
            else:
                logger.info(f'{ckpt_path} not exists')

    def forward(self, x):
        """Run the stem and every stage; return the feature maps of stages 3, 5, 7 and the final conv."""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        out = []
        for stage in self.blocks:
            x = stage(x)
            out.append(x)
        return [out[2], out[4], out[6], out[9]]
+
+
def ghostnet(**kwargs):
    """Construct a GhostNet model with the standard block configuration."""
    # per block: [kernel k, expansion t, out channels c, SE ratio, stride s]
    block_cfgs = [
        [[3, 16, 16, 0, 1]],                      # stage1
        [[3, 48, 24, 0, 2]],                      # stage2
        [[3, 72, 24, 0, 1]],
        [[5, 72, 40, 0.25, 2]],                   # stage3
        [[5, 120, 40, 0.25, 1]],
        [[3, 240, 80, 0, 2]],                     # stage4
        [[3, 200, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 480, 112, 0.25, 1],
         [3, 672, 112, 0.25, 1]],
        [[5, 672, 160, 0.25, 2]],                 # stage5
        [[5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1],
         [5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1]],
    ]
    return GhostNet(block_cfgs, **kwargs)
+
+
if __name__ == '__main__':
    model = ghostnet()
    model.eval()
    dummy = torch.randn(32, 3, 320, 256)
    feats = model(dummy)
    # BUGFIX: forward() returns a list of feature maps, not a single tensor,
    # so the previous `y.size()` raised AttributeError.
    for feat in feats:
        print(feat.size())

+ 175 - 0
torchocr/networks/backbones/DetMobilenetV3.py

@@ -0,0 +1,175 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import logging
+import os
+import torch
+from torch import nn
+from torchocr.networks.CommonModules import ConvBNACT, SEBlock
+from collections import OrderedDict
+
+
class ResidualUnit(nn.Module):
    """MobileNetV3 inverted residual: 1x1 expand -> depthwise -> optional SE -> 1x1 project."""

    def __init__(self, num_in_filter, num_mid_filter, num_out_filter, stride, kernel_size, act=None, use_se=False):
        super().__init__()
        self.conv0 = ConvBNACT(in_channels=num_in_filter, out_channels=num_mid_filter,
                               kernel_size=1, stride=1, padding=0, act=act)
        self.conv1 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_mid_filter,
                               kernel_size=kernel_size, stride=stride,
                               padding=(kernel_size - 1) // 2, act=act,
                               groups=num_mid_filter)
        self.se = SEBlock(in_channels=num_mid_filter) if use_se else None
        self.conv2 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_out_filter,
                               kernel_size=1, stride=1, padding=0)
        # residual add is only valid when input/output shapes match
        self.not_add = num_in_filter != num_out_filter or stride != 1

    def forward(self, x):
        out = self.conv0(x)
        out = self.conv1(out)
        if self.se is not None:
            out = self.se(out)
        out = self.conv2(out)
        return out if self.not_add else x + out
+
+
class MobileNetV3(nn.Module):
    def __init__(self, in_channels, pretrained=True, **kwargs):
        """
        The MobileNetV3 backbone network for the detection module.

        Args:
            in_channels: channel count of the input image.
            pretrained: when True, try to load ImageNet weights from ./weights.
            **kwargs: scale (width multiplier, default 0.5), model_name
                ('large' or 'small'), disable_se (default True; drops SE blocks
                and filters their weights from the checkpoint).
        """
        super().__init__()
        self.scale = kwargs.get('scale', 0.5)
        model_name = kwargs.get('model_name', 'large')
        self.disable_se = kwargs.get('disable_se', True)
        self.inplanes = 16
        if model_name == "large":
            self.cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, False, 'relu', 1],
                [3, 64, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', 2],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hard_swish', 2],
                [3, 200, 80, False, 'hard_swish', 1],
                [3, 184, 80, False, 'hard_swish', 1],
                [3, 184, 80, False, 'hard_swish', 1],
                [3, 480, 112, True, 'hard_swish', 1],
                [3, 672, 112, True, 'hard_swish', 1],
                [5, 672, 160, True, 'hard_swish', 2],
                [5, 960, 160, True, 'hard_swish', 1],
                [5, 960, 160, True, 'hard_swish', 1],
            ]
            self.cls_ch_squeeze = 960
            self.cls_ch_expand = 1280
        elif model_name == "small":
            self.cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, True, 'relu', 2],
                [3, 72, 24, False, 'relu', 2],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hard_swish', 2],
                [5, 240, 40, True, 'hard_swish', 1],
                [5, 240, 40, True, 'hard_swish', 1],
                [5, 120, 48, True, 'hard_swish', 1],
                [5, 144, 48, True, 'hard_swish', 1],
                [5, 288, 96, True, 'hard_swish', 2],
                [5, 576, 96, True, 'hard_swish', 1],
                [5, 576, 96, True, 'hard_swish', 1],
            ]
            self.cls_ch_squeeze = 576
            self.cls_ch_expand = 1280
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert self.scale in supported_scale, \
            "supported scale are {} but input scale is {}".format(supported_scale, self.scale)

        scale = self.scale
        inplanes = self.inplanes
        cfg = self.cfg
        cls_ch_squeeze = self.cls_ch_squeeze
        # stem conv, stride 2
        self.conv1 = ConvBNACT(in_channels=in_channels,
                               out_channels=self.make_divisible(inplanes * scale),
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               groups=1,
                               act='hard_swish')
        i = 0
        inplanes = self.make_divisible(inplanes * scale)
        self.stages = nn.ModuleList()
        block_list = []
        self.out_channels = []
        for layer_cfg in cfg:
            se = layer_cfg[3] and not self.disable_se
            # close the current stage at every later downsampling block and
            # remember its output channel count for the neck
            if layer_cfg[5] == 2 and i > 2:
                self.out_channels.append(inplanes)
                self.stages.append(nn.Sequential(*block_list))
                block_list = []
            block = ResidualUnit(num_in_filter=inplanes,
                                 num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
                                 num_out_filter=self.make_divisible(scale * layer_cfg[2]),
                                 act=layer_cfg[4],
                                 stride=layer_cfg[5],
                                 kernel_size=layer_cfg[0],
                                 use_se=se)
            block_list.append(block)
            inplanes = self.make_divisible(scale * layer_cfg[2])
            i += 1
        # final 1x1 conv that squeezes into cls_ch_squeeze channels
        block_list.append(ConvBNACT(
            in_channels=inplanes,
            out_channels=self.make_divisible(scale * cls_ch_squeeze),
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            act='hard_swish'))
        self.stages.append(nn.Sequential(*block_list))
        self.out_channels.append(self.make_divisible(scale * cls_ch_squeeze))

        if pretrained:
            ckpt_path = f'./weights/MobileNetV3_{model_name}_x{str(scale).replace(".", "_")}.pth'
            logger = logging.getLogger('torchocr')
            if os.path.exists(ckpt_path):
                logger.info('load imagenet weights')
                # BUGFIX: map_location lets GPU-saved checkpoints load on CPU-only hosts
                dic_ckpt = torch.load(ckpt_path, map_location='cpu')
                filtered_dict = OrderedDict()
                for key in dic_ckpt.keys():
                    # drop squeeze-excite weights when SE blocks are disabled
                    if self.disable_se and key.find('se') != -1:
                        continue
                    filtered_dict[key] = dic_ckpt[key]

                self.load_state_dict(filtered_dict)
            else:
                logger.info(f'{ckpt_path} not exists')

    def make_divisible(self, v, divisor=8, min_value=None):
        """Round v to the nearest multiple of divisor, never below min_value
        and never more than 10% below v."""
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    def forward(self, x):
        """Return the list of per-stage feature maps."""
        x = self.conv1(x)
        out = []
        for stage in self.stages:
            x = stage(x)
            out.append(x)

        return out

+ 213 - 0
torchocr/networks/backbones/DetResNetvd.py

@@ -0,0 +1,213 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+
+import torch
+from torch import nn
+
+from torchocr.networks.CommonModules import HSwish
+
+
class ConvBNACT(nn.Module):
    """Conv2d + BatchNorm + optional activation ('relu', 'hard_swish', or None)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None):
        super().__init__()
        # bias is folded into the following BatchNorm
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if act == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif act == 'hard_swish':
            self.act = HSwish()
        elif act is None:
            self.act = None
        else:
            # BUGFIX: an unrecognised act used to leave self.act unset, which
            # surfaced as a confusing AttributeError at forward time.
            raise NotImplementedError(f'unsupported activation: {act}')

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x
+
+
class ConvBNACTWithPool(nn.Module):
    """2x2 average pool followed by conv + BN + optional ReLU.

    Used for the ResNet-vd downsampling shortcut. Note that any non-None
    ``act`` value is treated as ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, groups=1, act=None):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=1,
                              padding=(kernel_size - 1) // 2,
                              groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = None if act is None else nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.bn(self.conv(self.pool(x)))
        if self.act is not None:
            out = self.act(out)
        return out
+
+
class ShortCut(nn.Module):
    """Residual shortcut: identity, strided 1x1 conv, or pool+conv (vd style)."""

    def __init__(self, in_channels, out_channels, stride, name, if_first=False):
        super().__init__()
        assert name is not None, 'shortcut must have name'
        self.name = name

        if if_first:
            # first block of the network: plain (possibly strided) projection
            self.conv = ConvBNACT(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=1, stride=stride, padding=0, groups=1, act=None)
        elif in_channels != out_channels or stride != 1:
            # vd trick: average-pool before the 1x1 projection
            self.conv = ConvBNACTWithPool(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=1, groups=1, act=None)
        else:
            # shapes already match: pure identity
            self.conv = None

    def forward(self, x):
        return x if self.conv is None else self.conv(x)
+
+
class BottleneckBlock(nn.Module):
    """ResNet bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand (x4) with shortcut."""

    def __init__(self, in_channels, out_channels, stride, if_first, name):
        super().__init__()
        assert name is not None, 'bottleneck must have name'
        self.name = name
        self.conv0 = ConvBNACT(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=1, stride=1, padding=0, groups=1, act='relu')
        self.conv1 = ConvBNACT(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=3, stride=stride, padding=1, groups=1, act='relu')
        self.conv2 = ConvBNACT(in_channels=out_channels, out_channels=out_channels * 4,
                               kernel_size=1, stride=1, padding=0, groups=1, act=None)
        self.shortcut = ShortCut(in_channels=in_channels, out_channels=out_channels * 4,
                                 stride=stride, if_first=if_first, name=f'{name}_branch1')
        self.relu = nn.ReLU(inplace=True)
        self.output_channels = out_channels * 4

    def forward(self, x):
        out = self.conv2(self.conv1(self.conv0(x)))
        return self.relu(out + self.shortcut(x))
+
+
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convs with a (possibly projected) shortcut."""

    def __init__(self, in_channels, out_channels, stride, if_first, name):
        super().__init__()
        assert name is not None, 'block must have name'
        self.name = name

        self.conv0 = ConvBNACT(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=3, stride=stride, padding=1, groups=1, act='relu')
        self.conv1 = ConvBNACT(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=3, stride=1, padding=1, groups=1, act=None)
        self.shortcut = ShortCut(in_channels=in_channels, out_channels=out_channels,
                                 stride=stride, name=f'{name}_branch1', if_first=if_first, )
        self.relu = nn.ReLU(inplace=True)
        self.output_channels = out_channels

    def forward(self, x):
        out = self.conv1(self.conv0(x))
        return self.relu(out + self.shortcut(x))
+
+
class ResNet(nn.Module):
    def __init__(self, in_channels, layers, out_indices=(0, 1, 2, 3), pretrained=True, **kwargs):
        """
        The ResNet-vd backbone network for the detection module.

        Args:
            in_channels: channel count of the input image.
            layers: network depth, one of 18/34/50/101/152/200.
            out_indices: indices of the stages whose outputs forward() returns.
            pretrained: load ImageNet weights from ./weights when available.
            **kwargs: use_supervised — also try supervised-pretraining weights.
        """
        super().__init__()
        supported_layers = {
            18: {'depth': [2, 2, 2, 2], 'block_class': BasicBlock},
            34: {'depth': [3, 4, 6, 3], 'block_class': BasicBlock},
            50: {'depth': [3, 4, 6, 3], 'block_class': BottleneckBlock},
            101: {'depth': [3, 4, 23, 3], 'block_class': BottleneckBlock},
            152: {'depth': [3, 8, 36, 3], 'block_class': BottleneckBlock},
            200: {'depth': [3, 12, 48, 3], 'block_class': BottleneckBlock}
        }
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(supported_layers, layers)
        depth = supported_layers[layers]['depth']
        block_class = supported_layers[layers]['block_class']
        self.use_supervised = kwargs.get('use_supervised', False)
        # BUGFIX: the default used to be the mutable list [0, 1, 2, 3]; a tuple
        # avoids the shared-mutable-default pitfall while behaving identically.
        self.out_indices = out_indices
        num_filters = [64, 128, 256, 512]
        # ResNet-vd deep stem: three 3x3 convs instead of one 7x7
        self.conv1 = nn.Sequential(
            ConvBNACT(in_channels=in_channels, out_channels=32, kernel_size=3, stride=2, padding=1, act='relu'),
            ConvBNACT(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, act='relu'),
            ConvBNACT(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, act='relu')
        )
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.stages = nn.ModuleList()
        self.out_channels = []
        tmp_channels = []
        in_ch = 64
        for block_index in range(len(depth)):
            block_list = []
            for i in range(depth[block_index]):
                # layer names follow the original naming convention so that
                # pretrained checkpoints map one-to-one
                if layers >= 50:
                    if layers in [101, 152, 200] and block_index == 2:
                        if i == 0:
                            conv_name = "res" + str(block_index + 2) + "a"
                        else:
                            conv_name = "res" + str(block_index + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block_index + 2) + chr(97 + i)
                else:
                    conv_name = f'res{str(block_index + 2)}{chr(97 + i)}'
                block_list.append(block_class(in_channels=in_ch, out_channels=num_filters[block_index],
                                              stride=2 if i == 0 and block_index != 0 else 1,
                                              if_first=block_index == i == 0, name=conv_name))
                in_ch = block_list[-1].output_channels
            tmp_channels.append(in_ch)
            self.stages.append(nn.Sequential(*block_list))
        for idx, ch in enumerate(tmp_channels):
            if idx in self.out_indices:
                self.out_channels.append(ch)
        if pretrained:
            ckpt_path = f'./weights/resnet{layers}_vd.pth'
            logger = logging.getLogger('torchocr')
            if os.path.exists(ckpt_path):
                logger.info('load imagenet weights')
                # BUGFIX: map_location keeps loading working on CPU-only hosts
                self.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            else:
                logger.info(f'{ckpt_path} not exists')
        if self.use_supervised:
            ckpt_path = f'./weights/res_supervised_140w_387e.pth'
            logger = logging.getLogger('torchocr')
            if os.path.exists(ckpt_path):
                logger.info('load supervised weights')
                self.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            else:
                logger.info(f'{ckpt_path} not exists')

    def forward(self, x):
        """Run the stem and stages; return the feature maps listed in out_indices."""
        x = self.conv1(x)
        x = self.pool1(x)
        out = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if idx in self.out_indices:
                out.append(x)
        return out

+ 254 - 0
torchocr/networks/backbones/MobileViT.py

@@ -0,0 +1,254 @@
+import torch
+import torch.nn as nn
+
+from einops import rearrange
+
+
def conv_1x1_bn(inp, oup):
    """1x1 convolution followed by BatchNorm and SiLU."""
    layers = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)
+
+
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    """NxN convolution followed by BatchNorm and SiLU.

    NOTE: padding is fixed at 1, which is 'same' only for the default
    kernal_size of 3 (the only size used by callers in this file).
    """
    layers = [
        nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)
+
+
class PreNorm(nn.Module):
    """Apply LayerNorm before the wrapped module (pre-norm transformer style)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
+
+
class FeedForward(nn.Module):
    """Two-layer MLP (dim -> hidden_dim -> dim) with SiLU and dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
+
+
class Attention(nn.Module):
    """Multi-head self-attention over the token axis of (b, p, n, dim) inputs."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # output projection is a no-op when a single head already matches dim
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x):
        b, p, n, _ = x.shape
        # split heads: (b, p, n, h*d) -> (b, p, h, n, d); equivalent to the
        # einops pattern 'b p n (h d) -> b p h n d'
        q, k, v = (t.reshape(b, p, n, self.heads, -1).permute(0, 1, 3, 2, 4)
                   for t in self.to_qkv(x).chunk(3, dim=-1))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        ctx = torch.matmul(weights, v)
        # merge heads back: (b, p, h, n, d) -> (b, p, n, h*d)
        ctx = ctx.permute(0, 1, 3, 2, 4).reshape(b, p, n, -1)
        return self.to_out(ctx)
+
+
class Transformer(nn.Module):
    """Stack of pre-norm attention + feed-forward layers with residual adds."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout)),
            ])
            for _ in range(depth)
        )

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)
        return x
+
+
class MV2Block(nn.Module):
    """MobileNetV2 inverted residual block (pw expand -> dw -> pw project)."""

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expansion != 1:
            # pointwise expansion
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ]
        # depthwise conv followed by the pointwise linear projection
        layers += [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
+
+
class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        """Fuse local conv features with global transformer attention over patches.

        Args:
            dim: transformer embedding width.
            depth: number of transformer layers.
            channel: input and output channel count of the block.
            kernel_size: kernel size of the local-representation convolutions.
            patch_size: (ph, pw) unfolding grid; the feature map H/W must be
                divisible by ph/pw — assumed here, TODO confirm at call sites.
            mlp_dim: hidden width of the transformer feed-forward layers.
            dropout: dropout rate inside the transformer.
        """
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        # fuses transformed features with the kept input copy, hence 2*channel in
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        # keep a copy of the input for the fusion step below
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations: unfold into (ph*pw) pixel groups of tokens,
        # run the transformer, then fold back to the spatial layout
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        # h and w above are the full feature sizes, so divide by the patch grid
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
                      pw=self.pw)

        # Fusion: project back to `channel` and blend with the original input
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x
+
+
class MobileViT(nn.Module):
    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2),**kwargs):
        """MobileViT backbone; forward() returns four multi-scale feature maps.

        Args:
            image_size: (h, w) of the input; must be divisible by patch_size.
            dims: transformer widths of the three MobileViT blocks.
            channels: 11-entry channel schedule indexed throughout this method.
            num_classes: size of the (currently disabled) classifier head.
            expansion: MV2Block expansion factor.
            kernel_size: conv kernel used inside the MobileViT blocks.
            patch_size: transformer patch grid of the MobileViT blocks.
        """
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        # feature maps must tile evenly into transformer patches
        assert ih % ph == 0 and iw % pw == 0

        # transformer depths of the three MobileViT blocks
        L = [2, 4, 3]
        # channel counts of the four feature maps returned by forward()
        self.out_channels = [channels[3], channels[5], channels[7], channels[9]]
        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))  # Repeat
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        # NOTE(review): conv2, pool and fc belong to the commented-out
        # classification head below; they are registered but unused by forward()
        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        # collect the four pyramid levels returned to the neck
        out = []
        x = self.conv1(x)
        x = self.mv2[0](x)

        x = self.mv2[1](x)
        x = self.mv2[2](x)
        out.append(x)
        x = self.mv2[3](x)  # Repeat

        x = self.mv2[4](x)  # b*48*32*32
        x = self.mvit[0](x)
        out.append(x)

        x = self.mv2[5](x)  # b*64*16*16
        x = self.mvit[1](x)
        out.append(x)

        x = self.mv2[6](x)  # b*80*8*8
        x = self.mvit[2](x)  # b*80*8*8
        out.append(x)
        return out

        # x = self.conv2(x)
        # x = self.pool(x).view(-1, x.shape[1])
        # x = self.fc(x)
        # return x
+
+
def mobilevit_xxs():
    """MobileViT-XXS configuration (expansion factor 2)."""
    dims = [64, 80, 96]
    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
    model = MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)
    return model
+
+
def mobilevit_xs(inchannel=3, **kwargs):
    """MobileViT-XS configuration.

    BUGFIX: `inchannel` now defaults to 3 so parameter-less calls (as in the
    __main__ demo) work. Note the MobileViT stem is hard-coded to 3 input
    channels, so `inchannel` is accepted only for config-driven construction.
    Extra kwargs are forwarded instead of being silently dropped.
    """
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
    return MobileViT((512, 512), dims, channels, num_classes=1000, **kwargs)
+
+
def mobilevit_s():
    """MobileViT-S configuration."""
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
    model = MobileViT((256, 256), dims, channels, num_classes=1000)
    return model
+
+
def count_parameters(model):
    """Number of trainable (requires_grad) parameters in *model*."""
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable)
+
+
if __name__ == '__main__':
    img = torch.randn(5, 3, 256, 256)

    # BUGFIX: mobilevit_xs requires the `inchannel` argument; calling it with
    # no arguments raised TypeError.
    vit = mobilevit_xs(3)
    out = vit(img)
    print(count_parameters(vit))

+ 141 - 0
torchocr/networks/backbones/RecMobileNetV3.py

@@ -0,0 +1,141 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from torch import nn
+
+from torchocr.networks.CommonModules import ConvBNACT, SEBlock
+
+
class ResidualUnit(nn.Module):
    """MobileNetV3 inverted-residual unit: expand -> depthwise -> (SE) -> project.

    The input is added back only when the channel count and stride allow an
    identity shortcut.
    """

    def __init__(self, num_in_filter, num_mid_filter, num_out_filter, stride, kernel_size, act=None, use_se=False):
        super().__init__()
        # 1x1 expansion convolution.
        self.expand_conv = ConvBNACT(in_channels=num_in_filter, out_channels=num_mid_filter,
                                     kernel_size=1, stride=1, padding=0, act=act)
        # Depthwise convolution (groups == channels).
        self.bottleneck_conv = ConvBNACT(in_channels=num_mid_filter, out_channels=num_mid_filter,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=int((kernel_size - 1) // 2), act=act,
                                         groups=num_mid_filter)
        # Optional squeeze-and-excitation block.
        self.se = SEBlock(in_channels=num_mid_filter) if use_se else None
        # 1x1 linear projection (no activation).
        self.linear_conv = ConvBNACT(in_channels=num_mid_filter, out_channels=num_out_filter,
                                     kernel_size=1, stride=1, padding=0)
        # Shortcut is only valid when the output shape matches the input.
        self.not_add = num_in_filter != num_out_filter or stride != 1

    def forward(self, x):
        out = self.expand_conv(x)
        out = self.bottleneck_conv(out)
        if self.se is not None:
            out = self.se(out)
        out = self.linear_conv(out)
        return out if self.not_add else x + out
+
+
class MobileNetV3(nn.Module):
    """MobileNetV3 feature extractor adapted for text recognition.

    Stage strides in the config tables are (2, 1) tuples, so the feature-map
    height is repeatedly reduced while the width is preserved; the trailing
    2x2 max-pool then halves both dimensions.

    Keyword args (via **kwargs):
        scale: width multiplier, one of [0.35, 0.5, 0.75, 1.0, 1.25]
            (default 0.5).
        model_name: 'small' or 'large' configuration (default 'small').
    """

    def __init__(self, in_channels=3, **kwargs):
        super().__init__()
        self.scale = kwargs.get('scale', 0.5)
        model_name = kwargs.get('model_name', 'small')
        self.inplanes = 16
        if model_name == "large":
            # Each row: kernel size, expansion width, out channels, use-SE,
            # activation name, stride.
            self.cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, False, 'relu', 1],
                [3, 64, 24, False, 'relu', (2, 1)],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', (2, 1)],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hard_swish', 1],
                [3, 200, 80, False, 'hard_swish', 1],
                [3, 184, 80, False, 'hard_swish', 1],
                [3, 184, 80, False, 'hard_swish', 1],
                [3, 480, 112, True, 'hard_swish', 1],
                [3, 672, 112, True, 'hard_swish', 1],
                [5, 672, 160, True, 'hard_swish', (2, 1)],
                [5, 960, 160, True, 'hard_swish', 1],
                [5, 960, 160, True, 'hard_swish', 1],
            ]
            self.cls_ch_squeeze = 960
            self.cls_ch_expand = 1280
        elif model_name == "small":
            self.cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, True, 'relu', (1, 1)],
                [3, 72, 24, False, 'relu', (2, 1)],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hard_swish', (2, 1)],
                [5, 240, 40, True, 'hard_swish', 1],
                [5, 240, 40, True, 'hard_swish', 1],
                [5, 120, 48, True, 'hard_swish', 1],
                [5, 144, 48, True, 'hard_swish', 1],
                [5, 288, 96, True, 'hard_swish', (2, 1)],
                [5, 576, 96, True, 'hard_swish', 1],
                [5, 576, 96, True, 'hard_swish', 1],
            ]
            self.cls_ch_squeeze = 576
            self.cls_ch_expand = 1280
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert self.scale in supported_scale, "supported scale are {} but input scale is {}".format(supported_scale,
                                                                                                    self.scale)

        scale = self.scale
        inplanes = self.inplanes
        cfg = self.cfg
        cls_ch_squeeze = self.cls_ch_squeeze
        # conv1: stride-2 stem convolution.
        self.conv1 = ConvBNACT(in_channels=in_channels,
                               out_channels=self.make_divisible(inplanes * scale),
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               groups=1,
                               act='hard_swish')
        inplanes = self.make_divisible(inplanes * scale)
        block_list = []
        for layer_cfg in cfg:
            # Expansion and output widths are scaled and rounded to a
            # hardware-friendly multiple (see make_divisible).
            block = ResidualUnit(num_in_filter=inplanes,
                                 num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
                                 num_out_filter=self.make_divisible(scale * layer_cfg[2]),
                                 act=layer_cfg[4],
                                 stride=layer_cfg[5],
                                 kernel_size=layer_cfg[0],
                                 use_se=layer_cfg[3])
            block_list.append(block)
            inplanes = self.make_divisible(scale * layer_cfg[2])

        self.blocks = nn.Sequential(*block_list)
        self.conv2 = ConvBNACT(in_channels=inplanes,
                               out_channels=self.make_divisible(scale * cls_ch_squeeze),
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               groups=1,
                               act='hard_swish')

        # Final 2x2 pool halves both spatial dimensions.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.out_channels = self.make_divisible(scale * cls_ch_squeeze)

    def make_divisible(self, v, divisor=8, min_value=None):
        """Round *v* to the nearest multiple of *divisor* (MobileNet
        convention), never dropping more than 10% below the original value."""
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    def forward(self, x):
        """Return the final feature map after stem, stages, 1x1 conv and pool."""
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x

+ 189 - 0
torchocr/networks/backbones/RecResNetvd.py

@@ -0,0 +1,189 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import OrderedDict
+import torch
+from torch import nn
+
+from torchocr.networks.CommonModules import HSwish
+
+
class ConvBNACT(nn.Module):
    """Conv2d + BatchNorm + optional activation ('relu', 'hard_swish', None).

    Raises:
        NotImplementedError: if *act* is an unsupported activation name.
            Previously an unknown name silently left ``self.act`` unset,
            deferring the failure to the first forward pass as a confusing
            AttributeError; now it fails eagerly at construction time.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups,
                              bias=False)  # bias is redundant before BatchNorm
        self.bn = nn.BatchNorm2d(out_channels)
        if act == 'relu':
            self.act = nn.ReLU()
        elif act == 'hard_swish':
            self.act = HSwish()
        elif act is None:
            self.act = None
        else:
            raise NotImplementedError('act [{}] is not supported'.format(act))

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x
+
+
class ConvBNACTWithPool(nn.Module):
    """AvgPool (kernel/stride = *stride*) followed by a stride-1 Conv2d + BN + optional ReLU.

    The pooling performs the spatial reduction so the subsequent convolution
    sees every input pixel (the ResNet-vd downsampling shortcut).

    NOTE(review): any non-None *act* yields ReLU here — 'hard_swish' is not
    honoured, unlike ConvBNACT. Callers in this file only pass None; confirm
    before relying on other activations.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, groups=1, act=None):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True)
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=1,
                              padding=(kernel_size - 1) // 2, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = None if act is None else nn.ReLU()

    def forward(self, x):
        out = self.bn(self.conv(self.pool(x)))
        return out if self.act is None else self.act(out)
+
+
class ShortCut(nn.Module):
    """Residual shortcut branch for ResNet-vd blocks.

    Applies a 1x1 projection when the channel count or stride changes —
    using the pool-then-conv variant except in the very first block — and
    is the identity otherwise.

    Args:
        stride: int or (h, w) tuple. The original implementation indexed
            ``stride[0]`` and therefore crashed on plain-int strides; ints
            are now normalized to a tuple (backward compatible).
    """

    def __init__(self, in_channels, out_channels, stride, name, if_first=False):
        super().__init__()
        assert name is not None, 'shortcut must have name'
        self.name = name
        # Generalization: accept both int and tuple strides.
        if isinstance(stride, int):
            stride = (stride, stride)
        if in_channels != out_channels or stride[0] != 1:
            if if_first:
                self.conv = ConvBNACT(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
                                      padding=0, groups=1, act=None)
            else:
                # vd variant: average-pool first so the 1x1 conv sees all pixels.
                self.conv = ConvBNACTWithPool(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                              stride=stride, groups=1, act=None)
        elif if_first:
            self.conv = ConvBNACT(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
                                  padding=0, groups=1, act=None)
        else:
            self.conv = None

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        return x
+
+
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convs plus a (possibly projected) shortcut."""

    def __init__(self, in_channels, out_channels, stride, if_first, name):
        super().__init__()
        assert name is not None, 'block must have name'
        self.name = name
        self.conv0 = ConvBNACT(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=3, stride=stride, padding=1, groups=1, act='relu')
        self.conv1 = ConvBNACT(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=3, stride=1, padding=1, groups=1, act=None)
        self.shortcut = ShortCut(in_channels=in_channels, out_channels=out_channels,
                                 stride=stride, name=f'{name}_branch1', if_first=if_first, )
        self.relu = nn.ReLU()
        self.output_channels = out_channels

    def forward(self, x):
        residual = self.shortcut(x)
        y = self.conv1(self.conv0(x))
        return self.relu(y + residual)
+
+
class BottleneckBlock(nn.Module):
    """ResNet bottleneck: 1x1 reduce -> 3x3 (strided) -> 1x1 expand (x4), plus shortcut."""

    def __init__(self, in_channels, out_channels, stride, if_first, name):
        super().__init__()
        assert name is not None, 'bottleneck must have name'
        self.name = name
        self.conv0 = ConvBNACT(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=1, stride=1, padding=0, groups=1, act='relu')
        self.conv1 = ConvBNACT(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=3, stride=stride, padding=1, groups=1, act='relu')
        self.conv2 = ConvBNACT(in_channels=out_channels, out_channels=out_channels * 4,
                               kernel_size=1, stride=1, padding=0, groups=1, act=None)
        self.shortcut = ShortCut(in_channels=in_channels, out_channels=out_channels * 4,
                                 stride=stride, if_first=if_first, name=f'{name}_branch1')
        self.relu = nn.ReLU()
        self.output_channels = out_channels * 4

    def forward(self, x):
        residual = self.shortcut(x)
        y = self.conv2(self.conv1(self.conv0(x)))
        return self.relu(y + residual)
+
+
class ResNet(nn.Module):
    """ResNet-vd feature extractor for text recognition.

    Inter-stage strides are (2, 1), so the feature-map height is halved
    between stages while the width is preserved; the final 2x2 max-pool
    reduces both dimensions.
    """

    def __init__(self, in_channels, layers, **kwargs):
        super().__init__()
        supported_layers = {
            18: {'depth': [2, 2, 2, 2], 'block_class': BasicBlock},
            34: {'depth': [3, 4, 6, 3], 'block_class': BasicBlock},
            50: {'depth': [3, 4, 6, 3], 'block_class': BottleneckBlock},
            101: {'depth': [3, 4, 23, 3], 'block_class': BottleneckBlock},
            152: {'depth': [3, 8, 36, 3], 'block_class': BottleneckBlock},
            200: {'depth': [3, 12, 48, 3], 'block_class': BottleneckBlock}
        }
        assert layers in supported_layers, "supported layers are {} but input layer is {}".format(supported_layers,
                                                                                                  layers)

        depth = supported_layers[layers]['depth']
        block_class = supported_layers[layers]['block_class']
        num_filters = [64, 128, 256, 512]

        # Deep "vd" stem: three 3x3 convolutions instead of a single 7x7.
        self.conv1 = nn.Sequential(
            ConvBNACT(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1, act='relu'),
            ConvBNACT(in_channels=32, out_channels=32, kernel_size=3, stride=1, act='relu', padding=1),
            ConvBNACT(in_channels=32, out_channels=64, kernel_size=3, stride=1, act='relu', padding=1),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.stages = nn.ModuleList()
        in_ch = 64
        for stage_idx, stage_depth in enumerate(depth):
            units = []
            for unit_idx in range(stage_depth):
                # Caffe-style layer names ("res2a", "res5b3", ...). The deep
                # third stage of ResNet-101/152/200 uses the "b<i>" suffix
                # form instead of consecutive letters.
                if layers in [101, 152, 200] and stage_idx == 2 and unit_idx > 0:
                    suffix = 'b' + str(unit_idx)
                elif layers in [101, 152, 200] and stage_idx == 2:
                    suffix = 'a'
                else:
                    suffix = chr(97 + unit_idx)
                conv_name = 'res' + str(stage_idx + 2) + suffix
                # Only the first unit of each non-initial stage downsamples
                # (height only).
                stride = (2, 1) if unit_idx == 0 and stage_idx != 0 else (1, 1)
                units.append(block_class(in_channels=in_ch,
                                         out_channels=num_filters[stage_idx],
                                         stride=stride,
                                         if_first=stage_idx == unit_idx == 0,
                                         name=conv_name))
                in_ch = units[-1].output_channels
            self.stages.append(nn.Sequential(*units))
        self.out_channels = in_ch
        self.out = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        x = self.pool1(self.conv1(x))
        for stage in self.stages:
            x = stage(x)
        return self.out(x)

+ 646 - 0
torchocr/networks/backbones/Transformer.py

@@ -0,0 +1,646 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from collections import OrderedDict
+
+
+
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> act -> drop -> Linear -> drop.

    Hidden and output widths default to *in_features* when not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+
+
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping square windows.

    H and W must be divisible by *window_size*.

    Returns:
        windows: (num_windows * B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    tiles = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return tiles.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+
+
def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: reassemble windows into a feature map.

    Args:
        windows: (num_windows * B, window_size, window_size, C)
        window_size: side length of each square window
        H, W: target spatial size (multiples of window_size)

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+
+
class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Registered as a buffer: part of the module state but not a parameter.
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # Single projection produces q, k, v: (3, B_, nH, N, C // nH).
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Look up the learned bias for every token pair in the window.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Add the (0/-inf) shift mask per window, then renormalize.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: randomly zero whole samples on the residual path.

    Identity when *drop_prob* is 0 or *training* is False. Kept samples are
    rescaled by 1/keep_prob so the expected value is unchanged. (Same as the
    DropConnect impl for EfficientNet-style networks; renamed 'drop path'
    to avoid confusion with the original DropConnect paper — see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.)
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask entry per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()  # binarize
    return x.div(keep_prob) * mask
+
+
class DropPath(nn.Module):
    """Module wrapper around drop_path_f (stochastic depth per sample).

    Active only in training mode; identity during evaluation.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)
+
+
class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.

    NOTE: ``self.H`` and ``self.W`` must be assigned by the caller (see
    BasicLayer.forward, which sets ``blk.H, blk.W`` before invoking the
    block) — forward() asserts L == H * W against them.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Spatial resolution; populated externally before each forward pass.
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # drop the padding added above, if any
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
+
+
class PatchMerging(nn.Module):
    """Downsample by 2x: concatenate each 2x2 group of tokens (4C) and project to 2C.

    Args:
        dim (int): number of input channels.
        norm_layer (nn.Module, optional): normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Merge patches: (B, H*W, C) -> (B, ceil(H/2)*ceil(W/2), 2C)."""
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # Pad bottom/right so both spatial dims are even.
        if (H % 2 == 1) or (W % 2 == 1):
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # Gather the four corners of every 2x2 group; order matters and
        # matches the reference implementation: (0,0), (1,0), (0,1), (1,1).
        corners = [x[:, dy::2, dx::2, :] for dy, dx in ((0, 0), (1, 0), (0, 1), (1, 1))]
        x = torch.cat(corners, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(x))
+
+
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Blocks alternate between regular (shift 0) and shifted (shift
    window_size // 2) window attention; the SW-MSA mask is rebuilt every
    forward pass from the current spatial resolution.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even blocks use W-MSA, odd blocks use SW-MSA
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            (x, H, W, x_down, Wh, Ww): features and resolution before and
            after the optional downsample (identical pairs when there is no
            downsample layer).
        """

        # calculate attention mask for SW-MSA: partition the (padded) image
        # into regions so tokens from different shifted regions cannot attend
        # to each other (pairs with different region ids get -100 bias).
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            # Blocks read their spatial resolution from attributes (see
            # SwinTransformerBlock.forward).
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
+
+
class PatchEmbed(nn.Module):
    """Image to patch embedding via a strided convolution.

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization applied token-wise
            after the projection. Default: None.
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Pad H/W up to multiples of patch_size, then project to (B, C, Wh, Ww)."""
        _, _, H, W = x.size()
        ph, pw = self.patch_size
        # Right/bottom padding so the strided conv covers the whole image.
        if W % pw != 0:
            x = F.pad(x, (0, pw - W % pw))
        if H % ph != 0:
            x = F.pad(x, (0, 0, 0, ph - H % ph))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = self.norm(x.flatten(2).transpose(1, 2))
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x
+
+
class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        pretrained (bool, via **kwargs): If True (default), load the local
            upernet checkpoint into this backbone at construction time.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.out_channels = [96, 192, 384, 768]
        self.pretrained = kwargs.get('pretrained', True)
        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = (pretrain_img_size, pretrain_img_size)
            patch_size = (patch_size, patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            nn.init.trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: per-block drop-path rates increase linearly with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output stage
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

        # BUG FIX: the checkpoint used to be loaded unconditionally, so the
        # `pretrained` kwarg was read but ignored and construction crashed
        # whenever the weight file was absent. Load only on request, and use
        # map_location='cpu' so a GPU-saved checkpoint loads on CPU-only hosts.
        if self.pretrained:
            ckpt_path = './weights/upernet_swin_tiny_patch4_window7_512x512.pth'
            weights_dict = torch.load(ckpt_path, map_location='cpu')["state_dict"]
            new_ckpt = OrderedDict()
            for k in list(weights_dict.keys()):
                # keep only backbone weights, stripping the 'backbone.' prefix
                if k.find('backbone.') != -1:
                    new_ckpt[k.replace('backbone.', '')] = weights_dict[k]
            self.load_state_dict(new_ckpt, strict=True)

    def _freeze_stages(self):
        """Put the first `frozen_stages` stages in eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            # NOTE(review): this freezes layers[0 .. frozen_stages-2]; confirm the
            # intended off-by-one against the upstream mmseg implementation.
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the backbone; returns a tuple of NCHW feature maps, one per out_indices stage."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the current feature resolution
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                # tokens (B, H*W, C) -> feature map (B, C, H, W)
                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen stages in eval mode."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()

+ 3 - 0
torchocr/networks/backbones/__init__.py

@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/15 17:41
+# @Author  : zhoujun

+ 76 - 0
torchocr/networks/heads/DetDbHead.py

@@ -0,0 +1,76 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch
+from torch import nn
+
+
class Head(nn.Module):
    """DB prediction branch: 3x3 conv followed by two stride-2 deconvolutions,
    producing a 1-channel sigmoid probability map at 4x the input resolution."""

    def __init__(self, in_channels):
        super().__init__()
        mid_channels = in_channels // 4
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        self.conv_bn1 = nn.BatchNorm2d(mid_channels)
        self.relu = nn.ReLU(inplace=True)
        # each transposed conv doubles the spatial resolution
        self.conv2 = nn.ConvTranspose2d(mid_channels, mid_channels, kernel_size=2, stride=2)
        self.conv_bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.ConvTranspose2d(mid_channels, 1, kernel_size=2, stride=2)

    def forward(self, x):
        """Return a probability map (B, 1, 4H, 4W) for input (B, C, H, W)."""
        out = self.relu(self.conv_bn1(self.conv1(x)))
        out = self.relu(self.conv_bn2(self.conv2(out)))
        return torch.sigmoid(self.conv3(out))
+
+
def weights_init(m):
    """Initializer for module.apply(): Kaiming-normal for (transposed) convs,
    N(1, 0.02) weight / zero bias for BatchNorm2d."""
    import torch.nn.init as init
    # Conv2d and ConvTranspose2d share the same initialization scheme.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
+
+
class DBHead(nn.Module):
    """Differentiable Binarization (DB) head for text detection.

    See https://arxiv.org/abs/1911.08947. At inference time only the shrink
    probability map is returned; during training the threshold map and the
    differentiable binary map are additionally produced and concatenated
    channel-wise.

    Args:
        in_channels (int): channels of the incoming feature map.
        k (int): steepness of the differentiable binarization. Default: 50.
    """

    def __init__(self, in_channels, k=50):
        super().__init__()
        self.k = k
        self.binarize = Head(in_channels)
        self.thresh = Head(in_channels)
        for branch in (self.binarize, self.thresh):
            branch.apply(weights_init)

    def step_function(self, x, y):
        # differentiable approximation of binarization: sigmoid(k * (x - y))
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))

    def forward(self, x):
        shrink_maps = self.binarize(x)
        if not self.training:
            return shrink_maps
        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)

+ 26 - 0
torchocr/networks/heads/DetPseHead.py

@@ -0,0 +1,26 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from torch import nn
+import torch.nn.functional as F
+from torchocr.networks.CommonModules import ConvBNACT
+
+
class PseHead(nn.Module):
    """PSENet output head: 3x3 conv-bn-relu reduction plus a 1x1 conv emitting
    `result_num` kernel maps, bilinearly upsampled to the configured size.

    Kwargs:
        H, W (int): training-time output resolution. Default: 640 each.
        scale (int): divisor applied to H/W at inference time. Default: 1.
    """

    def __init__(self, in_channels, result_num=6, **kwargs):
        super(PseHead, self).__init__()
        self.H = kwargs.get('H', 640)
        self.W = kwargs.get('W', 640)
        self.scale = kwargs.get('scale', 1)
        self.conv = ConvBNACT(in_channels, in_channels // 4, kernel_size=3, padding=1, stride=1, act='relu')
        self.out_conv = nn.Conv2d(in_channels // 4, result_num, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.out_conv(x)
        # BUG FIX: `self.train` is a bound method and therefore always truthy,
        # so the eval-time branch was unreachable; `self.training` is the
        # nn.Module mode flag toggled by .train()/.eval().
        if self.training:
            x = F.interpolate(x, size=(self.H, self.W), mode='bilinear', align_corners=True)
        else:
            x = F.interpolate(x, size=(self.H // self.scale, self.W // self.scale), mode='bilinear', align_corners=True)
        return x

+ 73 - 0
torchocr/networks/heads/FCEHead.py

@@ -0,0 +1,73 @@
+
+
+from torch import nn
+import torch.nn.functional as F
+import torch
+from functools import partial
+
+
def multi_apply(func, *args, **kwargs):
    """Apply `func` across parallel argument sequences and transpose the results.

    Returns a tuple of lists — one list per element of func's returned tuple.
    """
    call = partial(func, **kwargs) if kwargs else func
    per_item = [call(*items) for items in zip(*args)]
    return tuple(list(group) for group in zip(*per_item))
+
+
class FCEHead(nn.Module):
    """The class for implementing FCENet head.

    FCENet (CVPR 2021): Fourier Contour Embedding for Arbitrary-shaped Text
    Detection. [https://arxiv.org/abs/2104.10442]

    Args:
        in_channels (int): The number of input channels of each feature level.
        fourier_degree (int): The maximum Fourier transform degree k. Default: 5.
    """

    def __init__(self, in_channels, fourier_degree=5):
        super().__init__()
        assert isinstance(in_channels, int)

        self.downsample_ratio = 1.0
        self.in_channels = in_channels
        self.fourier_degree = fourier_degree
        # 2 text-region + 2 text-center-line classification channels
        self.out_channels_cls = 4
        # real+imag coefficients for Fourier degrees -k..k
        self.out_channels_reg = (2 * self.fourier_degree + 1) * 2

        self.out_conv_cls = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels_cls,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            bias=True)
        self.out_conv_reg = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels_reg,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            bias=True)

    def forward(self, feats, targets=None):
        """Run the shared per-level head on every pyramid level.

        Training mode returns {'levels': [[cls, reg], ...]}; eval mode returns
        softmaxed cls maps concatenated with reg maps per level.
        """
        per_level = [self.forward_single(feat) for feat in feats]
        cls_res = [cls for cls, _ in per_level]
        reg_res = [reg for _, reg in per_level]
        outs = {}
        if self.training:
            outs['levels'] = [[c, r] for c, r in zip(cls_res, reg_res)]
        else:
            for i, (cls_pred, reg_pred) in enumerate(zip(cls_res, reg_res)):
                tr_pred = F.softmax(cls_pred[:, 0:2, :, :], dim=1)
                tcl_pred = F.softmax(cls_pred[:, 2:, :, :], dim=1)
                outs['level_{}'.format(i)] = torch.cat(
                    [tr_pred, tcl_pred, reg_pred], dim=1)
        return outs

    def forward_single(self, x):
        return self.out_conv_cls(x), self.out_conv_reg(x)

+ 18 - 0
torchocr/networks/heads/RecCTCHead.py

@@ -0,0 +1,18 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import OrderedDict
+
+import torch
+from torch import nn
+
+
class CTC(nn.Module):
    """Linear projection head producing per-timestep class logits for CTC decoding."""

    def __init__(self, in_channels, n_class, **kwargs):
        super().__init__()
        self.n_class = n_class
        self.fc = nn.Linear(in_channels, n_class)

    def forward(self, x):
        """Map features (..., in_channels) to logits (..., n_class)."""
        logits = self.fc(x)
        return logits

+ 3 - 0
torchocr/networks/heads/__init__.py

@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/15 17:42
+# @Author  : zhoujun

+ 143 - 0
torchocr/networks/losses/CTCLoss.py

@@ -0,0 +1,143 @@
+from torch import nn
+import torch
+
+
class CTCLoss(nn.Module):
    """Wrapper around torch.nn.CTCLoss operating on batch-first (B, T, C) logits."""

    def __init__(self, loss_cfg, reduction='mean'):
        super().__init__()
        # zero_infinity guards against inf losses on degenerate alignments
        self.loss_func = torch.nn.CTCLoss(blank=loss_cfg['blank_idx'], reduction=reduction, zero_infinity=True)

    def forward(self, pred, args):
        """pred: logits (B, T, C); args: dict with 'targets' and 'targets_lengths'."""
        label = args['targets']
        label_length = args['targets_lengths']
        batch_size = pred.size(0)
        # CTCLoss expects (T, B, C) log-probabilities
        log_probs = pred.log_softmax(2).permute(1, 0, 2)
        # every sample uses the full time axis as its input length
        input_lengths = torch.full((batch_size,), log_probs.size(0), dtype=torch.long)
        return {'loss': self.loss_func(log_probs, label, input_lengths, label_length)}
+
class EnhancedCTCLoss(nn.Module):
    """CTC loss optionally augmented with a weighted center-loss term.

    Args:
        loss_cfg (dict): forwarded to CTCLoss (must contain 'blank_idx').
        use_center_loss (bool): add `center_loss_weight * center_loss` when True.
        center_loss_weight (float): scale of the center-loss contribution.
        num_classes (int): number of classes for the center loss.
        feat_dim (int): feature dimension of the center loss.
        init_center (bool): kept for interface compatibility; currently unused
            (loading centers from a file is disabled).
        center_file_path (str, optional): forwarded to CenterLoss.
    """

    def __init__(self,
                 loss_cfg,
                 use_center_loss=True,
                 center_loss_weight=0.05,
                 num_classes=6625,
                 feat_dim=96,
                 init_center=False,
                 center_file_path=None,
                 **kwargs):
        super(EnhancedCTCLoss, self).__init__()
        self.ctc_loss_func = CTCLoss(loss_cfg)

        self.use_center_loss = False
        if use_center_loss:
            self.use_center_loss = use_center_loss
            # BUG FIX: CenterLoss.__init__ does not accept an `init_center`
            # keyword; passing it raised TypeError at construction. The flag is
            # retained in this signature for backward compatibility only.
            self.center_loss_func = CenterLoss(
                num_classes=num_classes,
                feat_dim=feat_dim,
                center_file_path=center_file_path)
            self.center_loss_weight = center_loss_weight

    def forward(self, predicts, batch):
        """Return {'enhanced_ctc_loss': ctc (+ weighted center loss when enabled)}."""
        loss = self.ctc_loss_func(predicts, batch)["loss"]

        if self.use_center_loss:
            center_loss = self.center_loss_func(
                predicts, batch)["loss_center"] * self.center_loss_weight
            loss = loss + center_loss

        return {'enhanced_ctc_loss': loss}
+
class CenterLoss(nn.Module):
    """Center loss: pulls each feature toward its class center.

    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep
    Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes / learnable centers.
        feat_dim (int): dimensionality of each center vector.
        center_file_path (str, optional): currently unused (file-based center
            initialization is disabled).
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # one learnable feat_dim-vector per class
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, predicts, batch):
        # NOTE(review): the argument naming is inverted here — `predicts` is
        # used as the feature matrix x of shape (batch, feat_dim) and `batch`
        # as the integer class labels of shape (batch,); confirm with callers.
        x = predicts
        labels = batch
        batch_size = x.size(0)

        # squared L2 distance to every center: ||x||^2 + ||c||^2 - 2 * x . c
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # BUG FIX: the positional addmm_(beta, alpha, mat1, mat2) overload was
        # deprecated and removed from PyTorch; use the keyword form:
        # distmat = 1 * distmat + (-2) * (x @ centers^T).
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        # mask selects, for each sample, the distance to its own class center
        classes = torch.arange(self.num_classes).long()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))
        dist = distmat * mask.float()

        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return {'loss_center': loss}

+ 165 - 0
torchocr/networks/losses/CTCLoss_test.py

@@ -0,0 +1,165 @@
+from torch import nn
+import torch
+
+
class CTCLoss(nn.Module):
    """CTC loss over (features, logits) pairs, with optional focal re-weighting."""

    def __init__(self, loss_cfg, reduction='mean', use_focal_loss=False):
        super().__init__()
        # focal weighting needs per-sample losses, so disable reduction up front
        if use_focal_loss:
            reduction = 'none'
        self.use_focal_loss = use_focal_loss
        self.loss_func = torch.nn.CTCLoss(blank=loss_cfg['blank_idx'], reduction=reduction, zero_infinity=True)

    def forward(self, pred, args):
        """pred: a (features, logits) pair — only logits (B, T, C) enter the loss."""
        logits = pred[1]
        batch_size = logits.size(0)
        label = args['targets']
        label_length = args['targets_lengths']
        log_probs = logits.log_softmax(2).permute(1, 0, 2)  # (T, B, C) for CTC
        input_lengths = torch.full((batch_size,), log_probs.size(0), dtype=torch.long)
        loss = self.loss_func(log_probs, label, input_lengths, label_length)
        if self.use_focal_loss:
            # down-weight easy samples: loss * (1 - exp(-loss))^2, then average
            hardness = 1.0 - torch.exp(-loss)
            loss = (loss * hardness.square()).mean()

        return {'loss': loss}
+
class EnhancedCTCLoss(nn.Module):
    """CTC loss optionally combined with a weighted center-loss term.

    Args:
        loss_cfg (dict): forwarded to CTCLoss (must contain 'blank_idx').
        use_center_loss (bool): add the center-loss term when truthy.
        center_loss_weight (float): scale of the center-loss contribution.
        num_classes (int): number of classes for the center loss.
        feat_dim (int): feature dimension of the center loss.
        init_center (bool): accepted for interface compatibility; unused here.
        center_file_path (str, optional): forwarded to CenterLoss.
    """

    def __init__(self,
                 loss_cfg,
                 use_center_loss=True,
                 center_loss_weight=0.05,
                 num_classes=5990,
                 feat_dim=96,
                 init_center=False,
                 center_file_path=None,
                 **kwargs):
        super(EnhancedCTCLoss, self).__init__()
        self.ctc_loss_func = CTCLoss(loss_cfg)
        self.use_center_loss = use_center_loss if use_center_loss else False
        if self.use_center_loss:
            self.center_loss_func = CenterLoss(
                num_classes=num_classes,
                feat_dim=feat_dim,
                center_file_path=center_file_path)
            self.center_loss_weight = center_loss_weight

    def forward(self, predicts, args):
        """Return {'loss': ctc (+ weight * center loss when enabled)}."""
        total = self.ctc_loss_func(predicts, args)['loss']
        if self.use_center_loss:
            total = total + self.center_loss_func(predicts)['loss_center'] * self.center_loss_weight
        return {'loss': total}
+
class CenterLoss(nn.Module):
    """Center loss for sequence recognition: pulls per-timestep features toward
    the center of their argmax-predicted class.

    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep
    Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes / learnable centers.
        feat_dim (int): dimensionality of each center vector.
        center_file_path (str, optional): currently unused (file-based center
            initialization is disabled).
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # one learnable feat_dim-vector per class
        self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
        self.use_gpu = True if torch.cuda.is_available() else False

    def forward(self, predicts):
        """predicts: a (features, logits) pair — features (B, T, feat_dim),
        logits (B, T, num_classes). Pseudo-labels come from argmax over logits."""
        features, predicts = predicts
        label = predicts.argmax(axis=2)

        # flatten batch and time: (B*T, feat_dim) features, (B*T,) labels
        feats_reshape = torch.reshape(features, (-1, features.size(-1)))
        label = label.reshape(label.size(0) * label.size(1))

        batch_size = feats_reshape.size(0)

        # squared L2 distance to every center: ||f||^2 + ||c||^2 - 2 * f . c
        square_feat = torch.pow(feats_reshape, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes)
        square_center = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat = square_feat + square_center
        # BUG FIX: the positional addmm_(beta, alpha, mat1, mat2) overload was
        # deprecated and removed from PyTorch; use the keyword form:
        # distmat = 1 * distmat + (-2) * (feats @ centers^T).
        distmat.addmm_(feats_reshape, self.centers.t(), beta=1, alpha=-2)

        # mask selects, for each (sample, timestep), the distance to its own center
        classes = torch.arange(self.num_classes).long()
        if self.use_gpu:
            classes = classes.cuda()
        labels = label.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))
        dist = distmat * mask.float()

        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return {'loss_center': loss}

+ 27 - 0
torchocr/networks/losses/CombinedLoss.py

@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+from .distillation_loss import DistillationDilaDBLoss,DistillationDBLoss,DistillationDMLLoss
+
class CombinedLoss(nn.Module):
    """Weighted sum of several configured loss functions.

    Args:
        _cfg_list (dict): must contain 'combine_list', a mapping of loss class
            name -> kwargs; each kwargs dict must contain a 'weight' entry
            (popped here) used to scale that loss's outputs.
    """

    def __init__(self, _cfg_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        for key, val in _cfg_list['combine_list'].items():
            self.loss_weight.append(val.pop('weight'))
            # NOTE(review): eval() instantiates the class named by the config
            # key — config must be trusted; consider a whitelist dict instead.
            self.loss_func.append(eval(key)(**val))

    def forward(self, input, batch, **kwargs):
        """Return every weighted sub-loss plus their total under key 'loss'."""
        loss_dict = {}
        loss_all = 0.
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kwargs)
            weight = self.loss_weight[idx]
            loss = {key: loss[key] * weight for key in loss}
            if 'loss' in loss:
                loss_all = torch.add(loss_all, loss['loss'])
            else:
                # BUG FIX: `torch.add(list(loss.values()))` passed a single
                # argument to a two-argument function and raised TypeError;
                # sum the component losses instead.
                loss_all = loss_all + sum(loss.values())
            loss_dict.update(loss)
        loss_dict['loss'] = loss_all
        return loss_dict

+ 59 - 0
torchocr/networks/losses/DBLoss.py

@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2019/8/23 21:56
+# @Author  : zhoujun
+from torch import nn
+
+from torchocr.networks.losses.DetBasicLoss import BalanceCrossEntropyLoss, MaskL1Loss, DiceLoss, BalanceLoss
+
+
class DBLoss(nn.Module):
    """Loss for DB text detection: balanced shrink-map loss, masked-L1
    threshold-map loss and (when present) a dice loss on the binary map.

    Args:
        balance_loss: forwarded to BalanceLoss.
        main_loss_type: forwarded to BalanceLoss (e.g. 'DiceLoss').
        alpha: weight of the shrink-map loss.
        beta: weight of the threshold-map loss.
        ohem_ratio: negative/positive ratio for online hard example mining.
        reduction: 'mean' or 'sum' over the batch.
        eps: numerical stabilizer for the dice and L1 losses.
    """

    def __init__(self, balance_loss=True, main_loss_type='DiceLoss', alpha=1.0, beta=10, ohem_ratio=3, reduction='mean',
                 eps=1e-6):
        super().__init__()
        assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
        self.alpha = alpha
        self.beta = beta
        self.bce_loss = BalanceLoss(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            negative_ratio=ohem_ratio)
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.reduction = reduction

    def forward(self, pred, batch):
        """pred: (B, 3, H, W) with shrink/threshold/binary channels.

        batch: dict with 'shrink_map', 'shrink_mask', 'threshold_map' and
        'threshold_mask' ground-truth tensors.
        """
        shrink_pred = pred[:, 0, :, :]
        threshold_pred = pred[:, 1, :, :]
        binary_pred = pred[:, 2, :, :]

        loss_shrink_maps = self.alpha * self.bce_loss(shrink_pred, batch['shrink_map'], batch['shrink_mask'])
        loss_threshold_maps = self.beta * self.l1_loss(threshold_pred, batch['threshold_map'], batch['threshold_mask'])
        loss_dict = {'loss_shrink_maps': loss_shrink_maps, 'loss_threshold_maps': loss_threshold_maps}
        if pred.size(1) > 2:
            loss_binary_maps = self.dice_loss(binary_pred, batch['shrink_map'], batch['shrink_mask'])
            loss_dict['loss_binary_maps'] = loss_binary_maps
            loss_dict['loss'] = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
        else:
            loss_dict['loss'] = loss_shrink_maps

        return loss_dict

+ 182 - 0
torchocr/networks/losses/DetBasicLoss.py

@@ -0,0 +1,182 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2019/12/4 14:39
+# @Author  : zhoujun
+import torch
+import torch.nn as nn
+
+
class BalanceCrossEntropyLoss(nn.Module):
    """OHEM-balanced binary cross entropy.

    All positive pixels contribute; negatives are hard-mined so that at most
    ``negative_ratio`` times the positive count are kept (the highest-loss
    ones).

    Shape:
        - pred: (N, 1, H, W) probabilities in [0, 1]
        - gt:   (N, 1, H, W)
        - mask: (N, H, W) marking valid regions
        - returns a scalar tensor (plus the per-pixel loss map when
          ``return_origin`` is True)
    """

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        super(BalanceCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: torch.Tensor,
                return_origin=False):
        """Return the balanced BCE between ``pred`` and ``gt`` inside ``mask``."""
        pos_mask = (gt * mask).byte()
        neg_mask = ((1 - gt) * mask).byte()
        n_pos = int(pos_mask.float().sum())
        n_neg = min(int(neg_mask.float().sum()), int(n_pos * self.negative_ratio))

        pixel_loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
        pos_loss = pixel_loss * pos_mask.float()
        # keep only the hardest negatives (topk over the flattened map)
        hard_neg_loss, _ = torch.topk((pixel_loss * neg_mask.float()).view(-1), n_neg)

        balanced = (pos_loss.sum() + hard_neg_loss.sum()) / (n_pos + n_neg + self.eps)

        if return_origin:
            return balanced, pixel_loss
        return balanced
+
+
class DiceLoss(nn.Module):
    """Dice loss (https://arxiv.org/abs/1707.03237) measuring the overlap
    between a predicted heatmap and its ground truth inside a mask.
    """

    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps  # keeps the union denominator non-zero

    def forward(self, pred: torch.Tensor, gt, mask, weights=None):
        """
        pred: (N, 1, H, W) or (N, H, W) heatmap
        gt:   same layout as pred
        mask: (N, H, W) valid-region mask
        weights: optional per-pixel weights with the same shape as mask
        """
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        if pred.dim() == 4:
            # drop the channel axis so everything is (N, H, W)
            pred = pred[:, 0, :, :]
            gt = gt[:, 0, :, :]
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask

        inter = (pred * gt * mask).sum()
        union = (pred * mask).sum() + (gt * mask).sum() + self.eps
        dice_coeff = 2.0 * inter / union
        loss = 1 - dice_coeff
        assert loss <= 1
        return loss
+
+
class MaskL1Loss(nn.Module):
    """Mean absolute error restricted to a mask (the threshold-map border
    region in DB), normalised by the mask area."""

    def __init__(self, eps=1e-6):
        super(MaskL1Loss, self).__init__()
        self.eps = eps  # guards the division when the mask is empty

    def forward(self, pred: torch.Tensor, gt, mask):
        """Return masked mean |pred - gt|."""
        masked_abs_err = torch.abs(pred - gt) * mask
        return masked_abs_err.sum() / (mask.sum() + self.eps)
+
+
class BCELoss(nn.Module):
    """Thin wrapper around ``binary_cross_entropy``.

    The extra ``mask``/``weight``/``name`` parameters exist only for
    signature compatibility with the other losses and are ignored.
    """

    def __init__(self, reduction='mean'):
        super(BCELoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label, mask=None, weight=None, name=None):
        return nn.functional.binary_cross_entropy(
            input, label, reduction=self.reduction)
+
+
class BalanceLoss(nn.Module):
    """OHEM-balanced wrapper around a selectable per-pixel main loss.

    Used by DBLoss for the shrink-map term: every positive pixel is kept
    while only the hardest negatives (up to ``negative_ratio`` times the
    positive count) contribute to the balanced total.
    """

    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 negative_ratio=3,
                 return_origin=False,
                 eps=1e-6,
                 **kwargs):
        """
        :param balance_loss: if False, forward returns the raw main loss
        :param main_loss_type: 'CrossEntropy' | 'Euclidean' | 'DiceLoss' |
            'BCELoss' | 'MaskL1Loss'
        :param negative_ratio: maximum negatives kept per positive (OHEM)
        :param return_origin: if True, forward also returns the raw loss
        :param eps: numerical-stability term for the final division
        """
        super(BalanceLoss, self).__init__()
        self.balance_loss = balance_loss
        self.main_loss_type = main_loss_type
        self.negative_ratio = negative_ratio
        self.return_origin = return_origin
        self.eps = eps

        if self.main_loss_type == "CrossEntropy":
            # NOTE(review): nn.CrossEntropyLoss does not accept the `mask=`
            # keyword passed in forward() — this branch looks unused/broken;
            # confirm before relying on it.
            self.loss = nn.CrossEntropyLoss()
        elif self.main_loss_type == "Euclidean":
            # NOTE(review): same `mask=` keyword caveat as CrossEntropy.
            self.loss = nn.MSELoss()
        elif self.main_loss_type == "DiceLoss":
            self.loss = DiceLoss(self.eps)
        elif self.main_loss_type == "BCELoss":
            # reduction='none' keeps the per-pixel map needed for OHEM below
            self.loss = BCELoss(reduction='none')
        elif self.main_loss_type == "MaskL1Loss":
            self.loss = MaskL1Loss(self.eps)
        else:
            loss_type = [
                'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
            ]
            raise Exception(
                "main_loss_type in BalanceLoss() can only be one of {}".format(
                    loss_type))

    def forward(self, pred, gt, mask=None):
        """
        The BalanceLoss for Differentiable Binarization text detection
        args:
            pred (variable): predicted feature maps.
            gt (variable): ground truth feature maps.
            mask (variable): masked maps.
        return: (variable) balanced loss
        """
        # NOTE(review): `mask` defaults to None but is used unconditionally
        # below — callers must always pass it.
        positive = (gt * mask).byte()
        negative = ((1 - gt) * mask).byte()

        positive_count = int(positive.float().sum())
        negative_count = int(min(negative.float().sum(), positive_count * self.negative_ratio))

        loss = self.loss(pred, gt, mask=mask)

        if not self.balance_loss:
            return loss

        # weight the (per-pixel, or scalar for DiceLoss) loss by the masks
        positive_loss = positive.float() * loss
        negative_loss = negative.float() * loss
        if negative_count > 0:
            # keep only the hardest negatives (OHEM)
            negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
            balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
                    positive_count + negative_count + self.eps)
        else:
            balance_loss = positive_loss.sum() / (positive_count + self.eps)
        if self.return_origin:
            return balance_loss, loss

        return balance_loss

+ 209 - 0
torchocr/networks/losses/FCELoss.py

@@ -0,0 +1,209 @@
+import numpy as np
+from torch import nn
+import torch
+import torch.nn.functional as F
+from functools import partial
+
+
def multi_apply(func, *args, **kwargs):
    """Apply ``func`` to the elements of ``args`` in lockstep and regroup.

    ``kwargs`` are bound to every call. Returns a tuple of lists where
    element i collects the i-th return value of each call.
    """
    bound = partial(func, **kwargs) if kwargs else func
    per_call = [bound(*items) for items in zip(*args)]
    return tuple(list(group) for group in zip(*per_call))
+
+
class FCELoss(nn.Module):
    """The class for implementing FCENet loss
    FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
        Text Detection

    [https://arxiv.org/abs/2104.10442]

    Args:
        fourier_degree (int) : The maximum Fourier transform degree k.
        num_sample (int) : The sampling points number of regression
            loss. If it is too small, fcenet tends to be overfitting.
        ohem_ratio (float): the negative/positive ratio in OHEM.
    """

    def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
        super().__init__()
        self.fourier_degree = fourier_degree
        self.num_sample = num_sample
        self.ohem_ratio = ohem_ratio

    def forward(self, preds, labels):
        """Sum per-level losses over the three FPN levels.

        :param preds: dict with key 'levels' holding a (cls_pred, reg_pred)
            pair per level
        :param labels: dict with per-level ground-truth maps 'p3_maps',
            'p4_maps', 'p5_maps'
        :return: dict with the total under 'loss' plus the per-term sums
        """
        assert isinstance(preds, dict)
        preds = preds['levels']

        p3_maps, p4_maps, p5_maps = labels['p3_maps'], labels['p4_maps'], labels['p5_maps']
        # each GT map stacks text-region + center + effective-mask channels
        # plus 2*(2k+1) Fourier coefficient channels = 4k + 5 in total
        assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5, \
            'fourier degree not equal in FCEhead and FCEtarget'

        # to tensor
        gts = [p3_maps, p4_maps, p5_maps]
        # for idx, maps in enumerate(gts):
        #     gts[idx] = torch.tensor(np.stack(maps.cpu().detach().numpy()))
        #     torch.stack(maps)
        losses = multi_apply(self.forward_single, preds, gts)

        # NOTE(review): accumulators are created with .cuda(), so this loss
        # requires a GPU as written — confirm before running on CPU.
        loss_tr = torch.tensor(0.).cuda().float()
        loss_tcl = torch.tensor(0.).cuda().float()
        loss_reg_x = torch.tensor(0.).cuda().float()
        loss_reg_y = torch.tensor(0.).cuda().float()
        loss_all = torch.tensor(0.).cuda().float()

        # losses is a tuple of 4 lists (tr, tcl, reg_x, reg_y), one entry
        # per FPN level; sum each list into its accumulator
        for idx, loss in enumerate(losses):
            loss_all += sum(loss)
            if idx == 0:
                loss_tr += sum(loss)
            elif idx == 1:
                loss_tcl += sum(loss)
            elif idx == 2:
                loss_reg_x += sum(loss)
            else:
                loss_reg_y += sum(loss)

        results = dict(
            loss=loss_all,
            loss_text=loss_tr,
            loss_center=loss_tcl,
            loss_reg_x=loss_reg_x,
            loss_reg_y=loss_reg_y, )
        return results

    def forward_single(self, pred, gt):
        """Compute the four loss terms for one FPN level.

        :param pred: (cls_pred, reg_pred) for this level
        :param gt: ground-truth map for this level, (N, 4k+5, H, W)
        :return: (loss_tr, loss_tcl, loss_reg_x, loss_reg_y)
        """
        # move channels last so flattening below is per-pixel
        cls_pred = pred[0].permute(0, 2, 3, 1)
        reg_pred = pred[1].permute(0, 2, 3, 1)
        gt = gt.permute(0, 2, 3, 1)

        k = 2 * self.fourier_degree + 1
        tr_pred = torch.reshape(cls_pred[:, :, :, :2], (-1, 2))
        tcl_pred = torch.reshape(cls_pred[:, :, :, 2:], (-1, 2))
        x_pred = torch.reshape(reg_pred[:, :, :, 0:k], (-1, k))
        y_pred = torch.reshape(reg_pred[:, :, :, k:2 * k], (-1, k))

        # GT channel layout: 0 text region, 1 text center line, 2 effective
        # (train) mask, then k real and k imaginary Fourier coefficients
        tr_mask = gt[:, :, :, :1].reshape([-1])
        tcl_mask = gt[:, :, :, 1:2].reshape([-1])
        train_mask = gt[:, :, :, 2:3].reshape([-1])
        x_map = torch.reshape(gt[:, :, :, 3:3 + k], (-1, k))
        y_map = torch.reshape(gt[:, :, :, 3 + k:], (-1, k))

        tr_train_mask = (train_mask * tr_mask).bool()
        tr_train_mask2 = torch.cat(
            [tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], dim=1)
        # tr loss
        loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
        # tcl loss
        # NOTE(review): the zero fall-backs below are CPU tensors while the
        # accumulators in forward() live on CUDA — confirm the mixed-device
        # addition when a level has no positive pixels.
        loss_tcl = torch.tensor((0.), dtype=torch.float32)
        tr_neg_mask = tr_train_mask.logical_not()
        tr_neg_mask2 = torch.cat(
            [tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], dim=1)
        if tr_train_mask.sum().item() > 0:
            loss_tcl_pos = F.cross_entropy(
                tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]),
                tcl_mask.masked_select(tr_train_mask).long())
            loss_tcl_neg = F.cross_entropy(
                tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]),
                tcl_mask.masked_select(tr_neg_mask).long())
            # negatives are down-weighted by half
            loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg

        # regression loss
        loss_reg_x = torch.tensor(0.).float()
        loss_reg_y = torch.tensor(0.).float()
        if tr_train_mask.sum().item() > 0:
            # per-pixel weight: average of region and center-line membership
            weight = (tr_mask.masked_select(tr_train_mask.bool())
                      .float() + tcl_mask.masked_select(
                tr_train_mask.bool()).float()) / 2
            weight = weight.reshape([-1, 1])

            # compare contours reconstructed from the Fourier coefficients
            ft_x, ft_y = self.fourier2poly(x_map, y_map)
            ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)

            dim = ft_x.shape[1]

            tr_train_mask3 = torch.cat(
                [tr_train_mask.unsqueeze(1) for i in range(dim)], dim=1)

            loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
                ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
                ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
                reduction='none'))
            loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
                ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
                ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
                reduction='none'))

        return loss_tr, loss_tcl, loss_reg_x, loss_reg_y

    def ohem(self, predict, target, train_mask):
        """OHEM-balanced cross entropy for the text-region classification.

        Keeps all positives and at most ``ohem_ratio`` * positives hardest
        negatives (falls back to 100 negatives when there is no positive).
        """
        pos = (target * train_mask).bool()
        neg = ((1 - target) * train_mask).bool()

        pos2 = torch.cat([pos.unsqueeze(1), pos.unsqueeze(1)], dim=1)
        neg2 = torch.cat([neg.unsqueeze(1), neg.unsqueeze(1)], dim=1)

        n_pos = pos.float().sum()

        if n_pos.item() > 0:
            loss_pos = F.cross_entropy(
                predict.masked_select(pos2).reshape([-1, 2]),
                target.masked_select(pos).long(),
                reduction='sum')
            loss_neg = F.cross_entropy(
                predict.masked_select(neg2).reshape([-1, 2]),
                target.masked_select(neg).long(),
                reduction='none')
            n_neg = min(
                int(neg.float().sum().item()),
                int(self.ohem_ratio * n_pos.float()))
        else:
            loss_pos = torch.tensor(0.)
            loss_neg = F.cross_entropy(
                predict.masked_select(neg2).reshape([-1, 2]),
                target.masked_select(neg).long(),
                reduction='none')
            n_neg = 100
        if len(loss_neg) > n_neg:
            # keep only the hardest negatives
            loss_neg, _ = torch.topk(loss_neg, n_neg)

        return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()

    def fourier2poly(self, real_maps, imag_maps):
        """Transform Fourier coefficient maps to polygon maps.

        Args:
            real_maps (tensor): A map composed of the real parts of the
                Fourier coefficients, whose shape is (-1, 2k+1)
            imag_maps (tensor):A map composed of the imag parts of the
                Fourier coefficients, whose shape is (-1, 2k+1)

        Returns
            x_maps (tensor): A map composed of the x value of the polygon
                represented by n sample points (xn, yn), whose shape is (-1, n)
            y_maps (tensor): A map composed of the y value of the polygon
                represented by n sample points (xn, yn), whose shape is (-1, n)
        """

        k_vect = torch.arange(
            -self.fourier_degree, self.fourier_degree + 1,
            dtype=torch.float32).reshape([-1, 1])
        i_vect = torch.arange(
            0, self.num_sample, dtype=torch.float32).reshape([1, -1])

        # angle matrix 2*pi*k*i/N of the inverse Fourier transform
        transform_matrix = 2 * np.pi / self.num_sample * torch.matmul(k_vect,
                                                                      i_vect)

        # NOTE(review): trig matrices are moved to CUDA unconditionally —
        # GPU required here as well.
        x1 = torch.einsum('ak, kn-> an', real_maps,
                          torch.cos(transform_matrix).cuda())
        x2 = torch.einsum('ak, kn-> an', imag_maps,
                          torch.sin(transform_matrix).cuda())
        y1 = torch.einsum('ak, kn-> an', real_maps,
                          torch.sin(transform_matrix).cuda())
        y2 = torch.einsum('ak, kn-> an', imag_maps,
                          torch.cos(transform_matrix).cuda())

        x_maps = x1 - x2
        y_maps = y1 + y2

        return x_maps, y_maps

+ 104 - 0
torchocr/networks/losses/PSELoss.py

@@ -0,0 +1,104 @@
+import torch
+from torch import nn
+import numpy as np
+
+
class PSELoss(nn.Module):
    def __init__(self, Lambda, ratio=3, reduction='mean'):
        """Implement PSE Loss.

        :param Lambda: weight of the full-text loss; the kernel loss gets
            (1 - Lambda)
        :param ratio: OHEM negative/positive ratio
            (NOTE(review): ohem_single below hard-codes 3 instead of using
            self.ratio — confirm whether that is intended)
        :param reduction: 'mean' or 'sum' over the batch
        """
        super(PSELoss, self).__init__()
        assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
        self.Lambda = Lambda
        self.ratio = ratio
        self.reduction = reduction

    def forward(self, outputs, labels, training_masks):
        """
        :param outputs: (N, kernels+1, H, W) logits; last channel is the full
            text map, the preceding channels are the shrunk kernels
        :param labels: ground truth with the same channel layout
        :param training_masks: (N, H, W) mask of pixels to learn from
        :return: (loss_text, loss_kernels, combined loss)
        """
        texts = outputs[:, -1, :, :]
        kernels = outputs[:, :-1, :, :]
        gt_texts = labels[:, -1, :, :]
        gt_kernels = labels[:, :-1, :, :]

        # OHEM mask for the full-text map
        selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
        selected_masks = selected_masks.to(outputs.device)

        loss_text = self.dice_loss(texts, gt_texts, selected_masks)

        # kernel loss is only computed where the text prediction is already
        # confident (sigmoid > 0.5) and the pixel is trainable
        loss_kernels = []
        mask0 = torch.sigmoid(texts).data.cpu().numpy()
        mask1 = training_masks.data.cpu().numpy()
        selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32')
        selected_masks = torch.from_numpy(selected_masks).float()
        selected_masks = selected_masks.to(outputs.device)
        kernels_num = gt_kernels.size()[1]
        for i in range(kernels_num):
            kernel_i = kernels[:, i, :, :]
            gt_kernel_i = gt_kernels[:, i, :, :]
            loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
            loss_kernels.append(loss_kernel_i)
        loss_kernels = torch.stack(loss_kernels).mean(0)
        if self.reduction == 'mean':
            loss_text = loss_text.mean()
            loss_kernels = loss_kernels.mean()
        elif self.reduction == 'sum':
            loss_text = loss_text.sum()
            loss_kernels = loss_kernels.sum()

        loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
        return loss_text, loss_kernels, loss

    def dice_loss(self, input, target, mask):
        """Per-sample dice loss; ``input`` is logits (sigmoid applied here)."""
        input = torch.sigmoid(input)

        # flatten each sample to a vector
        input = input.contiguous().view(input.size()[0], -1)
        target = target.contiguous().view(target.size()[0], -1)
        mask = mask.contiguous().view(mask.size()[0], -1)

        input = input * mask
        target = target * mask

        a = torch.sum(input * target, 1)
        b = torch.sum(input * input, 1) + 0.001
        c = torch.sum(target * target, 1) + 0.001
        d = (2 * a) / (b + c)
        return 1 - d

    def ohem_single(self, score, gt_text, training_mask):
        """Build the OHEM mask for one sample (numpy arrays in/out)."""
        # positives = GT text pixels that are trainable
        pos_num = (int)(np.sum(gt_text > 0.5)) - (int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))

        if pos_num == 0:
            # selected_mask = gt_text.copy() * 0 # may be not good
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
            return selected_mask

        neg_num = (int)(np.sum(gt_text <= 0.5))
        neg_num = (int)(min(pos_num * 3, neg_num))

        if neg_num == 0:
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
            return selected_mask

        neg_score = score[gt_text <= 0.5]
        # sort negative scores from high to low
        neg_score_sorted = np.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]
        # keep the mask of high-scoring negatives plus all positives
        selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
        selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
        return selected_mask

    def ohem_batch(self, scores, gt_texts, training_masks):
        """Run ohem_single per sample and stack the results as a tensor."""
        scores = scores.data.cpu().numpy()
        gt_texts = gt_texts.data.cpu().numpy()
        training_masks = training_masks.data.cpu().numpy()

        selected_masks = []
        for i in range(scores.shape[0]):
            selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :]))

        selected_masks = np.concatenate(selected_masks, 0)
        selected_masks = torch.from_numpy(selected_masks).float()

        return selected_masks

+ 23 - 0
torchocr/networks/losses/__init__.py

@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2020/5/15 17:43
+# @Author  : zhoujun
+import copy
+from addict import Dict
+from .DBLoss import DBLoss
+from .CTCLoss import CTCLoss,EnhancedCTCLoss
+from .PSELoss import PSELoss
+from .CombinedLoss import CombinedLoss
+from .FCELoss import FCELoss
+
+__all__ = ['build_loss']
+
+support_loss = ['DBLoss', 'CTCLoss','PSELoss','CombinedLoss','FCELoss','EnhancedCTCLoss']
+
+
def build_loss(config):
    """Build a loss module from a config dict.

    :param config: dict with a 'type' key naming one of ``support_loss``;
        all remaining keys are forwarded to the loss constructor as
        keyword arguments (e.g. alpha/beta/ohem_ratio for DBLoss).
    :return: the instantiated loss module
    """
    copy_config = copy.deepcopy(config)
    loss_type = copy_config.pop('type')
    assert loss_type in support_loss, f'all support loss is {support_loss}'
    # bugfix: expand the remaining config as keyword arguments; passing the
    # dict positionally bound the whole dict to the constructor's first
    # parameter (e.g. DBLoss.alpha) instead of its named options
    criterion = eval(loss_type)(**copy_config)
    return criterion
+ 300 - 0
torchocr/networks/losses/distillation_loss.py

@@ -0,0 +1,300 @@
+import torch
+import cv2
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from .DBLoss import DBLoss
+
+def _sum_loss(loss_dict):
+    if "loss" in loss_dict.keys():
+        return loss_dict
+    else:
+        loss_dict["loss"] = 0.
+        for k, value in loss_dict.items():
+            if k == "loss":
+                continue
+            else:
+                loss_dict["loss"] += value
+        return loss_dict
+
+
class KLJSLoss(object):
    """Smoothed pixel-wise KL or JS divergence between two probability maps."""

    def __init__(self, mode='kl'):
        assert mode in ['kl', 'js', 'KL', 'JS'
                        ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
        self.mode = mode

    def __call__(self, p1, p2, reduction="mean"):
        eps = 1e-5
        # KL(p2 || p1) with additive smoothing for numerical stability
        loss = p2 * torch.log((p2 + eps) / (p1 + eps) + eps)
        if self.mode.lower() == "js":
            # symmetrise and halve for Jensen-Shannon
            loss = 0.5 * (loss + p1 * torch.log((p1 + eps) / (p2 + eps) + eps))
        if reduction == "mean":
            return torch.mean(loss)
        if reduction == "none" or reduction is None:
            return loss
        return torch.sum(loss)
+
+
class DMLLoss(nn.Module):
    """Deep Mutual Learning loss: symmetric divergence between two outputs.

    :param act: optional activation applied to both inputs first
        ('softmax' or 'sigmoid')
    :param use_log: if True compute symmetric ``F.kl_div`` on log-probs
        (recognition distillation); otherwise use the JS divergence
        (detection distillation)
    """

    def __init__(self, act=None, use_log=False):
        super().__init__()
        if act is not None:
            assert act in ["softmax", "sigmoid"]
        if act == "softmax":
            # bugfix: torch.nn.Softmax takes `dim`, not `axis` (a Paddle
            # keyword) — the original call raised TypeError
            self.act = nn.Softmax(dim=-1)
        elif act == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            self.act = None

        self.use_log = use_log

        self.jskl_loss = KLJSLoss(mode="js")

    def forward(self, out1, out2):
        """Return the symmetric divergence between ``out1`` and ``out2``."""
        if self.act is not None:
            out1 = self.act(out1)
            out2 = self.act(out2)
        if self.use_log:
            # for recognition distillation, log is needed for feature map
            log_out1 = torch.log(out1)
            log_out2 = torch.log(out2)
            loss = (F.kl_div(
                log_out1, out2, reduction='batchmean') + F.kl_div(
                log_out2, out1, reduction='batchmean')) / 2.0
        else:
            # for detection distillation log is not needed
            loss = self.jskl_loss(out1, out2)
        return loss
+
class DistanceLoss(nn.Module):
    """Feature-distance loss selectable among L1, L2 (MSE) and smooth-L1.

    :param mode: one of 'l1', 'l2', 'smooth_l1'
    :param kargs: extra keyword arguments forwarded to the torch loss
    """

    def __init__(self, mode="l2", **kargs):
        super().__init__()
        assert mode in ["l1", "l2", "smooth_l1"]
        factories = {
            "l1": nn.L1Loss,
            "l2": nn.MSELoss,
            "smooth_l1": nn.SmoothL1Loss,
        }
        self.loss_func = factories[mode](**kargs)

    def forward(self, x, y):
        """Return the configured distance between ``x`` and ``y``."""
        return self.loss_func(x, y)
+
+
class DistillationDMLLoss(DMLLoss):
    """DML (mutual-learning) loss between pairs of sub-model outputs inside
    a distillation model.

    :param model_name_pairs: pair(s) of sub-model names, e.g.
        [["Student", "Teacher"]]; a single flat pair is also accepted
    :param act: optional activation forwarded to DMLLoss
    :param use_log: forwarded to DMLLoss
    :param key: unused here; kept for config compatibility
    :param maps_name: optional DB output-map name(s) ('thrink_maps',
        'threshold_maps', 'binary_maps') to compare channel-wise
    :param name: prefix for the produced loss keys
    """

    def __init__(self,
                 model_name_pairs=[],
                 act=None,
                 use_log=False,
                 key=None,
                 maps_name=None,
                 name="dml"):
        super().__init__(act=act, use_log=use_log)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
        self.name = name
        self.maps_name = self._check_maps_name(maps_name)

    def _check_model_name_pairs(self, model_name_pairs):
        # normalise to a list of [name_a, name_b] pairs
        if not isinstance(model_name_pairs, list):
            return []
        elif isinstance(model_name_pairs[0], list) and isinstance(
                model_name_pairs[0][0], str):
            return model_name_pairs
        else:
            return [model_name_pairs]

    def _check_maps_name(self, maps_name):
        # normalise to a flat list of map names (or None)
        if maps_name is None:
            return None
        elif type(maps_name) == str:
            return [maps_name]
        elif type(maps_name) == list:
            # bugfix: previously returned [maps_name], nesting the list so
            # _slice_out matched no known name and produced no map loss
            return maps_name
        else:
            return None

    def _slice_out(self, outs):
        # split a DB prediction (N, 3, H, W) into its named channel maps;
        # the 'thrink_maps' spelling matches the upstream config key
        new_outs = {}
        for k in self.maps_name:
            if k == "thrink_maps":
                new_outs[k] = outs[:, 0, :, :]
            elif k == "threshold_maps":
                new_outs[k] = outs[:, 1, :, :]
            elif k == "binary_maps":
                new_outs[k] = outs[:, 2, :, :]
            else:
                continue
        return new_outs

    def forward(self, predicts, batch):
        """Apply the DML loss to every configured pair (and map) and sum."""
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.maps_name is None:
                loss = super().forward(out1, out2)
                if isinstance(loss, dict):
                    for key in loss:
                        loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],idx)] = loss[key]
                else:
                    loss_dict["{}_{}".format(self.name, idx)] = loss
            else:
                outs1 = self._slice_out(out1)
                outs2 = self._slice_out(out2)
                for _c, k in enumerate(outs1.keys()):
                    loss = super().forward(outs1[k], outs2[k])
                    if isinstance(loss, dict):
                        for key in loss:
                            loss_dict["{}_{}_{}_{}_{}".format(key, pair[0], pair[1], self.maps_name[_c], idx)] = loss[key]
                    else:
                        loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c], idx)] = loss

        loss_dict = _sum_loss(loss_dict)

        return loss_dict
+
+
class DistillationDBLoss(DBLoss):
    """DBLoss applied to each named sub-model output of a distillation model.

    :param model_name_list: sub-model names whose predictions get a DB loss
    :param name: prefix for the produced loss keys
    The remaining parameters configure the underlying DBLoss.
    """

    def __init__(self,
                 model_name_list=[],
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 name="db",
                 **kwargs):
        # bugfix: forward the loss configuration to DBLoss — previously
        # super().__init__() discarded it and always used DBLoss defaults
        super().__init__(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            alpha=alpha,
            beta=beta,
            ohem_ratio=ohem_ratio,
            eps=eps)
        self.model_name_list = model_name_list
        self.name = name
        self.key = None

    def forward(self, predicts, batch):
        """Run DBLoss on every configured sub-model and sum the parts."""
        loss_dict = {}
        for idx, model_name in enumerate(self.model_name_list):
            out = predicts[model_name]
            loss = super().forward(out, batch)
            if isinstance(loss, dict):
                for key in loss.keys():
                    if key == "loss":
                        # drop per-model totals; _sum_loss rebuilds one below
                        continue
                    name = "{}_{}_{}".format(self.name, model_name, key)
                    loss_dict[name] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, model_name)] = loss

        loss_dict = _sum_loss(loss_dict)
        return loss_dict
+
+
class DistillationDilaDBLoss(DBLoss):
    """DB distillation loss: supervise the student's shrink/binary maps with
    a dilated, binarised version of the teacher's shrink map.

    :param model_name_pairs: [student_name, teacher_name] pairs
    :param key: unused here; kept for config compatibility
    :param name: prefix for the produced loss keys
    The remaining parameters configure the underlying DBLoss.
    """

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 name="dila_dbloss"):
        # bugfix: forward the loss configuration to DBLoss — previously
        # super().__init__() discarded it and always used DBLoss defaults
        super().__init__(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            alpha=alpha,
            beta=beta,
            ohem_ratio=ohem_ratio,
            eps=eps)
        self.model_name_pairs = model_name_pairs
        self.name = name
        self.key = key

    def forward(self, predicts, batch):
        """
        :param predicts: dict of sub-model predictions, each (N, 3, H, W)
        :param batch: dict with the DB ground-truth maps/masks
        :return: dict of per-pair losses plus their sum under 'loss'
        """
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            stu_preds = predicts[pair[0]]
            tch_preds = predicts[pair[1]]
            stu_shrink_maps = stu_preds[:, 0, :, :]
            stu_binary_maps = stu_preds[:, 2, :, :]

            # dilation to teacher prediction
            dilation_w = np.array([[1, 1], [1, 1]])
            th_shrink_maps = tch_preds[:, 0, :, :]
            th_shrink_maps = th_shrink_maps.cpu().detach().numpy() > 0.3  # thresh = 0.3
            dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
            for i in range(th_shrink_maps.shape[0]):
                dilate_maps[i] = cv2.dilate(
                    th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
            # bugfix: follow the student's device instead of hard-coded
            # .cuda() so the loss also runs on CPU
            th_shrink_maps = torch.tensor(dilate_maps).to(stu_preds.device)

            label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch['threshold_map'], batch['threshold_mask'], batch['shrink_map'], batch['shrink_mask']

            # calculate the shrink map loss
            bce_loss = self.alpha * self.bce_loss(
                stu_shrink_maps, th_shrink_maps, label_shrink_mask)
            loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
                                              label_shrink_mask)

            k = "{}_{}_{}".format(self.name, pair[0], pair[1])
            loss_dict[k] = bce_loss + loss_binary_maps

        loss_dict = _sum_loss(loss_dict)
        return loss_dict
+
+
class DistillationDistanceLoss(DistanceLoss):
    """Distance loss between paired sub-model predictions for distillation.

    :param mode: distance type forwarded to DistanceLoss
    :param model_name_pairs: pairs of sub-model names to compare
    :param key: unused here; kept for config compatibility
    :param name: prefix for the produced loss keys (suffixed with "_l2")
    """

    def __init__(self,
                 mode="l2",
                 model_name_pairs=[],
                 key=None,
                 name="loss_distance",
                 **kargs):
        super().__init__(mode=mode, **kargs)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name + "_l2"

    def forward(self, predicts, batch):
        """Compute the distance for every configured pair of predictions."""
        loss_dict = dict()
        for idx, (first, second) in enumerate(self.model_name_pairs):
            loss = super().forward(predicts[first], predicts[second])
            if isinstance(loss, dict):
                for key, value in loss.items():
                    loss_dict["{}_{}_{}".format(self.name, key, idx)] = value
            else:
                loss_dict["{}_{}_{}_{}".format(self.name, first, second,
                                               idx)] = loss
        return loss_dict

Энэ ялгаанд хэт олон файл өөрчлөгдсөн тул зарим файлыг харуулаагүй болно