Python torch.optim.lr_scheduler.ExponentialLR() Examples

The following are 25 code examples of torch.optim.lr_scheduler.ExponentialLR(), drawn from open-source projects. The source file, project, and license are noted above each example. You may also want to check out all available functions and classes of the module torch.optim.lr_scheduler.
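
ExponentialLR multiplies the learning rate of every parameter group by gamma each time step() is called, usually once per epoch. Here is a minimal, self-contained sketch of that wiring (the model, learning rate, and gamma are illustrative placeholders, not taken from the projects below):

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

model = nn.Linear(10, 2)                          # toy model, for illustration only
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)   # lr is multiplied by 0.9 on each step()

for epoch in range(5):
    optimizer.step()                              # stand-in for one epoch of per-batch updates
    scheduler.step()                              # decay the learning rate once per epoch
    print(epoch, scheduler.get_last_lr())         # approximately [0.09], [0.081], [0.0729], ...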
Example #1
Source File: ssds_train.py    From ssds.pytorch with MIT License
def configure_lr_scheduler(self, optimizer, cfg):
        if cfg.SCHEDULER == 'step':
            scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg.STEPS[0], gamma=cfg.GAMMA)
        elif cfg.SCHEDULER == 'multi_step':
            scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=cfg.STEPS, gamma=cfg.GAMMA)
        elif cfg.SCHEDULER == 'exponential':
            scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=cfg.GAMMA)
        elif cfg.SCHEDULER == 'SGDR':
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.MAX_EPOCHS)
        else:
            raise AssertionError('scheduler cannot be recognized.')
        return scheduler 
Example #2
Source File: load_opt_sched.py    From amortized-variational-filtering with MIT License
def load_opt_sched(train_config, model):

    inf_params = model.inference_parameters()
    gen_params = model.generative_parameters()

    inf_opt = Optimizer(train_config['optimizer'], inf_params,
                        lr=train_config['inference_learning_rate'],
                        clip_grad_norm=train_config['clip_grad_norm'])
    inf_sched = ExponentialLR(inf_opt.opt, 0.999)

    gen_opt = Optimizer(train_config['optimizer'], gen_params,
                        lr=train_config['generation_learning_rate'],
                        clip_grad_norm=train_config['clip_grad_norm'])
    gen_sched = ExponentialLR(gen_opt.opt, 0.999)

    return (inf_opt, gen_opt), (inf_sched, gen_sched) 
Example #3
Source File: test_lr_scheduler_selector.py    From Auto-PyTorch with Apache License 2.0
def test_lr_scheduler_selector(self):
        pipeline = Pipeline([
            NetworkSelector(),
            OptimizerSelector(),
            LearningrateSchedulerSelector(),
        ])

        net_selector = pipeline[NetworkSelector.get_name()]
        net_selector.add_network("mlpnet", MlpNet)
        net_selector.add_network("shapedmlpnet", ShapedMlpNet)
        net_selector.add_final_activation('none', nn.Sequential())

        opt_selector = pipeline[OptimizerSelector.get_name()]
        opt_selector.add_optimizer("adam", AdamOptimizer)
        opt_selector.add_optimizer("sgd", SgdOptimizer)

        lr_scheduler_selector = pipeline[LearningrateSchedulerSelector.get_name()]
        lr_scheduler_selector.add_lr_scheduler("step", SchedulerStepLR)
        lr_scheduler_selector.add_lr_scheduler("exp", SchedulerExponentialLR)


        pipeline_config = pipeline.get_pipeline_config()
        pipeline_config["random_seed"] = 42
        hyper_config = pipeline.get_hyperparameter_search_space().sample_configuration()

        pipeline.fit_pipeline(hyperparameter_config=hyper_config, pipeline_config=pipeline_config,
                                X=torch.rand(3,3), Y=torch.rand(3, 2), embedding=nn.Sequential(), training_techniques=[], train_indices=np.array([0, 1, 2]))

        sampled_lr_scheduler = pipeline[lr_scheduler_selector.get_name()].fit_output['training_techniques'][0].training_components['lr_scheduler']

        self.assertIn(type(sampled_lr_scheduler), [lr_scheduler.ExponentialLR, lr_scheduler.StepLR]) 
Example #4
Source File: callbacks.py    From steppy-toolkit with MIT License
def set_params(self, transformer, validation_datagen):
        self.validation_datagen = validation_datagen
        self.model = transformer.model
        self.optimizer = transformer.optimizer
        self.loss_function = transformer.loss_function
        self.lr_scheduler = ExponentialLR(self.optimizer, self.gamma, last_epoch=-1) 
Example #5
Source File: lr_schedulers.py    From Auto-PyTorch with Apache License 2.0
def _get_scheduler(self, optimizer, config):
        return lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=config['gamma'], last_epoch=-1) 
Example #6
Source File: lr_scheduler.py    From XenonPy with BSD 3-Clause "New" or "Revised" License
def __init__(self, *, gamma, last_epoch=-1):
        """Decays the learning rate of each parameter group by gamma every epoch.
        When last_epoch=-1, sets initial lr as lr.

        Args:
            gamma (float): Multiplicative factor of learning rate decay.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        super().__init__(lr_scheduler.ExponentialLR, gamma=gamma, last_epoch=last_epoch) 
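
Equivalently, after t calls to step() the learning rate is lr_0 * gamma**t. A small illustrative check of that closed form (not part of XenonPy):

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

opt = optim.SGD(nn.Linear(4, 1).parameters(), lr=0.5)
sched = ExponentialLR(opt, gamma=0.9)

for t in range(1, 4):
    opt.step()                                    # step the optimizer before the scheduler
    sched.step()
    # after t steps the lr should match 0.5 * 0.9**t (up to floating-point rounding)
    assert abs(opt.param_groups[0]["lr"] - 0.5 * 0.9 ** t) < 1e-12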
Example #7
Source File: lr_scheduler.py    From homura with Apache License 2.0
def ExponentialLR(gamma,
                  last_epoch=-1):
    return partial(_lr_scheduler.ExponentialLR, **locals()) 
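
Because this helper returns a functools.partial rather than a scheduler instance, the optimizer is supplied later when the partial is applied. A hedged usage sketch, with the helper restated and a placeholder optimizer:

from functools import partial
from torch import nn, optim
from torch.optim import lr_scheduler as _lr_scheduler

def ExponentialLR(gamma, last_epoch=-1):
    return partial(_lr_scheduler.ExponentialLR, **locals())

optimizer = optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)
scheduler_factory = ExponentialLR(gamma=0.95)     # a partial holding the keyword arguments
scheduler = scheduler_factory(optimizer)          # the optimizer is supplied when the partial is applied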
Example #8
Source File: utils.py    From chemprop with MIT License
def build_lr_scheduler(optimizer: Optimizer, args: Namespace, total_epochs: List[int] = None) -> _LRScheduler:
    """
    Builds a learning rate scheduler.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param args: Arguments.
    :param total_epochs: Optional list of the total number of epochs for each learning rate (defaults to args.epochs for each).
    :return: An initialized learning rate scheduler.
    """
    # Learning rate scheduler
    if args.scheduler == 'noam':
        return NoamLR(
            optimizer=optimizer,
            warmup_epochs=args.warmup_epochs,
            total_epochs=total_epochs or [args.epochs] * args.num_lrs,
            steps_per_epoch=args.train_data_size // args.batch_size,
            init_lr=args.init_lr,
            max_lr=args.max_lr,
            final_lr=args.final_lr
        )

    if args.scheduler == 'none':
        return MockLR(optimizer=optimizer, lr=args.init_lr)

    if args.scheduler == 'decay':
        return ExponentialLR(optimizer, args.lr_decay_rate)

    raise ValueError(f'Learning rate scheduler "{args.scheduler}" not supported.') 
Example #9
Source File: train_siamfc.py    From SiamDW with MIT License
def lr_decay(cfg, optimizer):
    if cfg.SIAMFC.TRAIN.LR_POLICY == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=0.8685)
    elif cfg.SIAMFC.TRAIN.LR_POLICY == 'cos':
        scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif cfg.SIAMFC.TRAIN.LR_POLICY == 'Reduce':
        scheduler = ReduceLROnPlateau(optimizer, patience=5)
    elif cfg.SIAMFC.TRAIN.LR_POLICY == 'log':
        scheduler = np.logspace(math.log10(cfg.SIAMFC.TRAIN.LR), math.log10(cfg.SIAMFC.TRAIN.LR_END), cfg.SIAMFC.TRAIN.END_EPOCH)
    else:
        raise ValueError('unsupported learning rate scheduler')

    return scheduler 
Example #10
Source File: train_siamrpn.py    From SiamDW with MIT License
def lr_decay(cfg, optimizer):
    if cfg.SIAMRPN.TRAIN.LR_POLICY == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=0.8685)
    elif cfg.SIAMRPN.TRAIN.LR_POLICY == 'cos':
        scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif cfg.SIAMRPN.TRAIN.LR_POLICY == 'Reduce':
        scheduler = ReduceLROnPlateau(optimizer, patience=5)
    elif cfg.SIAMRPN.TRAIN.LR_POLICY == 'log':
        scheduler = np.logspace(math.log10(cfg.SIAMRPN.TRAIN.LR), math.log10(cfg.SIAMRPN.TRAIN.LR_END), cfg.SIAMRPN.TRAIN.END_EPOCH)
    else:
        raise ValueError('unsupported learning rate scheduler')

    return scheduler 
Example #11
Source File: callbacks.py    From open-solution-mapping-challenge with MIT License
def set_params(self, transformer, validation_datagen, *args, **kwargs):
        self.validation_datagen = validation_datagen
        self.model = transformer.model
        self.optimizer = transformer.optimizer
        self.loss_function = transformer.loss_function
        self.lr_scheduler = ExponentialLR(self.optimizer, self.gamma, last_epoch=-1) 
Example #12
Source File: scheduler_factory.py    From kaggle-hpa with BSD 2-Clause "Simplified" License
def exponential(optimizer, last_epoch, gamma=0.995, **_):
  return lr_scheduler.ExponentialLR(optimizer, gamma=gamma, last_epoch=last_epoch) 
Example #13
Source File: imagenet.py    From pytorch-lightning with Apache License 2.0
def configure_optimizers(self):
        optimizer = optim.SGD(
            self.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay
        )
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
        return [optimizer], [scheduler] 
Example #14
Source File: callbacks.py    From open-solution-data-science-bowl-2018 with MIT License
def set_params(self, transformer, validation_datagen, *args, **kwargs):
        self.validation_datagen = validation_datagen
        self.model = transformer.model
        self.optimizer = transformer.optimizer
        self.loss_function = transformer.loss_function
        self.lr_scheduler = ExponentialLR(self.optimizer, self.gamma, last_epoch=-1) 
Example #15
Source File: load_opt_sched.py    From amortized-variational-filtering with MIT License
def load_sched(optimizers, last_epoch):
    inf_opt, gen_opt = optimizers
    inf_sched = ExponentialLR(inf_opt.opt, 0.999, last_epoch=last_epoch)
    gen_sched = ExponentialLR(gen_opt.opt, 0.999, last_epoch=last_epoch)
    return (inf_sched, gen_sched) 
Example #16
Source File: siamfc.py    From siamfc-pytorch with MIT License
def __init__(self, net_path=None, **kwargs):
        super(TrackerSiamFC, self).__init__('SiamFC', True)
        self.cfg = self.parse_args(**kwargs)

        # setup GPU device if available
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')

        # setup model
        self.net = Net(
            backbone=AlexNetV1(),
            head=SiamFC(self.cfg.out_scale))
        ops.init_weights(self.net)
        
        # load checkpoint if provided
        if net_path is not None:
            self.net.load_state_dict(torch.load(
                net_path, map_location=lambda storage, loc: storage))
        self.net = self.net.to(self.device)

        # setup criterion
        self.criterion = BalancedLoss()

        # setup optimizer
        self.optimizer = optim.SGD(
            self.net.parameters(),
            lr=self.cfg.initial_lr,
            weight_decay=self.cfg.weight_decay,
            momentum=self.cfg.momentum)
        
        # setup lr scheduler
        gamma = np.power(
            self.cfg.ultimate_lr / self.cfg.initial_lr,
            1.0 / self.cfg.epoch_num)
        self.lr_scheduler = ExponentialLR(self.optimizer, gamma) 
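
The gamma expression above solves initial_lr * gamma**epoch_num == ultimate_lr for gamma, so the learning rate decays from initial_lr to ultimate_lr over epoch_num epochs. A worked example with made-up values (not SiamFC's actual configuration):

import numpy as np

initial_lr, ultimate_lr, epoch_num = 1e-2, 1e-5, 50
gamma = np.power(ultimate_lr / initial_lr, 1.0 / epoch_num)
print(gamma)                               # ~0.871
print(initial_lr * gamma ** epoch_num)     # ~1e-5, i.e. the lr reaches ultimate_lr after epoch_num epochs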
Example #17
Source File: lr_schedulers.py    From argus with MIT License
def __init__(self, gamma, step_on_iteration=False):
        super().__init__(
            lambda opt: _scheduler.ExponentialLR(opt,
                                                 gamma),
            step_on_iteration=step_on_iteration
        ) 
Example #18
Source File: callbacks.py    From open-solution-ship-detection with MIT License
def set_params(self, transformer, validation_datagen, *args, **kwargs):
        self.validation_datagen = validation_datagen
        self.model = transformer.model
        self.optimizer = transformer.optimizer
        self.loss_function = transformer.loss_function
        self.lr_scheduler = ExponentialLR(self.optimizer, self.gamma, last_epoch=-1) 
Example #19
Source File: callbacks.py    From open-solution-salt-identification with MIT License
def set_params(self, transformer, validation_datagen, *args, **kwargs):
        self.validation_datagen = validation_datagen
        self.model = transformer.model
        self.optimizer = transformer.optimizer
        self.loss_function = transformer.loss_function
        self.lr_scheduler = ExponentialLR(self.optimizer, self.gamma, last_epoch=-1) 
Example #20
Source File: scheduler_factory.py    From kaggle-humpback with BSD 2-Clause "Simplified" License
def exponential(optimizer, last_epoch, gamma=0.995, **_):
  return lr_scheduler.ExponentialLR(optimizer, gamma=gamma, last_epoch=last_epoch) 
Example #21
Source File: train.py    From dsb2018_topcoders with MIT License
def train(ds, val_ds, fold, train_idx, val_idx, config, num_workers=0, transforms=None, val_transforms=None, num_channels_changed=False, final_changed=False, cycle=False):
    os.makedirs(os.path.join('..', 'weights'), exist_ok=True)
    os.makedirs(os.path.join('..', 'logs'), exist_ok=True)

    save_path = os.path.join('..', 'weights', config.folder)
    model = models[config.network](num_classes=config.num_classes, num_channels=config.num_channels)
    estimator = Estimator(model, optimizers[config.optimizer], save_path,
                          config=config, num_channels_changed=num_channels_changed, final_changed=final_changed)

    estimator.lr_scheduler = ExponentialLR(estimator.optimizer, config.lr_gamma)#LRStepScheduler(estimator.optimizer, config.lr_steps)
    callbacks = [
        ModelSaver(1, ("fold"+str(fold)+"_best.pth"), best_only=True),
        ModelSaver(1, ("fold"+str(fold)+"_last.pth"), best_only=False),
        CheckpointSaver(1, ("fold"+str(fold)+"_checkpoint.pth")),
        # LRDropCheckpointSaver(("fold"+str(fold)+"_checkpoint_e{epoch}.pth")),
        ModelFreezer(),
        # EarlyStopper(10),
        TensorBoard(os.path.join('..', 'logs', config.folder, 'fold{}'.format(fold)))
    ]
    # if not num_channels_changed:
    #     callbacks.append(LastCheckpointSaver("fold"+str(fold)+"_checkpoint_rgb.pth", config.nb_epoch))

    hard_neg_miner = None#HardNegativeMiner(rate=10)
    # metrics = [('dr', dice_round)]

    trainer = PytorchTrain(estimator,
                           fold=fold,
                           callbacks=callbacks,
                           hard_negative_miner=hard_neg_miner)

    train_loader = PytorchDataLoader(TrainDataset(ds, train_idx, config, transforms=transforms),
                                     batch_size=config.batch_size,
                                     shuffle=True,
                                     drop_last=True,
                                     num_workers=num_workers,
                                     pin_memory=True)
    val_loader = PytorchDataLoader(ValDataset(val_ds, val_idx, config, transforms=val_transforms),
                                   batch_size=1,
                                   shuffle=False,
                                   drop_last=False,
                                   num_workers=num_workers,
                                   pin_memory=True)

    trainer.fit(train_loader, val_loader, config.nb_epoch) 
Example #22
Source File: main.py    From TuckER with MIT License
def train_and_eval(self):
        print("Training the TuckER model...")
        self.entity_idxs = {d.entities[i]:i for i in range(len(d.entities))}
        self.relation_idxs = {d.relations[i]:i for i in range(len(d.relations))}

        train_data_idxs = self.get_data_idxs(d.train_data)
        print("Number of training data points: %d" % len(train_data_idxs))

        model = TuckER(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
        if self.cuda:
            model.cuda()
        model.init()
        opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        if self.decay_rate:
            scheduler = ExponentialLR(opt, self.decay_rate)

        er_vocab = self.get_er_vocab(train_data_idxs)
        er_vocab_pairs = list(er_vocab.keys())

        print("Starting training...")
        for it in range(1, self.num_iterations+1):
            start_train = time.time()
            model.train()    
            losses = []
            np.random.shuffle(er_vocab_pairs)
            for j in range(0, len(er_vocab_pairs), self.batch_size):
                data_batch, targets = self.get_batch(er_vocab, er_vocab_pairs, j)
                opt.zero_grad()
                e1_idx = torch.tensor(data_batch[:,0])
                r_idx = torch.tensor(data_batch[:,1])  
                if self.cuda:
                    e1_idx = e1_idx.cuda()
                    r_idx = r_idx.cuda()
                predictions = model.forward(e1_idx, r_idx)
                if self.label_smoothing:
                    targets = ((1.0-self.label_smoothing)*targets) + (1.0/targets.size(1))           
                loss = model.loss(predictions, targets)
                loss.backward()
                opt.step()
                losses.append(loss.item())
            if self.decay_rate:
                scheduler.step()
            print(it)
            print(time.time()-start_train)    
            print(np.mean(losses))
            model.eval()
            with torch.no_grad():
                print("Validation:")
                self.evaluate(model, d.valid_data)
                if not it%2:
                    print("Test:")
                    start_test = time.time()
                    self.evaluate(model, d.test_data)
                    print(time.time()-start_test) 
Example #23
Source File: test_param_scheduler.py    From ignite with BSD 3-Clause "New" or "Revised" License
def test_scheduler_with_param_groups():
    def _test(lr_scheduler, optimizer):
        num_iterations = 10
        max_epochs = 20

        state_dict = lr_scheduler.state_dict()

        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr():
            lrs.append((optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"]))

        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

        data = [0] * num_iterations

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)
            assert [lr[0] for lr in lrs] == pytest.approx([lr[1] for lr in lrs])
            lr_scheduler.load_state_dict(state_dict)

    t1 = torch.zeros([1], requires_grad=True)
    t2 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])

    lr_scheduler = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.0, cycle_size=10)
    _test(lr_scheduler, optimizer)

    lr_scheduler = PiecewiseLinear(
        optimizer, "lr", milestones_values=[(5, 0.5), (15, 1.0), (25, 0.0), (35, 1.0), (40, 0.5)]
    )
    _test(lr_scheduler, optimizer)

    lr_scheduler = CosineAnnealingScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10)
    _test(lr_scheduler, optimizer)

    torch_lr_scheduler = ExponentialLR(optimizer, gamma=0.98)
    _test(LRScheduler(torch_lr_scheduler), optimizer)

    torch_lr_scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
    _test(LRScheduler(torch_lr_scheduler), optimizer) 
Example #24
Source File: utils.py    From inplace_abn with BSD 3-Clause "New" or "Revised" License
def create_optimizer(optimizer_config, model):
    """Creates optimizer and schedule from configuration

    Parameters
    ----------
    optimizer_config : dict
        Dictionary containing the configuration options for the optimizer.
    model : Model
        The network model.

    Returns
    -------
    optimizer : Optimizer
        The optimizer.
    scheduler : LRScheduler
        The learning rate scheduler.
    """
    if optimizer_config["classifier_lr"] != -1:
        # Separate classifier parameters from all others
        net_params = []
        classifier_params = []
        for k, v in model.named_parameters():
            if k.find("fc") != -1:
                classifier_params.append(v)
            else:
                net_params.append(v)
        params = [
            {"params": net_params},
            {"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
        ]
    else:
        params = model.parameters()

    if optimizer_config["type"] == "SGD":
        optimizer = optim.SGD(params,
                              lr=optimizer_config["learning_rate"],
                              momentum=optimizer_config["momentum"],
                              weight_decay=optimizer_config["weight_decay"],
                              nesterov=optimizer_config["nesterov"])
    elif optimizer_config["type"] == "Adam":
        optimizer = optim.Adam(params,
                               lr=optimizer_config["learning_rate"],
                               weight_decay=optimizer_config["weight_decay"])
    else:
        raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))

    if optimizer_config["schedule"]["type"] == "step":
        scheduler = lr_scheduler.StepLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "multistep":
        scheduler = lr_scheduler.MultiStepLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "exponential":
        scheduler = lr_scheduler.ExponentialLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "constant":
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
    elif optimizer_config["schedule"]["type"] == "linear":
        def linear_lr(it):
            return it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"]

        scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)

    return optimizer, scheduler 
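
For illustration, a hypothetical configuration dict that would exercise the exponential branch above; the key names mirror the ones create_optimizer reads, while the values are placeholders:

# Hypothetical configuration; values are illustrative only.
optimizer_config = {
    "type": "SGD",
    "learning_rate": 0.1,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "nesterov": True,
    "classifier_lr": -1,               # -1 keeps a single parameter group
    "schedule": {
        "type": "exponential",
        "params": {"gamma": 0.95},     # forwarded verbatim to lr_scheduler.ExponentialLR
    },
}
# e.g. optimizer, scheduler = create_optimizer(optimizer_config, model)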
Example #25
Source File: capsulenet.py    From CapsNet-Pytorch with MIT License
def train(model, train_loader, test_loader, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param train_loader: torch.utils.data.DataLoader for training data
    :param test_loader: torch.utils.data.DataLoader for test data
    :param args: arguments
    :return: The trained model
    """
    print('Begin Training' + '-'*70)
    from time import time
    import csv
    logfile = open(args.save_dir + '/log.csv', 'w')
    logwriter = csv.DictWriter(logfile, fieldnames=['epoch', 'loss', 'val_loss', 'val_acc'])
    logwriter.writeheader()

    t0 = time()
    optimizer = Adam(model.parameters(), lr=args.lr)
    lr_decay = lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay)
    best_val_acc = 0.
    for epoch in range(args.epochs):
        model.train()  # set to training mode
        lr_decay.step()  # decrease the learning rate by multiplying a factor `gamma`
        ti = time()
        training_loss = 0.0
        for i, (x, y) in enumerate(train_loader):  # batch training
            y = torch.zeros(y.size(0), 10).scatter_(1, y.view(-1, 1), 1.)  # convert to one-hot encoding
            x, y = Variable(x.cuda()), Variable(y.cuda())  # convert input data to GPU Variable

            optimizer.zero_grad()  # set gradients of optimizer to zero
            y_pred, x_recon = model(x, y)  # forward
            loss = caps_loss(y, y_pred, x, x_recon, args.lam_recon)  # compute loss
            loss.backward()  # backward, compute all gradients of loss w.r.t all Variables
            training_loss += loss.item() * x.size(0)  # record the batch loss
            optimizer.step()  # update the trainable parameters with computed gradients

        # compute validation loss and acc
        val_loss, val_acc = test(model, test_loader, args)
        logwriter.writerow(dict(epoch=epoch, loss=training_loss / len(train_loader.dataset),
                                val_loss=val_loss, val_acc=val_acc))
        print("==> Epoch %02d: loss=%.5f, val_loss=%.5f, val_acc=%.4f, time=%ds"
              % (epoch, training_loss / len(train_loader.dataset),
                 val_loss, val_acc, time() - ti))
        if val_acc > best_val_acc:  # update best validation acc and save model
            best_val_acc = val_acc
            torch.save(model.state_dict(), args.save_dir + '/epoch%d.pkl' % epoch)
            print("best val_acc increased to %.4f" % best_val_acc)
    logfile.close()
    torch.save(model.state_dict(), args.save_dir + '/trained_model.pkl')
    print('Trained model saved to \'%s/trained_model.pkl\'' % args.save_dir)
    print("Total time = %ds" % (time() - t0))
    print('End Training' + '-' * 70)
    return model
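
Note that this training loop calls lr_decay.step() at the start of each epoch, which follows the pre-1.1 PyTorch convention; since PyTorch 1.1 the scheduler should be stepped after the optimizer. A minimal runnable sketch of that ordering (the model, loss, and data are placeholders, not part of CapsNet-Pytorch):

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

model = nn.Linear(8, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
lr_decay = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(3):
    for _ in range(5):                            # stand-in for iterating a DataLoader
        x, y = torch.randn(4, 8), torch.randn(4, 1)
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
    lr_decay.step()                               # decay once per epoch, after the optimizer updates
    print(epoch, optimizer.param_groups[0]["lr"])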