generate_multi_language_configs.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import os.path
import logging

logging.basicConfig(level=logging.INFO)
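# This script fills the base template rec_multi_language_lite_train.yml with
# per-language dictionary, label-file and output paths, and writes the result
# to rec_<language>_lite_train.yml in the current directory.
#
# Example invocation (illustrative only; the label-file paths, data root and
# the Global.use_gpu override are placeholders, not values required by the script):
#   python3 generate_multi_language_configs.py -l it \
#       --train train_data/it_train.txt --val train_data/it_val.txt \
#       --data_dir train_data/ -o Global.use_gpu=False
#
# support_list below maps each supported language abbreviation to its full name.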
support_list = {
    'it': 'italian', 'xi': 'spanish', 'pu': 'portuguese', 'ru': 'russian', 'ar': 'arabic',
    'ta': 'tamil', 'ug': 'uyghur', 'fa': 'persian', 'ur': 'urdu', 'rs': 'serbian latin',
    'oc': 'occitan', 'rsc': 'serbian cyrillic', 'bg': 'bulgarian', 'uk': 'ukrainian', 'be': 'belarusian',
    'te': 'telugu', 'ka': 'kannada', 'chinese_cht': 'chinese traditional', 'hi': 'hindi', 'mr': 'marathi',
    'ne': 'nepali',
}
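# The base template must live next to this script (configs/rec/multi_language/
# in the PaddleOCR repository); fail early with a download hint otherwise.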
assert os.path.isfile("./rec_multi_language_lite_train.yml"), \
    "Missing basic configuration file rec_multi_language_lite_train.yml. " \
    "You can download it from " \
    "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"

global_config = yaml.load(
    open("./rec_multi_language_lite_train.yml", 'rb'), Loader=yaml.Loader)
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
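# ArgsParser: -o/--opt overrides arbitrary config keys, -l/--language selects
# the target language, and --train/--val/--dict/--data_dir replace the default
# dataset and dictionary paths.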
class ArgsParser(ArgumentParser):
    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            "-l", "--language", nargs='+',
            help="set language type, support {}".format(support_list))
        self.add_argument(
            "--train", type=str,
            help="you can use this command to change the train dataset default path")
        self.add_argument(
            "--val", type=str,
            help="you can use this command to change the eval dataset default path")
        self.add_argument(
            "--dict", type=str,
            help="you can use this command to change the dictionary default path")
        self.add_argument(
            "--data_dir", type=str,
            help="you can use this command to change the dataset default root path")
    def parse_args(self, argv=None):
        args = super(ArgsParser, self).parse_args(argv)
        args.opt = self._parse_opt(args.opt)
        args.language = self._set_language(args.language)
        return args
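    # -o/--opt values arrive as "key=value" strings; each value is parsed with
    # yaml.load so numbers, booleans and lists keep their YAML types.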
    def _parse_opt(self, opts):
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, v = s.split('=')
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
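    # Validate the requested language and point the template at that language's
    # dictionary, default label files, character type and output directory.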
    def _set_language(self, type):
        assert type, "please use -l or --language to choose language type"
        assert type[0] in support_list.keys(), \
            "the sub_keys(-l or --language) can only be one of support list:\n{},\n" \
            "but got: {}, please check your running command".format(support_list, type)
        global_config['Global']['character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(type[0])
        global_config['Global']['save_model_dir'] = './output/rec_{}_lite'.format(type[0])
        global_config['Train']['dataset']['label_file_list'] = ["train_data/{}_train.txt".format(type[0])]
        global_config['Eval']['dataset']['label_file_list'] = ["train_data/{}_val.txt".format(type[0])]
        global_config['Global']['character_type'] = type[0]
        assert os.path.isfile(
            os.path.join(project_path, global_config['Global']['character_dict_path'])
        ), "Missing default dictionary file {}_dict.txt. You can download it from " \
           "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(type[0])
        return type[0]
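# Keys passed to -o may be dotted paths into nested sections, e.g. (illustrative)
# -o Global.use_gpu=False Train.loader.batch_size_per_card=64 updates
# global_config['Global']['use_gpu'] and ['Train']['loader']['batch_size_per_card'].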
def merge_config(config):
    """
    Merge config into global config.
    Args:
        config (dict): Config to be merged.
    Returns: global config
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but got: {}, " \
               "please check your running command".format(global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            for idx, sub_key in enumerate(sub_keys[1:]):
                if idx == len(sub_keys) - 2:
                    cur[sub_key] = value
                else:
                    cur = cur[sub_key]
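# Despite its name, loss_file() just asserts that a required file exists and
# aborts with a readable message if it does not.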
def loss_file(path):
    assert os.path.exists(path), \
        "There is no such file: {}. Please make sure the specified file exists.".format(path)
if __name__ == '__main__':
    FLAGS = ArgsParser().parse_args()
    merge_config(FLAGS.opt)
    save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)

    if FLAGS.train:
        global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
        train_label_path = os.path.join(project_path, FLAGS.train)
        loss_file(train_label_path)
    if FLAGS.val:
        global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
        eval_label_path = os.path.join(project_path, FLAGS.val)
        loss_file(eval_label_path)
    if FLAGS.dict:
        global_config['Global']['character_dict_path'] = FLAGS.dict
        dict_path = os.path.join(project_path, FLAGS.dict)
        loss_file(dict_path)
    if FLAGS.data_dir:
        global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
        global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
        data_dir = os.path.join(project_path, FLAGS.data_dir)
        loss_file(data_dir)

    with open(save_file_path, 'w') as f:
        yaml.dump(dict(global_config), f, default_flow_style=False, sort_keys=False)
    logging.info("Project path is: {}".format(project_path))
    logging.info("Train list path set to: {}".format(global_config['Train']['dataset']['label_file_list'][0]))
    logging.info("Eval list path set to: {}".format(global_config['Eval']['dataset']['label_file_list'][0]))
    logging.info("Dataset root path set to: {}".format(global_config['Eval']['dataset']['data_dir']))
    logging.info("Dict path set to: {}".format(global_config['Global']['character_dict_path']))
    logging.info("Config file set to: configs/rec/multi_language/{}".format(save_file_path))