Python chainer.training.make_extension() Examples

The following are 16 code examples of chainer.training.make_extension(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module chainer.training, or try the search function.
Example #1
Source File: train_word2vec.py    From vecto with Mozilla Public License 2.0 6 votes vote down vote up
def get_model(args, loss_func, vocab, vocab_ngram_tokens, current_utils=utils.word):
    """Construct the word2vec model selected by ``args``.

    Args:
        args: Parsed options. ``args.subword`` selects word-level training
            when equal to ``'none'``; ``args.model`` selects ``'skipgram'``
            or ``'cbow'``; ``args.dimensions`` is passed to the model.
        loss_func: Passed through unchanged to the model constructor.
        vocab: Vocabulary object. For word-level models only
            ``vocab.cnt_words`` is passed on (presumably the number of
            words); subword models receive the whole object.
        vocab_ngram_tokens: Passed only to the subword ``SkipGram``.
        current_utils: Module providing the word-level ``SkipGram`` and
            ``ContinuousBoW`` classes.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If the ``args.model`` / ``args.subword`` combination is
            not supported (subword training only supports ``'skipgram'``).
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    model = None
    if args.subword == 'none':
        if args.model == 'skipgram':
            model = current_utils.SkipGram(vocab.cnt_words, args.dimensions, loss_func)
        elif args.model == 'cbow':
            model = current_utils.ContinuousBoW(vocab.cnt_words, args.dimensions, loss_func)
    elif args.model == 'skipgram':
        # Only skipgram is implemented for subword models; note that this
        # always comes from utils.subword, independent of current_utils.
        model = utils.subword.SkipGram(args.subword, vocab, vocab_ngram_tokens,
                                       args.dimensions, loss_func)

    if model is None:
        raise ValueError('Unknown model and word/subword type: {} and {}'.format(args.model, args.subword))
    return model


#@training.make_extension(trigger=(1, 'epoch'))
#def dump_embs(trainer):
#    print("dumping embeddings") 
Example #2
Source File: test_snapshot.py    From chainer with MIT License 6 votes vote down vote up
def test_on_error(self):
    """A snapshot registered with ``snapshot_on_error=True`` is written even
    though another extension aborts the run with an exception."""

    class TheOnlyError(Exception):
        pass

    @training.make_extension(trigger=(1, 'iteration'), priority=100)
    def exception_raiser(trainer):
        raise TheOnlyError()

    trainer = self.trainer
    trainer.extend(exception_raiser)
    trainer.extend(extensions.snapshot_object(trainer, self.filename,
                                              snapshot_on_error=True))

    # Nothing may exist on disk before the run starts.
    self.assertFalse(os.path.exists(self.filename))

    with self.assertRaises(TheOnlyError):
        trainer.run()

    # The aborted run must still have produced the snapshot file.
    self.assertTrue(os.path.exists(self.filename))
Example #3
Source File: test_trainer.py    From chainer with MIT License 6 votes vote down vote up
def test_exception_in_exception_handler(self):
    """An ``on_error`` handler that itself raises must not replace the
    original error, nor stop other extensions from handling the error and
    being finalized."""
    error_ext = ErrorHandlingExtension()
    self.trainer.extend(error_ext, trigger=(1, 'iteration'), priority=1)
    self.assertFalse(error_ext.is_error_handled)

    def exception_handler(trainer, exp, tb):
        raise ValueError('hogehoge from exception handler')

    @training.make_extension(trigger=(1, 'iteration'), priority=100,
                             on_error=exception_handler)
    def exception_raiser(trainer):
        raise TheOnlyError()

    self.trainer.extend(exception_raiser)

    finalize_probe = DummyExtension(self)
    self.trainer.extend(finalize_probe)

    # The error from exception_raiser, not the handler's ValueError, surfaces.
    with self.assertRaises(TheOnlyError):
        self.trainer.run()

    self.assertTrue(error_ext.is_error_handled)
    self.assertTrue(finalize_probe.is_finalized)
Example #4
Source File: asr_utils.py    From espnet with Apache License 2.0 6 votes vote down vote up
def adadelta_eps_decay(eps_decay):
    """Build an extension that performs AdaDelta eps decay once per epoch.

    Args:
        eps_decay (float): Decay rate of eps, forwarded to
            ``_adadelta_eps_decay`` on every invocation.

    Returns:
        An extension function.

    """

    def adadelta_eps_decay(trainer):
        _adadelta_eps_decay(trainer, eps_decay)

    # Apply the decorator explicitly; the inner function keeps the outer
    # name so the extension's default name is unchanged.
    return training.make_extension(trigger=(1, "epoch"))(adadelta_eps_decay)
Example #5
Source File: asr_utils.py    From espnet with Apache License 2.0 6 votes vote down vote up
def adam_lr_decay(eps_decay):
    """Build an extension that performs Adam learning-rate decay once per epoch.

    Args:
        eps_decay (float): Decay rate applied to the learning rate (despite
            the parameter name), forwarded to ``_adam_lr_decay``.

    Returns:
        An extension function.

    """

    def adam_lr_decay(trainer):
        _adam_lr_decay(trainer, eps_decay)

    # Explicit decoration; the inner name matches the outer one on purpose.
    return training.make_extension(trigger=(1, "epoch"))(adam_lr_decay)
Example #6
Source File: asr_utils.py    From espnet with Apache License 2.0 6 votes vote down vote up
def snapshot_object(target, filename):
    """Return a trainer extension that saves ``target`` once per epoch.

    Args:
        target (model): Object to serialize.
        filename (str): Name of the file into which the object is serialized.
            It can be a format string, where the trainer object is passed to
            the :meth:`str.format` method. For example,
            ``'snapshot_{.updater.iteration}'`` is converted to
            ``'snapshot_10000'`` at the 10,000th iteration. The formatted
            name is joined onto ``trainer.out``.

    Returns:
        An extension function.

    """

    def snapshot_object(trainer):
        out_path = os.path.join(trainer.out, filename.format(trainer))
        torch_save(out_path, target)

    # Negative priority, per-epoch trigger; same settings as the decorator form.
    return extension.make_extension(trigger=(1, "epoch"), priority=-100)(snapshot_object)
Example #7
Source File: test_trainer.py    From chainer with MIT License 5 votes vote down vote up
def test_add_make_extension(self):
    """A plain function wrapped by ``make_extension()`` is called during ``run()``."""
    self.is_called = False

    def dummy_extension(trainer):
        self.is_called = True

    self.trainer.extend(training.make_extension()(dummy_extension))
    self.trainer.run()
    self.assertTrue(self.is_called)
Example #8
Source File: test_trainer.py    From chainer with MIT License 5 votes vote down vote up
def test_add_make_extension_with_initializer(self):
    """The ``initializer`` given to ``make_extension()`` has run by the time
    the extension body executes."""
    self.is_called = False

    def initializer(trainer):
        trainer.is_initialized = True

    def dummy_extension(trainer):
        # Fails unless the initializer already ran on this trainer.
        self.assertTrue(trainer.is_initialized)
        self.is_called = True

    wrapped = training.make_extension(initializer=initializer)(dummy_extension)
    self.trainer.extend(wrapped)
    self.trainer.run()
    self.assertTrue(self.is_called)
Example #9
Source File: test_trainer.py    From chainer with MIT License 5 votes vote down vote up
def test_add_two_extensions_default_priority(self):
    """With equal (default) priority, extensions fire in registration order."""
    self.called_order = []

    @training.make_extension(trigger=(1, 'epoch'))
    def dummy_extension_1(trainer):
        self.called_order.append(1)

    @training.make_extension(trigger=(1, 'epoch'))
    def dummy_extension_2(trainer):
        self.called_order.append(2)

    for ext in (dummy_extension_1, dummy_extension_2):
        self.trainer.extend(ext)
    self.trainer.run()

    self.assertEqual(self.called_order, [1, 2])
Example #10
Source File: test_trainer.py    From chainer with MIT License 5 votes vote down vote up
def test_add_two_extensions_specific_priority(self):
    """The extension with the larger priority value fires first, even when
    it was registered second."""
    self.called_order = []

    @training.make_extension(trigger=(1, 'epoch'), priority=50)
    def dummy_extension_1(trainer):
        self.called_order.append(1)

    @training.make_extension(trigger=(1, 'epoch'), priority=100)
    def dummy_extension_2(trainer):
        self.called_order.append(2)

    for ext in (dummy_extension_1, dummy_extension_2):
        self.trainer.extend(ext)
    self.trainer.run()

    self.assertEqual(self.called_order, [2, 1])
Example #11
Source File: test_extension.py    From chainer with MIT License 5 votes vote down vote up
def test_make_extension(self):
    """``make_extension()`` exposes its trigger, default name, priority and
    initializer as attributes of the decorated function."""
    def initialize(trainer):
        pass

    decorate = training.make_extension(trigger=(2, 'epoch'), default_name='my_ext',
                                       priority=50, initializer=initialize)

    def my_extension(trainer):
        pass

    my_extension = decorate(my_extension)

    expected = {
        'trigger': (2, 'epoch'),
        'default_name': 'my_ext',
        'priority': 50,
    }
    for attr, value in expected.items():
        self.assertEqual(getattr(my_extension, attr), value)
    # The ``initializer`` keyword is stored under the attribute ``initialize``.
    self.assertIs(my_extension.initialize, initialize)
Example #12
Source File: test_extension.py    From chainer with MIT License 5 votes vote down vote up
def test_make_extension_default_values(self):
    """Calling ``make_extension()`` with no arguments yields the default
    trigger, name (taken from the function), priority and no initializer."""
    def my_extension(trainer):
        pass

    my_extension = training.make_extension()(my_extension)

    self.assertEqual(my_extension.trigger, (1, 'iteration'))
    self.assertEqual(my_extension.default_name, 'my_extension')
    self.assertEqual(my_extension.priority, training.PRIORITY_READER)
    self.assertIsNone(my_extension.initialize)
Example #13
Source File: test_extension.py    From chainer with MIT License 5 votes vote down vote up
def test_make_extension_deleted_argument(self):
    """The removed ``invoke_before_training`` option is rejected with ValueError."""
    def my_extension(_):
        pass

    with self.assertRaises(ValueError):
        training.make_extension(invoke_before_training=False)(my_extension)
Example #14
Source File: test_extension.py    From chainer with MIT License 5 votes vote down vote up
def test_make_extension_unexpected_kwargs(self):
    """An unknown keyword argument to ``make_extension()`` raises TypeError."""
    def my_extension(_):
        pass

    with self.assertRaises(TypeError):
        training.make_extension(foo=1)(my_extension)
Example #15
Source File: asr_utils.py    From espnet with Apache License 2.0 5 votes vote down vote up
def restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
    """Build a once-per-epoch extension that restores a snapshot.

    Args:
        model: Forwarded to ``_restore_snapshot``.
        snapshot: Forwarded to ``_restore_snapshot``.
        load_fn: Loader forwarded to ``_restore_snapshot``; defaults to
            ``chainer.serializers.load_npz``.

    Returns:
        An extension function. The trainer it receives is not used.

    """

    def restore_snapshot(trainer):
        _restore_snapshot(model, snapshot, load_fn)

    return training.make_extension(trigger=(1, "epoch"))(restore_snapshot)
Example #16
Source File: asr_utils.py    From espnet with Apache License 2.0 5 votes vote down vote up
def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
    """Build a once-per-epoch extension that snapshots the trainer for PyTorch.

    Args:
        savefun: Save function forwarded to ``_torch_snapshot_object``;
            defaults to ``torch.save``.
        filename: Format string; ``filename.format(trainer)`` produces the
            snapshot name (by default it embeds ``trainer.updater.epoch``).

    Returns:
        An extension function.

    """

    def torch_snapshot(trainer):
        snapshot_name = filename.format(trainer)
        # The trainer is both the owner and the object being saved.
        _torch_snapshot_object(trainer, trainer, snapshot_name, savefun)

    return extension.make_extension(trigger=(1, "epoch"), priority=-100)(torch_snapshot)