Python chainer.training.extensions.snapshot() Examples
The following are 30
code examples of chainer.training.extensions.snapshot().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions and classes of the module chainer.training.extensions, or try the search function.
Example #1
Source File: train_utils.py From see with GNU General Public License v3.0 | 6 votes |
def add_default_arguments(parser):
    """Register the command-line options shared by all training scripts.

    Args:
        parser (argparse.ArgumentParser): parser extended in place.

    Returns:
        argparse.ArgumentParser: the same parser, for call chaining.
    """
    parser.add_argument("log_dir",
                        help='directory where generated models and logs shall be stored')
    parser.add_argument('-b', '--batch-size', dest='batch_size', type=int, required=True,
                        help="Number of images per training batch")
    parser.add_argument('-g', '--gpus', type=int, nargs="*", default=[],
                        help="Ids of GPU to use [default: (use cpu)]")
    parser.add_argument('-e', '--epochs', type=int, default=20,
                        help="Number of epochs to train [default: 20]")
    parser.add_argument('-r', '--resume',
                        help="path to previously saved state of trained model from which training shall resume")
    parser.add_argument('-si', '--snapshot-interval', dest='snapshot_interval', type=int, default=20000,
                        help="number of iterations after which a snapshot shall be taken [default: 20000]")
    parser.add_argument('-ln', '--log-name', dest='log_name', default='training',
                        help="name of the log folder")
    parser.add_argument('-lr', '--learning-rate', dest='learning_rate', type=float, default=0.01,
                        help="initial learning rate [default: 0.01]")
    parser.add_argument('-li', '--log-interval', dest='log_interval', type=int, default=100,
                        help="number of iterations after which an update shall be logged [default: 100]")
    parser.add_argument('--lr-step', dest='learning_rate_step_size', type=float, default=0.1,
                        help="Step size for decreasing learning rate [default: 0.1]")
    parser.add_argument('-t', '--test-interval', dest='test_interval', type=int, default=1000,
                        help="number of iterations after which testing should be performed [default: 1000]")
    parser.add_argument('--test-iterations', dest='test_iterations', type=int, default=200,
                        help="number of test iterations [default: 200]")
    parser.add_argument("-dr", "--dropout-ratio", dest='dropout_ratio', default=0.5, type=float,
                        help="ratio for dropout layers")
    return parser
Example #2
Source File: test_snapshot.py From pfio with MIT License | 6 votes |
def test_snapshot_hdfs():
    """Take a trainer snapshot through an HDFS handler and load it back.

    Fix: the original ran cleanup only on success, so a failing assertion
    left ``some-pfio-tmp-dir`` behind and the ``len(file_list) == 0``
    precondition broke every rerun.  Cleanup now runs in a ``finally``.
    """
    trainer = chainer.testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    # Mark the trainer finished so calling the extension directly is valid.
    trainer._done = True

    with pfio.create_handler('hdfs') as fs:
        tmpdir = "some-pfio-tmp-dir"
        fs.makedirs(tmpdir, exist_ok=True)
        try:
            # The test requires an initially empty directory.
            file_list = list(fs.list(tmpdir))
            assert len(file_list) == 0

            writer = SimpleWriter(tmpdir, fs=fs)
            snapshot = extensions.snapshot(writer=writer)
            snapshot(trainer)

            assert 'snapshot_iter_0' in fs.list(tmpdir)

            # A fresh trainer must be able to load the snapshot we just wrote.
            trainer2 = chainer.testing.get_trainer_with_mock_updater()
            load_snapshot(trainer2, tmpdir, fs=fs, fail_on_no_file=True)
        finally:
            # Clean up even on assertion failure so reruns start clean.
            fs.remove(tmpdir, recursive=True)
Example #3
Source File: test_multi_node_snapshot.py From chainer with MIT License | 6 votes |
def test_smoke_wrapper():
    """Smoke-test multi_node_snapshot: with replica set [[0, 1]], ranks 0 and 2
    act as masters (rank-suffixed filename); all other ranks do not."""
    comm = create_communicator('naive')
    if comm.size < 2:
        pytest.skip()

    snapshot = extensions.snapshot()
    expected_name = '{}.{}'.format(snapshot.filename, comm.rank)

    mn_snapshot = multi_node_snapshot(comm, snapshot, [[0, 1], ])

    if comm.rank in (0, 2):
        assert mn_snapshot.is_master
        assert expected_name == mn_snapshot.snapshot.filename
    else:
        assert not mn_snapshot.is_master
    comm.finalize()
Example #4
Source File: test_snapshot.py From chainer with MIT License | 6 votes |
def test_on_error(self):
    """With ``snapshot_on_error=True`` a snapshot is written when training raises."""

    class DeliberateError(Exception):
        pass

    @training.make_extension(trigger=(1, 'iteration'), priority=100)
    def exception_raiser(trainer):
        raise DeliberateError()

    self.trainer.extend(exception_raiser)
    snapshot = extensions.snapshot_object(self.trainer, self.filename,
                                          snapshot_on_error=True)
    self.trainer.extend(snapshot)

    # No snapshot file before the run; one must exist after the failing run.
    self.assertFalse(os.path.exists(self.filename))
    with self.assertRaises(DeliberateError):
        self.trainer.run()
    self.assertTrue(os.path.exists(self.filename))
Example #5
Source File: NNet.py From alpha-zero-general with MIT License | 5 votes |
def _train_trainer(self, examples):
    """Train ``self.nnet`` on *examples* with the Chainer trainer loop.

    NOTE(review): ``args`` and ``converter`` are module-level names, not
    parameters — confirm where they are defined.
    """
    train_iter = SerialIterator(examples, args.batch_size)
    optimizer = optimizers.Adam(alpha=args.lr)
    optimizer.setup(self.nnet)

    def loss_func(boards, target_pis, target_vs):
        # Combined policy + value loss; both parts are reported separately.
        out_pi, out_v = self.nnet(boards)
        l_pi = self.loss_pi(target_pis, out_pi)
        l_v = self.loss_v(target_vs, out_v)
        total_loss = l_pi + l_v
        chainer.reporter.report({
            'loss': total_loss,
            'loss_pi': l_pi,
            'loss_v': l_v,
        }, observer=self.nnet)
        return total_loss

    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.device,
        loss_func=loss_func, converter=converter)
    # Set up the trainer.
    trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.out)
    # Snapshotting is intentionally disabled here:
    # trainer.extend(extensions.snapshot(), trigger=(args.epochs, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport([
        'epoch', 'main/loss', 'main/loss_pi', 'main/loss_v', 'elapsed_time']))
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.run()
Example #6
Source File: main.py From qb with MIT License | 5 votes |
def parse_args():
    """Parse the language-model training options from ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--batch_size', '-b', type=int, default=32,
                     help='Number of examples in each mini-batch')
    cli.add_argument('--bproplen', '-l', type=int, default=35,
                     help='Number of words in each mini-batch (= length of truncated BPTT)')
    cli.add_argument('--epoch', '-e', type=int, default=20,
                     help='Number of sweeps over the dataset to train')
    cli.add_argument('--gpu', '-g', type=int, default=0,
                     help='GPU ID (negative value indicates CPU)')
    cli.add_argument('--gradclip', '-c', type=float, default=5,
                     help='Gradient norm threshold to clip')
    cli.add_argument('--out', '-o', default='result',
                     help='Directory to output the result')
    cli.add_argument('--resume', '-r', default='',
                     help='Resume the training from snapshot')
    cli.add_argument('--test', action='store_true',
                     help='Use tiny datasets for quick tests')
    cli.set_defaults(test=False)
    cli.add_argument('--hidden_size', type=int, default=300,
                     help='Number of LSTM units in each layer')
    cli.add_argument('--embed_size', type=int, default=300,
                     help='Size of embeddings')
    cli.add_argument('--model', '-m', default='model.npz',
                     help='Model file name to serialize')
    cli.add_argument('--glove', default='data/glove.6B.300d.txt',
                     help='Path to glove embedding file.')
    return cli.parse_args()
Example #7
Source File: train.py From portrait_matting with GNU General Public License v3.0 | 5 votes |
def parse_arguments(argv):
    """Parse the training options for the segmentation/matting script.

    Fix: ``--momentum`` and ``--weight_decay`` had no ``type=``, so a value
    given on the command line arrived as ``str`` while the default was
    ``float`` — downstream arithmetic would then fail. Both are now floats.

    Args:
        argv (list[str]): argument vector (without the program name).

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='Training Script')
    parser.add_argument('--config', '-c', default='config.json',
                        help='Configure json filepath')
    parser.add_argument('--batchsize', '-b', type=int, default=1,
                        help='Number of images in each mini-batch')
    parser.add_argument('--max_iteration', '-e', type=int, default=30000,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpus', '-g', type=int, default=[-1], nargs='*',
                        help='GPU IDs (negative value indicates CPU)')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--momentum', type=float, default=0.99,
                        help='Momentum for SGD')
    parser.add_argument('--weight_decay', type=float, default=0.0005,
                        help='Weight decay')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--mode', choices=['seg', 'seg+', 'seg_tri', 'mat'],
                        help='Training mode', required=True)
    parser.add_argument('--pretrained_fcn8s', default=None,
                        help='Pretrained model path of FCN8s')
    parser.add_argument('--pretrained_n_input_ch', default=3, type=int,
                        help='Input channel number of Pretrained model')
    parser.add_argument('--pretrained_n_output_ch', default=21, type=int,
                        help='Output channel number of Pretrained model')
    parser.add_argument('--mat_scale', default=4, type=int,
                        help='Matting scale for speed up')
    args = parser.parse_args(argv)
    return args
Example #8
Source File: test_snapshot.py From pfio with MIT License | 5 votes |
def test_snapshot():
    """A snapshot written via SimpleWriter must appear on disk and load back."""
    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True

    with tempfile.TemporaryDirectory() as workdir:
        snapshot = extensions.snapshot(writer=SimpleWriter(workdir))
        snapshot(trainer)
        assert 'snapshot_iter_0' in os.listdir(workdir)

        restored = chainer.testing.get_trainer_with_mock_updater()
        load_snapshot(restored, workdir, fail_on_no_file=True)
Example #9
Source File: test_snapshot.py From pfio with MIT License | 5 votes |
def test_scan_directory():
    """_scan_directory picks the highest-numbered snapshot and ignores tmp files."""
    from pfio.chainer_extensions.snapshot import _scan_directory

    with tempfile.TemporaryDirectory() as workdir:
        for name in ('tmpfoobar_10', 'foobar_10', 'foobar_123', 'tmpfoobar_10234'):
            pathlib.Path(os.path.join(workdir, name)).touch()

        newest = _scan_directory(pfio, workdir)
        assert newest is not None
        assert 'foobar_123' == newest
Example #10
Source File: test_multi_node_snapshot.py From chainer with MIT License | 5 votes |
def _prepare_multinode_snapshot(n, result):
    """Build an MNIST classifier trainer wrapped in multi_node_snapshot and
    advance it *n* updates.

    Args:
        n (int): number of updater steps to run before returning.
        result: output directory handed to the Trainer.

    Returns:
        tuple: ``(updater, mn_snapshot, trainer)`` ready for snapshot tests.
    """
    n_units = 100
    batchsize = 10
    comm = create_communicator('naive')
    model = L.Classifier(MLP(n_units, 10))
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    # Only rank 0 loads the dataset; scatter_dataset distributes it.
    if comm.rank == 0:
        train, _ = chainer.datasets.get_mnist()
    else:
        train, _ = None, None
    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    train_iter = chainer.iterators.SerialIterator(train, batchsize)

    updater = StandardUpdater(train_iter, optimizer)
    trainer = Trainer(updater, out=result)

    # autoload=True: resume automatically from an existing snapshot, if any.
    snapshot = extensions.snapshot(target=updater, autoload=True)
    replica_sets = []
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    mn_snapshot.initialize(trainer)
    for _ in range(n):
        updater.update()
    return updater, mn_snapshot, trainer
Example #11
Source File: test_multi_node_snapshot.py From chainer with MIT License | 5 votes |
def test_smoke_multinode_snapshot():
    """Smoke-test multi_node_snapshot dispatch: only the master rank
    evaluates the condition and invokes the writer."""
    t = mock.MagicMock()
    # First call: condition True -> snapshot written; second call: False.
    c = mock.MagicMock(side_effect=[True, False])
    w = mock.MagicMock()
    snapshot = extensions.snapshot(target=t, condition=c, writer=w)
    trainer = mock.MagicMock()

    comm = create_communicator('naive')
    replica_sets = []
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    mn_snapshot.initialize(trainer)
    mn_snapshot(trainer)
    mn_snapshot(trainer)
    mn_snapshot.finalize()

    if comm.rank == 0:
        assert mn_snapshot.is_master
        # Condition checked on both calls; writer ran only when it was True.
        assert c.call_count == 2
        assert w.call_count == 1
    else:
        # Non-master ranks must not touch condition or writer at all.
        assert not mn_snapshot.is_master
        assert c.call_count == 0
        assert w.call_count == 0
    comm.finalize()
Example #12
Source File: test_multi_node_snapshot.py From chainer with MIT License | 5 votes |
def test_callable_filename():
    """A callable ``filename`` must be wrapped so each master rank produces a
    rank-suffixed name; non-master ranks are left untouched."""
    comm = create_communicator('naive')
    if comm.size < 2:
        pytest.skip()

    def name_for(t):
        return 'deadbeef-{.updater.iteration}'.format(t)

    snapshot = extensions.snapshot(filename=name_for)
    trainer = mock.MagicMock()
    expected_name = '{}.{}'.format(name_for(trainer), comm.rank)

    mn_snapshot = multi_node_snapshot(comm, snapshot, [[0, 1], ])

    if comm.rank in (0, 2):
        assert mn_snapshot.is_master
        assert expected_name == mn_snapshot.snapshot.filename(trainer)
    else:
        assert not mn_snapshot.is_master
    comm.finalize()
Example #13
Source File: test_snapshot.py From chainer with MIT License | 5 votes |
def test_remove_stale_snapshots(self):
    """With ``n_retains=3``, only the three newest snapshots must survive a
    10-iteration run; an autoloading snapshot must then initialize cleanly."""
    fmt = 'snapshot_iter_{.updater.iteration}'
    retain = 3
    snapshot = extensions.snapshot(filename=fmt, n_retains=retain,
                                   autoload=False)

    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = self.path
    # priority=2 > TimeStampUpdater's 1, so the snapshot file exists before
    # its mtime is rewritten below.
    trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)

    class TimeStampUpdater():
        # Start 100s in the past and bump by 1s per call so mtimes are
        # strictly increasing even on coarse-timestamp filesystems.
        t = time.time() - 100
        name = 'ts_updater'
        priority = 1  # This must be called after snapshot taken

        def __call__(self, _trainer):
            filename = os.path.join(_trainer.out, fmt.format(_trainer))
            self.t += 1
            # For filesystems that does low timestamp precision
            os.utime(filename, (self.t, self.t))

    trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
    trainer.run()
    assert 10 == trainer.updater.iteration
    assert trainer._done

    pattern = os.path.join(trainer.out, "snapshot_iter_*")
    found = [os.path.basename(path) for path in glob.glob(pattern)]
    assert retain == len(found)
    found.sort()
    # snapshot_iter_(8, 9, 10) expected
    expected = ['snapshot_iter_{}'.format(i) for i in range(8, 11)]
    expected.sort()
    assert expected == found

    trainer2 = testing.get_trainer_with_mock_updater()
    trainer2.out = self.path
    assert not trainer2._done
    snapshot2 = extensions.snapshot(filename=fmt, autoload=True)
    # Just making sure no error occurs
    snapshot2.initialize(trainer2)
Example #14
Source File: gen_mnist_mlp.py From chainer-compiler with MIT License | 5 votes |
def main():
    """Parse the MNIST example's command-line options and run ``main_impl``."""
    cli = argparse.ArgumentParser(description='Chainer example: MNIST')
    cli.add_argument('--batchsize', '-b', type=int, default=7,
                     help='Number of images in each mini-batch')
    cli.add_argument('--epoch', '-e', type=int, default=20,
                     help='Number of sweeps over the dataset to train')
    cli.add_argument('--frequency', '-f', type=int, default=-1,
                     help='Frequency of taking a snapshot')
    cli.add_argument('--gpu', '-g', type=int, default=-1,
                     help='GPU ID (negative value indicates CPU)')
    cli.add_argument('--out', '-o', default='result',
                     help='Directory to output the result')
    cli.add_argument('--resume', '-r', default='',
                     help='Resume the training from snapshot')
    cli.add_argument('--unit', '-u', type=int, default=1000,
                     help='Number of units')
    cli.add_argument('--noplot', dest='plot', action='store_false',
                     help='Disable PlotReport extension')
    cli.add_argument('--onnx', default='',
                     help='Export ONNX model')
    cli.add_argument('--model', '-m', default='model.npz',
                     help='Model file name to serialize')
    cli.add_argument('--timeout', type=int, default=0,
                     help='Enable timeout')
    cli.add_argument('--trace', default='',
                     help='Enable tracing')
    cli.add_argument('--run_training', action='store_true',
                     help='Run training')
    args = cli.parse_args()

    main_impl(args)
Example #15
Source File: 02-train.py From Semantic-Segmentation-using-Adversarial-Networks with MIT License | 5 votes |
def parse_args(generators, discriminators, updaters):
    """Parse training options, constraining architecture choices to the keys
    of the given registries.

    Args:
        generators (dict): name -> generator class registry.
        discriminators (dict): name -> discriminator class registry.
        updaters (dict): name -> updater class registry.

    Returns:
        argparse.Namespace: parsed options from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description='Semantic Segmentation using Adversarial Networks')
    cli.add_argument('--generator', choices=generators.keys(), default='fcn32s',
                     help='Generator(segmentor) architecture')
    cli.add_argument('--discriminator', choices=discriminators.keys(), default='largefov',
                     help='Discriminator architecture')
    cli.add_argument('--updater', choices=updaters.keys(), default='gan',
                     help='Updater')
    cli.add_argument('--initgen_path', default='pretrained_model/vgg16.npz',
                     help='Pretrained model of generator')
    cli.add_argument('--initdis_path', default=None,
                     help='Pretrained model of discriminator')
    cli.add_argument('--batchsize', '-b', type=int, default=1,
                     help='Number of images in each mini-batch')
    cli.add_argument('--iteration', '-i', type=int, default=100000,
                     help='Number of sweeps over the dataset to train')
    cli.add_argument('--gpu', '-g', type=int, default=-1,
                     help='GPU ID (negative value indicates CPU)')
    cli.add_argument('--out', '-o', default='snapshot',
                     help='Directory to output the result')
    cli.add_argument('--resume', '-r', default='',
                     help='Resume the training from snapshot')
    cli.add_argument('--evaluate_interval', type=int, default=1000,
                     help='Interval of evaluation')
    cli.add_argument('--snapshot_interval', type=int, default=10000,
                     help='Interval of snapshot')
    cli.add_argument('--display_interval', type=int, default=10,
                     help='Interval of displaying log to console')
    return cli.parse_args()
Example #16
Source File: test_snapshot.py From chainer with MIT License | 5 votes |
def test_clean_up_tempdir(self):
    """After a snapshot completes, no temporary 'tmp*' files may remain."""
    snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat')
    snapshot(self.trainer)

    leftovers = [name for name in os.listdir('.')
                 if name.startswith('tmpmyfile.dat')]
    self.assertEqual(len(leftovers), 0)
Example #17
Source File: test_snapshot.py From chainer with MIT License | 5 votes |
def test_call(self):
    """The writer must run only on calls where the condition is True."""
    target = mock.MagicMock()
    condition = mock.MagicMock(side_effect=[True, False])
    writer = mock.MagicMock()
    snapshot = extensions.snapshot(target=target, condition=condition,
                                   writer=writer)
    trainer = mock.MagicMock()

    snapshot(trainer)
    snapshot(trainer)

    # Condition was consulted twice, but only the first call wrote.
    assert condition.call_count == 2
    assert writer.call_count == 1
Example #18
Source File: test_snapshot.py From chainer with MIT License | 5 votes |
def test_savefun_and_writer_exclusive(self):
    """Passing both ``savefun`` and ``writer`` must raise TypeError."""
    def never_called(*args, **kwargs):
        assert False

    writer = extensions.snapshot_writers.SimpleWriter()
    with pytest.raises(TypeError):
        extensions.snapshot(savefun=never_called, writer=writer)

    trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(trainer, savefun=never_called,
                                   writer=writer)
Example #19
Source File: test_snapshot.py From chainer with MIT License | 5 votes |
def test_save_file(self):
    """snapshot_object with an explicit SimpleWriter writes the target file."""
    writer = extensions.snapshot_writers.SimpleWriter()
    snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat',
                                          writer=writer)
    snapshot(self.trainer)

    self.assertTrue(os.path.exists('myfile.dat'))
Example #20
Source File: train_ch_in1k.py From imgclsmob with MIT License | 4 votes |
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_iter,
                    val_iter,
                    logging_dir_path,
                    num_gpus=0):
    """Assemble a Chainer Trainer for ImageNet-style training.

    Args:
        net: the model (a Link) to train.
        optimizer_name (str): "sgd" or "nag".
        lr (float): learning rate.
        momentum (float): momentum coefficient.
        num_epochs (int): training length in epochs.
        train_iter / val_iter: dataset iterators.
        logging_dir_path (str): Trainer output directory.
        num_gpus (int): >0 selects device 0, otherwise CPU.

    Returns:
        training.Trainer: fully configured, not yet run.

    Raises:
        Exception: on an unsupported optimizer name.
    """
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception('Unsupported optimizer: {}'.format(optimizer_name))
    optimizer.setup(net)

    # Multi-GPU is disabled; only device 0 (or CPU) is used.
    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if num_gpus > 0 else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, 'epoch'),
        out=logging_dir_path)

    # Validate/snapshot rarely, log more often.
    val_interval = 100000, 'iteration'
    log_interval = 1000, 'iteration'

    trainer.extend(
        extension=extensions.Evaluator(
            val_iter,
            net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            'model_iter_{.updater.iteration}'),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'lr']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    return trainer
Example #21
Source File: train_ch_cifar.py From imgclsmob with MIT License | 4 votes |
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_iter,
                    val_iter,
                    logging_dir_path,
                    num_gpus=0):
    """Assemble a Chainer Trainer for CIFAR training.

    Same structure as the ImageNet variant: pick an optimizer, build a
    StandardUpdater on device 0 or CPU, and register evaluation, snapshot
    and reporting extensions.

    Raises:
        Exception: on an unsupported optimizer name.
    """
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception('Unsupported optimizer: {}'.format(optimizer_name))
    optimizer.setup(net)

    # Multi-GPU is disabled; only device 0 (or CPU) is used.
    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if num_gpus > 0 else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, 'epoch'),
        out=logging_dir_path)

    # Validate/snapshot rarely, log more often.
    val_interval = 100000, 'iteration'
    log_interval = 1000, 'iteration'

    trainer.extend(
        extension=extensions.Evaluator(
            val_iter,
            net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            'model_iter_{.updater.iteration}'),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'lr']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    return trainer
Example #22
Source File: train_ch.py From imgclsmob with MIT License | 4 votes |
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_data,
                    val_data,
                    logging_dir_path,
                    use_gpus):
    """Assemble a Chainer Trainer.

    Unlike the sibling variants this one takes ``train_data``/``val_data``
    dicts (iterator under the "iterator" key) and a boolean ``use_gpus``.

    Raises:
        Exception: on an unsupported optimizer name.
    """
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception("Unsupported optimizer: {}".format(optimizer_name))
    optimizer.setup(net)

    # Multi-GPU is disabled; only device 0 (or CPU) is used.
    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if use_gpus else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_data["iterator"],
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, "epoch"),
        out=logging_dir_path)

    # Validate/snapshot rarely, log more often.
    val_interval = 100000, "iteration"
    log_interval = 1000, "iteration"

    trainer.extend(
        extension=extensions.Evaluator(
            iterator=val_data["iterator"],
            target=net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph("main/loss"))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            "model_iter_{.updater.iteration}"),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            "epoch", "iteration", "main/loss", "validation/main/loss",
            "main/accuracy", "validation/main/accuracy", "lr"]),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    return trainer
Example #23
Source File: gen_mnist_mlp.py From chainer-compiler with MIT License | 4 votes |
def run_training(args, model): trainer = create_trainer(args, model) # Dump a computational graph from 'loss' variable at the first iteration # The "main" refers to the target link of the "main" optimizer. trainer.extend(extensions.dump_graph('main/loss')) # Take a snapshot for each specified epoch frequency = args.epoch if args.frequency == -1 else max(1, args.frequency) trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch')) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport()) # Save two plot images to the result dir if args.plot and extensions.PlotReport.available(): trainer.extend( extensions.PlotReport(['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png')) trainer.extend( extensions.PlotReport( ['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png')) # Print selected entries of the log to stdout # Here "main" refers to the target link of the "main" optimizer again, and # "validation" refers to the default name of the Evaluator extension. # Entries other than 'epoch' are reported by the Classifier link, called by # either the updater or the evaluator. trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) # Print a progress bar to stdout trainer.extend(extensions.ProgressBar()) if args.resume: # Resume from a snapshot chainer.serializers.load_npz(args.resume, trainer) # Run the training trainer.run()
Example #24
Source File: train_utils.py From see with GNU General Public License v3.0 | 4 votes |
def get_trainer(net, updater, log_dir, print_fields, curriculum=None, extra_extensions=(), epochs=10,
                snapshot_interval=20000, print_interval=100, postprocess=None, do_logging=True,
                model_files=()):
    """Build a Trainer with graph dump, lr observation, snapshotting, logging
    and any caller-supplied extra extensions.

    Args:
        net: unused here beyond the signature — NOTE(review): confirm callers rely on it.
        updater: the updater driving the run.
        log_dir (str): output directory.
        print_fields: report keys printed each ``print_interval`` iterations.
        curriculum: if given, training stops via EarlyStopIntervalTrigger.
        extra_extensions: iterable of ``(extension, trigger)`` pairs.
        epochs (int): training length.
        snapshot_interval (int): iteration period for snapshots.
        print_interval (int): iteration period for logging/printing.
        postprocess: forwarded to the Logger extension.
        do_logging (bool): disable Logger/PrintReport when False.
        model_files: forwarded to the Logger extension.

    Returns:
        chainer.training.Trainer: configured trainer.
    """
    if curriculum is None:
        trainer = chainer.training.Trainer(
            updater,
            (epochs, 'epoch'),
            out=log_dir,
        )
    else:
        trainer = chainer.training.Trainer(
            updater,
            EarlyStopIntervalTrigger(epochs, 'epoch', curriculum),
            out=log_dir,
        )

    # dump computational graph
    trainer.extend(extensions.dump_graph('main/loss'))

    # also observe learning rate
    observe_lr_extension = chainer.training.extensions.observe_lr()
    observe_lr_extension.trigger = (print_interval, 'iteration')
    trainer.extend(observe_lr_extension)

    # Take snapshots at every epoch boundary and every snapshot_interval iterations.
    trainer.extend(
        extensions.snapshot(filename="trainer_snapshot"),
        trigger=lambda trainer:
            trainer.updater.is_new_epoch or
            (trainer.updater.iteration > 0 and
             trainer.updater.iteration % snapshot_interval == 0)
    )

    if do_logging:
        # write all statistics to a file
        trainer.extend(Logger(model_files, log_dir, keys=print_fields,
                              trigger=(print_interval, 'iteration'),
                              postprocess=postprocess))

        # print some interesting statistics
        trainer.extend(extensions.PrintReport(
            print_fields,
            log_report='Logger',
        ))

    # Progressbar!!
    trainer.extend(extensions.ProgressBar(update_interval=1))

    for extra_extension, trigger in extra_extensions:
        trainer.extend(extra_extension, trigger=trigger)

    return trainer
Example #25
Source File: gen_resnet50.py From chainer-compiler with MIT License | 4 votes |
def run_training(args, model): trainer = create_trainer(args, model) # Dump a computational graph from 'loss' variable at the first iteration # The "main" refers to the target link of the "main" optimizer. trainer.extend(extensions.dump_graph('main/loss')) # Take a snapshot for each specified epoch frequency = args.epoch if args.frequency == -1 else max(1, args.frequency) trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch')) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport()) # Save two plot images to the result dir if args.plot and extensions.PlotReport.available(): trainer.extend( extensions.PlotReport(['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png')) trainer.extend( extensions.PlotReport( ['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png')) # Print selected entries of the log to stdout # Here "main" refers to the target link of the "main" optimizer again, and # "validation" refers to the default name of the Evaluator extension. # Entries other than 'epoch' are reported by the Classifier link, called by # either the updater or the evaluator. trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) # Print a progress bar to stdout trainer.extend(extensions.ProgressBar()) if args.resume: # Resume from a snapshot chainer.serializers.load_npz(args.resume, trainer) # Run the training trainer.run()
Example #26
Source File: gen_resnet50.py From chainer-compiler with MIT License | 4 votes |
def main():
    """Parse the ILSVRC2012 example's options, run ``main_impl``, and write
    the per-architecture stamp file.

    Fix: the ``--gpu`` help string was missing its closing parenthesis
    ('(negative value indicates CPU'); it now matches the sibling scripts.
    """
    archs = {
        'alex': alex.Alex,
        'nin': nin.NIN,
        'resnet50': resnet50.ResNet50,
    }

    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('--arch', '-a', choices=archs.keys(), default='resnet50',
                        help='Convnet architecture')
    parser.add_argument('--train', default='',
                        help='Path to training image-label list file')
    parser.add_argument('--val', default='',
                        help='Path to validation image-label list file')
    parser.add_argument('--batchsize', '-B', type=int, default=32,
                        help='Learning minibatch size')
    parser.add_argument('--epoch', '-E', type=int, default=10,
                        help='Number of epochs to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob', '-j', type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean', '-m', default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    parser.add_argument('--resume', '-r', default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out', '-o', default='result',
                        help='Output directory')
    parser.add_argument('--root', '-R', default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize', '-b', type=int, default=250,
                        help='Validation minibatch size')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--run_training', action='store_true',
                        help='Run training')
    parser.set_defaults(test=False)
    args = parser.parse_args()

    model_cls = archs[args.arch]
    main_impl(args, model_cls)

    # TODO(hamaji): Stop writing a file to scripts.
    with open('scripts/%s_stamp' % args.arch, 'w'):
        pass
Example #27
Source File: train.py From portrait_matting with GNU General Public License v3.0 | 4 votes |
def register_extensions(trainer, model, test_iter, args):
    """Attach visualization, reporting, snapshot and plotting extensions.

    Fix: the invalid-mode branch previously only logged and fell through,
    crashing later with a ``NameError`` on ``best_trigger``; it now raises
    a ``ValueError`` immediately.

    Args:
        trainer: the Trainer to extend in place.
        model: trained model (used by the evaluator and best-model snapshot).
        test_iter: validation iterator.
        args: parsed options (uses mode, gpus, out).
    """
    if args.mode.startswith('seg'):
        # Max accuracy
        best_trigger = training.triggers.BestValueTrigger(
            'validation/main/accuracy', lambda a, b: a < b, (1, 'epoch'))
    elif args.mode.startswith('mat'):
        # Min loss
        best_trigger = training.triggers.BestValueTrigger(
            'validation/main/loss', lambda a, b: a > b, (1, 'epoch'))
    else:
        logger.error('Invalid training mode')
        raise ValueError('Invalid training mode: {}'.format(args.mode))

    # Segmentation extensions
    trainer.extend(
        custom_extensions.PortraitVisEvaluator(
            test_iter, model, device=args.gpus[0],
            converter=select_converter(args.mode),
            filename='vis_epoch={epoch}_idx={index}.jpg',
            mode=args.mode
        ), trigger=(1, 'epoch'))

    # Basic extensions
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport(trigger=(200, 'iteration')))
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'lr', 'elapsed_time']))
    trainer.extend(extensions.observe_lr(), trigger=(200, 'iteration'))

    # Snapshots: periodic full-trainer snapshot plus best-model-only object.
    trainer.extend(extensions.snapshot(
        filename='snapshot_epoch_{.updater.epoch}'
    ), trigger=(5, 'epoch'))
    trainer.extend(extensions.snapshot_object(
        model, filename='model_best'
    ), trigger=best_trigger)

    # ChainerUI extensions
    trainer.extend(chainerui.extensions.CommandsExtension())
    chainerui.utils.save_args(args, args.out)

    # Plotting extensions
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(
                ['main/loss', 'validation/main/loss'],
                'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))
Example #28
Source File: train.py From portrait_matting with GNU General Public License v3.0 | 4 votes |
def main(argv):
    """Build the portrait segmentation/matting training pipeline and run it.

    Args:
        argv: raw command-line arguments forwarded to ``parse_arguments``.
    """
    # Command line and experiment configuration.
    args = parse_arguments(argv)
    config.load(args.config)

    # Datasets and iterators for the selected mode.
    train_data, test_data = setup_dataset(
        args.mode, config.img_crop_dir, config.img_mask_dir,
        config.img_mean_mask_dir, config.img_mean_grid_dir,
        config.img_trimap_dir, config.img_alpha_dir,
        config.img_alpha_weight_dir)
    train_iter, test_iter = setup_iterators(
        args.gpus, args.batchsize, train_data, test_data)

    # Model and optimizer.
    model = setup_model(args.mode, args.pretrained_fcn8s,
                        args.pretrained_n_input_ch,
                        args.pretrained_n_output_ch, args.mat_scale)
    optimizer = setup_optimizer(model, args.lr, args.momentum,
                                args.weight_decay)

    # Trainer with extensions for portrait segmentation / matting.
    updater = setup_updater(args.mode, args.gpus, train_iter, optimizer)
    trainer = training.Trainer(updater, (args.max_iteration, 'iteration'),
                               out=args.out)
    register_extensions(trainer, model, test_iter, args)

    # Optionally resume from a previous snapshot, then train.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)
    trainer.run()
Example #29
Source File: train_utils.py From chainer-chemistry with MIT License | 4 votes |
def run_node_classification_train(model, data, train_mask, valid_mask,
                                  epoch=10, optimizer=None, out='result',
                                  extensions_list=None, device=-1,
                                  converter=None,
                                  use_default_extensions=True,
                                  resume_path=None):
    """Train `model` for node classification on a single graph.

    Args:
        model: link to optimize.
        data: graph data object; must support ``to_device``.
        train_mask / valid_mask: per-node masks sent to the device.
        epoch (int): number of training epochs.
        optimizer: ``chainer.Optimizer`` instance; defaults to ``Adam``.
        out (str): trainer output directory.
        extensions_list: extra trainer extensions to register.
        device: chainer device spec (``-1`` means CPU).
        converter: optional custom batch converter; when ``None`` the built-in
            single-sample converter below is used.
        use_default_extensions (bool): register LogReport/PrintReport/
            ProgressBar when True.
        resume_path: trainer snapshot to resume from, if given.

    Raises:
        ValueError: if ``optimizer`` is neither ``None`` nor an ``Optimizer``.
    """
    if optimizer is None:
        # Use Adam optimizer as default
        optimizer = optimizers.Adam()
    elif not isinstance(optimizer, Optimizer):
        # Bug fix: report the type of the object actually passed, not the
        # `Optimizer` class itself (was `type(Optimizer)`).
        raise ValueError("[ERROR] optimizer must be instance of Optimizer, "
                         "but passed {}".format(type(optimizer)))
    optimizer.setup(model)

    def one_batch_converter(batch, device):
        # Move the single (data, train_mask, valid_mask) sample to `device`.
        if not isinstance(device, Device):
            device = chainer.get_device(device)
        data, train_mask, valid_mask = batch[0]
        return (data.to_device(device), device.send(train_mask),
                device.send(valid_mask))

    # Bug fix: the `converter` argument was previously accepted but silently
    # ignored; honor it when supplied, falling back to the default converter.
    if converter is None:
        converter = one_batch_converter

    data_iter = SerialIterator([(data, train_mask, valid_mask)], batch_size=1)
    updater = training.StandardUpdater(
        data_iter, optimizer, device=device, converter=converter)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)

    if use_default_extensions:
        trainer.extend(extensions.LogReport())
        trainer.extend(AutoPrintReport())
        trainer.extend(extensions.ProgressBar(update_interval=10))
        # TODO(nakago): consider to include snapshot as default extension.
        # trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    if extensions_list is not None:
        for e in extensions_list:
            trainer.extend(e)

    if resume_path:
        chainer.serializers.load_npz(resume_path, trainer)

    trainer.run()
    return
Example #30
Source File: train.py From Video-frame-prediction-by-multi-scale-GAN with MIT License | 4 votes |
def main(resume, gpu, load_path, data_path):
    """Train the multi-scale GAN for video frame prediction.

    Args:
        resume: truthy to restore trainer state from ``load_path``.
        gpu (int): device id; negative values keep everything on CPU.
        load_path (str): snapshot file used when ``resume`` is set.
        data_path (str): directory/file handed to ``Dataset``.
    """
    dataset = Dataset(data_path)

    # Generator / discriminator pair, sized from the project constants in `c`.
    gen_net = MultiScaleGenerator(c.SCALE_FMS_G, c.SCALE_KERNEL_SIZES_G)
    dis_net = MultiScaleDiscriminator(
        c.SCALE_CONV_FMS_D, c.SCALE_KERNEL_SIZES_D, c.SCALE_FC_LAYER_SIZES_D)

    # One plain-SGD optimizer per network, keyed by the names the custom
    # Updater expects.
    opts = {
        "GeneratorNetwork": chainer.optimizers.SGD(c.LRATE_G),
        "DiscriminatorNetwork": chainer.optimizers.SGD(c.LRATE_D),
    }

    iterator = chainer.iterators.SerialIterator(dataset, 1)
    params = {'LAM_ADV': 0.05, 'LAM_LP': 1, 'LAM_GDL': .1}
    updater = Updater(iterators=iterator,
                      optimizers=opts,
                      GeneratorNetwork=gen_net,
                      DiscriminatorNetwork=dis_net,
                      params=params,
                      device=gpu)

    if gpu >= 0:
        updater.GenNetwork.to_gpu()
        updater.DisNetwork.to_gpu()

    trainer = chainer.training.Trainer(
        updater, (500000, 'iteration'), out='result')

    # Snapshots: full trainer state every iteration plus the generator alone.
    trainer.extend(extensions.snapshot(filename='snapshot'),
                   trigger=(1, 'iteration'))
    trainer.extend(
        extensions.snapshot_object(trainer.updater.GenNetwork, "GEN"))
    trainer.extend(saveGen)

    # Console reporting every 10 iterations.
    log_keys = ['epoch', 'iteration', 'GeneratorNetwork/L2Loss',
                'GeneratorNetwork/GDL', 'DiscriminatorNetwork/DisLoss',
                'GeneratorNetwork/CompositeGenLoss']
    print_keys = ['GeneratorNetwork/CompositeGenLoss',
                  'DiscriminatorNetwork/DisLoss']
    trainer.extend(
        extensions.LogReport(keys=log_keys, trigger=(10, 'iteration')))
    trainer.extend(extensions.PrintReport(print_keys),
                   trigger=(10, 'iteration'))

    # Loss curves, refreshed on the same 10-iteration cadence.
    plot_every = (10, 'iteration')
    trainer.extend(extensions.PlotReport(
        ['DiscriminatorNetwork/DisLoss'], 'iteration', plot_every,
        file_name="DisLoss.png"))
    trainer.extend(extensions.PlotReport(
        ['GeneratorNetwork/CompositeGenLoss'], 'iteration', plot_every,
        file_name="GenLoss.png"))
    trainer.extend(extensions.PlotReport(
        ['GeneratorNetwork/AdvLoss'], 'iteration', plot_every,
        file_name="AdvGenLoss.png"))
    trainer.extend(extensions.PlotReport(
        ['GeneratorNetwork/AdvLoss', 'DiscriminatorNetwork/DisLoss'],
        'iteration', plot_every, file_name="AdversarialLosses.png"))
    trainer.extend(extensions.PlotReport(
        ['GeneratorNetwork/L2Loss'], 'iteration', plot_every,
        file_name="L2Loss.png"))
    trainer.extend(extensions.PlotReport(
        ['GeneratorNetwork/GDL'], 'iteration', plot_every,
        file_name="GDL.png"))

    trainer.extend(extensions.ProgressBar(update_interval=10))

    if resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(load_path, trainer)
        print(trainer.updater.__dict__)

    trainer.run()