Python pytorch_lightning.Trainer() Examples
The following are 30 code examples of pytorch_lightning.Trainer().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module pytorch_lightning, or try the search function.
Example #1
Source File: mnist_pytorch_lightning.py From ray with Apache License 2.0 | 7 votes |
def train_mnist_tune_checkpoint(config, checkpoint=None):
    """Fit ``LightningMNISTClassifier``, optionally restoring it from a checkpoint.

    Args:
        config: Configuration mapping passed to the model; ``config["data_dir"]``
            is read when a fresh model is built.
        checkpoint: Directory holding a file named ``"checkpoint"`` to restore
            from. A falsy value starts from a newly constructed model.
    """
    trainer = pl.Trainer(
        max_epochs=10,
        progress_bar_refresh_rate=0,
        callbacks=[CheckpointCallback(), TuneReportCallback()],
    )

    if not checkpoint:
        model = LightningMNISTClassifier(
            config=config,
            data_dir=config["data_dir"],
        )
    else:
        # Calling LightningMNISTClassifier.load_from_checkpoint(
        #     os.path.join(checkpoint, "checkpoint"))
        # currently leads to errors, so the state is restored by hand instead.
        checkpoint_file = os.path.join(checkpoint, "checkpoint")
        state = pl_load(
            checkpoint_file,
            map_location=lambda storage, loc: storage,
        )
        model = LightningMNISTClassifier._load_model_state(state, config=config)
        # Resume epoch counting from the value stored in the checkpoint.
        trainer.current_epoch = state["epoch"]

    trainer.fit(model)
# __tune_train_checkpoint_end__
# __tune_asha_begin__
Example #2
Source File: test_tpu.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_dataloaders_passed_to_fit(tmpdir):
    """Check that ``fit()`` succeeds on TPU when dataloaders are passed in explicitly."""
    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8)

    # Build the loaders up front so the fit() call reads as a single line.
    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()
    outcome = trainer.fit(
        model,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
    )

    assert outcome, "TPU doesn't work with dataloaders passed to fit()."
Example #3
Source File: test_base.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_custom_logger(tmpdir):
    """Fit for one epoch with a ``CustomLogger`` and verify what it recorded."""
    default_hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**default_hparams)
    custom_logger = CustomLogger()

    trainer_kwargs = dict(
        max_epochs=1,
        limit_train_batches=0.05,
        logger=custom_logger,
        default_root_dir=tmpdir,
    )
    trainer = Trainer(**trainer_kwargs)
    fit_status = trainer.fit(model)

    assert fit_status == 1, "Training failed"
    # The logger must have received the hparams, some metrics, and a clean finish.
    assert custom_logger.hparams_logged == default_hparams
    assert custom_logger.metrics_logged != {}
    assert custom_logger.finalized_status == "success"
Example #4
Source File: test_model_checkpoint.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Check that the "version_" prefix is only added for integer logger versions.

    The name of the directory containing ``trainer.ckpt_path`` must equal
    ``expected`` for the given ``logger_version``.
    """
    tutils.reset_seed()
    model = EvalModelTemplate()
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_pct=0.2,
        max_epochs=5,
        logger=tb_logger,
    )
    trainer.fit(model)

    version_dir = Path(trainer.ckpt_path).parent
    assert version_dir.name == expected
Example #5
Source File: computer_vision_fine_tuning.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def main(args: argparse.Namespace) -> None:
    """Train the transfer-learning model.

    Args:
        args: Model hyper-parameters. ``root_data_path``, ``gpus`` and
            ``nb_epochs`` are read here; all attributes are also forwarded
            to ``TransferLearningModel`` as keyword arguments.

    Note:
        For the sake of the example, the images dataset will be downloaded
        to a temporary directory.
    """
    with TemporaryDirectory(dir=args.root_data_path) as download_dir:
        hyper_params = vars(args)
        model = TransferLearningModel(dl_path=download_dir, **hyper_params)

        # min_epochs and max_epochs are pinned to the same value.
        epoch_count = args.nb_epochs
        trainer = pl.Trainer(
            weights_summary=None,
            show_progress_bar=True,
            num_sanity_val_steps=0,
            gpus=args.gpus,
            min_epochs=epoch_count,
            max_epochs=epoch_count,
        )

        # Training happens while the temporary directory still exists.
        trainer.fit(model)
Example #6
Source File: test_base.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_multiple_loggers(tmpdir):
    """Fit with two ``CustomLogger`` instances and verify each one recorded the run."""
    default_hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**default_hparams)

    logger1 = CustomLogger()
    logger2 = CustomLogger()

    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=0.05,
        logger=[logger1, logger2],
        default_root_dir=tmpdir,
    )
    fit_status = trainer.fit(model)
    assert fit_status == 1, "Training failed"

    # Both loggers are held to the same expectations, first logger first.
    for custom_logger in (logger1, logger2):
        assert custom_logger.hparams_logged == default_hparams
        assert custom_logger.metrics_logged != {}
        assert custom_logger.finalized_status == "success"
Example #7
Source File: test_base.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_multiple_loggers_pickle(tmpdir):
    """Verify that a trainer holding multiple loggers survives a pickle round trip."""
    logger1 = CustomLogger()
    logger2 = CustomLogger()

    original_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=[logger1, logger2],
    )

    # Serialize and immediately restore the trainer.
    restored_trainer = pickle.loads(pickle.dumps(original_trainer))
    restored_trainer.logger.log_metrics({"acc": 1.0}, 0)

    assert logger1.metrics_logged != {}
    assert logger2.metrics_logged != {}
Example #8
Source File: test_model_checkpoint.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """Check that ``filepath=None`` is valid for ``ModelCheckpoint`` and the checkpoint path is set."""
    tutils.reset_seed()
    model = EvalModelTemplate()

    checkpoint_cb = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
    epoch_budget = save_top_k + 2
    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint_cb,
        overfit_pct=0.20,
        max_epochs=epoch_budget,
    )
    trainer.fit(model)

    # These should be different if the dirpath has been overridden
    assert trainer.ckpt_path != trainer.default_root_dir
Example #9
Source File: test_lr_finder.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_model_reset_correctly(tmpdir):
    """Check that model weights are correctly reset after ``lr_find()``.

    A deep copy of the state dict is taken before the learning-rate search and
    compared entry by entry with the state dict read afterwards.
    """
    # Function-scope import so the block does not depend on file-level imports.
    from copy import deepcopy

    model = EvalModelTemplate()

    # logger file to get meta
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
    )

    # state_dict() returns tensors that share storage with the live parameters.
    # Without a deep copy the "before" snapshot would track every in-place
    # weight change and the comparison below could never fail.
    before_state_dict = deepcopy(model.state_dict())

    _ = trainer.lr_find(model, num_training=5)

    after_state_dict = model.state_dict()

    for key in before_state_dict:
        assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), \
            'Model was not reset correctly after learning rate finder'
Example #10
Source File: test_lr_finder.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_trainer_arg_str(tmpdir):
    """Check that ``auto_lr_find`` given as an attribute-name string updates that attribute."""
    model = EvalModelTemplate()
    model.my_fancy_lr = 1.0  # update with non-standard field
    initial_lr = model.my_fancy_lr

    # logger file to get meta
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=2,
        auto_lr_find='my_fancy_lr',
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    tuned_lr = model.my_fancy_lr
    assert initial_lr != tuned_lr, \
        'Learning rate was not altered after running learning rate finder'
Example #11
Source File: test_lr_finder.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_call_to_trainer_method(tmpdir):
    """Check that calling ``trainer.lr_find`` directly yields a new learning rate."""
    default_hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**default_hparams)
    initial_lr = default_hparams.get('learning_rate')

    # logger file to get meta
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)

    # Run the finder by hand, apply its suggestion, then train with it.
    finder = trainer.lr_find(model, mode='linear')
    suggested_lr = finder.suggestion()
    model.learning_rate = suggested_lr
    trainer.fit(model)

    assert initial_lr != suggested_lr, \
        'Learning rate was not altered after running learning rate finder'
Example #12
Source File: test_neptune.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_neptune_leave_open_experiment_after_fit(tmpdir):
    """Verify the Neptune experiment is stopped after fit unless ``close_after_fit=False``."""
    model = EvalModelTemplate()

    def _run_training(logger):
        # Swap in a mock experiment so calls to stop() can be counted.
        logger._experiment = MagicMock()
        fit_kwargs = dict(
            default_root_dir=tmpdir,
            max_epochs=1,
            limit_train_batches=0.05,
            logger=logger,
        )
        Trainer(**fit_kwargs).fit(model)
        return logger

    # Default logger: the experiment is stopped exactly once.
    closing_logger = _run_training(NeptuneLogger(offline_mode=True))
    assert closing_logger._experiment.stop.call_count == 1

    # close_after_fit=False: the experiment is never stopped.
    open_logger = _run_training(NeptuneLogger(offline_mode=True, close_after_fit=False))
    assert open_logger._experiment.stop.call_count == 0
Example #13
Source File: test_lr_logger.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_lr_logger_multi_lrs(tmpdir):
    """Test that learning rates are extracted and logged for multi lr schedulers."""
    tutils.reset_seed()

    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__multiple_schedulers

    lr_logger = LearningRateLogger()
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_logger],
    )
    trainer = Trainer(**trainer_kwargs)
    fit_status = trainer.fit(model)
    assert fit_status

    logged = lr_logger.lrs
    assert logged, 'No learning rates logged'
    assert len(logged) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'

    allowed_names = ['lr-Adam', 'lr-Adam-1']
    assert all(name in allowed_names for name in logged.keys()), \
        'Names of learning rates not set correctly'

    # One logged value per epoch is expected for every scheduler.
    assert all(len(history) == trainer.max_epochs for _, history in logged.items()), \
        'Length of logged learning rates exceeds the number of epochs'
Example #14
Source File: test_hparams.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_load_past_checkpoint(tmpdir, past_key):
    """Check that a checkpoint storing hparams under ``past_key`` still loads.

    The saved checkpoint is rewritten so its hyper-parameters live under
    ``past_key`` instead of the current key, with ``batch_size`` set to a
    sentinel value that is then looked for on the reloaded model.
    """
    model = EvalModelTemplate()

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    ckpt_file = _raw_checkpoint_path(trainer)
    ckpt = torch.load(ckpt_file)

    # Move the hparams over to the legacy key and tag them with a sentinel.
    current_key = LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
    ckpt[past_key] = ckpt[current_key]
    ckpt['hparams_type'] = 'Namespace'
    ckpt[past_key]['batch_size'] = -17
    del ckpt[current_key]

    # save back the checkpoint
    torch.save(ckpt, ckpt_file)

    # verify that model loads correctly
    reloaded = EvalModelTemplate.load_from_checkpoint(ckpt_file)
    assert reloaded.hparams.batch_size == -17
Example #15
Source File: test_lr_logger.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_lr_logger_single_lr(tmpdir):
    """Test that learning rates are extracted and logged for single lr scheduler."""
    tutils.reset_seed()

    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__single_scheduler

    lr_logger = LearningRateLogger()
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_logger],
    )
    trainer = Trainer(**trainer_kwargs)
    fit_status = trainer.fit(model)
    assert fit_status

    logged = lr_logger.lrs
    assert logged, 'No learning rates logged'
    assert len(logged) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'

    allowed_names = ['lr-Adam']
    assert all(name in allowed_names for name in logged.keys()), \
        'Names of learning rates not set correctly'
Example #16
Source File: test_gpu.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_multi_gpu_model(tmpdir, backend):
    """Make sure fitting works on two GPUs with the given distributed backend."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        gpus=[0, 1],
        distributed_backend=backend,
    )

    fit_status = trainer.fit(model)
    assert fit_status

    # test memory helper functions
    memory.get_memory_profile('min_max')
Example #17
Source File: test_gpu.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_multi_gpu_early_stop(tmpdir, backend):
    """Make sure fitting on two GPUs works with early stopping enabled."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=True,
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        distributed_backend=backend,
    )

    fit_status = trainer.fit(model)
    assert fit_status
Example #18
Source File: test_gpu.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """Make sure DDP works with dataloaders passed to fit()."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()

    # Loaders are created before the trainer, then handed to fit() explicitly.
    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()

    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        gpus=[0, 1],
        distributed_backend='ddp',
    )
    fit_status = trainer.fit(
        model,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
    )

    assert fit_status == 1, "DDP doesn't work with dataloaders passed to fit()."
Example #19
Source File: test_early_stopping.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
    """Test that early stopping is not triggered before patience is exhausted.

    The model's validation loss is scripted from ``loss_values`` (one value per
    validation epoch) and the epoch at which training stops is compared with
    ``expected_stop_epoch``.
    """

    class ModelOverrideValidationReturn(EvalModelTemplate):
        validation_return_values = torch.Tensor(loss_values)
        count = 0

        def validation_epoch_end(self, outputs):
            # Hand out the next scripted loss and advance the cursor.
            position = self.count
            self.count = position + 1
            return {"test_val_loss": self.validation_return_values[position]}

    model = ModelOverrideValidationReturn()
    stopper = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        early_stop_callback=stopper,
        val_check_interval=1.0,
        num_sanity_val_steps=0,
        max_epochs=10,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    assert trainer.current_epoch == expected_stop_epoch
Example #20
Source File: test_early_stopping.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test that callback methods aren't being invoked outside of the callback handler.

    A counting ``EarlyStopping`` subclass tallies ``on_validation_end`` calls
    and checks the tally against the expected number when training ends.
    """

    class EarlyStoppingTestInvocations(EarlyStopping):
        def __init__(self, expected_count):
            super().__init__()
            self.count = 0
            self.expected_count = expected_count

        def on_validation_end(self, trainer, pl_module):
            # Only count the call; the parent implementation is not invoked.
            self.count = self.count + 1

        def on_train_end(self, trainer, pl_module):
            assert self.count == self.expected_count

    model = EvalModelTemplate()

    # The same number is used as the epoch budget and the expected call count.
    expected_count = 4
    counting_callback = EarlyStoppingTestInvocations(expected_count)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        early_stop_callback=counting_callback,
        val_check_interval=1.0,
        max_epochs=expected_count,
    )
    Trainer(**trainer_kwargs).fit(model)
Example #21
Source File: test_callbacks.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Check that the "version_" prefix is only added for integer logger versions.

    The name of the directory containing ``trainer.ckpt_path`` must equal
    ``expected`` for the given ``logger_version``.
    """
    tutils.reset_seed()
    model = EvalModelTemplate()
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_batches=0.2,
        max_epochs=2,
        logger=tb_logger,
    )
    trainer.fit(model)

    version_dir = Path(trainer.ckpt_path).parent
    assert version_dir.name == expected
Example #22
Source File: test_callbacks.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """Check that ``filepath=None`` is valid for ``ModelCheckpoint`` and the checkpoint path is set."""
    tutils.reset_seed()
    model = EvalModelTemplate()

    checkpoint_cb = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint_cb,
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # These should be different if the dirpath has been overridden
    assert trainer.ckpt_path != trainer.default_root_dir
Example #23
Source File: test_callbacks.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping falls back to training metrics when no validation is defined."""

    class CurrentModel(EvalModelTemplate):
        def training_step(self, *args, **kwargs):
            step_output = super().training_step(*args, **kwargs)
            # Expose the loss under a second name; could be anything else.
            step_output.update({'my_train_metric': step_output['loss']})
            return step_output

    model = CurrentModel()
    # Remove the validation hooks so only training metrics are available.
    model.validation_step = None
    model.val_dataloader = None

    stopper = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        early_stop_callback=stopper,
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer = Trainer(**trainer_kwargs)
    fit_status = trainer.fit(model)

    assert fit_status == 1, 'training failed to complete'
    assert trainer.current_epoch < trainer.max_epochs
Example #24
Source File: test_callbacks.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_early_stopping_functionality(tmpdir):
    """Check that the default early-stopping callback halts at the expected epoch.

    The model reports a scripted ``val_loss`` per epoch that falls for the
    first three epochs and rises afterwards. Training is asserted to end with
    ``trainer.current_epoch == 5``.
    """

    class CurrentModel(EvalModelTemplate):
        def validation_epoch_end(self, outputs):
            # Scripted validation losses, indexed by the current epoch.
            losses = [8, 4, 2, 3, 4, 5, 8, 10]
            val_loss = losses[self.current_epoch]
            return {'val_loss': torch.tensor(val_loss)}

    model = CurrentModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=True,
        overfit_batches=0.20,
        max_epochs=20,
    )
    # The return value of fit() is not inspected; only the stopping epoch matters.
    trainer.fit(model)

    assert trainer.current_epoch == 5, 'early_stopping failed'
Example #25
Source File: test_amp.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_amp_single_gpu(tmpdir, backend):
    """Make sure DP/DDP + AMP work on a single GPU with 16-bit precision."""
    tutils.reset_seed()

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=1,
        distributed_backend=backend,
        precision=16,
    )
    trainer = Trainer(**trainer_kwargs)

    model = EvalModelTemplate()
    fit_status = trainer.fit(model)

    assert fit_status == 1
Example #26
Source File: test_amp.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_amp_multi_gpu(tmpdir, backend):
    """Make sure DP/DDP + AMP work on two GPUs selected via a GPU string."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus='0, 1',  # test init with gpu string (rather than gpus=2)
        distributed_backend=backend,
        precision=16,
    )
    fit_status = trainer.fit(model)

    assert fit_status
Example #27
Source File: test_amp.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_multi_gpu_wandb(tmpdir, backend):
    """Make sure fit and test run on two GPUs with 16-bit precision and a ``WandbLogger``."""
    from pytorch_lightning.loggers import WandbLogger

    tutils.set_random_master_port()

    model = EvalModelTemplate()
    wandb_logger = WandbLogger(name='utest')

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=2,
        distributed_backend=backend,
        precision=16,
        logger=wandb_logger,
    )
    fit_status = trainer.fit(model)
    assert fit_status

    # Also exercise the test loop with the same trainer.
    trainer.test(model)
Example #28
Source File: test_hooks.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_on_before_zero_grad_called(tmpdir, max_steps):
    """Check how often ``on_before_zero_grad`` fires during fit and test.

    The hook must not have run before training, must have run ``max_steps``
    times after ``fit``, and must not run at all during ``test``.
    """

    class CurrentTestModel(EvalModelTemplate):
        on_before_zero_grad_called = 0

        def on_before_zero_grad(self, optimizer):
            self.on_before_zero_grad_called = self.on_before_zero_grad_called + 1

    model = CurrentTestModel()

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_steps=max_steps,
        max_epochs=2,
        num_sanity_val_steps=5,
    )
    trainer = Trainer(**trainer_kwargs)
    assert model.on_before_zero_grad_called == 0

    trainer.fit(model)
    assert model.on_before_zero_grad_called == max_steps

    # Reset the counter so the test loop is measured on its own.
    model.on_before_zero_grad_called = 0
    trainer.test(model)
    assert model.on_before_zero_grad_called == 0
Example #29
Source File: test_lr_logger.py From pytorch-lightning with Apache License 2.0 | 6 votes |
def test_lr_logger_no_lr(tmpdir):
    """Check that fitting with a ``LearningRateLogger`` on the default model emits a ``RuntimeWarning``."""
    tutils.reset_seed()

    model = EvalModelTemplate()
    lr_callback = LearningRateLogger()

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[lr_callback],
    )
    trainer = Trainer(**trainer_kwargs)

    with pytest.warns(RuntimeWarning):
        fit_status = trainer.fit(model)
        assert fit_status
Example #30
Source File: test_progress_bar.py From pytorch-lightning with Apache License 2.0 | 5 votes |
def test_progress_bar_off(tmpdir, callbacks, refresh_rate):
    """Test different ways the progress bar can be turned off."""
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        progress_bar_refresh_rate=refresh_rate,
    )
    trainer = Trainer(**trainer_kwargs)

    # No registered callback may be a ProgressBar.
    bar_count = sum(1 for cb in trainer.callbacks if isinstance(cb, ProgressBar))
    assert bar_count == 0
    assert not trainer.progress_bar_callback