Python torch.optim.lr_scheduler._LRScheduler() Examples

The following are 30 code examples of torch.optim.lr_scheduler._LRScheduler(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.optim.lr_scheduler, or try the search function.
Example #1
Source File: model.py    From SlowFast-Network-pytorch with MIT License 6 votes vote down vote up
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> 'Model':
        """Restore only the model weights from a checkpoint file.

        ``optimizer`` and ``scheduler`` are accepted for signature compatibility
        but are not touched: their saved states are not loaded here, and the step
        value stored in the checkpoint (if any) is not read.

        Args:
            path_to_checkpoint: File that ``torch.load`` can read; the loaded object
                must provide a ``'state_dict'`` entry.
            optimizer: Unused.
            scheduler: Unused.

        Returns:
            Always ``0`` (the step counter is reset rather than resumed).
        """
        saved = torch.load(path_to_checkpoint)
        self.load_state_dict(saved['state_dict'])
        # Resuming the step / optimizer / scheduler is intentionally not done here.
        return 0
Example #2
Source File: logging_manager.py    From emmental with MIT License 6 votes vote down vote up
def checkpoint_model(
        self,
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        """Forward the current training state to ``self.checkpointer``.

        The checkpointer additionally receives ``self.unit_total`` as its first
        argument.

        Args:
          model: The model to checkpoint.
          optimizer: The optimizer used during training process.
          lr_scheduler: Learning rate scheduler.
          metric_dict: the metric dict.
        """
        do_checkpoint = self.checkpointer.checkpoint
        do_checkpoint(self.unit_total, model, optimizer, lr_scheduler, metric_dict)
Example #3
Source File: optimizer.py    From kge with MIT License 6 votes vote down vote up
def __init__(self, config: Config, optimizer):
        """Create the LR scheduler named in the configuration, if any.

        Reads ``train.lr_scheduler`` (scheduler class name inside
        ``torch.optim.lr_scheduler``) and ``train.lr_scheduler_args`` (keyword
        arguments for that class) from ``config``. An empty name means that no
        scheduler is created.

        Raises:
            ValueError: if the scheduler cannot be looked up or constructed.
        """
        super().__init__(config)
        scheduler_name = config.get("train.lr_scheduler")
        scheduler_args = config.get("train.lr_scheduler_args")
        self._lr_scheduler: _LRScheduler = None
        if scheduler_name != "":
            try:
                scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
                self._lr_scheduler = scheduler_cls(optimizer, **scheduler_args)
            except Exception as e:
                template = (
                    "Invalid LR scheduler {} or scheduler arguments {}. "
                    "Error: {}"
                )
                raise ValueError(template.format(scheduler_name, scheduler_args, e))

        # Only ReduceLROnPlateau is treated as metric based.
        self._metric_based = scheduler_name in ["ReduceLROnPlateau"]
Example #4
Source File: param_scheduler.py    From LaSO with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, lr_scheduler, save_history=False, **kwds):
        """Wrap a PyTorch LR scheduler whose optimizer has a single param group.

        Args:
            lr_scheduler: instance of ``torch.optim.lr_scheduler._LRScheduler``.
            save_history: forwarded to the parent constructor.
            **kwds: accepted but not used here.

        Raises:
            TypeError: if ``lr_scheduler`` is not an ``_LRScheduler`` instance.
            ValueError: if the scheduler's optimizer has more than one param group.
        """
        if not isinstance(lr_scheduler, _LRScheduler):
            type_message = ("Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                            "but given {}")
            raise TypeError(type_message.format(type(lr_scheduler)))

        group_count = len(lr_scheduler.optimizer.param_groups)
        if group_count > 1:
            group_message = ("Optimizer passed to lr_scheduler should have a single param group, "
                             "but currently there are {} param groups")
            raise ValueError(group_message.format(group_count))

        self.lr_scheduler = lr_scheduler
        parent_kwargs = dict(
            optimizer=self.lr_scheduler.optimizer,
            param_name='lr',
            save_history=save_history,
        )
        super(LRScheduler, self).__init__(**parent_kwargs)
Example #5
Source File: param_scheduler.py    From LaSO with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def simulate_values(cls, num_events, lr_scheduler, **kwargs):
        """Simulate the scheduled values over ``num_events`` events.

        The simulation runs on a copy made with ``LRScheduler._copy_lr_scheduler``.
        ``**kwargs`` is accepted but not used.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.

        Returns:
            list of pairs: [event_index, value], where the value is read from the
            first param group after the wrapper has been called for that event.

        """
        scheduler_copy = LRScheduler._copy_lr_scheduler(lr_scheduler)
        scheduler = cls(save_history=False, lr_scheduler=scheduler_copy)

        def _advance(event_index):
            # Step first, then read the value that is now in effect.
            scheduler(engine=None)
            return [event_index, scheduler.optimizer_param_groups[0][scheduler.param_name]]

        return [_advance(event_index) for event_index in range(num_events)]
Example #6
Source File: nnUNetTrainerV2_SGD_ReduceOnPlateau.py    From nnUNet with Apache License 2.0 5 votes vote down vote up
def maybe_update_lr(self, epoch=None):
        """Advance the learning-rate scheduler (if one is set) and log the current lr.

        A ``ReduceLROnPlateau`` scheduler is stepped with ``self.train_loss_MA``, but
        only once ``self.epoch > 0`` and a moving average is actually available.
        Any other scheduler is stepped with ``self.epoch + 1``.

        :param epoch: unused; ``self.epoch`` is read instead
        """
        scheduler = self.lr_scheduler
        if scheduler is not None:
            assert isinstance(scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust.
                # Stepping with None would raise inside ReduceLROnPlateau, so skip the update
                # until a moving average exists.
                if self.epoch > 0 and self.train_loss_MA is not None:
                    scheduler.step(self.train_loss_MA)
            else:
                scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))
Example #7
Source File: model.py    From easy-faster-rcnn.pytorch with MIT License 5 votes vote down vote up
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> 'Model':
        """Restore model weights and, optionally, optimizer and scheduler state.

        Args:
            path_to_checkpoint: File readable by ``torch.load`` holding the keys
                ``'state_dict'`` and ``'step'`` (plus ``'optimizer_state_dict'`` /
                ``'scheduler_state_dict'`` when the respective object is given).
            optimizer: Optimizer to restore, or ``None`` to skip.
            scheduler: Scheduler to restore, or ``None`` to skip.

        Returns:
            The step value stored in the checkpoint.
        """
        saved = torch.load(path_to_checkpoint)
        self.load_state_dict(saved['state_dict'])
        step = saved['step']
        restorable = (
            (optimizer, 'optimizer_state_dict'),
            (scheduler, 'scheduler_state_dict'),
        )
        for target, key in restorable:
            if target is not None:
                target.load_state_dict(saved[key])
        return step
Example #8
Source File: nnUNetTrainerV2_Adam_ReduceOnPlateau.py    From nnUNet with Apache License 2.0 5 votes vote down vote up
def maybe_update_lr(self, epoch=None):
        """Step the lr scheduler (when one is configured) and log the resulting lr.

        ``ReduceLROnPlateau`` is stepped with ``self.train_loss_MA`` once
        ``self.epoch > 0`` and the moving average is not None; every other
        scheduler is stepped with ``self.epoch + 1``.

        :param epoch: unused; ``self.epoch`` is read instead
        """
        scheduler = self.lr_scheduler
        if scheduler is not None:
            assert isinstance(scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            plateau_based = isinstance(scheduler, lr_scheduler.ReduceLROnPlateau)
            if not plateau_based:
                scheduler.step(self.epoch + 1)
            elif self.epoch > 0 and self.train_loss_MA is not None:
                # driven by the moving average of the training loss
                scheduler.step(self.train_loss_MA)

        current_lr = self.optimizer.param_groups[0]['lr']
        self.print_to_log_file("lr is now (scheduler) %s" % str(current_lr))
Example #9
Source File: param_scheduler.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def __init__(self, lr_scheduler, save_history=False, **kwargs):
        """Wrap a PyTorch LR scheduler.

        Args:
            lr_scheduler: instance of ``torch.optim.lr_scheduler._LRScheduler``.
            save_history: forwarded to the parent constructor.
            **kwargs: accepted but not used here.

        Raises:
            TypeError: if ``lr_scheduler`` is not an ``_LRScheduler`` instance.
        """
        if not isinstance(lr_scheduler, _LRScheduler):
            message = (
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                "but given {}"
            )
            raise TypeError(message.format(type(lr_scheduler)))

        self.lr_scheduler = lr_scheduler
        parent_kwargs = dict(optimizer=self.lr_scheduler.optimizer, param_name="lr", save_history=save_history)
        super(LRScheduler, self).__init__(**parent_kwargs)
        # Register the wrapped scheduler as an additional state attribute.
        self._state_attrs += ["lr_scheduler"]
Example #10
Source File: param_scheduler.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def simulate_values(cls, num_events, lr_scheduler, **kwargs):
        """Method to simulate scheduled values during num_events events.

        The state of ``lr_scheduler`` and of its optimizer is cached on disk before
        the simulation and restored afterwards, including when the simulation
        raises, so the caller's scheduler is never left perturbed.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.
            **kwargs: forwarded to ``cls`` when building the simulated wrapper.

        Returns:
            list of lists: ``[event_index, value_0, value_1, ...]`` with one value per
            optimizer param group, recorded before the wrapper is called for that event.

        Raises:
            TypeError: if ``lr_scheduler`` is not an ``_LRScheduler`` instance.

        """

        if not isinstance(lr_scheduler, _LRScheduler):
            raise TypeError(
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                "but given {}".format(type(lr_scheduler))
            )

        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
        # should be replicated in order to simulate LR values and
        # not perturb original scheduler.
        with tempfile.TemporaryDirectory() as tmpdirname:
            cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
            obj = {
                "lr_scheduler": lr_scheduler.state_dict(),
                "optimizer": lr_scheduler.optimizer.state_dict(),
            }
            torch.save(obj, cache_filepath.as_posix())

            try:
                values = []
                scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs)
                for i in range(num_events):
                    params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
                    values.append([i] + params)
                    scheduler(engine=None)
            finally:
                # Restore the cached state even if the simulation failed part-way.
                obj = torch.load(cache_filepath.as_posix())
                lr_scheduler.load_state_dict(obj["lr_scheduler"])
                lr_scheduler.optimizer.load_state_dict(obj["optimizer"])

            return values
Example #11
Source File: param_scheduler.py    From ignite with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def simulate_values(cls, num_events, schedulers, **kwargs):
        """Method to simulate scheduled values during num_events events.

        The state of every scheduler and of the optimizer of ``schedulers[0]`` is
        cached on disk before the simulation and restored afterwards, including
        when the simulation raises, so the caller's schedulers are never left
        perturbed.

        Args:
            num_events (int): number of events during the simulation.
            schedulers (list of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler objects to wrap.
                All of them should be related to the same optimizer.
            **kwargs: forwarded to ``cls`` when building the simulated wrapper.

        Returns:
            list of lists: ``[event_index] + scheduler.get_param()`` for each event, recorded
            before the wrapper is called for that event.

        """

        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
        # should be replicated in order to simulate LR values and
        # not perturb original scheduler.
        with tempfile.TemporaryDirectory() as tmpdirname:
            cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
            objs = {"lr_scheduler_{}".format(i): s.state_dict() for i, s in enumerate(schedulers)}
            # all schedulers should be related to the same optimizer
            objs["optimizer"] = schedulers[0].optimizer.state_dict()

            torch.save(objs, cache_filepath.as_posix())

            try:
                values = []
                scheduler = cls(schedulers=schedulers, **kwargs)
                for i in range(num_events):
                    params = scheduler.get_param()
                    values.append([i] + params)
                    scheduler(engine=None)
            finally:
                # Restore the cached state even if the simulation failed part-way.
                objs = torch.load(cache_filepath.as_posix())
                for i, s in enumerate(schedulers):
                    s.load_state_dict(objs["lr_scheduler_{}".format(i)])
                    s.optimizer.load_state_dict(objs["optimizer"])

            return values
Example #12
Source File: utils.py    From chemprop with MIT License 5 votes vote down vote up
def build_lr_scheduler(optimizer: Optimizer, args: Namespace, total_epochs: List[int] = None) -> _LRScheduler:
    """
    Creates the learning rate scheduler selected by ``args.scheduler``.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param args: Arguments. ``args.scheduler`` must be ``'noam'``, ``'none'`` or ``'decay'``.
    :param total_epochs: Epoch counts handed to the ``'noam'`` scheduler; when falsy,
        ``[args.epochs] * args.num_lrs`` is used instead.
    :return: An initialized learning rate scheduler.
    :raises ValueError: If ``args.scheduler`` names an unsupported scheduler.
    """
    kind = args.scheduler

    if kind == 'noam':
        noam_epochs = total_epochs or [args.epochs] * args.num_lrs
        batches_per_epoch = args.train_data_size // args.batch_size
        return NoamLR(
            optimizer=optimizer,
            warmup_epochs=args.warmup_epochs,
            total_epochs=noam_epochs,
            steps_per_epoch=batches_per_epoch,
            init_lr=args.init_lr,
            max_lr=args.max_lr,
            final_lr=args.final_lr
        )
    elif kind == 'none':
        return MockLR(optimizer=optimizer, lr=args.init_lr)
    elif kind == 'decay':
        return ExponentialLR(optimizer, args.lr_decay_rate)

    raise ValueError(f'Learning rate scheduler "{args.scheduler}" not supported.')
Example #13
Source File: trainers.py    From homura with Apache License 2.0 5 votes vote down vote up
def set_scheduler(self):
        """ Set scheduler(s) for optimizer(s). You can override as ::

            class YourTrainer(TrainerBase):
                def set_scheduler(self):
                    self.scheduler = torch.optim.lr_scheduler.Foo(self.optimizer)

        ``self.scheduler`` may be ``None``, a ``Scheduler`` instance, a ``Partial``
        (called with ``self.optimizer``) or a ``dict`` of those (resolved per key
        against ``self.optimizer`` and stored as a ``StepDict``).

        :return:
        """

        scheduler = self.scheduler
        if scheduler is not None and self.optimizer is None:
            raise TypeError("Optimizer is not set, so scheduler cannot be set")

        # Already usable as is.
        if isinstance(scheduler, Scheduler) or scheduler is None:
            self.scheduler = scheduler
            return

        # A partially applied scheduler class still needs its optimizer.
        if isinstance(scheduler, Partial):
            if not issubclass(scheduler.func, Scheduler):
                raise TypeError(f"`scheduler.func` is expected to be subclass of `_LRScheduler`"
                                f" but got {type(scheduler.func)}")
            self.scheduler = scheduler(self.optimizer)
            return

        if not isinstance(scheduler, dict):
            raise TypeError(f"Unexpected type {type(scheduler)} for `scheduler`")

        if not isinstance(self.optimizer, StepDict):
            raise TypeError("When `scheduler` is `dict`, `optimizer` is also needs to be `dict`")
        resolved = {
            key: (entry(self.optimizer[key]) if isinstance(entry, Partial) else entry)
            for key, entry in scheduler.items()
        }
        self.scheduler = StepDict(Scheduler, **resolved)
Example #14
Source File: trainers.py    From homura with Apache License 2.0 5 votes vote down vote up
def __init__(self,
                 model: nn.Module,
                 optimizer: Optimizer,
                 loss_f: Callable,
                 *,
                 callbacks: Optional[Callback or Iterable[Callable]] = None,
                 scheduler: Optional[Scheduler] = None,
                 verb=True,
                 use_cudnn_benchmark=True,
                 data_parallel=False,
                 use_amp=False,
                 **kwargs):
        """Set up a supervised trainer for a single (non-dict) model.

        After the parent constructor has run, the model is optionally wrapped in
        ``nn.DataParallel`` and a gradient scaler is created when amp is requested
        and available.

        :param data_parallel: wrap the model in ``nn.DataParallel`` when more than one
            CUDA device is visible and the model is not wrapped already
        :param use_amp: request automatic mixed precision; reset to ``False`` with a
            warning when ``torch.cuda.amp`` has no ``autocast``
        :raises TypeError: if ``model`` is a ``dict``
        """
        # NOTE(review): ``Callback or Iterable[Callable]`` in the signature evaluates to
        # just ``Callback``; a ``Union`` was probably intended - confirm before changing.
        if isinstance(model, dict):
            raise TypeError(f"{type(self)} does not support dict model")
        super(SupervisedTrainer, self).__init__(model, optimizer, loss_f, callbacks=callbacks, scheduler=scheduler,
                                                verb=verb, use_cudnn_benchmark=use_cudnn_benchmark, **kwargs)

        needs_wrapping = (data_parallel
                          and not isinstance(self.model, nn.DataParallel)
                          and torch.cuda.device_count() > 1)
        if needs_wrapping:
            self.model = nn.DataParallel(self.model)
            self.model.to(self.device)

        self._use_amp = use_amp
        if use_amp:
            if hasattr(torch.cuda.amp, 'autocast'):
                self.scaler = torch.cuda.amp.GradScaler()
            else:
                warnings.warn('amp is not available')
                self._use_amp = False
Example #15
Source File: callbacks.py    From kekas with MIT License 5 votes vote down vote up
def __init__(self,
                 sched: Union[_LRScheduler, ReduceLROnPlateau],
                 metric: str = None) -> None:
        """Remember which metric to monitor and whether the scheduler is plateau based.

        Args:
            sched: the learning rate scheduler; only its type is inspected here.
            metric: name of the monitored metric; falls back to ``"loss"`` when
                omitted or empty.
        """
        # ``"loss" or metric`` always evaluated to "loss" and silently ignored the
        # caller's choice; the operands must be the other way round.
        self.metric = metric or "loss"
        self.is_reduce = isinstance(sched, ReduceLROnPlateau)
Example #16
Source File: trainer.py    From XenonPy with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def lr_scheduler(self, scheduler):
        """Store a scheduler factory and build the scheduler once an optimizer exists.

        A ``None`` argument keeps the previously stored factory. The scheduler is
        only instantiated when both the factory and ``self._optimizer`` are set.
        """
        if scheduler is not None:
            self._scheduler = scheduler

        factory, optimizer = self._scheduler, self._optimizer
        if factory is None or optimizer is None:
            return
        self._lr_scheduler: Union[_LRScheduler, None] = factory(optimizer)
Example #17
Source File: train.py    From solaris with Apache License 2.0 5 votes vote down vote up
def initialize_model(self):
        """Load in and create all model training elements.

        * Resets the model weights unless ``self.pretrained`` is set.
        * ``'keras'`` framework: compiles the model with the configured optimizer,
          loss and training metrics.
        * ``'torch'`` framework: moves the model to the GPU(s) when available,
          instantiates the optimizer and wraps it with every ``_LRScheduler``
          callback found in ``self.callbacks`` (those callbacks are then removed
          from the list).

        Sets ``self.is_initialized`` to ``True`` when done.
        """
        if not self.pretrained:
            self.model = reset_weights(self.model, self.framework)

        if self.framework == 'keras':
            compiled = self.model.compile(optimizer=self.optimizer,
                                          loss=self.loss,
                                          metrics=self.metrics['train'])
            # Keras' ``compile`` configures the model in place and returns None, so
            # assigning its result unconditionally would throw the model away.
            if compiled is not None:
                self.model = compiled

        elif self.framework == 'torch':
            if self.gpu_available:
                self.model = self.model.cuda()
                if self.gpu_count > 1:
                    self.model = torch.nn.DataParallel(self.model)
            # create optimizer
            if self.config['training']['opt_args'] is not None:
                self.optimizer = self.optimizer(
                    self.model.parameters(), lr=self.lr,
                    **self.config['training']['opt_args']
                )
            else:
                self.optimizer = self.optimizer(
                    self.model.parameters(), lr=self.lr
                )
            # wrap in lr_scheduler if one was created
            for cb in self.callbacks:
                if isinstance(cb, _LRScheduler):
                    # The lookup key must be the string itself: a list is unhashable
                    # and made ``dict.get`` raise TypeError.
                    self.optimizer = cb(
                        self.optimizer,
                        **self.config['training']['callbacks'][
                            'lr_schedule'].get('schedule_dict', {})
                        )
                    # drop the LRScheduler callback from the list
                    self.callbacks = [i for i in self.callbacks if i != cb]

        self.is_initialized = True
Example #18
Source File: network_trainer.py    From nnUNet with Apache License 2.0 5 votes vote down vote up
def maybe_update_lr(self):
        """Step the lr scheduler (when one is configured) and log the resulting lr.

        ``ReduceLROnPlateau`` is stepped with ``self.train_loss_MA``; every other
        scheduler is stepped with ``self.epoch + 1``.
        """
        scheduler = self.lr_scheduler
        if scheduler is not None:
            assert isinstance(scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            # plateau schedulers are driven by the moving average of the training loss,
            # all others by the upcoming epoch index
            if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
                step_argument = self.train_loss_MA
            else:
                step_argument = self.epoch + 1
            scheduler.step(step_argument)

        current_lr = self.optimizer.param_groups[0]['lr']
        self.print_to_log_file("lr is now (scheduler) %s" % str(current_lr))
Example #19
Source File: schedulers.py    From incremental_learning.pytorch with MIT License 5 votes vote down vote up
def get_lr(self):
        """Return the learning rates for the upcoming step (cosine schedule with restarts).

        The very first call returns ``self.base_lrs`` unchanged. Later calls evaluate
        a cosine curve between each base lr and ``self.eta_min`` over the current
        cycle; when a cycle ends, the cycle length is rescaled by ``self.factor``
        and a new cycle starts.
        """
        # HACK: ``torch.optim.lr_scheduler._LRScheduler`` calls ``self.get_lr()`` once
        # during its own initialisation; the learning rate must stay untouched for
        # the first epoch, so that first call only flips the flag.
        if not self._initialized:
            self._initialized = True
            return self.base_lrs

        step = self.last_epoch + 1
        self._cycle_counter = step - self._last_restart

        cycle_len = self._updated_cycle_len
        position = self._cycle_counter % cycle_len
        # Identical for every param group, so evaluate the cosine once.
        cosine_term = np.cos(np.pi * position / cycle_len) + 1

        lrs = []
        for lr in self.base_lrs:
            lrs.append(self.eta_min + ((lr - self.eta_min) / 2) * cosine_term)

        if position == 0:
            # Adjust the cycle length.
            self._cycle_factor *= self.factor
            self._cycle_counter = 0
            self._updated_cycle_len = int(self._cycle_factor * self.t_max)
            self._last_restart = step

        return lrs
Example #20
Source File: model.py    From easy-fpn.pytorch with MIT License 5 votes vote down vote up
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> 'Model':
        """Load a checkpoint into the model and, if given, the optimizer and scheduler.

        Args:
            path_to_checkpoint: File readable by ``torch.load`` holding the keys
                ``'state_dict'`` and ``'step'`` (plus ``'optimizer_state_dict'`` /
                ``'scheduler_state_dict'`` when the respective object is given).
            optimizer: Optimizer to restore, or ``None`` to skip.
            scheduler: Scheduler to restore, or ``None`` to skip.

        Returns:
            The step value stored in the checkpoint.
        """
        payload = torch.load(path_to_checkpoint)

        def _restore(target, key):
            # Objects that were not passed in are simply left alone.
            if target is not None:
                target.load_state_dict(payload[key])

        self.load_state_dict(payload['state_dict'])
        resumed_step = payload['step']
        _restore(optimizer, 'optimizer_state_dict')
        _restore(scheduler, 'scheduler_state_dict')
        return resumed_step
Example #21
Source File: model.py    From easy-fpn.pytorch with MIT License 5 votes vote down vote up
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        """Write model, optimizer and scheduler state to ``model-<step>.pth``.

        Args:
            path_to_checkpoints_dir: Directory that receives the checkpoint file.
            step: Training step; stored in the checkpoint and used in the file name.
            optimizer: Optimizer whose ``state_dict()`` is stored.
            scheduler: Scheduler whose ``state_dict()`` is stored.

        Returns:
            Path of the written checkpoint file.
        """
        file_name = f'model-{step}.pth'
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, file_name)
        torch.save(
            {
                'state_dict': self.state_dict(),
                'step': step,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            },
            path_to_checkpoint,
        )
        return path_to_checkpoint
Example #22
Source File: checkpointer.py    From emmental with MIT License 5 votes vote down vote up
def collect_state_dict(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> Dict[str, Any]:
        """Assemble everything that goes into a checkpoint.

        Args:
          iteration: The current iteration.
          model: The model to checkpoint.
          optimizer: The optimizer used during training process.
          lr_scheduler: Learning rate scheduler; a falsy value is stored as ``None``.
          metric_dict: the metric dict.

        Returns:
          The state dict with the keys ``iteration``, ``model``, ``optimizer``,
          ``lr_scheduler`` and ``metric_dict``.
        """
        # Only the model name and its module pool are stored for the model; task
        # names, task flows, loss/output functions and scorers are not.
        return {
            "iteration": iteration,
            "model": {
                "name": model.name,
                "module_pool": model.collect_state_dict(),
            },
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler else None,
            "metric_dict": metric_dict,
        }
Example #23
Source File: model.py    From SlowFast-Network-pytorch with MIT License 5 votes vote down vote up
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        """Write model, optimizer and scheduler state to ``model-<step>.pth``.

        Args:
            path_to_checkpoints_dir: Directory that receives the checkpoint file.
            step: Training step; stored in the checkpoint and used in the file name.
            optimizer: Optimizer whose ``state_dict()`` is stored.
            scheduler: Scheduler whose ``state_dict()`` is stored.

        Returns:
            Path of the written checkpoint file.
        """
        file_name = 'model-{}.pth'.format(step)
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, file_name)
        checkpoint = dict([
            ('state_dict', self.state_dict()),
            ('step', step),
            ('optimizer_state_dict', optimizer.state_dict()),
            ('scheduler_state_dict', scheduler.state_dict()),
        ])
        torch.save(checkpoint, path_to_checkpoint)
        return path_to_checkpoint
    # 
Example #24
Source File: model.py    From easy-faster-rcnn.pytorch with MIT License 5 votes vote down vote up
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        """Persist the training state as ``model-<step>.pth`` in the given directory.

        Args:
            path_to_checkpoints_dir: Directory that receives the checkpoint file.
            step: Training step; stored in the checkpoint and used in the file name.
            optimizer: Optimizer whose ``state_dict()`` is stored.
            scheduler: Scheduler whose ``state_dict()`` is stored.

        Returns:
            Path of the written checkpoint file.
        """
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')

        checkpoint = {}
        checkpoint['state_dict'] = self.state_dict()
        checkpoint['step'] = step
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()

        torch.save(checkpoint, path_to_checkpoint)
        return path_to_checkpoint
Example #25
Source File: network_trainer.py    From nnUNet with Apache License 2.0 4 votes vote down vote up
def load_checkpoint_ram(self, saved_model, train=True):
        """
        Restore the trainer from a checkpoint that is already loaded in RAM.

        Initializes the trainer first if ``self.was_initialized`` is false, then loads
        the network weights and the epoch counter and, when ``train`` is true, the
        optimizer and lr scheduler state. Finally the loss / metric histories stored
        under ``'plot_stuff'`` are restored and amp initialization is re-run.

        :param saved_model: checkpoint mapping; the keys ``'state_dict'``, ``'epoch'`` and
            ``'plot_stuff'`` are always read, ``'optimizer_state_dict'`` and (if a scheduler
            with ``load_state_dict`` exists) ``'lr_scheduler_state_dict'`` only when ``train``
            is true
        :param train: forwarded to ``self.initialize``; also decides whether optimizer and
            lr scheduler state are restored
        :return:
        """
        if not self.was_initialized:
            self.initialize(train)

        new_state_dict = OrderedDict()
        curr_state_dict_keys = list(self.network.state_dict().keys())
        # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
        # match. Use heuristic to make it match
        for k, value in saved_model['state_dict'].items():
            key = k
            if key not in curr_state_dict_keys:
                print("duh")
                # drop the first 7 characters - presumably the 'module.' prefix that
                # nn.DataParallel adds to parameter names (TODO confirm)
                key = key[7:]
            new_state_dict[key] = value

        # if we are fp16, then we need to reinitialize the network and the optimizer. Otherwise amp will throw an error
        if self.fp16:
            self.network, self.optimizer, self.lr_scheduler = None, None, None
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

        self.network.load_state_dict(new_state_dict)
        self.epoch = saved_model['epoch']
        if train:
            optimizer_state_dict = saved_model['optimizer_state_dict']
            if optimizer_state_dict is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)

            # only restore scheduler state if there is a scheduler that supports it and state was saved
            if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and saved_model[
                'lr_scheduler_state_dict'] is not None:
                self.lr_scheduler.load_state_dict(saved_model['lr_scheduler_state_dict'])

            # epoch-based schedulers are additionally stepped to the restored epoch
            if issubclass(self.lr_scheduler.__class__, _LRScheduler):
                self.lr_scheduler.step(self.epoch)

        # 'plot_stuff' is unpacked into exactly these four history attributes, in this order
        self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = saved_model[
            'plot_stuff']

        # after the training is done, the epoch is incremented one more time in my old code. This results in
        # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
        # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
        if self.epoch != len(self.all_tr_losses):
            self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
                                   "due to an old bug and should only appear when you are loading old models. New "
                                   "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
            self.epoch = len(self.all_tr_losses)
            # truncate all histories to the corrected epoch count
            self.all_tr_losses = self.all_tr_losses[:self.epoch]
            self.all_val_losses = self.all_val_losses[:self.epoch]
            self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
            self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]

        # force amp to be set up again for the freshly loaded network / optimizer
        self.amp_initialized = False
        self._maybe_init_amp()
Example #26
Source File: optimization.py    From texar-pytorch with Apache License 2.0 4 votes vote down vote up
def get_train_op(params: Optional[Iterable[Union[torch.Tensor,
                                                 Dict[str, Any]]]] = None,
                 optimizer: Optional[Optimizer] = None,
                 scheduler: Optional[_LRScheduler] = None,
                 hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
        Callable[[], None]:
    r"""Builds the callable that performs one optimization step.

    When a ``scheduler`` is given, its optimizer is used. Otherwise the optimizer
    is taken from the argument or created from ``params``, and a scheduler is
    created for it from ``hparams``.

    Args:
        params: an iterable of :class:`torch.Tensor` or
            :class:`dict`. Specifies what Tensors should be optimized.
        optimizer: A :torch_docs:`torch.optim.Optimizer
            <optim.html#torch.optim.Optimizer>` instance.
        scheduler: A :torch_docs:`torch.optim.lr_scheduler._LRScheduler
            <optim.html#how-to-adjust-learning-rate>` instance.
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically. See
            :func:`~texar.torch.core.default_optimization_hparams` for
            all hyperparameters and default values.

    Returns:
        The callable used for variable optimization. Each call clips gradients
        (if configured), steps the optimizer, steps the scheduler (if any) and
        zeroes the gradients.

    Raises:
        ValueError: if ``params``, ``optimizer`` and ``scheduler`` are all ``None``.
    """
    hparams = HParams(hparams, default_optimization_hparams())

    if all(arg is None for arg in (params, optimizer, scheduler)):
        raise ValueError("'params', 'optimizer' and 'scheduler' must not be "
                         "None simultaneously.")

    if scheduler is not None:
        # An explicit scheduler dictates which optimizer is used.
        optimizer = scheduler.optimizer  # type: ignore
    else:
        if optimizer is None and params is not None:
            optimizer = get_optimizer(params, hparams)
        if optimizer is not None:
            scheduler = get_scheduler(optimizer, hparams)

    grad_clip_fn = get_grad_clip_fn(hparams)

    # TODO: Support per-parameter options in the future.
    params_list: List[nn.Parameter] = []
    for param_group in optimizer.param_groups:  # type: ignore
        group_params = param_group["params"]
        if isinstance(group_params, torch.Tensor):
            params_list.append(group_params)
        elif isinstance(group_params, list):
            params_list.extend(group_params)

    def _train_op():
        if grad_clip_fn is not None:
            grad_clip_fn(parameters=params_list)
        optimizer.step()
        # TODO: Ideally, scheduler should be used in the epoch level.
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    return _train_op
Example #27
Source File: nnUNetTrainerV2_DDP.py    From nnUNet with Apache License 2.0 4 votes vote down vote up
def load_checkpoint_ram(self, saved_model, train=True):
        """
        Restore the trainer from a checkpoint that is already loaded in RAM.

        Initializes the trainer first if ``self.was_initialized`` is false, then loads
        the network weights and the epoch counter and, when ``train`` is true, the
        optimizer and lr scheduler state. Finally the loss / metric histories stored
        under ``'plot_stuff'`` are restored and amp initialization is re-run. With fp16
        the network is rebuilt and wrapped in ``DDP`` again before the weights are loaded.

        :param saved_model: checkpoint mapping; the keys ``'state_dict'``, ``'epoch'`` and
            ``'plot_stuff'`` are always read, ``'optimizer_state_dict'`` and (if a scheduler
            with ``load_state_dict`` exists) ``'lr_scheduler_state_dict'`` only when ``train``
            is true
        :param train: forwarded to ``self.initialize``; also decides whether optimizer and
            lr scheduler state are restored
        :return:
        """
        if not self.was_initialized:
            self.initialize(train)

        new_state_dict = OrderedDict()
        curr_state_dict_keys = list(self.network.state_dict().keys())
        # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
        # match. Use heuristic to make it match
        for k, value in saved_model['state_dict'].items():
            key = k
            if key not in curr_state_dict_keys:
                print("duh")
                # drop the first 7 characters - presumably the 'module.' prefix that
                # nn.DataParallel adds to parameter names (TODO confirm)
                key = key[7:]
            new_state_dict[key] = value

        # if we are fp16, then we need to reinitialize the network and the optimizer. Otherwise amp will throw an error
        if self.fp16:
            self.network, self.optimizer, self.lr_scheduler = None, None, None
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            # we need to reinitialize DDP here
            self.network = DDP(self.network)

        self.network.load_state_dict(new_state_dict)
        self.epoch = saved_model['epoch']
        if train:
            optimizer_state_dict = saved_model['optimizer_state_dict']
            if optimizer_state_dict is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)

            # only restore scheduler state if there is a scheduler that supports it and state was saved
            if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and saved_model[
                'lr_scheduler_state_dict'] is not None:
                self.lr_scheduler.load_state_dict(saved_model['lr_scheduler_state_dict'])

            # epoch-based schedulers are additionally stepped to the restored epoch
            if issubclass(self.lr_scheduler.__class__, _LRScheduler):
                self.lr_scheduler.step(self.epoch)

        # 'plot_stuff' is unpacked into exactly these four history attributes, in this order
        self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = saved_model[
            'plot_stuff']

        # after the training is done, the epoch is incremented one more time in my old code. This results in
        # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
        # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
        if self.epoch != len(self.all_tr_losses):
            self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
                                   "due to an old bug and should only appear when you are loading old models. New "
                                   "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
            self.epoch = len(self.all_tr_losses)
            # truncate all histories to the corrected epoch count
            self.all_tr_losses = self.all_tr_losses[:self.epoch]
            self.all_val_losses = self.all_val_losses[:self.epoch]
            self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
            self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]

        # force amp to be set up again for the freshly loaded network / optimizer
        self.amp_initialized = False
        self._maybe_init_amp()
Example #28
Source File: optimization.py    From texar-pytorch with Apache License 2.0 4 votes vote down vote up
def get_scheduler(optimizer: Optimizer,
                  hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
        Optional[_LRScheduler]:
    r"""Creates a scheduler instance.

    The scheduler is described by the ``"learning_rate_decay"`` entry of
    :attr:`hparams`: its ``"type"`` selects the scheduler class and its
    ``"kwargs"`` are forwarded to the class constructor together with
    :attr:`optimizer`.

    Args:
        optimizer: A :torch_docs:`torch.optim.Optimizer
            <optim.html#torch.optim.Optimizer>` instance.
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically. See
            :func:`~texar.torch.core.default_optimization_hparams` for
            all hyperparameters and default values.

    :return:
        A :torch_docs:`torch.optim.lr_scheduler._LRScheduler
        <optim.html#how-to-adjust-learning-rate>` instance, or `None` if
        the configured scheduler type is `None` or an empty string.

    :raises ValueError:
        If resolving the scheduler type raises a :exc:`TypeError`.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_optimization_hparams())

    hparams_scheduler = hparams["learning_rate_decay"]

    scheduler_type = hparams_scheduler["type"]
    if scheduler_type == "" or scheduler_type is None:
        # No scheduler configured.
        return None

    # `scheduler_type` is called below to build the scheduler, so a directly
    # usable value must be a *class* deriving from `_LRScheduler`. An
    # `isinstance(scheduler_type, _LRScheduler)` test would only match an
    # already-built scheduler object, which is not callable.
    if (isinstance(scheduler_type, type)
            and issubclass(scheduler_type, _LRScheduler)):
        scheduler_class = scheduler_type
    else:
        # Anything else (e.g. a class name given as a string) is handed to
        # the project helper, searching these modules.
        scheduler_modules = ['torch.optim.lr_scheduler',
                             'texar.torch.custom']
        try:
            scheduler_class = utils.check_or_get_class(  # type: ignore
                scheduler_type, scheduler_modules, _LRScheduler)
        except TypeError as err:
            raise ValueError(
                "Unrecognized lr_scheduler. Must be string name of the "
                "lr_scheduler class, or the class which is a subclass of "
                "torch.optim.lr_scheduler._LRScheduler.") from err

    scheduler_kwargs = hparams_scheduler["kwargs"].todict()
    # The optimizer passed by the caller always overrides any "optimizer"
    # entry present in the configured kwargs.
    scheduler_kwargs.update({"optimizer": optimizer})
    return scheduler_class(**scheduler_kwargs)  # type: ignore
Example #29
Source File: param_scheduler.py    From LaSO with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def create_lr_scheduler_with_warmup(lr_scheduler, warmup_start_value, warmup_end_value, warmup_duration,
                                    save_history=False,
                                    output_simulated_values=None):
    """
    Prepend a linear warm-up phase to an existing LR scheduler.

    A `LinearCyclicalScheduler` going from `warmup_start_value` to `warmup_end_value` is
    concatenated with `lr_scheduler` through a `ConcatScheduler`; the warm-up part is given
    a duration of `warmup_duration` events.

    Args:
        lr_scheduler (ParamScheduler or subclass of `torch.optim.lr_scheduler._LRScheduler`): LR scheduler
            used once the warm-up is over. A torch scheduler is wrapped into `LRScheduler` first.
        warmup_start_value (float): LR value at the beginning of the warm-up.
        warmup_end_value (float): LR value at the end of the warm-up.
        warmup_duration (int): length of the warm-up, in number of events.
        save_history (bool, optional): forwarded to `ConcatScheduler` as `save_history`
            (default=False).
        output_simulated_values (list or tuple, optional): when not None, it is extended in place
            with the values returned by `ConcatScheduler.simulate_values` for
            `warmup_duration * 20` events.

    Returns:
        ConcatScheduler: the warm-up scheduler followed by `lr_scheduler`.

    Raises:
        TypeError: if `lr_scheduler` is neither a `ParamScheduler` nor a torch `_LRScheduler`.


    .. code-block:: python

        torch_lr_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.98)
        lr_values = []
        scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                                    warmup_start_value=0.0, warmup_end_value=0.1, warmup_duration=10,
                                                    output_simulated_values=lr_values)
        lr_values = np.array(lr_values)
        # Plot simulated values
        plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")

        # Attach to the trainer
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    """
    accepted_types = (ParamScheduler, _LRScheduler)
    if not isinstance(lr_scheduler, accepted_types):
        raise TypeError("Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler or "
                        "ParamScheduler, but given {}".format(type(lr_scheduler)))

    # Torch schedulers are adapted to the ParamScheduler-style wrapper; others are used as given.
    main_phase = LRScheduler(lr_scheduler) if isinstance(lr_scheduler, _LRScheduler) else lr_scheduler

    # The warm-up scheduler is built on an empty placeholder "optimizer" and then pointed at the
    # parameter groups of the main scheduler, so both phases act on the same groups.
    # NOTE(review): cycle_size is twice the warm-up length, presumably so that only the first
    # half of a linear cycle is played during the warm-up -- confirm against LinearCyclicalScheduler.
    warmup_phase = LinearCyclicalScheduler({}, param_name="lr",
                                           start_value=warmup_start_value,
                                           end_value=warmup_end_value,
                                           cycle_size=warmup_duration * 2)
    warmup_phase.optimizer_param_groups = main_phase.optimizer_param_groups

    phases = [warmup_phase, main_phase]
    phase_durations = [warmup_duration]
    warmup_then_main = ConcatScheduler(phases, durations=phase_durations,
                                       save_history=save_history)

    if output_simulated_values is not None:
        simulated = ConcatScheduler.simulate_values(num_events=warmup_duration * 20,
                                                    schedulers=phases, durations=phase_durations)
        output_simulated_values.extend(simulated)

    return warmup_then_main
Example #30
Source File: checkpointer.py    From emmental with MIT License 4 votes vote down vote up
def checkpoint(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        """Save a checkpoint for the given iteration and track best models.

        Nothing is written while ``iteration`` is below
        ``self.checkpoint_runway``. Otherwise the collected state dict is
        saved to ``checkpoint_{iteration}.pth`` under ``self.checkpoint_path``.
        When ``self.checkpoint_all`` is ``False``, previously recorded
        checkpoint files that still exist are deleted. If ``metric_dict``
        shares a key with ``self.checkpoint_all_metrics``, the new checkpoint
        is copied to a ``best_model_<metric>.pth`` file for every metric
        returned by ``self.is_new_best``.

        Args:
          iteration: The current iteration.
          model: The model to checkpoint.
          optimizer: The optimizer used during training process.
          lr_scheduler: Learning rate scheduler.
          metric_dict: The metric dict.
        """
        # Stay silent until the runway has been reached.
        if iteration < self.checkpoint_runway:
            return

        # Announce once, the first time the runway condition holds.
        runway_reached = iteration >= self.checkpoint_runway
        if runway_reached and not self.checkpoint_condition_met:
            self.checkpoint_condition_met = True
            logger.info("checkpoint_runway condition has been met. Start checkpoining.")

        # Serialize the current training state.
        state_dict = self.collect_state_dict(
            iteration, model, optimizer, lr_scheduler, metric_dict
        )
        checkpoint_path = f"{self.checkpoint_path}/checkpoint_{iteration}.pth"
        torch.save(state_dict, checkpoint_path)
        logger.info(
            f"Save checkpoint of {iteration} {self.checkpoint_unit} "
            f"at {checkpoint_path}."
        )

        # Unless every checkpoint is kept, drop the files recorded earlier.
        # The path saved above is not in the list yet at this point.
        if self.checkpoint_all is False:
            for stale_path in filter(os.path.exists, self.checkpoint_paths):
                os.remove(stale_path)

        self.checkpoint_paths.append(checkpoint_path)

        # Only look for new best models when a tracked metric was reported.
        tracked_metrics = set(self.checkpoint_all_metrics.keys())
        reported_metrics = set(metric_dict.keys())
        if tracked_metrics.isdisjoint(reported_metrics):
            return

        for metric in self.is_new_best(metric_dict):
            # '/' in a metric name is replaced so it cannot act as a path separator.
            best_model_path = (
                f"{self.checkpoint_path}/best_model_"
                f"{metric.replace('/', '_')}.pth"
            )
            copyfile(checkpoint_path, best_model_path)
            logger.info(f"Save best model of metric {metric} at {best_model_path}")