# MyFineTuning.py

import os
import pickle
import sys
import six
import yaml
import time
import shutil
import paddle
import paddle.distributed as dist
from paddle.fluid.dataloader import DistributedBatchSampler
from paddle.fluid.reader import DataLoader
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
from tools import program
from tools.program import ArgsParser
from ppocr.data.simple_dataset import SimpleDataSet
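
# Overall flow: parse CLI args, load the training config, build the Train
# dataset and DataLoader, build the recognition model with a resized output
# head, and load pretrained recognition weights for fine-tuning.
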
if __name__ == "__main__":
    # Parse command-line arguments (-c config path and -o overrides).
    FLAGS = ArgsParser().parse_args()
    print(FLAGS)
    # file_path = "./configs/rec/my_chinese_lite.yml"
    # print(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))

    # Load the YAML config and set up the device, logger and VisualDL writer.
    config, device, logger, vdl_writer = program.preprocess(is_train=True)

    # Build the training dataset from the 'Train' section of the config.
    dataset = eval('SimpleDataSet')(config, 'Train', logger, None)
    # print(config)

    batch_sampler = DistributedBatchSampler(
        dataset=dataset,
        batch_size=16,
        shuffle=False,
        drop_last=True)
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=4,
        return_list=True,
        use_shared_memory=True)
    print(len(data_loader))
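
    # Optional sanity check (not part of the original flow): set
    # CHECK_ONE_BATCH=1 in the environment to pull a single batch and print
    # each element's shape. With return_list=True a batch is a list of
    # tensors (e.g. image / label / length for the rec SimpleDataSet).
    if os.environ.get("CHECK_ONE_BATCH"):
        for batch in data_loader:
            for i, item in enumerate(batch):
                print(i, item.shape if hasattr(item, 'shape') else type(item))
            break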
    # Shrink the head output for the new character set (here only 3 classes).
    config['Architecture']['Head']['out_channels'] = 3
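    # For a CTC recognition head, out_channels must equal the number of
    # characters in Global.character_dict_path plus one for the CTC blank
    # (and one more if use_space_char is enabled); PaddleOCR's own
    # tools/train.py derives it from the post-process class. A sketch of
    # computing it from the config instead of hardcoding it (assumes the
    # character_dict_path key is set):
    # with open(config['Global']['character_dict_path'], 'rb') as f:
    #     char_num = len(f.readlines())
    # config['Architecture']['Head']['out_channels'] = char_num + 1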
    model = build_model(config['Architecture'])
    if config['Global']['distributed']:
        # Wrap the model for multi-card training when launched distributedly.
        model = paddle.DataParallel(model)
    # chardet and the 'file' path below are only used by the commented-out
    # debug code for inspecting the raw checkpoint file.
    import chardet
    file = "D:\\Project\\PaddleOCR-release-2.0\\pretrained_model\\ch_ppocr_mobile_v2.0_rec_pre\\best_accuracy"
    # f = open(file, "r")
    # data = f.read()
    # print(chardet.detect(data))

    # Load the pretrained recognition weights via PaddleOCR's init_model
    # helper (path taken from the Global section of the config).
    pre_best_model_dict = init_model(config, model, logger, None)
    # pre_best_model_dict = paddle.load(file)
    # with open(file, 'rb') as f:
    #     print(f.readline())
    #     load_result = pickle.load(f, encoding='latin1')
    print(pre_best_model_dict)
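
    # Hypothetical fallback (not PaddleOCR API): if loading complains because
    # the head shape no longer matches the resized out_channels, one option is
    # to load the raw .pdparams and keep only the parameters whose shapes
    # still agree with the new model before calling set_state_dict:
    # params = paddle.load(file + ".pdparams")
    # model_state = model.state_dict()
    # matched = {k: v for k, v in params.items()
    #            if k in model_state and list(v.shape) == list(model_state[k].shape)}
    # model_state.update(matched)
    # model.set_state_dict(model_state)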