Python apex.amp.state_dict() Examples

The following are 16 code examples of apex.amp.state_dict(), collected from open-source projects. The original project and source file for each example are noted above its code. You may also want to check out all available functions/classes of the module apex.amp.
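All of the examples below follow the same basic pattern: after amp.initialize(), the loss-scaler state returned by amp.state_dict() is saved alongside the model and optimizer states, and later restored with amp.load_state_dict(). A minimal sketch of that round trip (the model, optimizer and file name are illustrative placeholders):

import torch
from apex import amp

# model: an nn.Module, optimizer: a torch.optim optimizer (assumed to exist already)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# Save model, optimizer and amp (loss-scaler) state together
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pt')

# Restore: run amp.initialize() first, then load all three state dicts
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])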
Example #1
Source File: utils.py    From pytorch-image-models with Apache License 2.0
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
        save_state = {
            'epoch': epoch,
            'arch': args.model,
            'state_dict': get_state_dict(model),
            'optimizer': optimizer.state_dict(),
            'args': args,
            'version': 2,  # version < 2 increments epoch before save
        }
        if use_amp and 'state_dict' in amp.__dict__:
            save_state['amp'] = amp.state_dict()
        if model_ema is not None:
            save_state['state_dict_ema'] = get_state_dict(model_ema)
        if metric is not None:
            save_state['metric'] = metric
        torch.save(save_state, save_path) 
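A checkpoint written by this helper could be restored along the following lines (a sketch, not the project's own resume code; the key names mirror the save above):

import torch
from apex import amp

checkpoint = torch.load(save_path, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
# Only restore amp state if it was saved and amp has been initialized
if 'amp' in checkpoint and 'load_state_dict' in amp.__dict__:
    amp.load_state_dict(checkpoint['amp'])
# With 'version' >= 2 the saved epoch is the one just completed, so resume at epoch + 1
start_epoch = checkpoint['epoch'] + 1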
Example #2
Source File: torch_runner.py    From ray with Apache License 2.0
def state_dict(self):
        """Returns the state of the runner."""
        state = {
            "epoch": self.epochs,
            "operator": self.training_operator.state_dict(),
            "models": [model.state_dict() for model in self.models],
            "optimizers": [opt.state_dict() for opt in self.optimizers]
        }
        if self.schedulers:
            state.update({
                "schedulers": [
                    scheduler.state_dict() for scheduler in self.schedulers
                ]
            })
        # Check if fp16 is True and if NVIDIA Apex is imported.
        if self.use_fp16 and amp:
            state.update({"amp": amp.state_dict()})
        return state 
Example #3
Source File: training.py    From tape with BSD 3-Clause "New" or "Revised" License
def save_state(self, save_directory: typing.Union[str, Path], epoch_id: int):
        save_directory = Path(save_directory)
        if not save_directory.exists():
            save_directory.mkdir()
        else:
            assert save_directory.is_dir(), "Save path should be a directory"
        model_to_save = getattr(self.model, 'module', self.model)
        model_to_save.save_pretrained(save_directory)
        optimizer_state: typing.Dict[str, typing.Any] = {
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epoch': epoch_id}
        if APEX_FOUND:
            optimizer_state['master params'] = list(amp.master_params(self.optimizer))
            try:
                optimizer_state['amp'] = amp.state_dict()
            except AttributeError:
                pass
        torch.save(optimizer_state, save_directory / 'checkpoint.bin') 
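The FP32 master parameters saved above could be copied back after a restart roughly as follows (a sketch, assuming the optimizer and scheduler have already been re-created and amp re-initialized; key names mirror save_state):

import torch
from apex import amp

checkpoint = torch.load(save_directory / 'checkpoint.bin')
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
if APEX_FOUND:
    # Copy the saved FP32 master weights back into amp's master params
    for param, saved in zip(amp.master_params(optimizer), checkpoint['master params']):
        param.data.copy_(saved.data)
    if 'amp' in checkpoint:
        amp.load_state_dict(checkpoint['amp'])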
Example #4
Source File: train.py    From OpenTransformer with MIT License
def save_model(self, epoch=None, save_name=None):
        if save_name is None:
            save_name = 'model.epoch.%d.pt' % epoch

        if self.mixed_precision:
            import apex.amp as amp
            amp_state_dict = amp.state_dict()
        else:
            amp_state_dict = None

        checkpoint = {
            'epoch': epoch,
            'params': self.params,
            'model': self.model.module.state_dict() if self.ngpu > 1 else self.model.state_dict(),
            # 'optimizer': self.optimizer.state_dict(),
            'amp': amp_state_dict
        }

        torch.save(checkpoint, os.path.join(self.expdir, save_name)) 
Example #5
Source File: trainer.py    From allennlp with Apache License 2.0
def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]:
        if self._moving_average is not None:
            # Assigning average value to model parameters.  The checkpointer will call
            # `restore_state_after_checkpointing` when it is done to put this back to what it was.
            self._moving_average.assign_average_value()

        model_state = self.model.state_dict()

        # These are the training states we need to persist.
        training_states = {
            "metric_tracker": self._metric_tracker.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "batch_num_total": self._batch_num_total,
        }

        # If we have a learning rate or momentum scheduler, we should persist them too.
        if self._learning_rate_scheduler is not None:
            training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict()
        if self._momentum_scheduler is not None:
            training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict()
        # If model was trained with amp, we should persist the amp state.
        if self._opt_level is not None:
            training_states["amp"] = amp.state_dict()

        try:
            yield model_state, training_states
        finally:
            if self._moving_average is not None:
                self._moving_average.restore() 
Example #6
Source File: utils.py    From pytorch-image-models with Apache License 2.0
def get_state_dict(model):
    return unwrap_model(model).state_dict() 
Example #7
Source File: utils.py    From pytorch-image-models with Apache License 2.0
def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) 
Example #8
Source File: torch_runner.py    From ray with Apache License 2.0
def load_state_dict(self, state):
        """Sets the state of the model."""
        for model, state_dict in zip(self.models, state["models"]):
            model.load_state_dict(state_dict)
        for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
            optimizer.load_state_dict(state_dict)
        if self.schedulers:
            for scheduler, state_dict in zip(self.schedulers,
                                             state["schedulers"]):
                scheduler.load_state_dict(state_dict)

        if self.use_fp16 and "amp" in state and amp:
            amp.load_state_dict(state["amp"])
        self.epochs = state["epoch"]
        self.training_operator.load_state_dict(state["operator"])
Example #9
Source File: torch_runner.py    From ray with Apache License 2.0
def state_stream(self):
        """Returns a bytes object for the state dict."""
        state_dict = self.state_dict()
        _buffer = io.BytesIO()
        torch.save(state_dict, _buffer)
        return _buffer.getvalue() 
Example #10
Source File: torch_runner.py    From ray with Apache License 2.0
def load_state_stream(self, byte_obj):
        """Loads a bytes object the training state dict."""
        _buffer = io.BytesIO(byte_obj)
        state_dict = torch.load(_buffer)
        return self.load_state_dict(state_dict) 
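Together with Example #9, this lets the runner's full state, including the "amp" entry, be shipped between processes as plain bytes. A rough usage sketch (the two runner objects are illustrative):

# Serialize one runner's state (model, optimizer, scheduler and amp loss-scaler state) ...
state_bytes = source_runner.state_stream()

# ... and restore it on another runner, e.g. after sending the bytes over the network
target_runner.load_state_stream(state_bytes)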
Example #11
Source File: test_checkpointing.py    From apex with BSD 3-Clause "New" or "Revised" License
def check_state_dict_fp32(self, state_dict):
        for key in state_dict:
            if 'num_batches_tracked' in key:
                continue
            param = state_dict[key]
            self.assertEqual(param.type(), FLOAT,
                             'Parameter in state_dict not FLOAT') 
Example #12
Source File: test_checkpointing.py    From apex with BSD 3-Clause "New" or "Revised" License
def compare_models(self, modelA, modelB, test_setup=''):
        state_dictA = modelA.state_dict()
        state_dictB = modelB.state_dict()
        self.assertEqual(len(state_dictA), len(state_dictB),
                         'state_dicts have different lengths' + test_setup)
        for key in state_dictA:
            paramA = state_dictA[key]
            paramB = state_dictB[key]
            self.assertTrue((paramA==paramB).all(),
                msg='Parameters in state_dicts not equal. ' +
                    'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
                        key, paramA, paramB, paramA - paramB, test_setup)) 
Example #13
Source File: test_checkpointing.py    From apex with BSD 3-Clause "New" or "Revised" License
def test_state_dict(self):
        for opt_level in self.test_opt_levels:
            # Skip O3
            if opt_level == 'O3':
                continue

            model = MyModel().to('cuda')
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0)

            # Export state_dict and check for Half
            state_dict = model.state_dict()
            for key in state_dict:
                self.assertFalse('Half' in state_dict[key].type())

            # Check, if model is still trainable
            # Create dummy data
            data = torch.randn(10, 3, 4, 4, device='cuda')
            target = torch.randn(10, 6, 4, 4, device='cuda')
            
            # Get initial loss
            optimizer.zero_grad()
            output = model(data)
            loss = F.mse_loss(output, target)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            last_loss = loss.item()

            # train for some epochs
            for epoch in range(10):
                optimizer.zero_grad()
                output = model(data)
                loss = F.mse_loss(output, target)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
                self.assertTrue(loss.item() < last_loss)
                last_loss = loss.item() 
Example #14
Source File: train.py    From OpenTransformer with MIT License
def load_model(self, checkpoint):

        state_dict = torch.load(checkpoint)
        self.model.load_state_dict(state_dict['model'])
        if self.mixed_precision:
            import apex.amp as amp
            amp.load_state_dict(state_dict['amp']) 
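Apex's documentation recommends restoring with the same opt_level that was used for saving, and calling amp.load_state_dict() only after amp.initialize() has run, roughly in this order (a sketch; the checkpoint keys mirror save_model above):

model, optimizer = amp.initialize(model, optimizer, opt_level='O1')  # same opt_level as when saving

state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict['model'])
if state_dict.get('amp') is not None:
    amp.load_state_dict(state_dict['amp'])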
Example #15
Source File: test_checkpointing.py    From apex with BSD 3-Clause "New" or "Revised" License
def test_loss_scale_decrease(self):
        num_losses = 3
        nb_decrease_loss_scales = [0, 1, 2]
        for opt_level in self.test_opt_levels:
            #print('#' * 75 + f'\n opt_level {opt_level}\n')
            # Create new tmp copy for this run
            nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales)

            model = MyModel().to('cuda')
        
            optimizer = optim.SGD(model.parameters(),
                                  lr=self.initial_lr)
        
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, num_losses=num_losses,
                verbosity=0)

            if amp._amp_state.opt_properties.loss_scale != 'dynamic':
                #print('Static loss scale set. Skipping opt_level.')
                continue
        
            # force to skip some updates to decrease the loss_scale
            initial_loss_scales = []
            for idx in range(num_losses):
                initial_loss_scales.append(
                    amp._amp_state.loss_scalers[idx].loss_scale())
            
            for _ in range(len(nb_decrease_loss_scales)):
                x = torch.randn(16, 3, 24, 24, device='cuda')
                for idx in range(num_losses):
                    while nb_decrease_loss_scales_tmp[idx] > 0:
                        optimizer.zero_grad()
                        output = model(x * 2**17)
                        loss = output.mean()            
                    
                        with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                            scaled_loss.backward(retain_graph=True)
                        optimizer.step()
                        nb_decrease_loss_scales_tmp[idx] -= 1
                
            # Check loss scales afterwards
            updated_loss_scales = []
            for idx in range(num_losses):
                updated_loss_scales.append(
                    amp._amp_state.loss_scalers[idx].loss_scale())
            for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
                                                  updated_loss_scales,
                                                  initial_loss_scales):
                self.assertEqual(update_ls, init_ls / 2**factor)

            # Check state dict
            amp_state_dict = amp.state_dict()
            for scaler_idx, factor, init_ls in zip(amp_state_dict,
                                                   nb_decrease_loss_scales,
                                                   initial_loss_scales):
                scaler = amp_state_dict[scaler_idx]
                self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
                unskipped_target = 0
                self.assertEqual(scaler['unskipped'], unskipped_target) 
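As the final assertions suggest, the dictionary returned by amp.state_dict() holds one entry per loss scaler, each recording at least the current 'loss_scale' and the 'unskipped' step counter. A quick, illustrative way to inspect it:

amp_state = amp.state_dict()
for scaler_name, scaler_state in amp_state.items():
    print(scaler_name, scaler_state['loss_scale'], scaler_state['unskipped'])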
Example #16
Source File: train.py    From OpenTransformer with MIT License
def __init__(self, params, model, optimizer, scheduler=None, is_visual=True, expdir='./',
                 ngpu=1, parallel_mode='dp', local_rank=0, mixed_precision=False, opt_level='O1'):

        self.params = params
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.expdir = expdir
        self.is_visual = is_visual

        self.ngpu = ngpu
        self.parallel_mode = parallel_mode
        self.local_rank = local_rank

        self.shuffle = params['train']['shuffle']
        self.accum_steps = params['train']['accum_steps']
        self.grad_noise = params['train']['grad_noise']
        self.grad_clip = params['train']['clip_grad']
        self.global_step = 0
        self.log_interval = 10
        self.mean_loss = MeanLoss()

        self.mixed_precision = mixed_precision
        self.opt_level = opt_level

        self.logger = init_logger(log_file=os.path.join(expdir, 'train.log'))
        if self.is_visual and local_rank == 0:
            self.visulizer = Visulizer(log_dir=os.path.join(expdir, 'visual'))

        if self.params['train']['load_model']:
            self.load_model(self.params['train']['load_model'])
            self.logger.info('Load the checkpoint from %s' % self.params['train']['load_model'])

        if self.mixed_precision:
            import apex.amp as amp
            self.model, self.optimizer.optimizer = amp.initialize(self.model, self.optimizer.optimizer, opt_level=self.opt_level)

        if self.ngpu > 1:
#             if self.parallel_mode == 'hvd':
#                 import horovod.torch as hvd
#                 hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
#                 self.logger.info('[Horovod] Use %d gpus for training!' % self.ngpu)

            if self.parallel_mode == 'ddp':
                import torch.distributed as dist
                dist.init_process_group(backend="nccl", init_method='env://',
                                        rank=local_rank, world_size=self.ngpu)
                self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank)
                self.logger.info('[DDP] Use %d gpus for training!' % self.ngpu)

            elif self.parallel_mode == 'dp':
                self.model = torch.nn.DataParallel(self.model, device_ids=[i for i in range(self.ngpu)])
                self.logger.info('[DP] Use %d gpus for training!' % self.ngpu)

            else:
                self.logger.warning('Please choose one of dp, ddp or hvd for parallel computing!')
        elif self.ngpu == 1:
            self.logger.info('Use only 1 gpu for training!')
        else:
            self.logger.info('Train the model in CPU!')