Python apex.amp.load_state_dict() Examples
The following are 8 code examples of apex.amp.load_state_dict(), drawn from open source projects. The originating project, source file, and license are noted above each example. You may also want to check out the other available functions and classes of the apex.amp module.
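For context, the usual pattern documented by NVIDIA apex is to store the result of amp.state_dict() in the same checkpoint as the model and optimizer, and to call amp.load_state_dict() after amp.initialize() when resuming. Below is a minimal sketch only, not one of the 8 examples; it assumes apex is installed and a CUDA device is available, and names such as checkpoint.pt and the O1 opt_level are purely illustrative.

# Minimal sketch: saving and restoring apex.amp state alongside model and optimizer.
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# Save: amp.state_dict() captures amp's internal state (e.g. the loss scaler).
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'amp': amp.state_dict()},
           'checkpoint.pt')

# Resume: initialize amp the same way, then restore each piece of state.
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])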
Example #1
Source File: solver.py From End-to-end-ASR-Pytorch with MIT License
def load_ckpt(self):
    ''' Load ckpt if --load option is specified '''
    if self.paras.load:
        # Load weights
        ckpt = torch.load(
            self.paras.load, map_location=self.device if self.mode == 'train' else 'cpu')
        self.model.load_state_dict(ckpt['model'])
        if self.emb_decoder is not None:
            self.emb_decoder.load_state_dict(ckpt['emb_decoder'])
        # if self.amp:
        #     amp.load_state_dict(ckpt['amp'])

        # Load task-dependent items
        metric = "None"
        score = 0.0
        for k, v in ckpt.items():
            if type(v) is float:
                metric, score = k, v

        if self.mode == 'train':
            self.step = ckpt['global_step']
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.verbose('Load ckpt from {}, restarting at step {} (recorded {} = {:.2f} %)'.format(
                self.paras.load, self.step, metric, score))
        else:
            self.model.eval()
            if self.emb_decoder is not None:
                self.emb_decoder.eval()
            self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
                self.paras.load, metric, score))
Example #2
Source File: run_seq2seq.py From unilm with MIT License
def prepare_for_training(args, model, checkpoint_state_dict, amp):
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    if amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        if checkpoint_state_dict:
            amp.load_state_dict(checkpoint_state_dict['amp'])

    if checkpoint_state_dict:
        optimizer.load_state_dict(checkpoint_state_dict['optimizer'])
        model.load_state_dict(checkpoint_state_dict['model'])

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)

    return model, optimizer
Example #3
Source File: abs_task.py From espnet with Apache License 2.0
def resume(
    checkpoint: Union[str, Path],
    model: torch.nn.Module,
    reporter: Reporter,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    ngpu: int = 0,
    use_apex: bool = False,
):
    states = torch.load(
        checkpoint,
        map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
    )
    model.load_state_dict(states["model"])
    reporter.load_state_dict(states["reporter"])
    for optimizer, state in zip(optimizers, states["optimizers"]):
        optimizer.load_state_dict(state)
    for scheduler, state in zip(schedulers, states["schedulers"]):
        if scheduler is not None:
            scheduler.load_state_dict(state)
    if use_apex and states["amp"] is not None:
        try:
            from apex import amp
        except ImportError:
            logging.error(
                "You need to install apex. "
                "See https://github.com/NVIDIA/apex#linux"
            )
        amp.load_state_dict(states["amp"])

    logging.info(f"The training was resumed using {checkpoint}")
Example #4
Source File: abs_task.py From espnet with Apache License 2.0
def build_model_from_file(
    cls,
    config_file: Union[Path, str],
    model_file: Union[Path, str] = None,
    device: str = "cpu",
) -> Tuple[AbsESPnetModel, argparse.Namespace]:
    """This method is used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training.
        model_file: The model file saved when training.
        device:
    """
    assert check_argument_types()
    config_file = Path(config_file)

    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    args = argparse.Namespace(**args)
    model = cls.build_model(args)
    if not isinstance(model, AbsESPnetModel):
        raise RuntimeError(
            f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
        )
    model.to(device)
    if model_file is not None:
        if device == "cuda":
            # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
            #   in PyTorch<=1.4
            device = f"cuda:{torch.cuda.current_device()}"
        model.load_state_dict(torch.load(model_file, map_location=device))

    return model, args
Example #5
Source File: torch_runner.py From ray with Apache License 2.0
def load_state_dict(self, state):
    """Sets the state of the model."""
    for model, state_dict in zip(self.models, state["models"]):
        model.load_state_dict(state_dict)
    for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
        optimizer.load_state_dict(state_dict)
    if self.schedulers:
        for scheduler, state_dict in zip(self.schedulers, state["schedulers"]):
            scheduler.load_state_dict(state_dict)

    if self.use_fp16 and "amp" in state and amp:
        amp.load_state_dict(state["amp"])
    self.epochs = state["epoch"]
    self.training_operator.load_state_dict(state_dict)
Example #6
Source File: torch_runner.py From ray with Apache License 2.0
def load_state_stream(self, byte_obj):
    """Loads a bytes object into the training state dict."""
    _buffer = io.BytesIO(byte_obj)
    state_dict = torch.load(_buffer)
    return self.load_state_dict(state_dict)
Example #7
Source File: training.py From tape with BSD 3-Clause "New" or "Revised" License
def resume_from_checkpoint(self, checkpoint_dir: str) -> int:
    checkpoint = torch.load(
        os.path.join(checkpoint_dir, 'checkpoint.bin'), map_location=self.device)
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    if self.fp16:
        self.optimizer._lazy_init_maybe_master_weights()
        self.optimizer._amp_stash.lazy_init_called = True
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        for param, saved in zip(
                amp.master_params(self.optimizer), checkpoint['master params']):
            param.data.copy_(saved.data)
        amp.load_state_dict(checkpoint['amp'])
    self.scheduler.load_state_dict(checkpoint['scheduler'])
    start_epoch = checkpoint['epoch'] + 1
    return start_epoch
Example #8
Source File: train.py From OpenTransformer with MIT License
def load_model(self, checkpoint):
    state_dict = torch.load(checkpoint)
    self.model.load_state_dict(state_dict['model'])

    if self.mixed_precision:
        import apex.amp as amp
        amp.load_state_dict(state_dict['amp'])