Python blocks.main_loop.MainLoop() Examples

The following are 18 code examples of blocks.main_loop.MainLoop(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module blocks.main_loop, or try the search function.
Example #1
Source File: main.py    From attention-lvcsr with MIT License 6 votes vote down vote up
def train(config, save_path, bokeh_name,
          params, bokeh_server, bokeh, test_tag, use_load_ext,
          load_log, fast_start):
    """Build the training components and run the main loop.

    Every argument is forwarded, in the same order, to ``initialize_all``,
    which is expected to return the ``(model, algorithm, data, extensions)``
    tuple used to construct the :class:`MainLoop`.

    """
    components = initialize_all(
        config, save_path, bokeh_name, params, bokeh_server, bokeh,
        test_tag, use_load_ext, load_log, fast_start)
    model, algorithm, data, extensions = components

    # Keep a textual snapshot of the configuration in the log status
    training_log = NDarrayLog()
    training_log.status['_config'] = repr(config)

    train_stream = data.get_stream("train")
    MainLoop(
        model=model,
        log=training_log,
        algorithm=algorithm,
        data_stream=train_stream,
        extensions=extensions,
    ).run()
Example #2
Source File: test_plot.py    From blocks-extras with MIT License 6 votes vote down vote up
def test_load_log():
    """Check that ``plot.load_log`` recovers the log from both pickle kinds."""

    def check_round_trip(pickled_object):
        # Dump the object to a temporary file and read the log back from it
        with tempfile.NamedTemporaryFile() as handle:
            dump(pickled_object, handle)
            handle.flush()

            restored = plot.load_log(handle.name)
            assert restored[0]['channel0'] == 0

    log = TrainingLog()
    log[0]['channel0'] = 0

    # test simple TrainingLog pickles
    check_round_trip(log)

    # test MainLoop pickles
    check_round_trip(
        MainLoop(model=None, data_stream=None, algorithm=None, log=log))
Example #3
Source File: plot.py    From blocks-extras with MIT License 6 votes vote down vote up
def load_log(fname):
    """Load a :class:`TrainingLog` object from disk.

    This function automatically handles various file formats that contain
    an instance of an :class:`TrainingLog`. This includes a pickled
    Log object, a pickled :class:`MainLoop` or an experiment dump (TODO).

    Parameters
    ----------
    fname : str
        Path of the file to read; it is opened in binary mode.

    Returns
    -------
    log : :class:`TrainingLog`
        The loaded object itself when it is a :class:`TrainingLog`, or
        the ``log`` attribute of a loaded :class:`MainLoop`.

    Raises
    ------
    ValueError
        If the loaded object is neither a :class:`TrainingLog` nor a
        :class:`MainLoop`.

    """
    # NOTE(review): the docstring describes these files as pickles, so
    # ``load`` presumably unpickles -- only call this on trusted files.
    with change_recursion_limit(config.recursion_limit):
        with open(fname, 'rb') as f:
            from_disk = load(f)
        # TODO: Load "dumped" experiments

    if isinstance(from_disk, TrainingLog):
        return from_disk
    if isinstance(from_disk, MainLoop):
        return from_disk.log
    # The placeholder was previously never filled in, so the message did
    # not say which file failed to load.
    raise ValueError(
        "Could not load '{}': Unrecognized content.".format(fname))
Example #4
Source File: test_plot.py    From attention-lvcsr with MIT License 6 votes vote down vote up
def test_load_log():
    """Check that ``plot.load_log`` recovers the log from both pickle kinds."""

    def check_round_trip(pickled_object):
        # Dump the object to a temporary file and read the log back from it
        with tempfile.NamedTemporaryFile() as handle:
            dump(pickled_object, handle)
            handle.flush()

            restored = plot.load_log(handle.name)
            assert restored[0]['channel0'] == 0

    log = TrainingLog()
    log[0]['channel0'] = 0

    # test simple TrainingLog pickles
    check_round_trip(log)

    # test MainLoop pickles
    check_round_trip(
        MainLoop(model=None, data_stream=None, algorithm=None, log=log))
Example #5
Source File: plot.py    From attention-lvcsr with MIT License 6 votes vote down vote up
def load_log(fname):
    """Load a :class:`TrainingLog` object from disk.

    This function automatically handles various file formats that contain
    an instance of an :class:`TrainingLog`. This includes a pickled
    Log object, a pickled :class:`MainLoop` or an experiment dump (TODO).

    Parameters
    ----------
    fname : str
        Path of the file to read; it is opened in binary mode.

    Returns
    -------
    log : :class:`TrainingLog`
        The loaded object itself when it is a :class:`TrainingLog`, or
        the ``log`` attribute of a loaded :class:`MainLoop`.

    Raises
    ------
    ValueError
        If the loaded object is neither a :class:`TrainingLog` nor a
        :class:`MainLoop`.

    """
    # NOTE(review): the docstring describes these files as pickles, so
    # ``load`` presumably unpickles -- only call this on trusted files.
    with change_recursion_limit(config.recursion_limit):
        with open(fname, 'rb') as f:
            from_disk = load(f)
        # TODO: Load "dumped" experiments

    if isinstance(from_disk, TrainingLog):
        return from_disk
    if isinstance(from_disk, MainLoop):
        return from_disk.log
    # The placeholder was previously never filled in, so the message did
    # not say which file failed to load.
    raise ValueError(
        "Could not load '{}': Unrecognized content.".format(fname))
Example #6
Source File: test_main_loop.py    From attention-lvcsr with MIT License 6 votes vote down vote up
def test_main_loop():
    """Run a two-epoch main loop over ten examples and check its log.

    ``config.profile`` is switched on for the duration of the test and
    restored afterwards, including when the run or an assertion fails,
    so that the global setting does not leak into other tests.

    """
    old_config_profile_value = config.profile
    config.profile = True
    try:
        main_loop = MainLoop(
            MockAlgorithm(), IterableDataset(range(10)).get_example_stream(),
            extensions=[WriteBatchExtension(),
                        FinishAfter(after_n_epochs=2)])
        main_loop.run()
        # No model was passed, so reading ``model`` must fail
        assert_raises(AttributeError, getattr, main_loop, 'model')

        # Two epochs over ten examples give twenty iterations
        assert main_loop.log.status['iterations_done'] == 20
        assert main_loop.log.status['_epoch_ends'] == [10, 20]
        assert len(main_loop.log) == 20
        for i in range(20):
            assert main_loop.log[i + 1]['batch'] == {'data': i % 10}
    finally:
        config.profile = old_config_profile_value
Example #7
Source File: test_training.py    From attention-lvcsr with MIT License 5 votes vote down vote up
def test_shared_variable_modifier_two_parameters():
    """Check ``SharedVariableModifier`` with a two-argument callback.

    The callback multiplies the learning rate by 0.2, so after
    ``n_batches`` batches it should equal ``0.001 * 0.2 ** n_batches``.

    """
    floatX = theano.config.floatX
    weights = numpy.array([-1, 1], dtype=floatX)
    features = [numpy.array(row, dtype=floatX)
                for row in [[1, 2], [3, 4], [5, 6]]]
    targets = [(weights * row).sum() for row in features]
    n_batches = 3
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector('features')
    y = tensor.scalar('targets')
    W = shared_floatx([0, 0], name='W')
    cost = ((x * W).sum() - y) ** 2
    cost.name = 'cost'

    step_rule = Scale(0.001)
    sgd = GradientDescent(cost=cost, parameters=[W],
                          step_rule=step_rule)

    def shrink(_, val):
        # First argument is ignored; only the current value is rescaled
        return numpy.cast[theano.config.floatX](val * 0.2)

    modifier = SharedVariableModifier(step_rule.learning_rate, shrink)
    example_stream = dataset.get_example_stream()
    extensions = [FinishAfter(after_n_epochs=1), modifier]
    main_loop = MainLoop(
        model=None, data_stream=example_stream,
        algorithm=sgd, extensions=extensions)

    main_loop.run()

    expected = 0.001 * 0.2 ** n_batches
    actual = step_rule.learning_rate.get_value()
    assert_allclose(actual, expected, atol=1e-5)
Example #8
Source File: test_training.py    From attention-lvcsr with MIT License 5 votes vote down vote up
def test_shared_variable_modifier():
    """Check ``SharedVariableModifier`` with a one-argument callback.

    The callback sets the learning rate to ``10 / n``; after one epoch
    the test expects ``10 / n_batches``.

    """
    floatX = theano.config.floatX
    weights = numpy.array([-1, 1], dtype=floatX)
    features = [numpy.array(row, dtype=floatX)
                for row in [[1, 2], [3, 4], [5, 6]]]
    targets = [(weights * row).sum() for row in features]
    n_batches = 3
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector('features')
    y = tensor.scalar('targets')
    W = shared_floatx([0, 0], name='W')
    cost = ((x * W).sum() - y) ** 2
    cost.name = 'cost'

    step_rule = Scale(0.001)
    sgd = GradientDescent(cost=cost, parameters=[W],
                          step_rule=step_rule)

    def inverse_schedule(n):
        return numpy.cast[theano.config.floatX](10. / n)

    example_stream = dataset.get_example_stream()
    extensions = [
        FinishAfter(after_n_epochs=1),
        SharedVariableModifier(step_rule.learning_rate, inverse_schedule),
    ]
    main_loop = MainLoop(
        model=None, data_stream=example_stream,
        algorithm=sgd, extensions=extensions)

    main_loop.run()

    actual = step_rule.learning_rate.get_value()
    expected = numpy.cast[theano.config.floatX](10. / n_batches)
    assert_allclose(actual, expected)
Example #9
Source File: test_main_loop.py    From attention-lvcsr with MIT License 5 votes vote down vote up
def test_training_resumption():
    """Check that a stopped main loop can be resumed.

    The loop is first stopped after 14 batches, optionally round-tripped
    through ``cPickle``, and then run again until 27 iterations are done.

    """
    def do_test(with_serialization):
        stream = IterableDataset(range(10)).get_example_stream()
        main_loop = MainLoop(
            MockAlgorithm(), stream,
            extensions=[WriteBatchExtension(),
                        FinishAfter(after_n_batches=14)])
        main_loop.run()
        assert main_loop.log.status['iterations_done'] == 14

        if with_serialization:
            main_loop = cPickle.loads(cPickle.dumps(main_loop))

        # Add a second stopping condition to the single FinishAfter
        finishers = [extension for extension in main_loop.extensions
                     if isinstance(extension, FinishAfter)]
        finish_after = unpack(finishers, singleton=True)

        def reached_27(log):
            return log.status['iterations_done'] == 27

        finish_after.add_condition(["after_batch"], predicate=reached_27)
        main_loop.run()

        assert main_loop.log.status['iterations_done'] == 27
        assert main_loop.log.status['epochs_done'] == 2
        for step in range(1, 28):
            expected_batch = {"data": (step - 1) % 10}
            assert main_loop.log[step]['batch'] == expected_batch

    for with_serialization in (False, True):
        do_test(with_serialization)
Example #10
Source File: utils.py    From blocks-char-rnn with MIT License 5 votes vote down vote up
def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to the parent initializer.

        NOTE(review): the enclosing class is not visible in this excerpt;
        the ``super`` call implies it is named ``MainLoop`` -- confirm.
        """
        super(MainLoop, self).__init__(**kwargs)
Example #11
Source File: __init__.py    From blocks-examples with MIT License 5 votes vote down vote up
def main(save_to, num_batches):
    """Train a small MLP with a squared-error cost and return the loop.

    Parameters
    ----------
    save_to
        Passed to ``Checkpoint``.
    num_batches
        Passed to ``FinishAfter`` as ``after_n_batches``.

    Returns
    -------
    The :class:`MainLoop` after ``run()`` has completed.

    """
    mlp = MLP([Tanh(), Identity()], [1, 10, 1],
              weights_init=IsotropicGaussian(0.01),
              biases_init=Constant(0), seed=1)
    mlp.initialize()

    x = tensor.vector('numbers')
    y = tensor.vector('roots')
    cost = SquaredError().apply(y[:, None], mlp.apply(x[:, None]))
    cost.name = "cost"

    algorithm = GradientDescent(
        cost=cost, parameters=ComputationGraph(cost).parameters,
        step_rule=Scale(learning_rate=0.001))
    train_stream = get_data_stream(range(100))
    model = Model(cost)

    # Monitor the cost on a second stream built from the next 100 numbers
    extensions = [Timing(), FinishAfter(after_n_batches=num_batches)]
    extensions.append(DataStreamMonitoring(
        [cost], get_data_stream(range(100, 200)), prefix="test"))
    extensions.append(TrainingDataMonitoring([cost], after_epoch=True))
    extensions.append(Checkpoint(save_to))
    extensions.append(Printing())

    main_loop = MainLoop(algorithm, train_stream,
                         model=model, extensions=extensions)
    main_loop.run()
    return main_loop
Example #12
Source File: test_monitoring.py    From attention-lvcsr with MIT License 4 votes vote down vote up
def test_training_data_monitoring():
    """Compare ``TrainingDataMonitoring`` output with hand-computed values.

    A local extension writes the true cost to the log before each batch;
    the per-batch and per-epoch monitored values are checked against it.

    """
    floatX = theano.config.floatX
    weights = numpy.array([-1, 1], dtype=floatX)
    features = [numpy.array(row, dtype=floatX)
                for row in [[1, 2], [3, 4], [5, 6]]]
    targets = [(weights * row).sum() for row in features]
    n_batches = 3
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector('features')
    y = tensor.scalar('targets')
    W = shared_floatx([0, 0], name='W')
    V = shared_floatx(7, name='V')
    W_sum = W.sum().copy(name='W_sum')
    cost = ((x * W).sum() - y) ** 2
    cost.name = 'cost'

    class TrueCostExtension(TrainingExtension):
        """Log the squared error under the current ``W`` before a batch."""

        def before_batch(self, data):
            prediction = (W.get_value() * data["features"]).sum()
            self.main_loop.log.current_row['true_cost'] = (
                (prediction - data["targets"]) ** 2)

    example_stream = dataset.get_example_stream()
    algorithm = GradientDescent(cost=cost, parameters=[W],
                                step_rule=Scale(0.001))
    per_batch_monitor = TrainingDataMonitoring(
        [W_sum, cost, V], prefix="train1", after_batch=True)
    per_epoch_monitor = TrainingDataMonitoring(
        [aggregation.mean(W_sum), cost], prefix="train2", after_epoch=True)
    extensions = [FinishAfter(after_n_epochs=1),
                  per_batch_monitor,
                  per_epoch_monitor,
                  TrueCostExtension()]
    main_loop = MainLoop(
        model=None, data_stream=example_stream,
        algorithm=algorithm, extensions=extensions)

    main_loop.run()

    # Check monitoring of a shared variable
    assert_allclose(main_loop.log.current_row['train1_V'], 7.0)

    for i in range(n_batches):
        # The ground truth is written to the log before the batch is
        # processed, whereas the extension writes after the batch is
        # processed. This is why the iteration numbers differ here.
        assert_allclose(main_loop.log[i]['true_cost'],
                        main_loop.log[i + 1]['train1_cost'])

    # Epoch-level cost must equal the mean of the per-batch true costs
    logged_epoch_cost = main_loop.log[n_batches]['train2_cost']
    mean_true_cost = sum(
        [main_loop.log[i]['true_cost']
         for i in range(n_batches)]) / n_batches
    assert_allclose(logged_epoch_cost, mean_true_cost)

    # Epoch-level W_sum must equal the mean of the per-batch W_sum values
    logged_epoch_w_sum = main_loop.log[n_batches]['train2_W_sum']
    mean_batch_w_sum = sum(
        [main_loop.log[i]['train1_W_sum']
         for i in range(1, n_batches + 1)]) / n_batches
    assert_allclose(logged_epoch_w_sum, mean_batch_w_sum)
Example #13
Source File: test_saveload.py    From attention-lvcsr with MIT License 4 votes vote down vote up
def test_checkpointing():
    """Checkpoint a main loop, reload it, and check the restored weights.

    The checkpoint file is written to the current working directory and
    is removed in a ``finally`` clause, so a failing assertion no longer
    leaves ``myweirdmodel.tar`` behind.

    """
    checkpoint_path = 'myweirdmodel.tar'

    # Create a main loop and checkpoint it
    mlp = MLP(activations=[None], dims=[10, 10], weights_init=Constant(1.),
              use_bias=False)
    mlp.initialize()
    W = mlp.linear_transformations[0].W
    x = tensor.vector('data')
    cost = mlp.apply(x).mean()
    data = numpy.random.rand(10, 10).astype(theano.config.floatX)
    data_stream = IterableDataset(data).get_example_stream()

    try:
        main_loop = MainLoop(
            data_stream=data_stream,
            algorithm=GradientDescent(cost=cost, parameters=[W]),
            extensions=[FinishAfter(after_n_batches=5),
                        Checkpoint(checkpoint_path, parameters=[W])]
        )
        main_loop.run()

        # Load it again: perturb W first so a successful load is visible
        old_value = W.get_value()
        W.set_value(old_value * 2)
        main_loop = MainLoop(
            model=Model(cost),
            data_stream=data_stream,
            algorithm=GradientDescent(cost=cost, parameters=[W]),
            extensions=[Load(checkpoint_path)]
        )
        # The extension is wired up by hand because run() is not called
        main_loop.extensions[0].main_loop = main_loop
        main_loop._run_extensions('before_training')
        assert_allclose(W.get_value(), old_value)

        # Make sure things work too if the model was never saved before
        main_loop = MainLoop(
            model=Model(cost),
            data_stream=data_stream,
            algorithm=GradientDescent(cost=cost, parameters=[W]),
            extensions=[Load('mynonexisting.tar')]
        )
        main_loop.extensions[0].main_loop = main_loop
        main_loop._run_extensions('before_training')
    finally:
        # Cleaning
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
Example #14
Source File: __init__.py    From blocks-examples with MIT License 4 votes vote down vote up
def main(save_to, num_epochs):
    """Train an MLP on MNIST with L2-regularized cross-entropy.

    Parameters
    ----------
    save_to
        Passed to ``Checkpoint``.
    num_epochs
        Passed to ``FinishAfter`` as ``after_n_epochs``.

    """
    def flat_stream(dataset, batch_size):
        # Sequential mini-batches with the 'features' source flattened
        scheme = SequentialScheme(dataset.num_examples, batch_size)
        return Flatten(
            DataStream.default_stream(dataset, iteration_scheme=scheme),
            which_sources=('features',))

    mlp = MLP([Tanh(), Softmax()], [784, 100, 10],
              weights_init=IsotropicGaussian(0.01),
              biases_init=Constant(0))
    mlp.initialize()

    x = tensor.matrix('features')
    y = tensor.lmatrix('targets')
    probs = mlp.apply(x)
    flat_targets = y.flatten()
    cost = CategoricalCrossEntropy().apply(flat_targets, probs)
    error_rate = MisclassificationRate().apply(y.flatten(), probs)

    # The graph is built from the unregularized cost; the weight penalty
    # is added to the cost afterwards
    cg = ComputationGraph([cost])
    W1, W2 = VariableFilter(roles=[WEIGHT])(cg.variables)
    cost = cost + .00005 * (W1 ** 2).sum() + .00005 * (W2 ** 2).sum()
    cost.name = 'final_cost'

    mnist_train = MNIST(("train",))
    mnist_test = MNIST(("test",))

    algorithm = GradientDescent(
        cost=cost, parameters=cg.parameters,
        step_rule=Scale(learning_rate=0.1))

    test_monitor = DataStreamMonitoring(
        [cost, error_rate], flat_stream(mnist_test, 500), prefix="test")
    train_monitor = TrainingDataMonitoring(
        [cost, error_rate,
         aggregation.mean(algorithm.total_gradient_norm)],
        prefix="train",
        after_epoch=True)
    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=num_epochs),
        test_monitor,
        train_monitor,
        Checkpoint(save_to),
        Printing(),
    ]

    if BLOCKS_EXTRAS_AVAILABLE:
        plot_channels = [
            ['test_final_cost',
             'test_misclassificationrate_apply_error_rate'],
            ['train_total_gradient_norm'],
        ]
        extensions.append(Plot('MNIST example', channels=plot_channels))

    main_loop = MainLoop(
        algorithm,
        flat_stream(mnist_train, 50),
        model=Model(cost),
        extensions=extensions)

    main_loop.run()
Example #15
Source File: train_celeba_classifier.py    From discgen with MIT License 4 votes vote down vote up
def run():
    """Train the CelebA classifier for 50 epochs, checkpointing every 5."""
    streams = create_celeba_streams(training_batch_size=100,
                                    monitoring_batch_size=500,
                                    include_targets=True)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = (
        streams[0], streams[1], streams[2])

    cg, bn_dropout_cg = create_training_computation_graphs()

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    decay_rate = 0.05
    extra_updates = []
    for population, minibatch in get_batch_normalization_updates(
            bn_dropout_cg):
        extra_updates.append(
            (population,
             minibatch * decay_rate + population * (1 - decay_rate)))

    # Prepare algorithm
    algorithm = GradientDescent(cost=bn_dropout_cg.outputs[0],
                                parameters=bn_dropout_cg.parameters,
                                step_rule=Adam())
    algorithm.add_updates(extra_updates)

    # Prepare monitoring: the training cost is evaluated once, after
    # training; validation quantities every five epochs
    train_cost = bn_dropout_cg.outputs[0]
    train_cost.name = 'cost'
    train_monitoring = DataStreamMonitoring(
        [train_cost], train_monitor_stream, prefix="train",
        before_first_epoch=False, after_epoch=False, after_training=True,
        updates=extra_updates)

    valid_cost, accuracy = cg.outputs
    valid_cost.name = 'cost'
    accuracy.name = 'accuracy'
    valid_monitoring = DataStreamMonitoring(
        [valid_cost, accuracy], valid_monitor_stream, prefix="valid",
        before_first_epoch=False, after_epoch=False, every_n_epochs=5)

    # Prepare checkpoint
    checkpoint = Checkpoint(
        'celeba_classifier.zip', every_n_epochs=5, use_cpickle=True)

    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=50),
        train_monitoring,
        valid_monitoring,
        checkpoint,
        Printing(),
        ProgressBar(),
    ]
    MainLoop(data_stream=main_loop_stream, algorithm=algorithm,
             extensions=extensions).run()
Example #16
Source File: train_celeba_vae.py    From discgen with MIT License 4 votes vote down vote up
def run(discriminative_regularization=True):
    """Train the CelebA VAE for 75 epochs, checkpointing every 5.

    Parameters
    ----------
    discriminative_regularization : bool
        Forwarded to ``create_training_computation_graphs``; also selects
        the checkpoint file name.

    """
    streams = create_celeba_streams(training_batch_size=100,
                                    monitoring_batch_size=500,
                                    include_targets=False)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    cg, bn_cg, variance_parameters = create_training_computation_graphs(
        discriminative_regularization)
    # Duplicates are allowed by the helper and removed through the set
    pop_updates = list(
        set(get_batch_normalization_updates(bn_cg, allow_duplicates=True)))
    decay_rate = 0.05
    extra_updates = []
    for population, minibatch in pop_updates:
        extra_updates.append(
            (population,
             minibatch * decay_rate + population * (1 - decay_rate)))

    model = Model(bn_cg.outputs[0])
    trained_brick_names = ('encoder_convnet', 'encoder_mlp',
                           'decoder_convnet', 'decoder_mlp')
    selector = Selector(
        find_bricks(
            model.top_bricks,
            lambda brick: brick.name in trained_brick_names))
    parameters = list(selector.get_parameters().values()) + variance_parameters

    # Prepare algorithm
    algorithm = GradientDescent(cost=bn_cg.outputs[0],
                                parameters=parameters,
                                step_rule=Adam())
    algorithm.add_updates(extra_updates)

    # Prepare monitoring: one list of quantities per graph, in the order
    # [batch-normalized graph, other graph]
    monitored_quantities_list = []
    for graph in (bn_cg, cg):
        cost, kl_term, reconstruction_term = graph.outputs
        cost.name = 'nll_upper_bound'
        avg_kl_term = kl_term.mean(axis=0)
        avg_kl_term.name = 'avg_kl_term'
        avg_reconstruction_term = -reconstruction_term.mean(axis=0)
        avg_reconstruction_term.name = 'avg_reconstruction_term'
        monitored_quantities_list.append(
            [cost, avg_kl_term, avg_reconstruction_term])
    train_quantities, valid_quantities = monitored_quantities_list

    train_monitoring = DataStreamMonitoring(
        train_quantities, train_monitor_stream, prefix="train",
        updates=extra_updates, after_epoch=False, before_first_epoch=False,
        every_n_epochs=5)
    valid_monitoring = DataStreamMonitoring(
        valid_quantities, valid_monitor_stream, prefix="valid",
        after_epoch=False, before_first_epoch=False, every_n_epochs=5)

    # Prepare checkpoint
    if discriminative_regularization:
        name_infix = ''
    else:
        name_infix = 'no_'
    save_path = 'celeba_vae_{}regularization.zip'.format(name_infix)
    checkpoint = Checkpoint(save_path, every_n_epochs=5, use_cpickle=True)

    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=75),
        train_monitoring,
        valid_monitoring,
        checkpoint,
        Printing(),
        ProgressBar(),
    ]
    MainLoop(data_stream=main_loop_stream,
             algorithm=algorithm, extensions=extensions).run()
Example #17
Source File: pacgan_task.py    From PacGAN with MIT License 4 votes vote down vote up
def create_main_loop(self):
        """Assemble the training ``MainLoop`` from this task's configuration.

        Models come from ``self.create_models()``; hyper-parameters, data
        settings and logging switches are read from ``self._config``.
        The loop is constructed and returned, not run.
        """
        model, bn_model, bn_updates = self.create_models()
        gan, = bn_model.top_bricks
        discriminator_loss, generator_loss = bn_model.outputs

        cfg = self._config
        step_rule = Adam(learning_rate=cfg["learning_rate"],
                         beta1=cfg["beta1"])
        # The same step rule instance is used for both parameter sets
        algorithm = ali_algorithm(
            discriminator_loss, gan.discriminator_parameters, step_rule,
            generator_loss, gan.generator_parameters, step_rule)
        algorithm.add_updates(bn_updates)

        streams = create_packing_gaussian_mixture_data_streams(
            num_packings=cfg["num_packing"],
            batch_size=cfg["batch_size"],
            monitoring_batch_size=cfg["monitoring_batch_size"],
            means=cfg["x_mode_means"],
            variances=cfg["x_mode_variances"],
            priors=cfg["x_mode_priors"],
            num_examples=cfg["num_sample"])
        main_loop_stream, train_monitor_stream, valid_monitor_stream = streams

        # Monitor every auxiliary variable whose name lacks 'norm',
        # followed by the model outputs
        bn_monitored_variables = [
            variable for variable in bn_model.auxiliary_variables
            if 'norm' not in variable.name]
        bn_monitored_variables = bn_monitored_variables + bn_model.outputs
        monitored_variables = [
            variable for variable in model.auxiliary_variables
            if 'norm' not in variable.name]
        monitored_variables = monitored_variables + model.outputs

        extensions = [
            Timing(),
            FinishAfter(after_n_epochs=cfg["num_epoch"]),
            DataStreamMonitoring(
                bn_monitored_variables, train_monitor_stream, prefix="train",
                updates=bn_updates),
            DataStreamMonitoring(
                monitored_variables, valid_monitor_stream, prefix="valid"),
            Checkpoint(
                os.path.join(self._work_dir, cfg["main_loop_file"]),
                after_epoch=True, after_training=True, use_cpickle=True),
            ProgressBar(),
            Printing(),
        ]

        # Optional loggers, each enabled by its own configuration flag
        if cfg["log_models"]:
            extensions.append(
                ModelLogger(folder=self._work_dir, after_epoch=True))
        if cfg["log_figures"]:
            extensions.append(GraphLogger(
                num_modes=cfg["num_zmode"],
                num_samples=cfg["num_log_figure_sample"],
                dimension=cfg["num_zdim"],
                r=cfg["z_mode_r"],
                std=cfg["z_mode_std"],
                folder=self._work_dir,
                after_epoch=True,
                after_training=True))
        if cfg["log_metrics"]:
            extensions.append(MetricLogger(
                means=cfg["x_mode_means"],
                variances=cfg["x_mode_variances"],
                folder=self._work_dir,
                after_epoch=True))

        return MainLoop(model=bn_model, data_stream=main_loop_stream,
                        algorithm=algorithm, extensions=extensions)
Example #18
Source File: pacgan_task.py    From PacGAN with MIT License 4 votes vote down vote up
def create_main_loop(self):
        """Assemble the training ``MainLoop`` from this task's configuration.

        Models come from ``self.create_models()``; hyper-parameters, data
        settings and logging switches are read from ``self._config``.
        The loop is constructed and returned, not run.
        """
        model, bn_model, bn_updates = self.create_models()
        gan, = bn_model.top_bricks
        discriminator_loss, generator_loss = bn_model.outputs

        cfg = self._config
        step_rule = Adam(learning_rate=cfg["learning_rate"],
                         beta1=cfg["beta1"])
        # The same step rule instance is used for both parameter sets
        algorithm = ali_algorithm(
            discriminator_loss, gan.discriminator_parameters, step_rule,
            generator_loss, gan.generator_parameters, step_rule)
        algorithm.add_updates(bn_updates)

        streams = create_packing_gaussian_mixture_data_streams(
            num_packings=cfg["num_packing"],
            batch_size=cfg["batch_size"],
            monitoring_batch_size=cfg["monitoring_batch_size"],
            means=cfg["x_mode_means"],
            variances=cfg["x_mode_variances"],
            priors=cfg["x_mode_priors"],
            num_examples=cfg["num_sample"])
        main_loop_stream, train_monitor_stream, valid_monitor_stream = streams

        # Monitor every auxiliary variable whose name lacks 'norm',
        # followed by the model outputs
        bn_monitored_variables = [
            variable for variable in bn_model.auxiliary_variables
            if 'norm' not in variable.name]
        bn_monitored_variables = bn_monitored_variables + bn_model.outputs
        monitored_variables = [
            variable for variable in model.auxiliary_variables
            if 'norm' not in variable.name]
        monitored_variables = monitored_variables + model.outputs

        extensions = [
            Timing(),
            FinishAfter(after_n_epochs=cfg["num_epoch"]),
            DataStreamMonitoring(
                bn_monitored_variables, train_monitor_stream, prefix="train",
                updates=bn_updates),
            DataStreamMonitoring(
                monitored_variables, valid_monitor_stream, prefix="valid"),
            Checkpoint(
                os.path.join(self._work_dir, cfg["main_loop_file"]),
                after_epoch=True, after_training=True, use_cpickle=True),
            ProgressBar(),
            Printing(),
        ]

        # Optional loggers, each enabled by its own configuration flag
        if cfg["log_models"]:
            extensions.append(
                ModelLogger(folder=self._work_dir, after_epoch=True))
        if cfg["log_figures"]:
            extensions.append(GraphLogger(
                num_modes=cfg["num_zmode"],
                num_samples=cfg["num_log_figure_sample"],
                dimension=cfg["num_zdim"],
                r=cfg["z_mode_r"],
                std=cfg["z_mode_std"],
                folder=self._work_dir,
                after_epoch=True,
                after_training=True))
        if cfg["log_metrics"]:
            extensions.append(MetricLogger(
                means=cfg["x_mode_means"],
                variances=cfg["x_mode_variances"],
                folder=self._work_dir,
                after_epoch=True))

        return MainLoop(model=bn_model, data_stream=main_loop_stream,
                        algorithm=algorithm, extensions=extensions)