Python chainer.training.make_extension() Examples
The following are 16 code examples of `chainer.training.make_extension()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `chainer.training`, or try the search function.
Example #1
Source File: train_word2vec.py From vecto with Mozilla Public License 2.0 | 6 votes |
def get_model(args, loss_func, vocab, vocab_ngram_tokens, current_utils=utils.word):
    """Build the word2vec model selected by the command-line arguments.

    Args:
        args: Parsed arguments; ``args.subword``, ``args.model`` and
            ``args.dimensions`` are read.
        loss_func: Loss object forwarded to the model constructor.
        vocab: Vocabulary; ``vocab.cnt_words`` is used by word-level models,
            and the whole object is passed to the subword model.
        vocab_ngram_tokens: Forwarded to the subword ``SkipGram`` model.
        current_utils: Module providing the word-level ``SkipGram`` and
            ``ContinuousBoW`` classes (defaults to ``utils.word``).

    Returns:
        The constructed model.

    Raises:
        ValueError: If the ``args.model`` / ``args.subword`` combination is not
            supported. With subwords enabled only ``'skipgram'`` is handled.
    """
    model = None
    if args.subword == 'none':
        if args.model == 'skipgram':
            model = current_utils.SkipGram(vocab.cnt_words, args.dimensions, loss_func)
        elif args.model == 'cbow':
            # TODO(review): original note reads "only skipgram supported" —
            # confirm whether CBoW is fully supported.
            model = current_utils.ContinuousBoW(vocab.cnt_words, args.dimensions, loss_func)
    elif args.model == 'skipgram':
        model = utils.subword.SkipGram(args.subword, vocab, vocab_ngram_tokens,
                                       args.dimensions, loss_func)
    if model is None:
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers still catch this.
        raise ValueError('Unknown model and word/subword type: {} and {}'.format(
            args.model, args.subword))
    return model

# Disabled extension kept from the original source:
# @training.make_extension(trigger=(1, 'epoch'))
# def dump_embs(trainer):
#     print("dumping embeddings")
Example #2
Source File: test_snapshot.py From chainer with MIT License | 6 votes |
def test_on_error(self):
    """A snapshot created with ``snapshot_on_error=True`` is written when run() fails."""

    class TheOnlyError(Exception):
        pass

    # Priority 100 places this extension ahead of lower-priority ones.
    @training.make_extension(trigger=(1, 'iteration'), priority=100)
    def exception_raiser(trainer):
        raise TheOnlyError()

    trainer = self.trainer
    trainer.extend(exception_raiser)

    snapshot_ext = extensions.snapshot_object(
        trainer, self.filename, snapshot_on_error=True)
    trainer.extend(snapshot_ext)

    # Nothing has been written before training starts ...
    self.assertFalse(os.path.exists(self.filename))
    with self.assertRaises(TheOnlyError):
        trainer.run()
    # ... but the snapshot exists after the failed run.
    self.assertTrue(os.path.exists(self.filename))
Example #3
Source File: test_trainer.py From chainer with MIT License | 6 votes |
def test_exception_in_exception_handler(self):
    """A failing ``on_error`` handler does not replace the original exception.

    ``TheOnlyError`` still propagates out of ``trainer.run()``, the other
    extension's error handler still runs, and the dummy extension is still
    finalized.
    """
    trainer = self.trainer

    error_handling_ext = ErrorHandlingExtension()
    trainer.extend(error_handling_ext, trigger=(1, 'iteration'), priority=1)
    self.assertFalse(error_handling_ext.is_error_handled)

    def exception_handler(trainer, exp, tb):
        # Deliberately fails while handling the extension's own error.
        raise ValueError('hogehoge from exception handler')

    @training.make_extension(trigger=(1, 'iteration'), priority=100,
                             on_error=exception_handler)
    def exception_raiser(trainer):
        raise TheOnlyError()

    trainer.extend(exception_raiser)

    finalizable_ext = DummyExtension(self)
    trainer.extend(finalizable_ext)

    with self.assertRaises(TheOnlyError):
        trainer.run()

    self.assertTrue(error_handling_ext.is_error_handled)
    self.assertTrue(finalizable_ext.is_finalized)
Example #4
Source File: asr_utils.py From espnet with Apache License 2.0 | 6 votes |
def adadelta_eps_decay(eps_decay):
    """Create an extension that performs Adadelta eps decay once per epoch.

    Args:
        eps_decay (float): Decay rate of eps.

    Returns:
        An extension function.

    """
    def adadelta_eps_decay(trainer):
        # The actual update is delegated to the module-level helper.
        _adadelta_eps_decay(trainer, eps_decay)

    # The inner function deliberately reuses the outer name: make_extension
    # derives the extension's default name from the function name.
    return training.make_extension(trigger=(1, "epoch"))(adadelta_eps_decay)
Example #5
Source File: asr_utils.py From espnet with Apache License 2.0 | 6 votes |
def adam_lr_decay(eps_decay):
    """Create an extension that performs Adam lr decay once per epoch.

    Args:
        eps_decay (float): Decay rate of lr. The parameter is named
            ``eps_decay`` for interface compatibility with
            ``adadelta_eps_decay``.

    Returns:
        An extension function.

    """
    def adam_lr_decay(trainer):
        # The actual update is delegated to the module-level helper.
        _adam_lr_decay(trainer, eps_decay)

    # Keeping the outer name on the inner function preserves the
    # extension's default name (make_extension uses the function name).
    return training.make_extension(trigger=(1, "epoch"))(adam_lr_decay)
Example #6
Source File: asr_utils.py From espnet with Apache License 2.0 | 6 votes |
def snapshot_object(target, filename):
    """Return a trainer extension that saves ``target`` once per epoch.

    Args:
        target (model): Object to serialize.
        filename (str): Name of the file into which the object is serialized.
            It can be a format string, where the trainer object is passed to
            the :meth:`str.format` method. For example,
            ``'snapshot_{.updater.iteration}'`` is converted to
            ``'snapshot_10000'`` at the 10,000th iteration.

    Returns:
        An extension function.

    """
    def snapshot_object(trainer):
        # The file goes under the trainer's output directory.
        path = os.path.join(trainer.out, filename.format(trainer))
        torch_save(path, target)

    # priority=-100: lower priority, so it is called after higher-priority
    # extensions sharing the same trigger.
    decorate = extension.make_extension(trigger=(1, "epoch"), priority=-100)
    return decorate(snapshot_object)
Example #7
Source File: test_trainer.py From chainer with MIT License | 5 votes |
def test_add_make_extension(self):
    """A function decorated with make_extension() is invoked by run()."""
    trainer = self.trainer
    self.is_called = False

    @training.make_extension()
    def dummy_extension(trainer):
        self.is_called = True

    trainer.extend(dummy_extension)
    trainer.run()

    self.assertTrue(self.is_called)
Example #8
Source File: test_trainer.py From chainer with MIT License | 5 votes |
def test_add_make_extension_with_initializer(self):
    """The ``initializer`` passed to make_extension() runs before the extension body."""
    trainer = self.trainer
    self.is_called = False

    def initializer(trainer):
        # Marker checked inside the extension body below.
        trainer.is_initialized = True

    @training.make_extension(initializer=initializer)
    def dummy_extension(trainer):
        self.assertTrue(trainer.is_initialized)
        self.is_called = True

    trainer.extend(dummy_extension)
    trainer.run()

    self.assertTrue(self.is_called)
Example #9
Source File: test_trainer.py From chainer with MIT License | 5 votes |
def test_add_two_extensions_default_priority(self):
    """With equal (default) priorities, extensions run in registration order."""
    self.called_order = []

    @training.make_extension(trigger=(1, 'epoch'))
    def dummy_extension_1(trainer):
        self.called_order.append(1)

    @training.make_extension(trigger=(1, 'epoch'))
    def dummy_extension_2(trainer):
        self.called_order.append(2)

    for ext in (dummy_extension_1, dummy_extension_2):
        self.trainer.extend(ext)
    self.trainer.run()

    self.assertEqual(self.called_order, [1, 2])
Example #10
Source File: test_trainer.py From chainer with MIT License | 5 votes |
def test_add_two_extensions_specific_priority(self):
    """The extension with the higher priority runs first, regardless of registration order."""
    self.called_order = []

    @training.make_extension(trigger=(1, 'epoch'), priority=50)
    def dummy_extension_1(trainer):
        self.called_order.append(1)

    @training.make_extension(trigger=(1, 'epoch'), priority=100)
    def dummy_extension_2(trainer):
        self.called_order.append(2)

    # Registered low-priority first; priority 100 must still win.
    for ext in (dummy_extension_1, dummy_extension_2):
        self.trainer.extend(ext)
    self.trainer.run()

    self.assertEqual(self.called_order, [2, 1])
Example #11
Source File: test_extension.py From chainer with MIT License | 5 votes |
def test_make_extension(self):
    """Explicit make_extension() arguments are stored on the decorated function."""
    def initialize(trainer):
        pass

    @training.make_extension(trigger=(2, 'epoch'), default_name='my_ext',
                             priority=50, initializer=initialize)
    def my_extension(trainer):
        pass

    expected_attrs = {
        'trigger': (2, 'epoch'),
        'default_name': 'my_ext',
        'priority': 50,
    }
    for attr_name, expected in expected_attrs.items():
        self.assertEqual(getattr(my_extension, attr_name), expected)
    self.assertIs(my_extension.initialize, initialize)
Example #12
Source File: test_extension.py From chainer with MIT License | 5 votes |
def test_make_extension_default_values(self):
    """make_extension() without arguments fills in the documented defaults."""
    @training.make_extension()
    def my_extension(trainer):
        pass

    expected_attrs = {
        'trigger': (1, 'iteration'),
        # The default name is taken from the decorated function's name.
        'default_name': 'my_extension',
        'priority': training.PRIORITY_READER,
    }
    for attr_name, expected in expected_attrs.items():
        self.assertEqual(getattr(my_extension, attr_name), expected)
    self.assertIsNone(my_extension.initialize)
Example #13
Source File: test_extension.py From chainer with MIT License | 5 votes |
def test_make_extension_deleted_argument(self):
    """The removed ``invoke_before_training`` keyword is rejected with ValueError."""
    with self.assertRaises(ValueError):
        @training.make_extension(invoke_before_training=False)
        def my_extension(unused_trainer):
            pass
Example #14
Source File: test_extension.py From chainer with MIT License | 5 votes |
def test_make_extension_unexpected_kwargs(self):
    """An unknown keyword argument to make_extension() raises TypeError."""
    with self.assertRaises(TypeError):
        @training.make_extension(foo=1)
        def my_extension(unused_trainer):
            pass
Example #15
Source File: asr_utils.py From espnet with Apache License 2.0 | 5 votes |
def restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
    """Create an extension that restores a snapshot once per epoch.

    Returns:
        An extension function.

    """
    def restore_snapshot(trainer):
        # ``trainer`` is unused. Presumably the helper loads ``snapshot``
        # into ``model`` via ``load_fn`` (verify against _restore_snapshot).
        _restore_snapshot(model, snapshot, load_fn)

    # Same name as the outer function so the default extension name matches.
    return training.make_extension(trigger=(1, "epoch"))(restore_snapshot)
Example #16
Source File: asr_utils.py From espnet with Apache License 2.0 | 5 votes |
def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
    """Create an extension that snapshots the trainer itself (PyTorch) once per epoch.

    Returns:
        An extension function.

    """
    def torch_snapshot(trainer):
        # ``filename`` is a format string filled in from the trainer, e.g.
        # the default yields "snapshot.ep.<epoch>".
        snapshot_name = filename.format(trainer)
        _torch_snapshot_object(trainer, trainer, snapshot_name, savefun)

    # priority=-100: runs after higher-priority extensions on the same trigger.
    decorate = extension.make_extension(trigger=(1, "epoch"), priority=-100)
    return decorate(torch_snapshot)