Python chainer.training.extensions.snapshot() Examples

The following are 30 code examples of chainer.training.extensions.snapshot(), drawn from open-source projects; the source file, project, and license for each are noted above the example. You may also want to check out all available functions/classes of the module chainer.training.extensions, or try the search function.
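Before the project examples, here is a minimal, self-contained usage sketch (not taken from any of the projects below; the toy dataset and model are assumptions for illustration). extensions.snapshot() serializes the whole trainer state when its trigger fires, and chainer.serializers.load_npz() restores that state to resume training.

import numpy as np
import chainer
from chainer import training
from chainer.training import extensions

# Toy dataset and model, assumed purely for illustration.
x = np.random.rand(100, 4).astype(np.float32)
t = np.random.randint(0, 2, size=100).astype(np.int32)
train_iter = chainer.iterators.SerialIterator(
    chainer.datasets.TupleDataset(x, t), batch_size=10)

model = chainer.links.Classifier(chainer.links.Linear(4, 2))
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (2, 'epoch'), out='result')

# Serialize the whole trainer (model, optimizer, iterator state) each epoch.
# Use extensions.snapshot_object(model, ...) instead to save only one object.
trainer.extend(
    extensions.snapshot(filename='snapshot_iter_{.updater.iteration}'),
    trigger=(1, 'epoch'))
trainer.run()

# To resume later, deserialize a snapshot back into an equivalent trainer:
# chainer.serializers.load_npz('result/snapshot_iter_10', trainer)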
Example #1
Source File: train_utils.py    From see with GNU General Public License v3.0
def add_default_arguments(parser):
    parser.add_argument("log_dir", help='directory where generated models and logs shall be stored')
    parser.add_argument('-b', '--batch-size', dest='batch_size', type=int, required=True,
                        help="Number of images per training batch")
    parser.add_argument('-g', '--gpus', type=int, nargs="*", default=[], help="IDs of GPUs to use [default: (use CPU)]")
    parser.add_argument('-e', '--epochs', type=int, default=20, help="Number of epochs to train [default: 20]")
    parser.add_argument('-r', '--resume', help="path to previously saved state of trained model from which training shall resume")
    parser.add_argument('-si', '--snapshot-interval', dest='snapshot_interval', type=int, default=20000,
                        help="number of iterations after which a snapshot shall be taken [default: 20000]")
    parser.add_argument('-ln', '--log-name', dest='log_name', default='training', help="name of the log folder")
    parser.add_argument('-lr', '--learning-rate', dest='learning_rate', type=float, default=0.01,
                        help="initial learning rate [default: 0.01]")
    parser.add_argument('-li', '--log-interval', dest='log_interval', type=int, default=100,
                        help="number of iterations after which an update shall be logged [default: 100]")
    parser.add_argument('--lr-step', dest='learning_rate_step_size', type=float, default=0.1,
                        help="Step size for decreasing learning rate [default: 0.1]")
    parser.add_argument('-t', '--test-interval', dest='test_interval', type=int, default=1000,
                        help="number of iterations after which testing should be performed [default: 1000]")
    parser.add_argument('--test-iterations', dest='test_iterations', type=int, default=200,
                        help="number of test iterations [default: 200]")
    parser.add_argument("-dr", "--dropout-ratio", dest='dropout_ratio', default=0.5, type=float,
                        help="ratio for dropout layers")

    return parser 
Example #2
Source File: test_snapshot.py    From pfio with MIT License
def test_snapshot_hdfs():
    trainer = chainer.testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True

    with pfio.create_handler('hdfs') as fs:
        tmpdir = "some-pfio-tmp-dir"
        fs.makedirs(tmpdir, exist_ok=True)
        file_list = list(fs.list(tmpdir))
        assert len(file_list) == 0

        writer = SimpleWriter(tmpdir, fs=fs)
        snapshot = extensions.snapshot(writer=writer)
        snapshot(trainer)

        assert 'snapshot_iter_0' in fs.list(tmpdir)

        trainer2 = chainer.testing.get_trainer_with_mock_updater()
        load_snapshot(trainer2, tmpdir, fs=fs, fail_on_no_file=True)

        # Cleanup
        fs.remove(tmpdir, recursive=True) 
Example #3
Source File: test_multi_node_snapshot.py    From chainer with MIT License
def test_smoke_wrapper():
    rs = [[0, 1], ]
    comm = create_communicator('naive')
    if comm.size < 2:
        pytest.skip()

    snapshot = extensions.snapshot()
    filename = '{}.{}'.format(snapshot.filename, comm.rank)

    replica_sets = rs
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    if comm.rank == 0:
        assert mn_snapshot.is_master
        assert filename == mn_snapshot.snapshot.filename
    elif comm.rank == 1:
        assert not mn_snapshot.is_master
    elif comm.rank == 2:
        assert mn_snapshot.is_master
        assert filename == mn_snapshot.snapshot.filename
    else:
        assert not mn_snapshot.is_master

    comm.finalize() 
Example #4
Source File: test_snapshot.py    From chainer with MIT License
def test_on_error(self):

        class TheOnlyError(Exception):
            pass

        @training.make_extension(trigger=(1, 'iteration'), priority=100)
        def exception_raiser(trainer):
            raise TheOnlyError()
        self.trainer.extend(exception_raiser)

        snapshot = extensions.snapshot_object(self.trainer, self.filename,
                                              snapshot_on_error=True)
        self.trainer.extend(snapshot)

        self.assertFalse(os.path.exists(self.filename))

        with self.assertRaises(TheOnlyError):
            self.trainer.run()

        self.assertTrue(os.path.exists(self.filename)) 
Example #5
Source File: NNet.py    From alpha-zero-general with MIT License
def _train_trainer(self, examples):
        """Training with chainer trainer module"""
        train_iter = SerialIterator(examples, args.batch_size)
        optimizer = optimizers.Adam(alpha=args.lr)
        optimizer.setup(self.nnet)

        def loss_func(boards, target_pis, target_vs):
            out_pi, out_v = self.nnet(boards)
            l_pi = self.loss_pi(target_pis, out_pi)
            l_v = self.loss_v(target_vs, out_v)
            total_loss = l_pi + l_v
            chainer.reporter.report({
                'loss': total_loss,
                'loss_pi': l_pi,
                'loss_v': l_v,
            }, observer=self.nnet)
            return total_loss

        updater = training.StandardUpdater(
            train_iter, optimizer, device=args.device, loss_func=loss_func, converter=converter)
        # Set up the trainer.
        trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.out)
        # trainer.extend(extensions.snapshot(), trigger=(args.epochs, 'epoch'))
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.PrintReport([
            'epoch', 'main/loss', 'main/loss_pi', 'main/loss_v', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar(update_interval=10))
        trainer.run() 
Example #6
Source File: main.py    From qb with MIT License
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', '-b', type=int, default=32,
                        help='Number of examples in each mini-batch')
    parser.add_argument('--bproplen', '-l', type=int, default=35,
                        help='Number of words in each mini-batch '
                             '(= length of truncated BPTT)')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--gradclip', '-c', type=float, default=5,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--test', action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.set_defaults(test=False)
    parser.add_argument('--hidden_size', type=int, default=300,
                        help='Number of LSTM units in each layer')
    parser.add_argument('--embed_size', type=int, default=300,
                        help='Size of embeddings')
    parser.add_argument('--model', '-m', default='model.npz',
                        help='Model file name to serialize')
    parser.add_argument('--glove', default='data/glove.6B.300d.txt',
                        help='Path to glove embedding file.')
    args = parser.parse_args()
    return args 
Example #7
Source File: train.py    From portrait_matting with GNU General Public License v3.0
def parse_arguments(argv):
    parser = argparse.ArgumentParser(description='Training Script')
    parser.add_argument('--config', '-c', default='config.json',
                        help='Configure json filepath')
    parser.add_argument('--batchsize', '-b', type=int, default=1,
                        help='Number of images in each mini-batch')
    parser.add_argument('--max_iteration', '-e', type=int, default=30000,
                        help='Number of iterations to train')
    parser.add_argument('--gpus', '-g', type=int, default=[-1], nargs='*',
                        help='GPU IDs (negative value indicates CPU)')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--momentum', type=float, default=0.99,
                        help='Momentum for SGD')
    parser.add_argument('--weight_decay', type=float, default=0.0005,
                        help='Weight decay')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--mode', choices=['seg', 'seg+', 'seg_tri', 'mat'],
                        help='Training mode', required=True)
    parser.add_argument('--pretrained_fcn8s', default=None,
                        help='Pretrained model path of FCN8s')
    parser.add_argument('--pretrained_n_input_ch', default=3, type=int,
                        help='Input channel number of Pretrained model')
    parser.add_argument('--pretrained_n_output_ch', default=21, type=int,
                        help='Output channel number of Pretrained model')
    parser.add_argument('--mat_scale', default=4, type=int,
                        help='Matting scale for speed up')
    args = parser.parse_args(argv)
    return args 
Example #8
Source File: test_snapshot.py    From pfio with MIT License
def test_snapshot():
    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True

    with tempfile.TemporaryDirectory() as td:
        writer = SimpleWriter(td)
        snapshot = extensions.snapshot(writer=writer)
        snapshot(trainer)
        assert 'snapshot_iter_0' in os.listdir(td)

        trainer2 = chainer.testing.get_trainer_with_mock_updater()
        load_snapshot(trainer2, td, fail_on_no_file=True) 
Example #9
Source File: test_snapshot.py    From pfio with MIT License
def test_scan_directory():
    from pfio.chainer_extensions.snapshot import _scan_directory
    with tempfile.TemporaryDirectory() as td:
        files = ['tmpfoobar_10', 'foobar_10', 'foobar_123', 'tmpfoobar_10234']
        for file in files:
            pathlib.Path(os.path.join(td, file)).touch()

        latest = _scan_directory(pfio, td)
        assert latest is not None
        assert 'foobar_123' == latest 
Example #10
Source File: test_multi_node_snapshot.py    From chainer with MIT License
def _prepare_multinode_snapshot(n, result):
    n_units = 100
    batchsize = 10
    comm = create_communicator('naive')
    model = L.Classifier(MLP(n_units, 10))
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    if comm.rank == 0:
        train, _ = chainer.datasets.get_mnist()
    else:
        train, _ = None, None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    train_iter = chainer.iterators.SerialIterator(train, batchsize)

    updater = StandardUpdater(train_iter, optimizer)
    trainer = Trainer(updater, out=result)

    snapshot = extensions.snapshot(target=updater, autoload=True)
    replica_sets = []
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    mn_snapshot.initialize(trainer)
    for _ in range(n):
        updater.update()

    return updater, mn_snapshot, trainer 
Example #11
Source File: test_multi_node_snapshot.py    From chainer with MIT License
def test_smoke_multinode_snapshot():
    t = mock.MagicMock()
    c = mock.MagicMock(side_effect=[True, False])
    w = mock.MagicMock()
    snapshot = extensions.snapshot(target=t, condition=c, writer=w)
    trainer = mock.MagicMock()

    comm = create_communicator('naive')
    replica_sets = []
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)

    mn_snapshot.initialize(trainer)
    mn_snapshot(trainer)
    mn_snapshot(trainer)
    mn_snapshot.finalize()

    if comm.rank == 0:
        assert mn_snapshot.is_master
        assert c.call_count == 2
        assert w.call_count == 1
    else:
        assert not mn_snapshot.is_master
        assert c.call_count == 0
        assert w.call_count == 0

    comm.finalize() 
Example #12
Source File: test_multi_node_snapshot.py    From chainer with MIT License
def test_callable_filename():
    rs = [[0, 1], ]
    comm = create_communicator('naive')
    if comm.size < 2:
        pytest.skip()

    def filename_fun(t):
        return 'deadbeef-{.updater.iteration}'.format(t)

    snapshot = extensions.snapshot(filename=filename_fun)

    trainer = mock.MagicMock()
    filename = '{}.{}'.format(filename_fun(trainer), comm.rank)

    replica_sets = rs
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    if comm.rank == 0:
        assert mn_snapshot.is_master
        assert filename == mn_snapshot.snapshot.filename(trainer)
    elif comm.rank == 1:
        assert not mn_snapshot.is_master
    elif comm.rank == 2:
        assert mn_snapshot.is_master
        assert filename == mn_snapshot.snapshot.filename(trainer)
    else:
        assert not mn_snapshot.is_master

    comm.finalize() 
Example #13
Source File: test_snapshot.py    From chainer with MIT License
def test_remove_stale_snapshots(self):
        fmt = 'snapshot_iter_{.updater.iteration}'
        retain = 3
        snapshot = extensions.snapshot(filename=fmt, n_retains=retain,
                                       autoload=False)

        trainer = testing.get_trainer_with_mock_updater()
        trainer.out = self.path
        trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)

        class TimeStampUpdater:
            t = time.time() - 100
            name = 'ts_updater'
            priority = 1  # Lower priority, so this runs after the snapshot is taken

            def __call__(self, _trainer):
                filename = os.path.join(_trainer.out, fmt.format(_trainer))
                self.t += 1
                # For filesystems that have low timestamp precision
                os.utime(filename, (self.t, self.t))

        trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
        trainer.run()
        assert 10 == trainer.updater.iteration
        assert trainer._done

        pattern = os.path.join(trainer.out, "snapshot_iter_*")
        found = [os.path.basename(path) for path in glob.glob(pattern)]
        assert retain == len(found)
        found.sort()
        # snapshot_iter_(8, 9, 10) expected
        expected = ['snapshot_iter_{}'.format(i) for i in range(8, 11)]
        expected.sort()
        assert expected == found

        trainer2 = testing.get_trainer_with_mock_updater()
        trainer2.out = self.path
        assert not trainer2._done
        snapshot2 = extensions.snapshot(filename=fmt, autoload=True)
        # Just making sure no error occurs
        snapshot2.initialize(trainer2) 
Example #14
Source File: gen_mnist_mlp.py    From chainer-compiler with MIT License
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=7,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    parser.add_argument('--onnx', default='',
                        help='Export ONNX model')
    parser.add_argument('--model', '-m', default='model.npz',
                        help='Model file name to serialize')
    parser.add_argument('--timeout', type=int, default=0,
                        help='Enable timeout')
    parser.add_argument('--trace', default='',
                        help='Enable tracing')
    parser.add_argument('--run_training', action='store_true',
                        help='Run training')
    args = parser.parse_args()

    main_impl(args) 
Example #15
Source File: 02-train.py    From Semantic-Segmentation-using-Adversarial-Networks with MIT License
def parse_args(generators, discriminators, updaters):
    parser = argparse.ArgumentParser(description='Semantic Segmentation using Adversarial Networks')
    parser.add_argument('--generator', choices=generators.keys(), default='fcn32s',
                        help='Generator(segmentor) architecture')
    parser.add_argument('--discriminator', choices=discriminators.keys(), default='largefov',
                        help='Discriminator architecture')
    parser.add_argument('--updater', choices=updaters.keys(), default='gan',
                        help='Updater')
    parser.add_argument('--initgen_path', default='pretrained_model/vgg16.npz',
                        help='Pretrained model of generator')
    parser.add_argument('--initdis_path', default=None,
                        help='Pretrained model of discriminator')
    parser.add_argument('--batchsize', '-b', type=int, default=1,
                        help='Number of images in each mini-batch')
    parser.add_argument('--iteration', '-i', type=int, default=100000,
                        help='Number of iterations to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='snapshot',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--evaluate_interval', type=int, default=1000,
                        help='Interval of evaluation')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval', type=int, default=10,
                        help='Interval of displaying log to console')
    return parser.parse_args() 
Example #16
Source File: test_snapshot.py    From chainer with MIT License
def test_clean_up_tempdir(self):
        snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat')
        snapshot(self.trainer)

        left_tmps = [fn for fn in os.listdir('.')
                     if fn.startswith('tmpmyfile.dat')]
        self.assertEqual(len(left_tmps), 0) 
Example #17
Source File: test_snapshot.py    From chainer with MIT License
def test_call(self):
        t = mock.MagicMock()
        c = mock.MagicMock(side_effect=[True, False])
        w = mock.MagicMock()
        snapshot = extensions.snapshot(target=t, condition=c, writer=w)
        trainer = mock.MagicMock()
        snapshot(trainer)
        snapshot(trainer)

        assert c.call_count == 2
        assert w.call_count == 1 
Example #18
Source File: test_snapshot.py    From chainer with MIT License
def test_savefun_and_writer_exclusive(self):
        # savefun and writer arguments cannot be specified together.
        def savefun(*args, **kwargs):
            assert False
        writer = extensions.snapshot_writers.SimpleWriter()
        with pytest.raises(TypeError):
            extensions.snapshot(savefun=savefun, writer=writer)

        trainer = mock.MagicMock()
        with pytest.raises(TypeError):
            extensions.snapshot_object(trainer, savefun=savefun, writer=writer) 
Example #19
Source File: test_snapshot.py    From chainer with MIT License
def test_save_file(self):
        w = extensions.snapshot_writers.SimpleWriter()
        snapshot = extensions.snapshot_object(self.trainer, 'myfile.dat',
                                              writer=w)
        snapshot(self.trainer)

        self.assertTrue(os.path.exists('myfile.dat')) 
Example #20
Source File: train_ch_in1k.py    From imgclsmob with MIT License
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_iter,
                    val_iter,
                    logging_dir_path,
                    num_gpus=0):
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception('Unsupported optimizer: {}'.format(optimizer_name))
    optimizer.setup(net)

    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if num_gpus > 0 else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, 'epoch'),
        out=logging_dir_path)

    val_interval = 100000, 'iteration'
    log_interval = 1000, 'iteration'

    trainer.extend(
        extension=extensions.Evaluator(
            val_iter,
            net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            'model_iter_{.updater.iteration}'),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy',
            'lr']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer 
Example #21
Source File: train_ch_cifar.py    From imgclsmob with MIT License
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_iter,
                    val_iter,
                    logging_dir_path,
                    num_gpus=0):
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception('Unsupported optimizer: {}'.format(optimizer_name))
    optimizer.setup(net)

    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if num_gpus > 0 else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, 'epoch'),
        out=logging_dir_path)

    val_interval = 100000, 'iteration'
    log_interval = 1000, 'iteration'

    trainer.extend(
        extension=extensions.Evaluator(
            val_iter,
            net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            'model_iter_{.updater.iteration}'),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy',
            'lr']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer 
Example #22
Source File: train_ch.py    From imgclsmob with MIT License
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_data,
                    val_data,
                    logging_dir_path,
                    use_gpus):
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception("Unsupported optimizer: {}".format(optimizer_name))
    optimizer.setup(net)

    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if use_gpus else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_data["iterator"],
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, "epoch"),
        out=logging_dir_path)

    val_interval = 100000, "iteration"
    log_interval = 1000, "iteration"

    trainer.extend(
        extension=extensions.Evaluator(
            iterator=val_data["iterator"],
            target=net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph("main/loss"))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            "model_iter_{.updater.iteration}"),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            "epoch", "iteration", "main/loss", "validation/main/loss", "main/accuracy", "validation/main/accuracy",
            "lr"]),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer 
Example #23
Source File: gen_mnist_mlp.py    From chainer-compiler with MIT License
def run_training(args, model):
    trainer = create_trainer(args, model)

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run() 
Example #24
Source File: train_utils.py    From see with GNU General Public License v3.0
def get_trainer(net, updater, log_dir, print_fields, curriculum=None, extra_extensions=(), epochs=10, snapshot_interval=20000, print_interval=100, postprocess=None, do_logging=True, model_files=()):
    if curriculum is None:
        trainer = chainer.training.Trainer(
            updater,
            (epochs, 'epoch'),
            out=log_dir,
        )
    else:
        trainer = chainer.training.Trainer(
            updater,
            EarlyStopIntervalTrigger(epochs, 'epoch', curriculum),
            out=log_dir,
        )

    # dump computational graph
    trainer.extend(extensions.dump_graph('main/loss'))

    # also observe learning rate
    observe_lr_extension = chainer.training.extensions.observe_lr()
    observe_lr_extension.trigger = (print_interval, 'iteration')
    trainer.extend(observe_lr_extension)

    # Take snapshots
    trainer.extend(
        extensions.snapshot(filename="trainer_snapshot"),
        trigger=lambda trainer:
        trainer.updater.is_new_epoch or
        (trainer.updater.iteration > 0 and trainer.updater.iteration % snapshot_interval == 0)
    )

    if do_logging:
        # write all statistics to a file
        trainer.extend(Logger(model_files, log_dir, keys=print_fields, trigger=(print_interval, 'iteration'), postprocess=postprocess))

        # print some interesting statistics
        trainer.extend(extensions.PrintReport(
            print_fields,
            log_report='Logger',
        ))

    # Progressbar!!
    trainer.extend(extensions.ProgressBar(update_interval=1))

    for extra_extension, trigger in extra_extensions:
        trainer.extend(extra_extension, trigger=trigger)

    return trainer 
Example #25
Source File: gen_resnet50.py    From chainer-compiler with MIT License
def run_training(args, model):
    trainer = create_trainer(args, model)

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run() 
Example #26
Source File: gen_resnet50.py    From chainer-compiler with MIT License
def main():
    archs = {
        'alex': alex.Alex,
        'nin': nin.NIN,
        'resnet50': resnet50.ResNet50,
    }
    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('--arch', '-a', choices=archs.keys(),
                        default='resnet50',
                        help='Convnet architecture')
    parser.add_argument('--train', default='',
                        help='Path to training image-label list file')
    parser.add_argument('--val', default='',
                        help='Path to validation image-label list file')
    parser.add_argument('--batchsize', '-B', type=int, default=32,
                        help='Learning minibatch size')
    parser.add_argument('--epoch', '-E', type=int, default=10,
                        help='Number of epochs to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--loaderjob', '-j', type=int,
                        help='Number of parallel data loading processes')
    parser.add_argument('--mean', '-m', default='mean.npy',
                        help='Mean file (computed by compute_mean.py)')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    parser.add_argument('--resume', '-r', default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out', '-o', default='result',
                        help='Output directory')
    parser.add_argument('--root', '-R', default='.',
                        help='Root directory path of image files')
    parser.add_argument('--val_batchsize', '-b', type=int, default=250,
                        help='Validation minibatch size')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--run_training', action='store_true',
                        help='Run training')
    parser.set_defaults(test=False)
    args = parser.parse_args()

    model_cls = archs[args.arch]
    main_impl(args, model_cls)

    # TODO(hamaji): Stop writing a file to scripts.
    with open('scripts/%s_stamp' % args.arch, 'w'): pass 
Example #27
Source File: train.py    From portrait_matting with GNU General Public License v3.0
def register_extensions(trainer, model, test_iter, args):
    if args.mode.startswith('seg'):
        # Max accuracy
        best_trigger = training.triggers.BestValueTrigger(
            'validation/main/accuracy', lambda a, b: a < b, (1, 'epoch'))
    elif args.mode.startswith('mat'):
        # Min loss
        best_trigger = training.triggers.BestValueTrigger(
            'validation/main/loss', lambda a, b: a > b, (1, 'epoch'))
    else:
        logger.error('Invalid training mode')
        raise ValueError('Invalid training mode: {}'.format(args.mode))
    # Segmentation extensions
    trainer.extend(
        custom_extensions.PortraitVisEvaluator(
            test_iter, model, device=args.gpus[0],
            converter=select_converter(args.mode),
            filename='vis_epoch={epoch}_idx={index}.jpg',
            mode=args.mode
        ), trigger=(1, 'epoch'))

    # Basic extensions
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport(trigger=(200, 'iteration')))
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
         'validation/main/accuracy', 'lr', 'elapsed_time']))
    trainer.extend(extensions.observe_lr(), trigger=(200, 'iteration'))

    # Snapshots
    trainer.extend(extensions.snapshot(
        filename='snapshot_epoch_{.updater.epoch}'
    ), trigger=(5, 'epoch'))
    trainer.extend(extensions.snapshot_object(
        model, filename='model_best'
    ), trigger=best_trigger)

    # ChainerUI extensions
    trainer.extend(chainerui.extensions.CommandsExtension())
    chainerui.utils.save_args(args, args.out)

    # Plotting extensions
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(
                ['main/loss', 'validation/main/loss'],
                'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png')) 
Example #28
Source File: train.py    From portrait_matting with GNU General Public License v3.0
def main(argv):
    # Argument
    args = parse_arguments(argv)

    # Load config
    config.load(args.config)

    # Setup dataset
    train, test = setup_dataset(args.mode, config.img_crop_dir,
                                config.img_mask_dir, config.img_mean_mask_dir,
                                config.img_mean_grid_dir,
                                config.img_trimap_dir, config.img_alpha_dir,
                                config.img_alpha_weight_dir)

    # Setup iterators
    train_iter, test_iter = setup_iterators(args.gpus, args.batchsize, train,
                                            test)

    # Setup model
    model = setup_model(args.mode, args.pretrained_fcn8s,
                        args.pretrained_n_input_ch,
                        args.pretrained_n_output_ch,
                        args.mat_scale)

    # Setup an optimizer
    optimizer = setup_optimizer(model, args.lr, args.momentum,
                                args.weight_decay)

    # Set up a trainer
    updater = setup_updater(args.mode, args.gpus, train_iter, optimizer)
    trainer = training.Trainer(updater, (args.max_iteration, 'iteration'),
                               out=args.out)

    # Register extensions for portrait segmentation / matting
    register_extensions(trainer, model, test_iter, args)

    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run() 
Example #29
Source File: train_utils.py    From chainer-chemistry with MIT License
def run_node_classification_train(model, data,
                                  train_mask, valid_mask,
                                  epoch=10,
                                  optimizer=None,
                                  out='result',
                                  extensions_list=None,
                                  device=-1,
                                  converter=None,
                                  use_default_extensions=True,
                                  resume_path=None):
    if optimizer is None:
        # Use Adam optimizer as default
        optimizer = optimizers.Adam()
    elif not isinstance(optimizer, Optimizer):
        raise ValueError("[ERROR] optimizer must be instance of Optimizer, "
                         "but passed {}".format(type(Optimizer)))

    optimizer.setup(model)

    def one_batch_converter(batch, device):
        if not isinstance(device, Device):
            device = chainer.get_device(device)
        data, train_mask, valid_mask = batch[0]
        return (data.to_device(device),
                device.send(train_mask), device.send(valid_mask))

    data_iter = SerialIterator([(data, train_mask, valid_mask)], batch_size=1)
    updater = training.StandardUpdater(
        data_iter, optimizer, device=device,
        converter=one_batch_converter)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)
    if use_default_extensions:
        trainer.extend(extensions.LogReport())
        trainer.extend(AutoPrintReport())
        trainer.extend(extensions.ProgressBar(update_interval=10))
        # TODO(nakago): consider to include snapshot as default extension.
        # trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    if extensions_list is not None:
        for e in extensions_list:
            trainer.extend(e)

    if resume_path:
        chainer.serializers.load_npz(resume_path, trainer)
    trainer.run()

    return 
Example #30
Source File: train.py    From Video-frame-prediction-by-multi-scale-GAN with MIT License
def main(resume, gpu, load_path, data_path):
	dataset = Dataset(data_path)


	GenNetwork = MultiScaleGenerator(c.SCALE_FMS_G, c.SCALE_KERNEL_SIZES_G)
	DisNetwork = MultiScaleDiscriminator(c.SCALE_CONV_FMS_D, c.SCALE_KERNEL_SIZES_D, c.SCALE_FC_LAYER_SIZES_D)

	optimizers = {}
	optimizers["GeneratorNetwork"] = chainer.optimizers.SGD(c.LRATE_G)
	optimizers["DiscriminatorNetwork"] = chainer.optimizers.SGD(c.LRATE_D)

	iterator = chainer.iterators.SerialIterator(dataset, 1)
	params = {'LAM_ADV': 0.05, 'LAM_LP': 1, 'LAM_GDL': .1}
	updater = Updater(iterators=iterator, optimizers=optimizers,
	                  GeneratorNetwork=GenNetwork,
	                  DiscriminatorNetwork=DisNetwork,
	                  params=params,
	                  device=gpu
	                  )
	if gpu >= 0:
		updater.GenNetwork.to_gpu()
		updater.DisNetwork.to_gpu()

	trainer = chainer.training.Trainer(updater, (500000, 'iteration'), out='result')
	trainer.extend(extensions.snapshot(filename='snapshot'), trigger=(1, 'iteration'))
	trainer.extend(extensions.snapshot_object(trainer.updater.GenNetwork, "GEN"))
	trainer.extend(saveGen)

	log_keys = ['epoch', 'iteration', 'GeneratorNetwork/L2Loss', 'GeneratorNetwork/GDL',
	            'DiscriminatorNetwork/DisLoss', 'GeneratorNetwork/CompositeGenLoss']
	print_keys = ['GeneratorNetwork/CompositeGenLoss', 'DiscriminatorNetwork/DisLoss']
	trainer.extend(extensions.LogReport(keys=log_keys, trigger=(10, 'iteration')))
	trainer.extend(extensions.PrintReport(print_keys), trigger=(10, 'iteration'))
	trainer.extend(extensions.PlotReport(['DiscriminatorNetwork/DisLoss'], 'iteration', (10, 'iteration'), file_name="DisLoss.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/CompositeGenLoss'], 'iteration', (10, 'iteration'), file_name="GenLoss.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/AdvLoss'], 'iteration', (10, 'iteration'), file_name="AdvGenLoss.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/AdvLoss','DiscriminatorNetwork/DisLoss'], 'iteration', (10, 'iteration'), file_name="AdversarialLosses.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/L2Loss'], 'iteration', (10, 'iteration'),file_name="L2Loss.png"))
	trainer.extend(extensions.PlotReport(['GeneratorNetwork/GDL'], 'iteration', (10, 'iteration'),file_name="GDL.png"))

	trainer.extend(extensions.ProgressBar(update_interval=10))
	if resume:
		# Resume from a snapshot
		chainer.serializers.load_npz(load_path, trainer)
	print(trainer.updater.__dict__)
	trainer.run()