Python chainer.training.Trainer() Examples

The following are 21 code examples of chainer.training.Trainer(), collected from open-source projects. The original project and source file for each example are noted above it. You may also want to check out the other available functions and classes of the chainer.training module.
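Before the project-specific examples, here is a minimal, self-contained sketch of the typical Trainer workflow (iterator, optimizer, updater, Trainer, extensions, run). It is illustrative only: the MLP architecture, batch size, and epoch count are assumptions, not taken from any example below.

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import iterators, optimizers, training
from chainer.training import extensions

# A small MLP wrapped in a Classifier, which adds softmax cross-entropy
# loss and accuracy reporting under 'main/loss' and 'main/accuracy'.
model = L.Classifier(chainer.Sequential(
    L.Linear(None, 100), F.relu, L.Linear(100, 10)))

optimizer = optimizers.Adam()
optimizer.setup(model)

train, test = chainer.datasets.get_mnist()
train_iter = iterators.SerialIterator(train, batch_size=128)
test_iter = iterators.SerialIterator(test, batch_size=128,
                                     repeat=False, shuffle=False)

# The updater performs one optimization step per iteration; the Trainer
# drives it until the stop trigger fires.
updater = training.updaters.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (5, 'epoch'), out='result')

trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss',
     'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
trainer.run()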
Example #1
Source File: plot_chainer_MLP.py    From soft-dtw with BSD 2-Clause "Simplified" License
def train(network, loss, X_tr, Y_tr, X_te, Y_te, n_epochs=30, gamma=1):
    model = Objective(network, loss=loss, gamma=gamma)

    # optimizer = optimizers.SGD()
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    train = tuple_dataset.TupleDataset(X_tr, Y_tr)
    test = tuple_dataset.TupleDataset(X_te, Y_te)

    train_iter = iterators.SerialIterator(train, batch_size=1, shuffle=True)
    test_iter = iterators.SerialIterator(test, batch_size=1, repeat=False,
                                         shuffle=False)
    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (n_epochs, 'epoch'))

    trainer.run() 
Example #2
Source File: gen_mnist_mlp.py    From chainer-compiler with MIT License
def create_trainer(args, model):
    # Setup an optimizer
    # optimizer = chainer.optimizers.Adam()
    optimizer = chainer.optimizers.SGD()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    train_iter = MyIterator(train, args.batchsize, shuffle=False)
    test_iter = MyIterator(test, args.batchsize, repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    return trainer 
Example #3
Source File: train.py    From voxelnet_chainer with MIT License
def train_voxelnet():
    """Training VoxelNet."""
    config = parse_args()
    model = get_model(config["model"])
    devices = parse_devices(config['gpus'], config['updater']['name'])
    train_data, test_data = load_dataset(config["dataset"])
    train_iter, test_iter = create_iterator(train_data, test_data,
                                            config['iterator'], devices,
                                            config['updater']['name'])
    class_weight = get_class_weight(config)
    optimizer = create_optimizer(config['optimizer'], model)
    updater = create_updater(train_iter, optimizer, config['updater'], devices)
    trainer = training.Trainer(updater, config['end_trigger'], out=config['results'])
    trainer = create_extension(trainer, test_iter, model,
                               config['extension'], devices=devices)
    trainer.run()
    chainer.serializers.save_npz(os.path.join(config['results'], 'model.npz'),
                                 model) 
Example #4
Source File: test_parameter_statistics.py    From chainer with MIT License
def _get_mocked_trainer(links, stop_trigger=(10, 'iteration')):
    updater = mock.Mock()
    optimizer = mock.Mock()
    target = mock.Mock()
    target.namedlinks.return_value = [
        (str(i), link) for i, link in enumerate(links)]

    optimizer.target = target
    updater.get_all_optimizers.return_value = {'optimizer_name': optimizer}
    updater.iteration = 0
    updater.epoch = 0
    updater.epoch_detail = 0
    updater.is_new_epoch = True
    iter_per_epoch = 10

    def update():
        time.sleep(0.001)
        updater.iteration += 1
        updater.epoch = updater.iteration // iter_per_epoch
        updater.epoch_detail = updater.iteration / iter_per_epoch
        updater.is_new_epoch = updater.epoch == updater.epoch_detail

    updater.update = update

    return training.Trainer(updater, stop_trigger) 
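A hypothetical way to exercise this helper in a test; the Linear link and the ParameterStatistics extension are assumptions for illustration (report_grads is disabled because the mocked update never computes gradients):

links = [chainer.links.Linear(3, 2)]
trainer = _get_mocked_trainer(links)
trainer.extend(training.extensions.ParameterStatistics(
    links, report_grads=False, trigger=(1, 'iteration')))
trainer.run()  # 10 mocked iterations, reporting parameter data statistics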
Example #5
Source File: test_computational_graph.py    From chainer with MIT License
def _run_test(self, tempdir, initial_flag):
        n_data = 4
        n_epochs = 3
        outdir = os.path.join(tempdir, 'testresult')

        # Prepare
        model = Model()
        classifier = links.Classifier(model)
        optimizer = chainer.optimizers.Adam()
        optimizer.setup(classifier)

        dataset = Dataset([i for i in range(n_data)])
        iterator = chainer.iterators.SerialIterator(dataset, 1, shuffle=False)
        updater = training.updaters.StandardUpdater(iterator, optimizer)
        trainer = training.Trainer(updater, (n_epochs, 'epoch'), out=outdir)

        extension = c.DumpGraph('main/loss', filename='test.dot')
        trainer.extend(extension)

        # Run
        with chainer.using_config('keep_graph_on_report', initial_flag):
            trainer.run()

        # Check flag history
        self.assertEqual(model.flag_history,
                         [True] + [initial_flag] * (n_data * n_epochs - 1))

        # Check the dumped graph
        graph_path = os.path.join(outdir, 'test.dot')
        with open(graph_path) as f:
            graph_dot = f.read()

        # Check that only the first iteration is dumped
        self.assertIn('Function1', graph_dot)
        self.assertNotIn('Function2', graph_dot)

        if c.is_graphviz_available():
            self.assertTrue(os.path.exists(os.path.join(outdir, 'test.png'))) 
Example #6
Source File: NNet.py    From alpha-zero-general with MIT License
def _train_trainer(self, examples):
        """Training with chainer trainer module"""
        train_iter = SerialIterator(examples, args.batch_size)
        optimizer = optimizers.Adam(alpha=args.lr)
        optimizer.setup(self.nnet)

        def loss_func(boards, target_pis, target_vs):
            out_pi, out_v = self.nnet(boards)
            l_pi = self.loss_pi(target_pis, out_pi)
            l_v = self.loss_v(target_vs, out_v)
            total_loss = l_pi + l_v
            chainer.reporter.report({
                'loss': total_loss,
                'loss_pi': l_pi,
                'loss_v': l_v,
            }, observer=self.nnet)
            return total_loss

        updater = training.StandardUpdater(
            train_iter, optimizer, device=args.device, loss_func=loss_func, converter=converter)
        # Set up the trainer.
        trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.out)
        # trainer.extend(extensions.snapshot(), trigger=(args.epochs, 'epoch'))
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.PrintReport([
            'epoch', 'main/loss', 'main/loss_pi', 'main/loss_v', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar(update_interval=10))
        trainer.run() 
Example #7
Source File: train.py    From chainer-wasserstein-gan with MIT License
def train(args):
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu

    # CIFAR-10 images in range [-1, 1] (tanh generator outputs)
    train, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    train -= 1.0
    train_iter = iterators.SerialIterator(train, batch_size)

    z_iter = RandomNoiseIterator(GaussianNoiseGenerator(0, 1, args.nz),
                                 batch_size)

    optimizer_generator = optimizers.RMSprop(lr=0.00005)
    optimizer_critic = optimizers.RMSprop(lr=0.00005)
    optimizer_generator.setup(Generator())
    optimizer_critic.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_critic=optimizer_critic,
        device=gpu)

    trainer = training.Trainer(updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'critic/loss',
            'critic/loss/real', 'critic/loss/fake', 'generator/loss']))
    trainer.run() 
Example #8
Source File: test_multi_node_snapshot.py    From chainer with MIT License
def _prepare_multinode_snapshot(n, result):
    n_units = 100
    batchsize = 10
    comm = create_communicator('naive')
    model = L.Classifier(MLP(n_units, 10))
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    if comm.rank == 0:
        train, _ = chainer.datasets.get_mnist()
    else:
        train, _ = None, None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    train_iter = chainer.iterators.SerialIterator(train, batchsize)

    updater = StandardUpdater(train_iter, optimizer)
    trainer = Trainer(updater, out=result)

    snapshot = extensions.snapshot(target=updater, autoload=True)
    replica_sets = []
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    mn_snapshot.initialize(trainer)
    for _ in range(n):
        updater.update()

    return updater, mn_snapshot, trainer 
Example #9
Source File: test_fail_on_nonnumber.py    From chainer with MIT License
def prepare(self, dirname='test', device=None):
        outdir = os.path.join(self.temp_dir, dirname)
        self.updater = training.updaters.StandardUpdater(
            self.iterator, self.optimizer, device=device)
        self.trainer = training.Trainer(
            self.updater, (self.n_epochs, 'epoch'), out=outdir)
        self.trainer.extend(training.extensions.FailOnNonNumber()) 
Example #10
Source File: train.py    From qb with MIT License
def main(model):
    train = read_data(fold=BUZZER_TRAIN_FOLD)
    valid = read_data(fold=BUZZER_DEV_FOLD)
    print('# train data: {}'.format(len(train)))
    print('# valid data: {}'.format(len(valid)))

    train_iter = chainer.iterators.SerialIterator(train, 64)
    valid_iter = chainer.iterators.SerialIterator(valid, 64, repeat=False, shuffle=False)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(1e-4))

    updater = training.updaters.StandardUpdater(train_iter, optimizer, converter=convert_seq, device=0)
    trainer = training.Trainer(updater, (20, 'epoch'), out=model.model_dir)

    trainer.extend(extensions.Evaluator(valid_iter, model, converter=convert_seq, device=0))

    record_trigger = training.triggers.MaxValueTrigger('validation/main/accuracy', (1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model, 'buzzer.npz'), trigger=record_trigger)

    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PrintReport([
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
    ]))

    if not os.path.isdir(model.model_dir):
        os.mkdir(model.model_dir)

    trainer.run() 
Example #11
Source File: test_commands_extension.py    From chainerui with MIT License
def _get_mock_trainer(self, out_path, trigger=None, updater=None):
        class _MockTrainer(Trainer):
            def __init__(
                self, out_path, stop_trigger=IntervalTrigger(100, 'epoch'),
                    updater=None):

                self.out = out_path
                self.stop_trigger = stop_trigger

                hyperparam = Hyperparameter()
                hyperparam.lr = 0.005
                optimizer = MagicMock()
                optimizer.__class__.__name__ = 'MomentumSGD'
                optimizer.hyperparam = hyperparam

                if updater is None:
                    updater = MagicMock()
                    updater.epoch = 0
                    updater.iteration = 0
                    updater.get_optimizer.return_value = optimizer
                self.updater = updater

            @property
            def elapsed_time(self):
                return 0

            def serialize(self, serializer):
                pass

        return _MockTrainer(out_path, trigger, updater) 
Example #12
Source File: test_linear_network.py    From shoelace with MIT License
def test_linear_network():

    # To ensure repeatability of experiments
    np.random.seed(1042)

    # Load data set
    dataset = get_dataset(True)
    iterator = LtrIterator(dataset, repeat=True, shuffle=True)
    eval_iterator = LtrIterator(dataset, repeat=False, shuffle=False)

    # Create neural network with chainer and apply our loss function
    predictor = links.Linear(None, 1)
    loss = Ranker(predictor, listnet)

    # Build optimizer, updater and trainer
    optimizer = optimizers.Adam(alpha=0.2)
    optimizer.setup(loss)
    updater = training.StandardUpdater(iterator, optimizer)
    trainer = training.Trainer(updater, (10, 'epoch'))

    # Evaluate loss before training
    before_loss = eval(loss, eval_iterator)

    # Train neural network
    trainer.run()

    # Evaluate loss after training
    after_loss = eval(loss, eval_iterator)

    # Assert precomputed values
    assert_almost_equal(before_loss, 0.26958397)
    assert_almost_equal(after_loss, 0.2326711) 
Example #13
Source File: gen_resnet50.py    From chainer-compiler with MIT License
def create_trainer(args, model):
    # Setup an optimizer
    # optimizer = chainer.optimizers.Adam()
    optimizer = chainer.optimizers.SGD()
    optimizer.setup(model)

    # Load the datasets and mean file
    mean = np.load(args.mean)

    train = PreprocessedDataset(args.train, args.root, mean, insize)
    val = PreprocessedDataset(args.val, args.root, mean, insize, False)

    # These iterators load the images with subprocesses running in parallel to
    # the training/validation.
    train_iter = MyIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = MyIterator(
        val, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(val_iter, model, device=args.gpu))
    return trainer 
Example #14
Source File: train_ch_in1k.py    From imgclsmob with MIT License
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_iter,
                    val_iter,
                    logging_dir_path,
                    num_gpus=0):
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception('Unsupported optimizer: {}'.format(optimizer_name))
    optimizer.setup(net)

    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if num_gpus > 0 else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, 'epoch'),
        out=logging_dir_path)

    val_interval = 100000, 'iteration'
    log_interval = 1000, 'iteration'

    trainer.extend(
        extension=extensions.Evaluator(
            val_iter,
            net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            'model_iter_{.updater.iteration}'),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy',
            'lr']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer 
Example #15
Source File: train_ch_cifar.py    From imgclsmob with MIT License
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_iter,
                    val_iter,
                    logging_dir_path,
                    num_gpus=0):
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception('Unsupported optimizer: {}'.format(optimizer_name))
    optimizer.setup(net)

    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if num_gpus > 0 else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, 'epoch'),
        out=logging_dir_path)

    val_interval = 100000, 'iteration'
    log_interval = 1000, 'iteration'

    trainer.extend(
        extension=extensions.Evaluator(
            val_iter,
            net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            'model_iter_{.updater.iteration}'),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy',
            'lr']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer 
Example #16
Source File: training.py    From chainer with MIT License
def get_trainer_with_mock_updater(
        stop_trigger=(10, 'iteration'), iter_per_epoch=10, extensions=None):
    """Returns a :class:`~chainer.training.Trainer` object with mock updater.

    The returned trainer can be used for testing the trainer itself and the
    extensions. A mock object is used as its updater. The update function set
    to the mock correctly increments the iteration counts (
    ``updater.iteration``), and thus you can write a test relying on it.

    Args:
        stop_trigger: Stop trigger of the trainer.
        iter_per_epoch: The number of iterations per epoch.
        extensions: Extensions registered to the trainer.

    Returns:
        Trainer object with a mock updater.

    """
    if extensions is None:
        extensions = []
    check_available()
    updater = mock.Mock()
    updater.get_all_optimizers.return_value = {}
    updater.iteration = 0
    updater.epoch = 0
    updater.epoch_detail = 0
    updater.is_new_epoch = True
    updater.previous_epoch_detail = None

    def update():
        updater.update_core()
        updater.iteration += 1
        updater.epoch = updater.iteration // iter_per_epoch
        updater.epoch_detail = updater.iteration / iter_per_epoch
        updater.is_new_epoch = (updater.iteration - 1) // \
            iter_per_epoch != updater.epoch
        updater.previous_epoch_detail = (updater.iteration - 1) \
            / iter_per_epoch

    updater.update = update
    trainer = training.Trainer(updater, stop_trigger, extensions=extensions)
    return trainer 
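A hypothetical test built on this helper; the recording extension and the final assertion are illustrative assumptions, not part of the utility itself:

trainer = get_trainer_with_mock_updater(stop_trigger=(5, 'iteration'))
seen = []
trainer.extend(lambda t: seen.append(t.updater.iteration),
               name='record', trigger=(1, 'iteration'))
trainer.run()
assert seen == [1, 2, 3, 4, 5]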
Example #17
Source File: train.py    From models with MIT License
def train_one_epoch(model, train_data, lr, gpu, batchsize, out):
    train_model = PixelwiseSoftmaxClassifier(model)
    if gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu).use()
        train_model.to_gpu()  # Copy the model to the GPU
    log_trigger = (0.1, 'epoch')
    validation_trigger = (1, 'epoch')
    end_trigger = (1, 'epoch')

    train_data = TransformDataset(
        train_data, ('img', 'label_map'), SimpleDoesItTransform(model.mean))
    val = VOCSemanticSegmentationWithBboxDataset(
        split='val').slice[:, ['img', 'label_map']]

    # Iterator
    train_iter = iterators.MultiprocessIterator(train_data, batchsize)
    val_iter = iterators.MultiprocessIterator(
        val, 1, shuffle=False, repeat=False, shared_mem=100000000)

    # Optimizer
    optimizer = optimizers.MomentumSGD(lr=lr, momentum=0.9)
    optimizer.setup(train_model)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(rate=0.0001))

    # Updater
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=gpu)

    # Trainer
    trainer = training.Trainer(updater, end_trigger, out=out)

    trainer.extend(extensions.LogReport(trigger=log_trigger))
    trainer.extend(extensions.observe_lr(), trigger=log_trigger)
    trainer.extend(extensions.dump_graph('main/loss'))

    if extensions.PlotReport.available():
        trainer.extend(extensions.PlotReport(
            ['main/loss'], x_key='iteration',
            file_name='loss.png'))
        trainer.extend(extensions.PlotReport(
            ['validation/main/miou'], x_key='iteration',
            file_name='miou.png'))

    trainer.extend(extensions.snapshot_object(
        model, filename='snapshot.npy'),
        trigger=end_trigger)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'elapsed_time', 'lr',
         'main/loss', 'validation/main/miou',
         'validation/main/mean_class_accuracy',
         'validation/main/pixel_accuracy']),
        trigger=log_trigger)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.extend(
        SemanticSegmentationEvaluator(
            val_iter, model,
            voc_semantic_segmentation_label_names),
        trigger=validation_trigger)
    trainer.run() 
Example #18
Source File: train_ch.py    From imgclsmob with MIT License 4 votes vote down vote up
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_data,
                    val_data,
                    logging_dir_path,
                    use_gpus):
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception("Unsupported optimizer: {}".format(optimizer_name))
    optimizer.setup(net)

    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if use_gpus else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_data["iterator"],
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, "epoch"),
        out=logging_dir_path)

    val_interval = 100000, "iteration"
    log_interval = 1000, "iteration"

    trainer.extend(
        extension=extensions.Evaluator(
            iterator=val_data["iterator"],
            target=net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph("main/loss"))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            "model_iter_{.updater.iteration}"),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            "epoch", "iteration", "main/loss", "validation/main/loss", "main/accuracy", "validation/main/accuracy",
            "lr"]),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer 
Example #19
Source File: train.py    From portrait_matting with GNU General Public License v3.0
def main(argv):
    # Argument
    args = parse_arguments(argv)

    # Load config
    config.load(args.config)

    # Setup dataset
    train, test = setup_dataset(args.mode, config.img_crop_dir,
                                config.img_mask_dir, config.img_mean_mask_dir,
                                config.img_mean_grid_dir,
                                config.img_trimap_dir, config.img_alpha_dir,
                                config.img_alpha_weight_dir)

    # Setup iterators
    train_iter, test_iter = setup_iterators(args.gpus, args.batchsize, train,
                                            test)

    # Setup model
    model = setup_model(args.mode, args.pretrained_fcn8s,
                        args.pretrained_n_input_ch,
                        args.pretrained_n_output_ch,
                        args.mat_scale)

    # Setup an optimizer
    optimizer = setup_optimizer(model, args.lr, args.momentum,
                                args.weight_decay)

    # Set up a trainer
    updater = setup_updater(args.mode, args.gpus, train_iter, optimizer)
    trainer = training.Trainer(updater, (args.max_iteration, 'iteration'),
                               out=args.out)

    # Register extensions for portrait segmentation / matting
    register_extensions(trainer, model, test_iter, args)

    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run() 
Example #20
Source File: train.py    From ConvLSTM with MIT License
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--model', '-m', type=str, default=None)
    parser.add_argument('--opt', type=str, default=None)
    parser.add_argument('--epoch', '-e', type=int, default=3)
    parser.add_argument('--lr', '-l', type=float, default=0.001)
    parser.add_argument('--inf', type=int, default=10)
    parser.add_argument('--outf', type=int, default=10)
    parser.add_argument('--batch', '-b', type=int, default=8)
    args = parser.parse_args()

    train = dataset.MovingMnistDataset(0, 7000, args.inf, args.outf)
    train_iter = iterators.SerialIterator(train, batch_size=args.batch, shuffle=True)
    test = dataset.MovingMnistDataset(7000, 10000, args.inf, args.outf)
    test_iter = iterators.SerialIterator(test, batch_size=args.batch, repeat=False, shuffle=False)

    model = network.MovingMnistNetwork(sz=[128,64,64], n=2)

    if args.model is not None:
        print("loading model from " + args.model)
        serializers.load_npz(args.model, model)

    if args.gpu >= 0:
        cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    opt = optimizers.Adam(alpha=args.lr)
    opt.setup(model)

    if args.opt is not None:
        print("loading opt from " + args.opt)
        serializers.load_npz(args.opt, opt)

    updater = training.StandardUpdater(train_iter, opt, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out='results')

    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    trainer.extend(extensions.LogReport(trigger=(10, 'iteration')))
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss']))
    trainer.extend(extensions.ProgressBar(update_interval=1))

    trainer.run()

    modelname = "./results/model"
    print("saving model to " + modelname)
    serializers.save_npz(modelname, model)

    optname = "./results/opt"
    print("saving opt to " + optname)
    serializers.save_npz(optname, opt)
Example #21
Source File: mnist.py    From cloudml-samples with Apache License 2.0
def main():
  # Training settings
  args = get_args()

  # Set up a neural network to train
  model = L.Classifier(Net())

  if args.gpu >= 0:
    # Make a specified GPU current
    chainer.backends.cuda.get_device_from_id(args.gpu).use()
    model.to_gpu() # Copy the model to the GPU

  # Setup an optimizer
  optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=args.momentum)
  optimizer.setup(model)

  # Load the MNIST dataset
  train, test = chainer.datasets.get_mnist(ndim=3)
  train_iter = chainer.iterators.SerialIterator(train, args.batch_size)
  test_iter = chainer.iterators.SerialIterator(test, args.test_batch_size,
                                               repeat=False, shuffle=False)

  # Set up a trainer
  updater = training.updaters.StandardUpdater(
      train_iter, optimizer, device=args.gpu)
  trainer = training.Trainer(updater, (args.epochs, 'epoch'))

  # Evaluate the model with the test dataset for each epoch
  trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

  # Write a log of evaluation statistics for each epoch
  trainer.extend(extensions.LogReport())

  # Print selected entries of the log to stdout
  trainer.extend(extensions.PrintReport(
      ['epoch', 'main/loss', 'validation/main/loss',
       'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

  # Send selected entries of the log to CMLE HP tuning system
  trainer.extend(
    HpReport(hp_metric_val='validation/main/loss', hp_metric_tag='my_loss'))

  if args.resume:
    # Resume from a snapshot
    tmp_model_file = os.path.join('/tmp', MODEL_FILE_NAME)
    if not os.path.exists(tmp_model_file):
      subprocess.check_call([
        'gsutil', 'cp', os.path.join(args.model_dir, MODEL_FILE_NAME),
        tmp_model_file])
    if os.path.exists(tmp_model_file):
      chainer.serializers.load_npz(tmp_model_file, trainer)
  
  trainer.run()

  if args.model_dir:
    tmp_model_file = os.path.join('/tmp', MODEL_FILE_NAME)
    serializers.save_npz(tmp_model_file, model)
    subprocess.check_call([
        'gsutil', 'cp', tmp_model_file,
        os.path.join(args.model_dir, MODEL_FILE_NAME)])