Python chainer.Optimizer() Examples

The following are 16 code examples of chainer.Optimizer(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module chainer, or try the search function.
Example #1
Source File: optimizer.py    From chainer with MIT License 6 votes vote down vote up
def serialize(self, serializer):
        """Save or restore the optimizer through ``serializer``.

        Only the following items are handled:

        - Global states (:attr:`t` and :attr:`epoch`)
        - The state of the update rule attached to each parameter of the
          target link (parameters without an update rule are skipped)

        **The parameters of the target link are neither saved nor loaded.**
        They have to be serialized separately.

        Args:
            serializer (~chainer.AbstractSerializer): Serializer or
                deserializer object.

        """
        self.t = serializer('t', self.t)
        self.epoch = serializer('epoch', self.epoch)
        for key, parameter in self.target.namedparams():
            update_rule = getattr(parameter, 'update_rule', None)
            if update_rule is None:
                # Nothing to (de)serialize for a parameter without a rule.
                continue
            update_rule.serialize(serializer[key])
Example #2
Source File: data_parallel.py    From delira with GNU Affero General Public License v3.0 6 votes vote down vote up
def __init__(self, optimizer):
        """
        Wrap an existing optimizer instance

        Parameters
        ----------
        optimizer : :class:`chainer.Optimizer`
            the optimizer to wrap

        Raises
        ------
        RuntimeError
            if ``optimizer`` is not an instance of
            :class:`chainer.Optimizer`

        """
        # Reject anything that is not an optimizer before storing it.
        if not isinstance(optimizer, chainer.Optimizer):
            raise RuntimeError("Invalid optimizer class given: Expected "
                               "instance of chainer.Optimizer, but got %s"
                               % optimizer.__class__.__name__)

        self._optimizer = optimizer
Example #3
Source File: data_parallel.py    From delira with GNU Affero General Public License v3.0 6 votes vote down vote up
def from_optimizer_class(cls, optim_cls, *args, **kwargs):
        """
        Create the internally used optimizer and wrap it with ``cls``

        Parameters
        ----------
        optim_cls : subclass of :class:`chainer.Optimizer`
            the optimizer to use internally
        *args :
            arbitrary positional arguments (will be used for
            initialization of internally used optimizer)
        **kwargs :
            arbitrary keyword arguments (will be used for initialization
            of internally used optimizer)

        Returns
        -------
        object
            the result of calling ``cls`` with the newly created optimizer

        Raises
        ------
        RuntimeError
            if ``optim_cls`` is not a subclass of
            :class:`chainer.Optimizer`; this includes ``None`` and
            objects that are not classes at all

        """
        # ``issubclass`` raises TypeError for non-class arguments, so make
        # sure a class was given before asking.
        if not (isinstance(optim_cls, type)
                and issubclass(optim_cls, chainer.Optimizer)):
            # ``None`` and plain instances have no ``__name__``; fall back
            # to their repr so the intended RuntimeError is raised instead
            # of an AttributeError while building the message.
            given = getattr(optim_cls, "__name__", repr(optim_cls))
            raise RuntimeError("Invalid optimizer class given: Expected "
                               "Subclass of chainer.Optimizer, but got %s"
                               % given)

        _optim = optim_cls(*args, **kwargs)
        return cls(_optim)
Example #4
Source File: utils.py    From deep_metric_learning with MIT License 6 votes vote down vote up
def save(self, dir_name):
        """Write the public items of this container below ``dir_name``.

        The output directory is ``self._root_dir_path`` joined with
        ``dir_name``; it is created (including missing parent
        directories) when it does not exist yet.

        Every ``(key, value)`` pair from ``self.items()`` whose key does
        not start with ``'_'`` is handled according to the value's type:

        - ``np.ndarray`` or ``list``: stored with ``np.save`` as
          ``<key>.npy``
        - ``chainer.Chain`` or ``chainer.ChainList``: stored with
          ``chainer.serializers.save_npz`` as ``model.npz``
        - ``chainer.Optimizer``: stored with
          ``chainer.serializers.save_npz`` as ``optimizer.npz``
        - anything else: appended as a ``"key: value"`` line to
          ``log.txt``

        Args:
            dir_name (str): Directory name relative to
                ``self._root_dir_path``.

        """
        dir_path = os.path.join(self._root_dir_path, dir_name)
        # ``exist_ok`` avoids the check-then-create race of
        # ``os.path.exists`` + ``os.mkdir`` and also supports nested
        # ``dir_name`` values.
        os.makedirs(dir_path, exist_ok=True)

        others = []
        for key, value in self.items():
            if key.startswith('_'):
                continue

            if isinstance(value, (np.ndarray, list)):
                np.save(os.path.join(dir_path, key + ".npy"), value)
            elif isinstance(value, (chainer.Chain, chainer.ChainList)):
                # NOTE(review): the file name is fixed, so several links
                # in one container overwrite each other -- confirm intended.
                model_path = os.path.join(dir_path, "model.npz")
                chainer.serializers.save_npz(model_path, value)
            elif isinstance(value, chainer.Optimizer):
                optimizer_path = os.path.join(dir_path, "optimizer.npz")
                chainer.serializers.save_npz(optimizer_path, value)
            else:
                others.append("{}: {}".format(key, value))

        # Append mode: repeated calls add to the existing log.
        with open(os.path.join(dir_path, "log.txt"), "a") as f:
            text = "\n".join(others) + "\n"
            f.write(text)
Example #5
Source File: async_.py    From chainerrl with MIT License 5 votes vote down vote up
def set_shared_states(a, b):
    """Rebind the update-rule states of optimizer ``a`` to buffers in ``b``.

    ``b`` is indexed by parameter name and each entry maps a state name
    to a buffer. Every such state array is replaced by an array created
    with ``np.frombuffer`` on that buffer, using the dtype and shape of
    the array it replaces.
    """
    assert isinstance(a, chainer.Optimizer)
    assert hasattr(a, 'target'), 'Optimizer.setup must be called first'
    for param_name, param in a.target.namedparams():
        ensure_initialized_update_rule(param)
        rule_state = param.update_rule.state
        for key, buf in b[param_name].items():
            current = rule_state[key]
            rebound = np.frombuffer(buf, dtype=current.dtype)
            rule_state[key] = rebound.reshape(current.shape)
Example #6
Source File: async_.py    From chainerrl with MIT License 5 votes vote down vote up
def extract_states_as_shared_arrays(optimizer):
    """Copy the update-rule states of ``optimizer`` into ``mp.RawArray``s.

    Returns a dict keyed by parameter name; each value is a dict mapping
    a state name to an ``mp.RawArray`` of type code ``'f'`` built from
    the flattened state array.
    """
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'
    shared_arrays = {}
    for param_name, param in optimizer.target.namedparams():
        ensure_initialized_update_rule(param)
        shared_arrays[param_name] = {
            key: mp.RawArray('f', array.ravel())
            for key, array in param.update_rule.state.items()
        }
    return shared_arrays
Example #7
Source File: async_.py    From chainerrl with MIT License 5 votes vote down vote up
def as_shared_objects(obj):
    """Return the shared counterpart of ``obj``.

    * tuple: converted element by element (recursively) into a new tuple
    * ``chainer.Link``: result of ``share_params_as_shared_arrays``
    * ``chainer.Optimizer``: result of ``share_states_as_shared_arrays``
    * ``mp.sharedctypes.Synchronized``: returned unchanged

    Raises:
        ValueError: If ``obj`` is of none of the types above.
    """
    if isinstance(obj, tuple):
        return tuple(map(as_shared_objects, obj))
    if isinstance(obj, chainer.Link):
        return share_params_as_shared_arrays(obj)
    if isinstance(obj, chainer.Optimizer):
        return share_states_as_shared_arrays(obj)
    if isinstance(obj, mp.sharedctypes.Synchronized):
        return obj
    raise ValueError('')
Example #8
Source File: async_.py    From chainerrl with MIT License 5 votes vote down vote up
def synchronize_to_shared_objects(obj, shared_memory):
    """Bind ``obj`` to ``shared_memory`` and return the synchronized object.

    * tuple: elements of ``obj`` and ``shared_memory`` are paired up and
      processed recursively; a new tuple is returned
    * ``chainer.Link``: ``set_shared_params`` is called, ``obj`` returned
    * ``chainer.Optimizer``: ``set_shared_states`` is called, ``obj``
      returned
    * ``mp.sharedctypes.Synchronized``: ``shared_memory`` is returned

    Raises:
        ValueError: If ``obj`` is of none of the types above.
    """
    if isinstance(obj, tuple):
        # ``map`` with two iterables stops at the shorter one, like ``zip``.
        return tuple(map(synchronize_to_shared_objects, obj, shared_memory))
    if isinstance(obj, chainer.Link):
        set_shared_params(obj, shared_memory)
        return obj
    if isinstance(obj, chainer.Optimizer):
        set_shared_states(obj, shared_memory)
        return obj
    if isinstance(obj, mp.sharedctypes.Synchronized):
        return shared_memory
    raise ValueError('')
Example #9
Source File: async.py    From async-rl with MIT License 5 votes vote down vote up
def set_shared_states(a, b):
    """Rebind the entries of ``a._states`` to the buffers given in ``b``.

    ``b`` maps a state name to a dict of ``{param_name: buffer}``. Each
    matching array in ``a._states`` is replaced by an array created with
    ``np.frombuffer`` on the buffer, using the dtype and shape of the
    array it replaces.
    """
    assert isinstance(a, chainer.Optimizer)
    assert hasattr(a, 'target'), 'Optimizer.setup must be called first'
    for state_name, shared_state in b.items():
        for param_name, buf in shared_state.items():
            # Looked up per entry so an empty shared state touches nothing.
            per_state = a._states[state_name]
            previous = per_state[param_name]
            rebound = np.frombuffer(buf, dtype=previous.dtype)
            per_state[param_name] = rebound.reshape(previous.shape)
Example #10
Source File: async.py    From async-rl with MIT License 5 votes vote down vote up
def extract_states_as_shared_arrays(optimizer):
    """Copy every array in ``optimizer._states`` into an ``mp.RawArray``.

    Returns a dict keyed by state name; each value is a dict mapping a
    parameter name to an ``mp.RawArray`` of type code ``'f'`` built from
    the flattened array.
    """
    assert isinstance(optimizer, chainer.Optimizer)
    assert hasattr(optimizer, 'target'), 'Optimizer.setup must be called first'
    return {
        state_name: {
            param_name: mp.RawArray('f', array.ravel())
            for param_name, array in state.items()
        }
        for state_name, state in optimizer._states.items()
    }
Example #11
Source File: test_optimizers.py    From chainer with MIT License 5 votes vote down vote up
def test_all_optimizers_coverage(self):
        """Check ``_all_optimizers`` against ``chainer.optimizers``.

        Collects the names of all attributes of ``chainer.optimizers``
        that are classes derived from ``chainer.Optimizer`` and compares
        them, order-independently, with ``_all_optimizers``.
        """
        module = chainer.optimizers
        exported = ((name, getattr(module, name)) for name in dir(module))
        module_optimizers = [
            name
            for name, obj in exported
            if isinstance(obj, type) and issubclass(obj, chainer.Optimizer)
        ]

        assert sorted(_all_optimizers) == sorted(module_optimizers)
Example #12
Source File: optimizer.py    From chainer with MIT License 5 votes vote down vote up
def _check_set_up(self):
        """Raise ``RuntimeError`` when ``self._hookable`` is ``None``."""
        if self._hookable is not None:
            return
        raise RuntimeError('Optimizer is not set up. Call `setup` method.')
Example #13
Source File: standard_updater.py    From chainer with MIT License 5 votes vote down vote up
def get_optimizer(self, name):
        """Looks up a registered optimizer by its name.

        Args:
            name (str): Name of the optimizer.

        Returns:
            ~chainer.Optimizer: The entry stored under ``name`` in
            ``self._optimizers``.

        """
        registry = self._optimizers
        return registry[name]
Example #14
Source File: data_parallel.py    From delira with GNU Affero General Public License v3.0 5 votes vote down vote up
def __call__(self, optimizer: chainer.Optimizer):
        """
        Accumulate the gradients of all replicated modules in the first
        module if the target is an instance of ``DataParallel``

        Parameters
        ----------
        optimizer : chainer.Optimizer
            the optimizer holding the target, whose gradients should be
            summed across the replications

        """
        network = optimizer.target
        if not isinstance(network, DataParallelChainerNetwork):
            # Other targets are left untouched.
            return

        for replica in network.modules[1:]:
            # The first module is looked up inside the loop, so an empty
            # module list is never indexed.
            network.modules[0].addgrads(replica)
Example #15
Source File: data_parallel.py    From delira with GNU Affero General Public License v3.0 5 votes vote down vote up
def __call__(self, optimizer: chainer.Optimizer):
        """
        Call ``copyparams`` on every module after the first one, passing
        the first module, if the target is an instance of
        ``DataParallelChainerNetwork``

        Parameters
        ----------
        optimizer : chainer.Optimizer
            the optimizer holding the target

        """
        network = optimizer.target
        if not isinstance(network, DataParallelChainerNetwork):
            # Other targets are left untouched.
            return

        for replica in network.modules[1:]:
            replica.copyparams(network.modules[0])
Example #16
Source File: multi_updater.py    From Comicolorization with MIT License 5 votes vote down vote up
def __init__(
            self,
            args,
            loss_maker,
            main_optimizer,
            main_lossfun,
            reinput_optimizer=None,
            reinput_lossfun=None,
            discriminator_optimizer=None,
            discriminator_lossfun=None,
            *_args, **kwargs
    ):
        # type: (any, comicolorization.loss.LossMaker, any, typing.Callable[[typing.Dict], any], typing.List[chainer.Optimizer], typing.Callable[[int, typing.Dict], any], any, typing.Callable[[typing.Dict], any], *any, **any) -> None
        """Register the optimizers with the base class and store the setup.

        The optimizers are passed to the parent initializer as a dict:
        ``'main'`` for ``main_optimizer``, ``'reinput<i>'`` for the i-th
        entry of ``reinput_optimizer`` (if given) and ``'discriminator'``
        for ``discriminator_optimizer`` (if given). ``*_args`` and
        ``**kwargs`` are forwarded unchanged.

        ``args`` must provide ``separate_backward_reinput`` and, when that
        is truthy and no ``reinput_optimizer`` was given,
        ``loss_blend_ratio_reinput`` (only its length is used).
        """
        optimizers = {'main': main_optimizer}
        if reinput_optimizer is not None:
            optimizers.update(
                ('reinput{}'.format(index), entry)
                for index, entry in enumerate(reinput_optimizer)
            )

        if discriminator_optimizer is not None:
            optimizers['discriminator'] = discriminator_optimizer

        super().__init__(optimizer=optimizers, *_args, **kwargs)

        # Original note: chainer.reporter cannot work when several
        # optimizers focus on the same model. The fallback below therefore
        # happens after the registration above: the main optimizer is
        # reused for every reinput step without being registered again.
        if args.separate_backward_reinput and reinput_optimizer is None:
            repeats = len(args.loss_blend_ratio_reinput)
            reinput_optimizer = [main_optimizer] * repeats

        self.args = args
        self.loss_maker = loss_maker

        self.main_optimizer = main_optimizer
        self.main_lossfun = main_lossfun

        self.reinput_optimizer = reinput_optimizer
        self.reinput_lossfun = reinput_lossfun

        self.discriminator_optimizer = discriminator_optimizer
        self.discriminator_lossfun = discriminator_lossfun