1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- import os
- import pickle
- import sys
- import six
- import yaml
- import time
- import shutil
- import paddle
- import paddle.distributed as dist
- from paddle.fluid.dataloader import DistributedBatchSampler
- from paddle.fluid.reader import DataLoader
- from tqdm import tqdm
- from argparse import ArgumentParser, RawDescriptionHelpFormatter
- from ppocr.modeling.architectures import build_model
- from ppocr.utils.stats import TrainingStats
- from ppocr.utils.save_load import init_model
- from ppocr.utils.utility import print_dict
- from ppocr.utils.logging import get_logger
- from ppocr.data import build_dataloader
- from tools import program
- from tools.program import ArgsParser
- from ppocr.data.simple_dataset import SimpleDataSet
def main():
    """Smoke-test the recognition training pipeline and print what was loaded.

    Steps, in order:
      1. Parse and print the command-line flags.
      2. Run ``program.preprocess(is_train=True)`` to obtain the config,
         device and logger.
      3. Build a ``SimpleDataSet`` for the ``'Train'`` split and wrap it in a
         ``DataLoader`` driven by a ``DistributedBatchSampler``; print the
         loader length.
      4. Build the model from ``config['Architecture']`` and hand it to
         ``init_model``; print whatever ``init_model`` returns.

    Nothing is trained here; the function only constructs objects and prints.
    """
    flags = ArgsParser().parse_args()
    print(flags)

    config, device, logger, _vdl_writer = program.preprocess(is_train=True)

    # Instantiate the imported class directly; there is no reason to route a
    # constant class name through eval().
    dataset = SimpleDataSet(config, 'Train', logger, None)

    batch_sampler = DistributedBatchSampler(
        dataset=dataset,
        batch_size=16,
        shuffle=False,
        drop_last=True)
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=4,
        return_list=True,
        use_shared_memory=True)
    print(len(data_loader))

    # NOTE(review): the head's output size is forced to 3 regardless of what
    # the config file says -- confirm this matches the intended label set.
    config['Architecture']['Head']['out_channels'] = 3
    model = build_model(config['Architecture'])
    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

    # NOTE(review): presumably init_model restores pretrained/checkpoint
    # weights as configured and returns the saved metadata -- verify against
    # ppocr.utils.save_load.
    pre_best_model_dict = init_model(config, model, logger, None)
    print(pre_best_model_dict)


if __name__ == "__main__":
    main()
|