Python apex.amp.state_dict() Examples
The following are 16 code examples of apex.amp.state_dict(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module apex.amp, or try the search function.
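Before the examples, here is a minimal sketch of the save/restore pattern described in the apex documentation: call amp.state_dict() after amp.initialize() and persist it alongside the model and optimizer state, then restore all three with their load_state_dict counterparts. The Linear model, file name, and opt_level below are placeholders, and a CUDA device is assumed.

import torch
import torch.nn as nn
import torch.optim as optim
from apex import amp

# Placeholder model/optimizer purely for illustration (apex amp needs a CUDA device).
model = nn.Linear(4, 2).cuda()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

# amp.state_dict() is only meaningful after amp.initialize().
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# Save model, optimizer, and amp (loss-scaler) state together.
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict(),
}
torch.save(checkpoint, 'amp_checkpoint.pt')

# To resume: rebuild model/optimizer, call amp.initialize() again with the
# same opt_level, then restore each piece, including the amp state.
checkpoint = torch.load('amp_checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])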
Example #1
Source File: utils.py From pytorch-image-models with Apache License 2.0 | 6 votes |
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
    save_state = {
        'epoch': epoch,
        'arch': args.model,
        'state_dict': get_state_dict(model),
        'optimizer': optimizer.state_dict(),
        'args': args,
        'version': 2,  # version < 2 increments epoch before save
    }
    if use_amp and 'state_dict' in amp.__dict__:
        save_state['amp'] = amp.state_dict()
    if model_ema is not None:
        save_state['state_dict_ema'] = get_state_dict(model_ema)
    if metric is not None:
        save_state['metric'] = metric
    torch.save(save_state, save_path)
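A resume counterpart is not shown in this gallery, but restoring such a checkpoint would load the model and optimizer state and, when mixed precision is active, pass the saved 'amp' entry to amp.load_state_dict(). The helper below is a hypothetical sketch, assuming the key names used by _save() above; _resume and load_path are illustrative names, not part of the original repository.

import torch
from apex import amp

def _resume(load_path, model, optimizer, use_amp=False):
    # Hypothetical resume helper; mirrors the keys written by _save() above.
    checkpoint = torch.load(load_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if use_amp and 'amp' in checkpoint and 'load_state_dict' in amp.__dict__:
        amp.load_state_dict(checkpoint['amp'])
    return checkpoint.get('epoch', 0)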
Example #2
Source File: torch_runner.py From ray with Apache License 2.0 | 6 votes |
def state_dict(self):
    """Returns the state of the runner."""
    state = {
        "epoch": self.epochs,
        "operator": self.training_operator.state_dict(),
        "models": [model.state_dict() for model in self.models],
        "optimizers": [opt.state_dict() for opt in self.optimizers]
    }
    if self.schedulers:
        state.update({
            "schedulers": [
                scheduler.state_dict() for scheduler in self.schedulers
            ]
        })
    # Check if fp16 is True and if NVIDIA Apex is imported.
    if self.use_fp16 and amp:
        state.update({"amp": amp.state_dict()})
    return state
Example #3
Source File: training.py From tape with BSD 3-Clause "New" or "Revised" License | 6 votes |
def save_state(self, save_directory: typing.Union[str, Path], epoch_id: int):
    save_directory = Path(save_directory)
    if not save_directory.exists():
        save_directory.mkdir()
    else:
        assert save_directory.is_dir(), "Save path should be a directory"

    model_to_save = getattr(self.model, 'module', self.model)
    model_to_save.save_pretrained(save_directory)

    optimizer_state: typing.Dict[str, typing.Any] = {
        'optimizer': self.optimizer.state_dict(),
        'scheduler': self.scheduler.state_dict(),
        'epoch': epoch_id}

    if APEX_FOUND:
        optimizer_state['master params'] = list(amp.master_params(self.optimizer))
        try:
            optimizer_state['amp'] = amp.state_dict()
        except AttributeError:
            pass

    torch.save(optimizer_state, save_directory / 'checkpoint.bin')
Example #4
Source File: train.py From OpenTransformer with MIT License | 6 votes |
def save_model(self, epoch=None, save_name=None):
    if save_name is None:
        save_name = 'model.epoch.%d.pt' % epoch

    if self.mixed_precision:
        import apex.amp as amp
        amp_state_dict = amp.state_dict()
    else:
        amp_state_dict = None

    checkpoint = {
        'epoch': epoch,
        'params': self.params,
        'model': self.model.module.state_dict() if self.ngpu > 1 else self.model.state_dict(),
        # 'optimizer': self.optimizer.state_dict(),
        'amp': amp_state_dict
    }

    torch.save(checkpoint, os.path.join(self.expdir, save_name))
Example #5
Source File: trainer.py From allennlp with Apache License 2.0 | 5 votes |
def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]:
    if self._moving_average is not None:
        # Assigning average value to model parameters. The checkpointer will call
        # `restore_state_after_checkpointing` when it is done to put this back to what it was.
        self._moving_average.assign_average_value()

    model_state = self.model.state_dict()

    # These are the training states we need to persist.
    training_states = {
        "metric_tracker": self._metric_tracker.state_dict(),
        "optimizer": self.optimizer.state_dict(),
        "batch_num_total": self._batch_num_total,
    }

    # If we have a learning rate or momentum scheduler, we should persist them too.
    if self._learning_rate_scheduler is not None:
        training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict()
    if self._momentum_scheduler is not None:
        training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict()

    # If model was trained with amp, we should persist the amp state.
    if self._opt_level is not None:
        training_states["amp"] = amp.state_dict()

    try:
        yield model_state, training_states
    finally:
        if self._moving_average is not None:
            self._moving_average.restore()
Example #6
Source File: utils.py From pytorch-image-models with Apache License 2.0 | 5 votes |
def get_state_dict(model):
    return unwrap_model(model).state_dict()
Example #7
Source File: utils.py From pytorch-image-models with Apache License 2.0 | 5 votes |
def update(self, model):
    # correct a mismatch in state dict keys
    needs_module = hasattr(model, 'module') and not self.ema_has_module
    with torch.no_grad():
        msd = model.state_dict()
        for k, ema_v in self.ema.state_dict().items():
            if needs_module:
                k = 'module.' + k
            model_v = msd[k].detach()
            if self.device:
                model_v = model_v.to(device=self.device)
            ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
Example #8
Source File: torch_runner.py From ray with Apache License 2.0 | 5 votes |
def load_state_dict(self, state):
    """Sets the state of the model."""
    for model, state_dict in zip(self.models, state["models"]):
        model.load_state_dict(state_dict)
    for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
        optimizer.load_state_dict(state_dict)
    if self.schedulers:
        for scheduler, state_dict in zip(self.schedulers, state["schedulers"]):
            scheduler.load_state_dict(state_dict)

    if self.use_fp16 and "amp" in state and amp:
        amp.load_state_dict(state["amp"])
    self.epochs = state["epoch"]
    # Restore the operator state saved under the "operator" key in state_dict().
    self.training_operator.load_state_dict(state["operator"])
Example #9
Source File: torch_runner.py From ray with Apache License 2.0 | 5 votes |
def state_stream(self):
    """Returns a bytes object for the state dict."""
    state_dict = self.state_dict()
    _buffer = io.BytesIO()
    torch.save(state_dict, _buffer)
    return _buffer.getvalue()
Example #10
Source File: torch_runner.py From ray with Apache License 2.0 | 5 votes |
def load_state_stream(self, byte_obj):
    """Loads a bytes object into the training state dict."""
    _buffer = io.BytesIO(byte_obj)
    state_dict = torch.load(_buffer)
    return self.load_state_dict(state_dict)
Example #11
Source File: test_checkpointing.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def check_state_dict_fp32(self, state_dict):
    for key in state_dict:
        if 'num_batches_tracked' in key:
            continue
        param = state_dict[key]
        self.assertEqual(param.type(), FLOAT,
                         'Parameter in state_dict not FLOAT')
Example #12
Source File: test_checkpointing.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def compare_models(self, modelA, modelB, test_setup=''):
    state_dictA = modelA.state_dict()
    state_dictB = modelB.state_dict()
    self.assertEqual(len(state_dictA), len(state_dictB),
                     'state_dicts have different lengths' + test_setup)
    for key in state_dictA:
        paramA = state_dictA[key]
        paramB = state_dictB[key]
        self.assertTrue((paramA == paramB).all(),
                        msg='Parameters in state_dicts not equal.' +
                            'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
                                key, paramA, paramB, paramA - paramB, test_setup))
Example #13
Source File: test_checkpointing.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_state_dict(self):
    for opt_level in self.test_opt_levels:
        # Skip O3
        if opt_level == 'O3':
            continue

        model = MyModel().to('cuda')
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=opt_level, verbosity=0)

        # Export state_dict and check for Half
        state_dict = model.state_dict()
        for key in state_dict:
            self.assertFalse('Half' in state_dict[key].type())

        # Check, if model is still trainable
        # Create dummy data
        data = torch.randn(10, 3, 4, 4, device='cuda')
        target = torch.randn(10, 6, 4, 4, device='cuda')

        # Get initial loss
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, target)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        last_loss = loss.item()

        # train for some epochs
        for epoch in range(10):
            optimizer.zero_grad()
            output = model(data)
            loss = F.mse_loss(output, target)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            self.assertTrue(loss.item() < last_loss)
            last_loss = loss.item()
Example #14
Source File: train.py From OpenTransformer with MIT License | 5 votes |
def load_model(self, checkpoint):
    state_dict = torch.load(checkpoint)
    self.model.load_state_dict(state_dict['model'])

    if self.mixed_precision:
        import apex.amp as amp
        amp.load_state_dict(state_dict['amp'])
Example #15
Source File: test_checkpointing.py From apex with BSD 3-Clause "New" or "Revised" License | 4 votes |
def test_loss_scale_decrease(self):
    num_losses = 3
    nb_decrease_loss_scales = [0, 1, 2]
    for opt_level in self.test_opt_levels:
        # print('#' * 75 + f'\n opt_level {opt_level}\n')
        # Create new tmp copy for this run
        nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales)

        model = MyModel().to('cuda')
        optimizer = optim.SGD(model.parameters(), lr=self.initial_lr)
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=opt_level,
            num_losses=num_losses, verbosity=0)

        if amp._amp_state.opt_properties.loss_scale != 'dynamic':
            # print('Static loss scale set. Skipping opt_level.')
            continue

        # force to skip some updates to decrease the loss_scale
        initial_loss_scales = []
        for idx in range(num_losses):
            initial_loss_scales.append(
                amp._amp_state.loss_scalers[idx].loss_scale())

        for _ in range(len(nb_decrease_loss_scales)):
            x = torch.randn(16, 3, 24, 24, device='cuda')
            for idx in range(num_losses):
                while nb_decrease_loss_scales_tmp[idx] > 0:
                    optimizer.zero_grad()
                    output = model(x * 2**17)
                    loss = output.mean()

                    with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                    optimizer.step()
                    nb_decrease_loss_scales_tmp[idx] -= 1

        # Check loss scales afterwards
        updated_loss_scales = []
        for idx in range(num_losses):
            updated_loss_scales.append(
                amp._amp_state.loss_scalers[idx].loss_scale())
        for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
                                              updated_loss_scales,
                                              initial_loss_scales):
            self.assertEqual(update_ls, init_ls / 2**factor)

        # Check state dict
        amp_state_dict = amp.state_dict()
        for scaler_idx, factor, init_ls in zip(amp_state_dict,
                                               nb_decrease_loss_scales,
                                               initial_loss_scales):
            scaler = amp_state_dict[scaler_idx]
            self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
            unskipped_target = 0
            self.assertEqual(scaler['unskipped'], unskipped_target)
Example #16
Source File: train.py From OpenTransformer with MIT License | 4 votes |
def __init__(self, params, model, optimizer, scheduler=None, is_visual=True, expdir='./',
             ngpu=1, parallel_mode='dp', local_rank=0, mixed_precision=False, opt_level='O1'):

    self.params = params
    self.model = model
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.expdir = expdir
    self.is_visual = is_visual
    self.ngpu = ngpu
    self.parallel_mode = parallel_mode
    self.local_rank = local_rank

    self.shuffle = params['train']['shuffle']
    self.accum_steps = params['train']['accum_steps']
    self.grad_noise = params['train']['grad_noise']
    self.grad_clip = params['train']['clip_grad']

    self.global_step = 0
    self.log_interval = 10
    self.mean_loss = MeanLoss()

    self.mixed_precision = mixed_precision
    self.opt_level = opt_level

    self.logger = init_logger(log_file=os.path.join(expdir, 'train.log'))
    if self.is_visual and local_rank == 0:
        self.visulizer = Visulizer(log_dir=os.path.join(expdir, 'visual'))

    if self.params['train']['load_model']:
        self.load_model(self.params['train']['load_model'])
        self.logger.info('Load the checkpoint from %s' % self.params['train']['load_model'])

    if self.mixed_precision:
        import apex.amp as amp
        self.model, self.optimizer.optimizer = amp.initialize(
            self.model, self.optimizer.optimizer, opt_level=self.opt_level)

    if self.ngpu > 1:
        # if self.parallel_mode == 'hvd':
        #     import horovod.torch as hvd
        #     hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
        #     self.logger.info('[Horovod] Use %d gpus for training!' % self.ngpu)
        if self.parallel_mode == 'ddp':
            import torch.distributed as dist
            dist.init_process_group(backend="nccl", init_method='env://',
                                    rank=local_rank, world_size=self.ngpu)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[local_rank], output_device=local_rank)
            self.logger.info('[DDP] Use %d gpus for training!' % self.ngpu)
        elif self.parallel_mode == 'dp':
            self.model = torch.nn.DataParallel(self.model, device_ids=[i for i in range(self.ngpu)])
            self.logger.info('[DP] Use %d gpus for training!' % self.ngpu)
        else:
            self.logger.warning('Please chose one of dp, ddp and hvd for parallel computing!')
    elif self.ngpu == 1:
        self.logger.info('Use only 1 gpu for training!')
    else:
        self.logger.info('Train the model in CPU!')