Python pytorch_lightning.Trainer() Examples

The following are 30 code examples of pytorch_lightning.Trainer(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module pytorch_lightning, or try the search function.
Example #1
Source File: mnist_pytorch_lightning.py    From ray with Apache License 2.0 7 votes vote down vote up
def train_mnist_tune_checkpoint(config, checkpoint=None):
    """Fit the MNIST classifier, optionally resuming from a checkpoint.

    Args:
        config: Hyper-parameter mapping. ``config["data_dir"]`` is read when
            a fresh model is built; the whole mapping is handed to the model.
        checkpoint: Optional checkpoint directory. When truthy, the file
            named ``checkpoint`` inside it is loaded and the trainer's
            ``current_epoch`` is set from the stored ``"epoch"`` entry.
    """
    tune_callbacks = [CheckpointCallback(), TuneReportCallback()]
    trainer = pl.Trainer(
        max_epochs=10,
        progress_bar_refresh_rate=0,
        callbacks=tune_callbacks)

    if not checkpoint:
        model = LightningMNISTClassifier(
            config=config, data_dir=config["data_dir"])
    else:
        # Currently, this leads to errors:
        # model = LightningMNISTClassifier.load_from_checkpoint(
        #     os.path.join(checkpoint, "checkpoint"))
        # Workaround: load the raw checkpoint and rebuild the model from it.
        checkpoint_file = os.path.join(checkpoint, "checkpoint")
        state = pl_load(
            checkpoint_file,
            map_location=lambda storage, loc: storage)
        model = LightningMNISTClassifier._load_model_state(state, config=config)
        trainer.current_epoch = state["epoch"]

    trainer.fit(model)
# __tune_train_checkpoint_end__


# __tune_asha_begin__ 
Example #2
Source File: test_tpu.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_dataloaders_passed_to_fit(tmpdir):
    """Check that fit() accepts explicitly passed dataloaders with tpu_cores=8."""
    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8)

    # Hand the model's own loaders to fit() instead of letting it look them up.
    loaders = dict(
        train_dataloader=model.train_dataloader(),
        val_dataloaders=model.val_dataloader(),
    )
    result = trainer.fit(model, **loaders)
    assert result, "TPU doesn't work with dataloaders passed to fit()."
Example #3
Source File: test_base.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_custom_logger(tmpdir):
    """Fit once with a CustomLogger and check what the logger recorded."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
    custom_logger = CustomLogger()

    trainer_kwargs = dict(
        max_epochs=1,
        limit_train_batches=0.05,
        logger=custom_logger,
        default_root_dir=tmpdir,
    )
    fit_status = Trainer(**trainer_kwargs).fit(model)
    assert fit_status == 1, "Training failed"

    # The logger must have seen the hyper-parameters, some metrics, and a
    # successful finalization.
    assert custom_logger.hparams_logged == hparams
    assert custom_logger.metrics_logged != {}
    assert custom_logger.finalized_status == "success"
Example #4
Source File: test_model_checkpoint.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Check the version directory name used for checkpoints.

    The name of the parent directory of ``trainer.ckpt_path`` must equal
    ``expected`` for the given ``logger_version``.
    """
    tutils.reset_seed()
    model = EvalModelTemplate()
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        overfit_pct=0.2,
        max_epochs=5,
        logger=tb_logger,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    version_dir = Path(trainer.ckpt_path).parent
    assert version_dir.name == expected
Example #5
Source File: computer_vision_fine_tuning.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def main(args: argparse.Namespace) -> None:
    """Fit the transfer-learning model.

    Args:
        args: Model hyper-parameters. ``root_data_path``, ``gpus`` and
            ``nb_epochs`` are read here; the full namespace is also forwarded
            to the model as keyword arguments.

    Note:
        The model receives a temporary directory (created under
        ``args.root_data_path``) as its ``dl_path``; the directory is removed
        when the ``with`` block exits.
    """
    with TemporaryDirectory(dir=args.root_data_path) as download_dir:
        model = TransferLearningModel(dl_path=download_dir, **vars(args))

        # min_epochs == max_epochs pins the run to exactly nb_epochs epochs.
        trainer_kwargs = dict(
            weights_summary=None,
            show_progress_bar=True,
            num_sanity_val_steps=0,
            gpus=args.gpus,
            min_epochs=args.nb_epochs,
            max_epochs=args.nb_epochs,
        )
        pl.Trainer(**trainer_kwargs).fit(model)
Example #6
Source File: test_base.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_multiple_loggers(tmpdir):
    """Fit with two CustomLoggers and check that each one recorded the run."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    logger1 = CustomLogger()
    logger2 = CustomLogger()

    trainer_kwargs = dict(
        max_epochs=1,
        limit_train_batches=0.05,
        logger=[logger1, logger2],
        default_root_dir=tmpdir,
    )
    fit_status = Trainer(**trainer_kwargs).fit(model)
    assert fit_status == 1, "Training failed"

    # Both loggers must have received the same information.
    for each_logger in (logger1, logger2):
        assert each_logger.hparams_logged == hparams
        assert each_logger.metrics_logged != {}
        assert each_logger.finalized_status == "success"
Example #7
Source File: test_base.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_multiple_loggers_pickle(tmpdir):
    """Check that a trainer holding several loggers survives a pickle round trip."""
    logger1 = CustomLogger()
    logger2 = CustomLogger()

    original = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=[logger1, logger2],
    )

    # Round-trip through pickle, then log through the restored trainer.
    restored = pickle.loads(pickle.dumps(original))
    restored.logger.log_metrics({"acc": 1.0}, 0)

    for each_logger in (logger1, logger2):
        assert each_logger.metrics_logged != {}
Example #8
Source File: test_model_checkpoint.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """Check that ``filepath=None`` is accepted by ModelCheckpoint and that the
    resulting checkpoint path differs from the trainer's root directory."""
    tutils.reset_seed()
    model = EvalModelTemplate()

    ckpt_callback = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        checkpoint_callback=ckpt_callback,
        overfit_pct=0.20,
        max_epochs=(save_top_k + 2),
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # These should be different if the dirpath has been overridden
    assert trainer.ckpt_path != trainer.default_root_dir
Example #9
Source File: test_lr_finder.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_model_reset_correctly(tmpdir):
    """Check that model weights are correctly reset after lr_find().

    The weights are snapshotted before the learning-rate search and compared
    entry by entry with the weights afterwards.
    """
    # Function-scope import so the block does not rely on file-level imports.
    from copy import deepcopy

    model = EvalModelTemplate()

    # logger file to get meta
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
    )

    # state_dict() returns tensors that share storage with the live
    # parameters, so an un-copied "before" snapshot would follow every later
    # in-place change and the comparison below could never fail. Deep-copy
    # to freeze the values as they are right now.
    before_state_dict = deepcopy(model.state_dict())

    _ = trainer.lr_find(model, num_training=5)

    after_state_dict = model.state_dict()

    for key, before in before_state_dict.items():
        assert torch.all(torch.eq(before, after_state_dict[key])), \
            'Model was not reset correctly after learning rate finder'
Example #10
Source File: test_lr_finder.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_trainer_arg_str(tmpdir):
    """Check that ``auto_lr_find`` given as a string updates the model
    attribute of that name during fit()."""
    model = EvalModelTemplate()
    model.my_fancy_lr = 1.0  # update with non-standard field
    lr_before = model.my_fancy_lr

    # logger file to get meta
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=2,
        auto_lr_find='my_fancy_lr',
    )
    Trainer(**trainer_kwargs).fit(model)

    lr_after = model.my_fancy_lr
    assert lr_before != lr_after, \
        'Learning rate was not altered after running learning rate finder'
Example #11
Source File: test_lr_finder.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_call_to_trainer_method(tmpdir):
    """Check that calling ``trainer.lr_find`` directly yields a suggestion
    that differs from the default learning rate."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
    lr_before = hparams.get('learning_rate')

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)

    # Run the finder by hand, apply its suggestion, then train with it.
    finder = trainer.lr_find(model, mode='linear')
    lr_after = finder.suggestion()
    model.learning_rate = lr_after
    trainer.fit(model)

    assert lr_before != lr_after, \
        'Learning rate was not altered after running learning rate finder'
Example #12
Source File: test_neptune.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_neptune_leave_open_experiment_after_fit(tmpdir):
    """Check that the Neptune experiment is stopped once after fit() by
    default and is not stopped when ``close_after_fit=False``."""
    model = EvalModelTemplate()

    def _fit_with(neptune_logger):
        # Swap in a mock experiment so calls to stop() can be counted.
        neptune_logger._experiment = MagicMock()
        trainer_kwargs = dict(
            default_root_dir=tmpdir,
            max_epochs=1,
            limit_train_batches=0.05,
            logger=neptune_logger,
        )
        Trainer(**trainer_kwargs).fit(model)
        return neptune_logger

    closing_logger = _fit_with(NeptuneLogger(offline_mode=True))
    assert closing_logger._experiment.stop.call_count == 1

    open_logger = _fit_with(NeptuneLogger(offline_mode=True, close_after_fit=False))
    assert open_logger._experiment.stop.call_count == 0
Example #13
Source File: test_lr_logger.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_lr_logger_multi_lrs(tmpdir):
    """Test that learning rates are extracted and logged for multi lr schedulers.

    After fitting with the model's multiple-scheduler optimizer
    configuration, the LearningRateLogger must hold one entry per scheduler,
    named ``lr-Adam`` / ``lr-Adam-1``, each with ``max_epochs`` values.
    """
    tutils.reset_seed()

    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__multiple_schedulers

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_logger],
    )
    result = trainer.fit(model)
    assert result

    expected_names = ('lr-Adam', 'lr-Adam-1')

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    assert all(name in expected_names for name in lr_logger.lrs), \
        'Names of learning rates not set correctly'
    # The check is an equality, so the message must cover both too many and
    # too few logged values.
    assert all(len(lrs) == trainer.max_epochs for lrs in lr_logger.lrs.values()), \
        'Length of logged learning rates does not match the number of epochs'
Example #14
Source File: test_hparams.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_load_past_checkpoint(tmpdir, past_key):
    """Check that a checkpoint storing its hyper-parameters under ``past_key``
    can still be loaded.

    A freshly saved checkpoint is rewritten so that its hyper-parameters live
    under ``past_key`` (with ``batch_size`` set to -17) instead of the current
    key, and the model loaded from it must expose that batch size.
    """
    model = EvalModelTemplate()

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    current_key = LightningModule.CHECKPOINT_HYPER_PARAMS_KEY

    # make sure the raw checkpoint saved the properties
    ckpt_path = _raw_checkpoint_path(trainer)
    ckpt = torch.load(ckpt_path)

    # Move the hyper-parameters to the legacy key; statement order matters.
    ckpt[past_key] = ckpt[current_key]
    ckpt['hparams_type'] = 'Namespace'
    ckpt[past_key]['batch_size'] = -17
    del ckpt[current_key]

    # save back the checkpoint
    torch.save(ckpt, ckpt_path)

    # verify that model loads correctly
    reloaded = EvalModelTemplate.load_from_checkpoint(ckpt_path)
    assert reloaded.hparams.batch_size == -17
Example #15
Source File: test_lr_logger.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_lr_logger_single_lr(tmpdir):
    """Test that learning rates are extracted and logged for single lr scheduler.

    After fitting with the model's single-scheduler optimizer configuration,
    the LearningRateLogger must hold one entry per scheduler, named
    ``lr-Adam``.
    """
    tutils.reset_seed()

    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__single_scheduler

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_logger],
    )
    result = trainer.fit(model)
    assert result

    expected_names = ('lr-Adam',)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    # Generator instead of a throwaway list; iterating the mapping yields keys.
    assert all(name in expected_names for name in lr_logger.lrs), \
        'Names of learning rates not set correctly'
Example #16
Source File: test_gpu.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_multi_gpu_model(tmpdir, backend):
    """Check that fit() succeeds on GPUs 0 and 1 with the given distributed backend."""
    tutils.set_random_master_port()

    trainer_options = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'limit_train_batches': 0.4,
        'limit_val_batches': 0.2,
        'gpus': [0, 1],
        'distributed_backend': backend,
    }

    model = EvalModelTemplate()
    # tutils.run_model_test(trainer_options, model)
    fit_status = Trainer(**trainer_options).fit(model)
    assert fit_status

    # test memory helper functions
    memory.get_memory_profile('min_max')
Example #17
Source File: test_gpu.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_multi_gpu_early_stop(tmpdir, backend):
    """Check that fit() succeeds on GPUs 0 and 1 with early stopping enabled."""
    tutils.set_random_master_port()

    trainer_options = {
        'default_root_dir': tmpdir,
        'early_stop_callback': True,
        'max_epochs': 50,
        'limit_train_batches': 10,
        'limit_val_batches': 10,
        'gpus': [0, 1],
        'distributed_backend': backend,
    }

    model = EvalModelTemplate()
    # tutils.run_model_test(trainer_options, model)
    fit_status = Trainer(**trainer_options).fit(model)
    assert fit_status
Example #18
Source File: test_gpu.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """Check that fit() accepts explicitly passed dataloaders under the 'ddp' backend."""
    tutils.set_random_master_port()

    trainer_options = {
        'default_root_dir': tmpdir,
        'progress_bar_refresh_rate': 0,
        'max_epochs': 1,
        'limit_train_batches': 0.1,
        'limit_val_batches': 0.1,
        'gpus': [0, 1],
        'distributed_backend': 'ddp',
    }

    model = EvalModelTemplate()
    # Hand the model's own loaders to fit() instead of letting it look them up.
    fit_options = {
        'train_dataloader': model.train_dataloader(),
        'val_dataloaders': model.val_dataloader(),
    }

    trainer = Trainer(**trainer_options)
    fit_status = trainer.fit(model, **fit_options)
    assert fit_status == 1, "DDP doesn't work with dataloaders passed to fit()."
Example #19
Source File: test_early_stopping.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
    """Check that training ends at ``expected_stop_epoch`` for the given
    scripted validation losses and early-stopping ``patience``."""

    class ModelOverrideValidationReturn(EvalModelTemplate):
        # Scripted losses, handed out one per validation epoch.
        validation_return_values = torch.Tensor(loss_values)
        count = 0

        def validation_epoch_end(self, outputs):
            current_loss = self.validation_return_values[self.count]
            self.count += 1
            return {"test_val_loss": current_loss}

    model = ModelOverrideValidationReturn()
    stopper = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        early_stop_callback=stopper,
        val_check_interval=1.0,
        num_sanity_val_steps=0,
        max_epochs=10,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    assert trainer.current_epoch == expected_stop_epoch
Example #20
Source File: test_early_stopping.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Check that ``on_validation_end`` of the early-stopping callback is
    invoked exactly the expected number of times by the end of training."""

    class EarlyStoppingTestInvocations(EarlyStopping):
        def __init__(self, expected_count):
            super().__init__()
            self.expected_count = expected_count
            self.count = 0

        def on_validation_end(self, trainer, pl_module):
            # Only count the invocation.
            self.count += 1

        def on_train_end(self, trainer, pl_module):
            assert self.count == self.expected_count

    model = EvalModelTemplate()
    expected_count = 4
    counting_callback = EarlyStoppingTestInvocations(expected_count)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        early_stop_callback=counting_callback,
        val_check_interval=1.0,
        max_epochs=expected_count,
    )
    Trainer(**trainer_kwargs).fit(model)
Example #21
Source File: test_callbacks.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Check the version directory name used for checkpoints.

    The name of the parent directory of ``trainer.ckpt_path`` must equal
    ``expected`` for the given ``logger_version``.
    """
    tutils.reset_seed()
    model = EvalModelTemplate()
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        overfit_batches=0.2,
        max_epochs=2,
        logger=tb_logger,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    version_dir = Path(trainer.ckpt_path).parent
    assert version_dir.name == expected
Example #22
Source File: test_callbacks.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """Check that ``filepath=None`` is accepted by ModelCheckpoint and that the
    resulting checkpoint path differs from the trainer's root directory."""
    tutils.reset_seed()
    model = EvalModelTemplate()

    ckpt_callback = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        checkpoint_callback=ckpt_callback,
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # These should be different if the dirpath has been overridden
    assert trainer.ckpt_path != trainer.default_root_dir
Example #23
Source File: test_callbacks.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_early_stopping_no_val_step(tmpdir):
    """Check that early stopping can monitor a training metric when the model
    has no validation step or validation dataloader."""

    class CurrentModel(EvalModelTemplate):
        def training_step(self, *args, **kwargs):
            step_output = super().training_step(*args, **kwargs)
            # Expose the loss under a second name; could be anything else.
            extra_metric = {'my_train_metric': step_output['loss']}
            step_output.update(extra_metric)
            return step_output

    model = CurrentModel()
    # Disable validation entirely.
    model.validation_step = None
    model.val_dataloader = None

    stopper = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        early_stop_callback=stopper,
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer = Trainer(**trainer_kwargs)
    fit_status = trainer.fit(model)

    assert fit_status == 1, 'training failed to complete'
    assert trainer.current_epoch < trainer.max_epochs
Example #24
Source File: test_callbacks.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_early_stopping_functionality(tmpdir):
    """Check that training stops at epoch 5 when ``early_stop_callback=True``
    and the validation loss follows a scripted sequence."""

    class CurrentModel(EvalModelTemplate):
        def validation_epoch_end(self, outputs):
            # Scripted losses, indexed by the current epoch.
            scripted_losses = [8, 4, 2, 3, 4, 5, 8, 10]
            epoch_loss = scripted_losses[self.current_epoch]
            return {'val_loss': torch.tensor(epoch_loss)}

    model = CurrentModel()

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        early_stop_callback=True,
        overfit_batches=0.20,
        max_epochs=20,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    print(trainer.current_epoch)

    assert trainer.current_epoch == 5, 'early_stopping failed'
Example #25
Source File: test_amp.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_amp_single_gpu(tmpdir, backend):
    """Check that fit() succeeds on one GPU with ``precision=16`` and the
    given distributed backend."""
    tutils.reset_seed()

    trainer_options = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'gpus': 1,
        'distributed_backend': backend,
        'precision': 16,
    }
    trainer = Trainer(**trainer_options)

    model = EvalModelTemplate()
    # tutils.run_model_test(trainer_options, model)
    fit_status = trainer.fit(model)

    assert fit_status == 1
Example #26
Source File: test_amp.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_amp_multi_gpu(tmpdir, backend):
    """Check that fit() succeeds with ``precision=16`` when the GPUs are
    given as the string ``'0, 1'``."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()

    trainer_options = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        # 'gpus': 2,
        'gpus': '0, 1',  # test init with gpu string
        'distributed_backend': backend,
        'precision': 16,
    }

    # tutils.run_model_test(trainer_options, model)
    fit_status = Trainer(**trainer_options).fit(model)
    assert fit_status
Example #27
Source File: test_amp.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_multi_gpu_wandb(tmpdir, backend):
    """Check that fit() and test() run on two GPUs with ``precision=16`` and
    a WandbLogger attached."""
    from pytorch_lightning.loggers import WandbLogger
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    wandb_logger = WandbLogger(name='utest')

    trainer_options = {
        'default_root_dir': tmpdir,
        'max_epochs': 1,
        'gpus': 2,
        'distributed_backend': backend,
        'precision': 16,
        'logger': wandb_logger,
    }
    # tutils.run_model_test(trainer_options, model)
    trainer = Trainer(**trainer_options)
    fit_status = trainer.fit(model)
    assert fit_status
    trainer.test(model)
Example #28
Source File: test_hooks.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_on_before_zero_grad_called(tmpdir, max_steps):
    """Check that ``on_before_zero_grad`` fires ``max_steps`` times during
    fit() and never during test()."""

    class CurrentTestModel(EvalModelTemplate):
        # Number of times the hook has been invoked.
        on_before_zero_grad_called = 0

        def on_before_zero_grad(self, optimizer):
            self.on_before_zero_grad_called += 1

    model = CurrentTestModel()

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_steps=max_steps,
        max_epochs=2,
        num_sanity_val_steps=5,
    )
    trainer = Trainer(**trainer_kwargs)

    assert model.on_before_zero_grad_called == 0
    trainer.fit(model)
    assert model.on_before_zero_grad_called == max_steps

    # Reset the counter; testing must not trigger the hook at all.
    model.on_before_zero_grad_called = 0
    trainer.test(model)
    assert model.on_before_zero_grad_called == 0
Example #29
Source File: test_lr_logger.py    From pytorch-lightning with Apache License 2.0 6 votes vote down vote up
def test_lr_logger_no_lr(tmpdir):
    """Check that fitting the default model with a LearningRateLogger
    callback emits a RuntimeWarning and still succeeds."""
    tutils.reset_seed()

    model = EvalModelTemplate()

    lr_callback = LearningRateLogger()
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_callback],
    )
    trainer = Trainer(**trainer_kwargs)

    with pytest.warns(RuntimeWarning):
        assert trainer.fit(model)
Example #30
Source File: test_progress_bar.py    From pytorch-lightning with Apache License 2.0 5 votes vote down vote up
def test_progress_bar_off(tmpdir, callbacks, refresh_rate):
    """Check that the given ``callbacks`` / ``refresh_rate`` combination
    leaves the trainer without any progress bar callback."""
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        progress_bar_refresh_rate=refresh_rate,
    )
    trainer = Trainer(**trainer_kwargs)

    # No registered callback may be a ProgressBar instance.
    progress_bars = list(
        filter(lambda cb: isinstance(cb, ProgressBar), trainer.callbacks)
    )
    assert len(progress_bars) == 0
    assert not trainer.progress_bar_callback