Python trainer.Trainer() Examples

The following are 30 code examples of trainer.Trainer(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module trainer, or try the search function.
Example #1
Source File: classifier.py    From ModelFeast with MIT License 6 votes vote down vote up
def train(self):
        """Validate the configured components, then build a Trainer and run it.

        Asserts that model/loss/metrics are callable, that a "trainer" section
        exists in the config, and that the data loader is iterable.  If the
        loader exposes its class list, also checks it matches the model's
        configured output size before training starts.
        """
        assert callable(self.model), "model is not callable!!"
        assert callable(self.loss), "loss is not callable!!"
        assert all(callable(met) for met in self.metrics), "metrics is not callable!!"
        assert "trainer" in self.config, "trainer hasn't been configured!!"
        assert isinstance(self.data_loader, Iterable), "data_loader is not iterable!!"

        # the num of classes in dataset must be the same as model's output
        if hasattr(self.data_loader, 'classes'):
            true_classes = len(self.data_loader.classes)
            model_output = self.config['arch']['args']['n_class']
            assert true_classes==model_output, "model分类数为{},可是实际上有{}个类".format(
                model_output, true_classes)

        if "name" not in self.config:
            # BUG FIX: str.join takes a single iterable, not two positional
            # arguments -- the original call raised TypeError whenever the
            # config had no "name" key.
            self.config["name"] = "_".join([self.config["arch"]["type"],
                self.config["data_loader"]["type"]])
        self.trainer = Trainer(self.model, self.loss, self.metrics, self.optimizer, 
            resume=self.resume, config=self.config, data_loader=self.data_loader,
            valid_data_loader=self.valid_data_loader, lr_scheduler=self.lr_scheduler,
            train_logger=self.train_logger)        
        self.trainer.train() 
Example #2
Source File: main.py    From JointBERT with Apache License 2.0 6 votes vote down vote up
def main(args):
    """Initialise logging/seed/tokenizer, build all splits, then train/eval."""
    init_logger()
    set_seed(args)
    tokenizer = load_tokenizer(args)

    # Load and cache every split up front.
    datasets = {
        split: load_and_cache_examples(args, tokenizer, mode=split)
        for split in ("train", "dev", "test")
    }

    trainer = Trainer(args, datasets["train"], datasets["dev"], datasets["test"])

    if args.do_train:
        trainer.train()

    if args.do_eval:
        # Evaluation always restores the saved model first.
        trainer.load_model()
        trainer.evaluate("test")
Example #3
Source File: main.py    From EDSR-PyTorch with MIT License 6 votes vote down vote up
def main():
    """Dispatch to the video tester or the standard train/test loop."""
    global model
    if args.data_test == ['video']:
        # Video mode: wrap the built model in a dedicated tester.
        from videotester import VideoTester
        model = model.Model(args, checkpoint)
        VideoTester(args, model, checkpoint).test()
        return

    # Standard mode: only proceed when the checkpoint is usable.
    if not checkpoint.ok:
        return

    loader = data.Data(args)
    _model = model.Model(args, checkpoint)
    # No loss function is needed when only testing.
    _loss = None if args.test_only else loss.Loss(args, checkpoint)
    runner = Trainer(args, loader, _model, _loss, checkpoint)
    while not runner.terminate():
        runner.train()
        runner.test()

    checkpoint.done()
Example #4
Source File: main.py    From KoBERT-NER with Apache License 2.0 6 votes vote down vote up
def main(args):
    """Set up logging/seed, load only the needed splits, then train/eval."""
    init_logger()
    set_seed(args)

    tokenizer = load_tokenizer(args)

    train_dataset = dev_dataset = test_dataset = None

    # The test split is used by both training and evaluation; the train
    # split only when training.  (The dev split is never loaded here.)
    if args.do_train or args.do_eval:
        test_dataset = load_and_cache_examples(args, tokenizer, mode="test")
    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, mode="train")

    trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)

    if args.do_train:
        trainer.train()

    if args.do_eval:
        # Restore the saved model before evaluating.
        trainer.load_model()
        trainer.evaluate("test", "eval")
Example #5
Source File: main.py    From pointer-network-tensorflow with MIT License 6 votes vote down vote up
def main(_):
  """Prepare directories/logging, seed RNGs, then train or test on TSP."""
  prepare_dirs_and_logger(config)

  if not config.task.lower().startswith('tsp'):
    raise Exception("[!] Task should starts with TSP")

  # Default both sequence-length caps to the dataset's maximum length.
  for attr in ('max_enc_length', 'max_dec_length'):
    if getattr(config, attr) is None:
      setattr(config, attr, config.max_data_length)

  rng = np.random.RandomState(config.random_seed)
  tf.set_random_seed(config.random_seed)

  trainer = Trainer(config, rng)
  save_config(config.model_dir, config)

  if config.is_train:
    trainer.train()
  else:
    # Testing requires a checkpoint to restore from.
    if not config.load_path:
      raise Exception("[!] You should specify `load_path` to load a pretrained model")
    trainer.test()

  tf.logging.info("Run finished.")
Example #6
Source File: main.py    From neural-combinatorial-rl-tensorflow with MIT License 6 votes vote down vote up
def main(_):
  """Entry point: validate the task, seed RNGs, then run train or test."""
  prepare_dirs_and_logger(config)

  if not config.task.lower().startswith('tsp'):
    raise Exception("[!] Task should starts with TSP")

  # Missing length caps fall back to the dataset maximum.
  max_len = config.max_data_length
  if config.max_enc_length is None:
    config.max_enc_length = max_len
  if config.max_dec_length is None:
    config.max_dec_length = max_len

  seed = config.random_seed
  rng = np.random.RandomState(seed)
  tf.set_random_seed(seed)

  trainer = Trainer(config, rng)
  save_config(config.model_dir, config)

  if not config.is_train:
    # A pretrained checkpoint is mandatory for testing.
    if not config.load_path:
      raise Exception("[!] You should specify `load_path` to load a pretrained model")
    trainer.test()
  else:
    trainer.train()

  tf.logging.info("Run finished.")
Example #7
Source File: train.py    From pytorch_segmentation with MIT License 6 votes vote down vote up
def main(config, resume):
    """Build loaders, model and loss from the config, then start training."""
    logger = Logger()

    # Data loaders, both resolved from the config by name.
    train_loader = get_instance(dataloaders, 'train_loader', config)
    val_loader = get_instance(dataloaders, 'val_loader', config)

    # Model sized to the dataset's class count.
    num_classes = train_loader.dataset.num_classes
    model = get_instance(models, 'arch', config, num_classes)
    print(f'\n{model}\n')

    # Loss function resolved by name, with the configured ignore index.
    criterion = getattr(losses, config['loss'])(ignore_index=config['ignore_index'])

    # Kick off training.
    Trainer(
        model=model,
        loss=criterion,
        resume=resume,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
        train_logger=logger).train()
Example #8
Source File: main.py    From EverybodyDanceNow_reproduce_pytorch with MIT License 6 votes vote down vote up
def main(is_debug):
    """Assemble the face-crop dataset/loader and train the GAN."""
    import os

    # Static configuration.
    dataset_dir = '../data/face'
    pose_name = '../data/target/pose.npy'
    ckpt_dir = '../checkpoints/face'
    log_dir = '../checkpoints/face/logs'
    batch_num = 10
    batch_size = 10

    image_folder = dataset.ImageFolderDataset(
        dataset_dir, cache=os.path.join(dataset_dir, 'local.db'))
    # 48 for 512-frame, 96 for HD frame
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms, crop_size=48)
    data_loader = DataLoader(face_dataset, batch_size=batch_size,
                             drop_last=True, num_workers=4, shuffle=True)

    generator, discriminator, batch_num = load_models(ckpt_dir, batch_num)

    # Debug mode logs and checkpoints after every batch.
    extra = {'log_every': 1, 'save_every': 1} if is_debug else {}
    trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader, **extra)
    trainer.train(generator, discriminator, batch_num)
Example #9
Source File: main.py    From everybody_dance_now_pytorch with GNU Affero General Public License v3.0 6 votes vote down vote up
def main(is_debug):
    """Train the dance-generation GAN on the 512px cardio dataset."""
    # Paths and hyper-parameters.
    dataset_dir = '../datasets/cardio_dance_512'
    pose_name = '../datasets/cardio_dance_512/poses.npy'
    ckpt_dir = './checkpoints/dance_test_new_down2_res6'
    log_dir = './logs/dance_test_new_down2_res6'
    batch_num = 0
    batch_size = 64

    cache_path = os.path.join(dataset_dir, 'local.db')
    image_folder = dataset.ImageFolderDataset(dataset_dir, cache=cache_path)
    face_dataset = dataset.FaceCropDataset(
        image_folder, pose_name, image_transforms,
        crop_size=48)  # 48 for 512-frame, 96 for HD frame
    data_loader = DataLoader(face_dataset, batch_size=batch_size,
                             drop_last=True, num_workers=4, shuffle=True)

    generator, discriminator, batch_num = load_models(ckpt_dir, batch_num)

    if is_debug:
        # Debug: log/save after every single batch.
        trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader,
                          log_every=1, save_every=1)
    else:
        trainer = Trainer(ckpt_dir, log_dir, face_dataset, data_loader)
    trainer.train(generator, discriminator, batch_num)
Example #10
Source File: gcrn_main.py    From gconvRNN with MIT License 6 votes vote down vote up
def main(_):
    """Create save directories, seed RNGs, then train or test the model."""
    # Directory generation for saved artifacts.
    prepare_dirs(config)

    # Deterministic seeding for numpy and tensorflow.
    rng = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    # Build the trainer and persist the configuration used for this run.
    trainer = Trainer(config, rng)
    save_config(config.model_dir, config)

    if not config.is_train:
        # Testing needs a checkpoint to restore.
        if not config.load_path:
            raise Exception(
                "[!] You should specify `load_path` to "
                "load a pretrained model")
        trainer.test()
    else:
        trainer.train()
Example #11
Source File: main.py    From conditional-motion-propagation with MIT License 6 votes vote down vote up
def main(args):
    """Load the YAML config into args, init distributed mode, then train.

    Reads the YAML file named by args.config, merges its keys onto args,
    defaults exp_path to the config's directory, initialises the
    multiprocessing start method and the distributed backend, and runs
    the trainer.
    """
    with open(args.config) as f:
        # BUG FIX: the original wrote version.parse(yaml.version >= "5.1"),
        # which (a) compares the version lexicographically as a string inside
        # the parse() call and (b) reads yaml.version, while PyYAML exposes
        # its version string as yaml.__version__.  Compare parsed versions.
        if version.parse(yaml.__version__) >= version.parse("5.1"):
            # PyYAML >= 5.1 warns unless a Loader is given explicitly.
            config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            config = yaml.load(f)

    # Merge every config key onto the args namespace.
    for k, v in config.items():
        setattr(args, k, v)

    # exp path: default to the directory holding the config file.
    if not hasattr(args, 'exp_path'):
        args.exp_path = os.path.dirname(args.config)

    # dist init: force the 'spawn' start method before NCCL setup.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    dist_init(args.launcher, backend='nccl')

    # train
    trainer = Trainer(args)
    trainer.run()
Example #12
Source File: train.py    From DDRNet with MIT License 6 votes vote down vote up
def main(args):
    """Configure logging, validate the CSV index file, then fit the trainer."""
    logging.basicConfig(
      level=logging.DEBUG,
      format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
      filename=os.path.join(args.logdir, 'logging.txt'))
    # Mirror log records to the console as well as the log file.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s'))
    logging.getLogger().addHandler(console)

    filename = os.path.realpath(args.index_file)
    if not os.path.isfile(filename):
        raise ValueError('No such index_file: {}'.format(filename))
    print("Reading csv file: {}".format(filename))

    # Sanity-check that the first CSV entry points at existing data.
    with open(filename, "r") as f:
        input_path = f.readline().strip().split(',')[0]
        if not os.path.exists(input_path):
            raise ValueError('Input path in csv not exist: {}'.format(input_path))

    trainer.Trainer(filename, args).fit()
Example #13
Source File: main.py    From Graph-U-Nets with GNU General Public License v3.0 5 votes vote down vote up
def app_run(args, G_data, fold_idx):
    """Select one cross-validation fold, build a GNet and train it."""
    G_data.use_fold_data(fold_idx)
    model = GNet(G_data.feat_dim, G_data.num_class, args)
    Trainer(args, model, G_data).train()
Example #14
Source File: train.py    From Real-time-Text-Detection with Apache License 2.0 5 votes vote down vote up
def main(config):
    """Wire up the data loader, loss and model from config, then train."""
    loader_cfg = config['data_loader']
    train_loader = get_dataloader(loader_cfg['type'], loader_cfg['args'])

    # The loss runs on the GPU.
    criterion = get_loss(config).cuda()
    model = get_model(config)

    Trainer(config=config,
            model=model,
            criterion=criterion,
            train_loader=train_loader).train()
Example #15
Source File: main.py    From neural-question-generation with MIT License 5 votes vote down vote up
def main(args):
    """Train a model, or beam-search decode from a saved one."""
    if not args.train:
        # Inference: decode with beam search from the saved model.
        BeamSearcher(args.model_path, args.output_dir).decode()
    else:
        Trainer(args).train()
Example #16
Source File: main.py    From densenet with MIT License 5 votes vote down vote up
def main(config):
    """Prepare dirs, seed the right device, build a loader, train or test."""
    # Make sure output directories exist.
    prepare_dirs(config)

    # Seed CUDA (with pinned-memory workers) or the CPU.
    if config.num_gpu > 0:
        torch.cuda.manual_seed(config.random_seed)
        kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        torch.manual_seed(config.random_seed)
        kwargs = {}

    # Training gets a train/valid split; otherwise the test loader.
    if config.is_train:
        data_loader = get_train_valid_loader(
            config.data_dir, config.dataset, config.batch_size,
            config.augment, config.random_seed, config.valid_size,
            config.shuffle, config.show_sample, **kwargs)
    else:
        data_loader = get_test_loader(
            config.data_dir, config.dataset, config.batch_size,
            config.shuffle, **kwargs)

    trainer = Trainer(config, data_loader)

    if config.is_train:
        # Persist the configuration before training starts.
        save_config(config)
        trainer.train()
    else:
        # Load a pretrained model and test.
        trainer.test()
Example #17
Source File: main.py    From Deep-Mutual-Learning with MIT License 5 votes vote down vote up
def main(config):
    """Instantiate data loaders per the config and train or test the model."""
    # Make sure output directories exist.
    prepare_dirs(config)

    # NOTE: manual seeding is intentionally disabled in this code path.
    kwargs = {}
    if config.use_gpu:
        kwargs = {'num_workers': config.num_workers,
                  'pin_memory': config.pin_memory}

    # The test loader is always needed; when training, a train loader is
    # added and both are handed over as a (train, test) tuple.
    test_data_loader = get_test_loader(
        config.data_dir, config.batch_size, **kwargs
    )

    if config.is_train:
        train_data_loader = get_train_loader(
            config.data_dir, config.batch_size,
            config.random_seed, config.shuffle, **kwargs
        )
        data_loader = (train_data_loader, test_data_loader)
    else:
        data_loader = test_data_loader

    trainer = Trainer(config, data_loader)

    if config.is_train:
        # Persist the config, then train.
        save_config(config)
        trainer.train()
    else:
        # Load a pretrained model and test.
        trainer.test()
Example #18
Source File: main.py    From ENAS-pytorch with Apache License 2.0 5 votes vote down vote up
def main(args):  # pylint:disable=redefined-outer-name
    """main: Entry point."""
    utils.prepare_dirs(args)

    # Seed CPU (and GPU when available).
    torch.manual_seed(args.random_seed)
    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)

    # Pick the dataset class from the network/dataset flags.
    if args.network_type == 'rnn':
        dataset = data.text.Corpus(args.data_path)
    elif args.dataset == 'cifar':
        dataset = data.image.Image(args.data_path)
    else:
        raise NotImplementedError(f"{args.dataset} is not supported")

    runner = trainer.Trainer(args, dataset)

    # Dispatch on the requested mode.
    mode = args.mode
    if mode == 'train':
        utils.save_args(args)
        runner.train()
    elif mode == 'derive':
        # Deriving architectures needs a trained controller checkpoint.
        assert args.load_path != "", ("`--load_path` should be given in "
                                      "`derive` mode")
        runner.derive()
    elif mode == 'test':
        if not args.load_path:
            raise Exception("[!] You should specify `load_path` to load a "
                            "pretrained model")
        runner.test()
    elif mode == 'single':
        # Train a single fixed DAG loaded from disk.
        if not args.dag_path:
            raise Exception("[!] You should specify `dag_path` to load a dag")
        utils.save_args(args)
        runner.train(single=True)
    else:
        raise Exception(f"[!] Mode not found: {args.mode}")
Example #19
Source File: train.py    From DBNet.pytorch with Apache License 2.0 5 votes vote down vote up
def main(config):
    """Set up (optionally distributed) DBNet training from a config dict."""
    import torch
    from models import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from post_processing import get_post_processing
    from utils import get_metric

    # Multi-GPU: initialise an NCCL process group, one process per rank.
    distributed = torch.cuda.device_count() > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
            world_size=torch.cuda.device_count(), rank=args.local_rank)
    config['distributed'] = distributed
    config['local_rank'] = args.local_rank

    train_loader = get_dataloader(config['dataset']['train'], config['distributed'])
    assert train_loader is not None
    validate_loader = None
    if 'validate' in config['dataset']:
        validate_loader = get_dataloader(config['dataset']['validate'], False)

    criterion = build_loss(config['loss']).cuda()

    # Grayscale input means a single channel, otherwise RGB.
    img_mode = config['dataset']['train']['dataset']['args']['img_mode']
    config['arch']['backbone']['in_channels'] = 1 if img_mode == 'GRAY' else 3
    model = build_model(config['arch'])

    post_p = get_post_processing(config['post_processing'])
    metric = get_metric(config['metric'])

    Trainer(config=config,
            model=model,
            criterion=criterion,
            train_loader=train_loader,
            post_process=post_p,
            metric_cls=metric,
            validate_loader=validate_loader).train()
Example #20
Source File: main.py    From BEGAN-tensorflow with Apache License 2.0 5 votes vote down vote up
def main(config):
    """Prepare dirs/logging, seed RNGs, build the data loader, train or test.

    In test mode the batch size is config.sample_per_image and the test data
    path is used when provided; in train mode the configured training path
    and batch size apply.
    """
    prepare_dirs_and_logger(config)

    rng = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    if config.is_train:
        data_path = config.data_path
        batch_size = config.batch_size
        do_shuffle = True
    else:
        setattr(config, 'batch_size', 64)
        if config.test_data_path is None:
            data_path = config.data_path
        else:
            data_path = config.test_data_path
        batch_size = config.sample_per_image
        do_shuffle = False

    # BUG FIX: the original passed config.batch_size here, silently ignoring
    # the sample_per_image batch size computed for the test branch above.
    # NOTE(review): do_shuffle is still unused -- get_loader's signature is
    # not visible here, so it is left as-is; confirm against data_loader.
    data_loader = get_loader(
            data_path, batch_size, config.input_scale_size,
            config.data_format, config.split)
    trainer = Trainer(config, data_loader)

    if config.is_train:
        save_config(config)
        trainer.train()
    else:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
Example #21
Source File: main.py    From recurrent-visual-attention with MIT License 5 votes vote down vote up
def main(config):
    """Prepare dirs, seed for reproducibility, build a loader, train/test."""
    utils.prepare_dirs(config)

    # Reproducibility: seed CPU always, CUDA only when in use.
    torch.manual_seed(config.random_seed)
    kwargs = {}
    if config.use_gpu:
        torch.cuda.manual_seed(config.random_seed)
        kwargs = {"num_workers": 1, "pin_memory": True}

    # Test mode gets the plain test loader; training a train/valid split.
    if not config.is_train:
        dloader = data_loader.get_test_loader(
            config.data_dir, config.batch_size, **kwargs,
        )
    else:
        dloader = data_loader.get_train_valid_loader(
            config.data_dir,
            config.batch_size,
            config.random_seed,
            config.valid_size,
            config.shuffle,
            config.show_sample,
            **kwargs,
        )

    trainer = Trainer(config, dloader)

    if not config.is_train:
        # Load a pretrained model and test.
        trainer.test()
    else:
        # Persist the config, then train.
        utils.save_config(config)
        trainer.train()
Example #22
Source File: main.py    From simulated-unsupervised-tensorflow with Apache License 2.0 5 votes vote down vote up
def main(_):
  """Seed RNGs, build the trainer, then run training or testing."""
  prepare_dirs(config)

  seed = config.random_seed
  rng = np.random.RandomState(seed)
  tf.set_random_seed(seed)

  trainer = Trainer(config, rng)
  save_config(config.model_dir, config)

  if not config.is_train:
    # Testing requires a checkpoint to restore from.
    if not config.load_path:
      raise Exception("[!] You should specify `load_path` to load a pretrained model")
    trainer.test()
  else:
    trainer.train()
Example #23
Source File: main.py    From video_captioning_rl with MIT License 5 votes vote down vote up
def main(args):
    """Prepare dirs, seed, build the MSR-VTT dataset splits, then train/test.

    Only the 'seq2seq' network type with the 'msrvtt' dataset is supported;
    anything else raises.
    """
    prepare_dirs(args)

    torch.manual_seed(args.random_seed)

    if args.num_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)


    if args.network_type == 'seq2seq':
        vocab = data.common_loader.Vocab(args.vocab_file, args.max_vocab_size)
        dataset = {}
        if args.dataset == 'msrvtt':
            # One batcher per split, sharing the same vocabulary.
            dataset['train'] = data.common_loader.MSRVTTBatcher(args, 'train', vocab)
            dataset['val'] = data.common_loader.MSRVTTBatcher(args, 'val', vocab)
            dataset['test'] = data.common_loader.MSRVTTBatcher(args, 'test', vocab)
        else:
            raise Exception(f"Unknown dataset: {args.dataset} for the corresponding network type: {args.network_type}")

    else:
        # BUG FIX: the original raised NotImplemented, which is the
        # non-exception singleton and itself raises TypeError when used
        # with `raise`; NotImplementedError is the correct exception.
        raise NotImplementedError(f"{args.dataset} is not supported")

    trainer = Trainer(args, dataset)

    if args.mode == 'train':
        save_args(args)
        trainer.train()
    else:
        # Any non-train mode needs a pretrained checkpoint.
        if not args.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        else:
            trainer.test(args.mode)
Example #24
Source File: main.py    From HistoGAN with GNU General Public License v3.0 5 votes vote down vote up
def main(config):
    """Seed torch, build the paired (A, B) data loaders, then train or test."""
    prepare_dirs_and_logger(config)

    torch.manual_seed(config.random_seed)
    if config.num_gpu > 0:
        torch.cuda.manual_seed(config.random_seed)

    # Training uses the train path/batch size; testing prefers the explicit
    # test path (falling back to data_path) with one batch per sample image.
    if config.is_train:
        data_path = config.data_path
        batch_size = config.batch_size
    else:
        data_path = config.test_data_path
        if data_path is None:
            data_path = config.data_path
        batch_size = config.sample_per_image

    a_data_loader, b_data_loader = get_loader(
            data_path, batch_size, config.input_scale_size,
            config.num_worker, config.skip_pix2pix_processing)

    trainer = Trainer(config, a_data_loader, b_data_loader)

    if not config.is_train:
        # Testing requires a checkpoint to restore from.
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
    else:
        save_config(config)
        trainer.train()
Example #25
Source File: main.py    From R-BERT with Apache License 2.0 5 votes vote down vote up
def main(args):
    """Build the train/test datasets and run R-BERT training/evaluation."""
    init_logger()
    tokenizer = load_tokenizer(args)

    # Cache both splits up front.
    datasets = {
        mode: load_and_cache_examples(args, tokenizer, mode=mode)
        for mode in ('train', 'test')
    }

    trainer = Trainer(args, train_dataset=datasets['train'],
                      test_dataset=datasets['test'])

    if args.do_train:
        trainer.train()

    if args.do_eval:
        # Restore the saved model before evaluating on the test split.
        trainer.load_model()
        trainer.evaluate('test')
Example #26
Source File: convert.py    From ZeroSpeech-TTS-without-T with MIT License 5 votes vote down vote up
def get_trainer(hps_path, model_path, g_mode, enc_mode, clf_path):
	"""Build a Trainer from saved hyper-params and load its model weights."""
	hps = Hps(hps_path).get_tuple()
	global MIN_LEN
	# Gumbel-temperature encoders need at least one full segment length.
	if hps.enc_mode == 'gumbel_t':
		MIN_LEN = hps.seg_len
	trainer = Trainer(hps, None, g_mode, enc_mode)
	trainer.load_model(model_path, load_model_list=hps.load_model_list, clf_path = clf_path)
	return trainer
Example #27
Source File: train.py    From pytorch-template with MIT License 5 votes vote down vote up
def main(config):
    """Assemble data, model, loss/metrics and optimizer, then run training."""
    logger = config.get_logger('train')

    # Data loaders: validation is split off the training loader.
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # Model architecture, logged for reference.
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # Loss and metric function handles resolved by name from the config.
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, name) for name in config['metrics']]

    # Optimizer over trainable parameters only, plus an LR scheduler.
    # (Delete every line containing lr_scheduler to disable scheduling.)
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    Trainer(model, criterion, metrics, optimizer,
            config=config,
            data_loader=data_loader,
            valid_data_loader=valid_data_loader,
            lr_scheduler=lr_scheduler).train()
Example #28
Source File: main.py    From BigGAN-pytorch with Apache License 2.0 5 votes vote down vote up
def main(config):
    """Count classes from the image folder layout, build loaders/dirs, then
    train (sagan/qgan) or test depending on config.train.

    Raises:
        ValueError: if config.model names an unknown trainer type.
    """
    # For fast training
    cudnn.benchmark = True


    # One class per top-level subdirectory of the image path.
    config.n_class = len(glob.glob(os.path.join(config.image_path, '*/')))
    print('number class:', config.n_class)
    # Data loader
    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                             config.batch_size, shuf=config.train)

    # Create directories if not exist
    make_folder(config.model_save_path, config.version)
    make_folder(config.sample_path, config.version)
    make_folder(config.log_path, config.version)
    make_folder(config.attn_path, config.version)


    print('config data_loader and build logs folder')

    if config.train:
        if config.model=='sagan':
            trainer = Trainer(data_loader.loader(), config)
        elif config.model == 'qgan':
            trainer = qgan_trainer(data_loader.loader(), config)
        else:
            # BUG FIX: an unrecognised model name previously fell through
            # and crashed with NameError on trainer.train(); fail clearly.
            raise ValueError('unknown model: {}'.format(config.model))
        trainer.train()
    else:
        tester = Tester(data_loader.loader(), config)
        tester.test()
Example #29
Source File: main.py    From CausalGAN with MIT License 5 votes vote down vote up
def get_trainer(config):
    """Reset the TF default graph and build a Trainer for config.data_type."""
    print('tf: resetting default graph!')
    tf.reset_default_graph()

    # Random seeding is deliberately left disabled here.

    print('Using data_type ',config.data_type)
    trainer = Trainer(config,config.data_type)
    print('built trainer successfully')

    # Silence everything below ERROR level from TF logging.
    tf.logging.set_verbosity(tf.logging.ERROR)

    return trainer
Example #30
Source File: train.py    From crnn.gluon with Apache License 2.0 4 votes vote down vote up
def main(config):
    """Build the CRNN model, data loaders and CTC loss from config, then train."""
    from mxnet import nd
    from mxnet.gluon.loss import CTCLoss

    from models import get_model
    from data_loader import get_dataloader
    from trainer import Trainer
    from utils import get_ctx, load

    # The alphabet may be given as a file path; inline its contents.
    alphabet = config['dataset']['alphabet']
    if os.path.isfile(alphabet):
        config['dataset']['alphabet'] = ''.join(load(alphabet))

    prediction_type = config['arch']['args']['prediction']['type']
    num_class = len(config['dataset']['alphabet'])

    # Loss selection (only CTC prediction is implemented).
    if prediction_type != 'CTC':
        raise NotImplementedError
    criterion = CTCLoss()

    ctx = get_ctx(config['trainer']['gpus'])
    model = get_model(num_class, ctx, config['arch']['args'])
    model.hybridize()
    model.initialize(ctx=ctx)

    # Derive the input resolution from the Resize pre-process, if present.
    img_h, img_w = 32, 100
    for process in config['dataset']['train']['dataset']['args']['pre_processes']:
        if process['type'] == "Resize":
            img_h, img_w = process['args']['img_h'], process['args']['img_w']
            break
    img_channel = 1 if config['dataset']['train']['dataset']['args']['img_mode'] == 'GRAY' else 3
    # Probe the model once to learn the maximum label length per batch.
    sample_input = nd.zeros((2, img_channel, img_h, img_w), ctx[0])
    num_label = model.get_batch_max_length(sample_input)

    train_loader = get_dataloader(config['dataset']['train'], num_label, config['dataset']['alphabet'])
    assert train_loader is not None
    validate_loader = None
    if 'validate' in config['dataset']:
        validate_loader = get_dataloader(config['dataset']['validate'], num_label, config['dataset']['alphabet'])

    # Scale the LR schedule step from epochs to iterations.
    config['lr_scheduler']['args']['step'] *= len(train_loader)

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      validate_loader=validate_loader,
                      sample_input=sample_input,
                      ctx=ctx)
    trainer.train()