Python torch.nn.parallel.DistributedDataParallel() Examples

The following are 30 code examples of torch.nn.parallel.DistributedDataParallel(). You can go to the original project or source file by following the link above each example. You may also want to check out all available functions/classes of the module torch.nn.parallel, or try the search function.
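All of the examples below follow the same basic pattern: initialize a process group, move the model to the device owned by the current process, and wrap it in DistributedDataParallel so that gradients are synchronized during the backward pass. The following is a minimal, generic sketch of that pattern, not taken from any of the projects below; the backend, port handling, and tensor shapes are arbitrary placeholders.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size):
    # Assumes MASTER_ADDR / MASTER_PORT are provided by the launcher (e.g. torchrun).
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    model = nn.Linear(10, 5).cuda(rank)         # one GPU per process
    ddp_model = DDP(model, device_ids=[rank])   # wrap after moving the model to its device
    out = ddp_model(torch.randn(8, 10).cuda(rank))
    out.sum().backward()                        # gradients are all-reduced across processes
    dist.destroy_process_group()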
Example #1
Source File: base_task.py    From Doc2EDAG with MIT License
def _decorate_model(self, parallel_decorate=True):
        self.logging('='*20 + 'Decorate Model' + '='*20)

        if self.setting.fp16:
            self.model.half()

        self.model.to(self.device)
        self.logging('Set model device to {}'.format(str(self.device)))

        if parallel_decorate:
            if self.in_distributed_mode():
                self.model = para.DistributedDataParallel(self.model,
                                                          device_ids=[self.setting.local_rank],
                                                          output_device=self.setting.local_rank)
                self.logging('Wrap distributed data parallel')
                # self.logging('In Distributed Mode, but do not use DistributedDataParallel Wrapper')
            elif self.n_gpu > 1:
                self.model = para.DataParallel(self.model)
                self.logging('Wrap data parallel')
        else:
            self.logging('Do not wrap parallel layers') 
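Note that para is not defined inside this snippet; in the source project it presumably refers to torch.nn.parallel imported under an alias, along the lines of:

import torch.nn.parallel as para
# after which para.DistributedDataParallel and para.DataParallel resolve as used above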
Example #2
Source File: plain_train_net.py    From detectron2 with Apache License 2.0
def main(args):
    cfg = setup(args)

    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )

    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model) 
Example #3
Source File: main.py    From elastic with BSD 3-Clause "New" or "Revised" License
def initialize_model(
    arch: str, lr: float, momentum: float, weight_decay: float, device_id: int
):
    print(f"=> creating model: {arch}")
    model = models.__dict__[arch]()
    # For multiprocessing distributed, DistributedDataParallel constructor
    # should always set the single device scope, otherwise,
    # DistributedDataParallel will use all available devices.
    model.cuda(device_id)
    cudnn.benchmark = True
    model = DistributedDataParallel(model, device_ids=[device_id])
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(device_id)
    optimizer = SGD(
        model.parameters(), lr, momentum=momentum, weight_decay=weight_decay
    )
    return model, criterion, optimizer 
Example #4
Source File: trainer.py    From GCA-Matting with MIT License
def build_model(self):

        self.G = networks.get_generator(encoder=self.model_config.arch.encoder, decoder=self.model_config.arch.decoder)
        self.G.cuda()

        if CONFIG.dist:
            self.logger.info("Using pytorch synced BN")
            self.G = SyncBatchNorm.convert_sync_batchnorm(self.G)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(),
                                            lr = self.train_config.G_lr,
                                            betas = [self.train_config.beta1, self.train_config.beta2])

        if CONFIG.dist:
            # SyncBatchNorm only supports DistributedDataParallel with single GPU per process
            self.G = DistributedDataParallel(self.G, device_ids=[CONFIG.local_rank], output_device=CONFIG.local_rank)
        else:
            self.G = nn.DataParallel(self.G)

        self.build_lr_scheduler() 
Example #5
Source File: plain_train_net.py    From detectron2 with Apache License 2.0
def main(args):
    cfg = setup(args)

    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )

    do_train(cfg, model)
    return do_test(cfg, model) 
Example #6
Source File: base.py    From rlpyt with MIT License
def data_parallel(self):
        """Wraps the model with PyTorch's DistributedDataParallel.  The
        intention is for rlpyt to create a separate Python process to drive
        each GPU (or CPU-group for CPU-only, MPI-like configuration). Agents
        with additional model components (beyond ``self.model``) which will
        have gradients computed through them should extend this method to wrap
        those, as well.

        Typically called in the runner during startup.
        """
        if self.device.type == "cpu":
            self.model = DDPC(self.model)
            logger.log("Initialized DistributedDataParallelCPU agent model.")
        else:
            self.model = DDP(self.model,
                device_ids=[self.device.index], output_device=self.device.index)
            logger.log("Initialized DistributedDataParallel agent model on "
                f"device {self.device}.") 
Example #7
Source File: trainer.py    From seq2seq.pytorch with MIT License
def __init__(self, *kargs, **kwargs):
        super(NestedTrainer, self).__init__(*kargs, **kwargs)
        self.model_with_loss = AddLossModule(self.model, self.criterion)
        if self.distributed:
            self.model_with_loss = DistributedDataParallel(
                self.model_with_loss,
                device_ids=[self.local_rank],
                output_device=self.local_rank)
        else:
            if isinstance(self.device_ids, tuple):
                self.model_with_loss = DataParallel(self.model_with_loss,
                                                    self.device_ids,
                                                    dim=0 if self.batch_first else 1)
        _, target_tok = self.save_info['tokenizers'].values()
        target_words = target_tok.common_words(8188)
        self.contrast_batch = batch_nested_sequences(target_words) 
Example #8
Source File: checkpoint.py    From SSD with MIT License
def save(self, name, **kwargs):
        if not self.save_dir:
            return

        if not self.save_to_disk:
            return

        data = {}
        if isinstance(self.model, DistributedDataParallel):
            data['model'] = self.model.module.state_dict()
        else:
            data['model'] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)

        self.tag_last_checkpoint(save_file) 
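Because the state dict is taken from self.model.module when the model is wrapped in DistributedDataParallel, the saved keys carry no 'module.' prefix, so the checkpoint can later be loaded into a plain, unwrapped model. A rough usage sketch (build_model here is a hypothetical constructor for the same architecture):

model = build_model()                      # hypothetical: same architecture, unwrapped
checkpoint = torch.load(save_file)
model.load_state_dict(checkpoint['model'])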
Example #9
Source File: checkpoint.py    From SSD with MIT License
def load(self, f=None, use_latest=True):
        if self.has_checkpoint() and use_latest:
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found.")
            return {}

        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        model = self.model
        if isinstance(model, DistributedDataParallel):
            model = self.model.module

        model.load_state_dict(checkpoint.pop("model"))
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # return any further checkpoint data
        return checkpoint 
Example #10
Source File: main.py    From examples with BSD 3-Clause "New" or "Revised" License
def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup() 
Example #11
Source File: main.py    From examples with BSD 3-Clause "New" or "Revised" License
def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # setup mp_model and devices for this process
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup() 
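Examples #10 and #11 both call setup() and cleanup() helpers that are not shown in the snippets. They are presumably thin wrappers around process-group creation and teardown; a minimal single-node sketch (the address, port, and backend are placeholder choices) looks like this:

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"  # single-node assumption
    os.environ["MASTER_PORT"] = "12355"      # any free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

# Typically launched with one process per rank, e.g.:
# mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)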
Example #12
Source File: model.py    From pytorch-project-template with Apache License 2.0
def __init__(self, hp, net_arch, loss_f, rank=0, world_size=1):
        self.hp = hp
        self.device = self.hp.model.device
        self.net = net_arch.to(self.device)
        self.rank = rank
        self.world_size = world_size
        if self.device != "cpu" and self.world_size != 0:
            self.net = DDP(self.net, device_ids=[self.rank])
        self.input = None
        self.GT = None
        self.step = 0
        self.epoch = -1

        # init optimizer
        optimizer_mode = self.hp.train.optimizer.mode
        if optimizer_mode == "adam":
            self.optimizer = torch.optim.Adam(
                self.net.parameters(), **(self.hp.train.optimizer[optimizer_mode])
            )
        else:
            raise Exception("%s optimizer not supported" % optimizer_mode)

        # init loss
        self.loss_f = loss_f
        self.log = DotDict() 
Example #13
Source File: example.py    From examples with BSD 3-Clause "New" or "Revised" License
def demo_basic(local_world_size, local_rank):

    # setup devices for this process. For local_world_size = 2, num_gpus = 8,
    # rank 0 uses GPUs [0, 1, 2, 3] and
    # rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))

    print(
        f"[{os.getpid()}] rank = {dist.get_rank()}, "
        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids}"
    )

    model = ToyModel().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step() 
Example #14
Source File: io.py    From torchpack with MIT License
def save_checkpoint(model,
                    epoch,
                    num_iters,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    optimizer=None,
                    is_best=False):
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        model = model.module
    filename = os.path.join(out_dir, filename_tmpl.format(epoch))
    checkpoint = {
        'epoch': epoch,
        'num_iters': num_iters,
        'state_dict': model_weights_to_cpu(model.state_dict())
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, filename)
    latest_link = os.path.join(out_dir, 'latest.pth')
    make_link(filename, latest_link)
    if is_best:
        best_link = os.path.join(out_dir, 'best.pth')
        make_link(filename, best_link) 
Example #15
Source File: benchmark.py    From detectron2 with Apache License 2.0
def benchmark_train(args):
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )
    optimizer = build_optimizer(cfg, model)
    checkpointer = DetectionCheckpointer(model, optimizer=optimizer)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    data_loader = build_detection_train_loader(cfg)
    dummy_data = list(itertools.islice(data_loader, 100))

    def f():
        data = DatasetFromList(dummy_data, copy=False)
        while True:
            yield from data

    max_iter = 400
    trainer = SimpleTrainer(model, f(), optimizer)
    trainer.register_hooks(
        [hooks.IterationTimer(), hooks.PeriodicWriter([CommonMetricPrinter(max_iter)])]
    )
    trainer.train(1, max_iter) 
Example #16
Source File: gradnorm_logger.py    From catalyst with Apache License 2.0
def grad_norm(*, model: Model, prefix: str, norm_type: int) -> Dict:
        """Computes gradient norms for a given model.

        Args:
            model (Model): model whose gradients are to be saved.
            prefix (str): prefix for keys in resulting dictionary.
            norm_type (int): norm type of gradient norm.

        Returns:
            Dict: dictionary in which gradient norms are stored.
        """
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            model = model.module

        total_norm = 0.0
        grad_norm = {}

        for tag, value in model.named_parameters():
            tag = tag.replace(".", "/")
            metrics_tag = f"{prefix}/{tag}"
            param_norm = value.grad.data.norm(norm_type).item()
            total_norm += param_norm ** norm_type
            grad_norm[metrics_tag] = param_norm

        total_norm = total_norm ** (1.0 / norm_type)
        metrics_tag = f"{prefix}/total"
        grad_norm[metrics_tag] = total_norm

        return grad_norm 
Example #17
Source File: runner.py    From torchpack with MIT License
def __init__(self,
                 model,
                 optimizer,
                 batch_processor,
                 work_dir=None,
                 log_level=logging.INFO):
        self.model = model
        self.optimizer = self.set_optimizer(optimizer)
        assert callable(batch_processor)
        self.batch_processor = batch_processor

        self.rank, self.world_size = get_dist_info()

        if isinstance(work_dir, str):
            self.work_dir = os.path.abspath(work_dir)
            if not os.path.isdir(self.work_dir):
                os.makedirs(self.work_dir)
        elif work_dir is None:
            self.work_dir = work_dir
        else:
            raise TypeError('"work_dir" must be a str or None')

        self.logger = self.init_logger(work_dir, log_level)

        if isinstance(self.model, (DataParallel, DistributedDataParallel)):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self.log_buffer = LogBuffer()
        self.hooks = []
        self.max_epoch = 0
        self.max_iter = 0
        self.epoch = 0
        self.num_iters = 0
        self.num_epoch_iters = 0
        self.mode = None 
Example #18
Source File: io.py    From torchpack with MIT License
def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None):
    # load checkpoint from modelzoo or file or url
    if filename.startswith('modelzoo://'):
        model_name = filename[11:]
        checkpoint = model_zoo.load_url(model_urls[model_name])
    elif filename.startswith(('http://', 'https://')):
        checkpoint = model_zoo.load_url(filename)
    else:
        if not os.path.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
        checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # load state_dict
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        load_state_dict(model.module, state_dict, strict, logger)
    else:
        load_state_dict(model, state_dict, strict, logger)
    return checkpoint 
Example #19
Source File: base_model.py    From mmsr with Apache License 2.0
def save_network(self, network, network_label, iter_label):
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(self.opt['path']['models'], save_filename)
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path) 
Example #20
Source File: base_model.py    From mmsr with Apache License 2.0
def get_network_description(self, network):
        """Get the string and total parameters of the network"""
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        return str(network), sum(map(lambda x: x.numel(), network.parameters())) 
Example #21
Source File: SRGAN_model.py    From mmsr with Apache License 2.0
def print_network(self):
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
                                                                    DistributedDataParallel):
                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
                                                 self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network D structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
                                                     self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info('Network F structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                    logger.info(s) 
Example #22
Source File: base_model.py    From mmsr with Apache License 2.0
def load_network(self, load_path, network, strict=True):
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        load_net = torch.load(load_path)
        load_net_clean = OrderedDict()  # remove unnecessary 'module.'
        for k, v in load_net.items():
            if k.startswith('module.'):
                load_net_clean[k[7:]] = v
            else:
                load_net_clean[k] = v
        network.load_state_dict(load_net_clean, strict=strict) 
Example #23
Source File: actions.py    From NeMo with Apache License 2.0
def get_DDP_modules(self, call_chain):
        modules = []
        for ind in range(1, len(call_chain)):
            m_id = call_chain[ind][0].unique_instance_id
            module = self.ddp_module_dict[m_id]
            if isinstance(module, DDP):
                modules.append(module)

        return modules 
Example #24
Source File: base_model.py    From EDVR with Apache License 2.0
def save_network(self, network, network_label, iter_label):
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(self.opt['path']['models'], save_filename)
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path) 
Example #25
Source File: ddpg_agent.py    From rlpyt with MIT License
def data_parallel(self):
        super().data_parallel()  # Takes care of self.model.
        if self.device.type == "cpu":
            self.q_model = DDPC(self.q_model)
        else:
            self.q_model = DDP(self.q_model) 
Example #26
Source File: base_model.py    From EDVR with Apache License 2.0
def load_network(self, load_path, network, strict=True):
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        load_net = torch.load(load_path)
        load_net_clean = OrderedDict()  # remove unnecessary 'module.'
        for k, v in load_net.items():
            if k.startswith('module.'):
                load_net_clean[k[7:]] = v
            else:
                load_net_clean[k] = v
        network.load_state_dict(load_net_clean, strict=strict) 
Example #27
Source File: td3_agent.py    From rlpyt with MIT License
def data_parallel(self):
        super().data_parallel()
        if self.device.type == "cpu":
            self.q2_model = DDPC(self.q2_model)
        else:
            self.q2_model = DDP(self.q2_model) 
Example #28
Source File: SR_model.py    From EDVR with Apache License 2.0
def print_network(self):
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s) 
Example #29
Source File: lightning.py    From pytorch-lightning with Apache License 2.0
def configure_ddp(
            self,
            model: 'LightningModule',
            device_ids: List[int]
    ) -> DistributedDataParallel:
        r"""
        Override to init DDP in your own way or with your own wrapper.
        The only requirements are that:

        1. On a validation batch the call goes to ``model.validation_step``.
        2. On a training batch the call goes to ``model.training_step``.
        3. On a testing batch, the call goes to ``model.test_step``.

        Args:
            model: the :class:`LightningModule` currently being optimized.
            device_ids: the list of GPU ids.

        Return:
            DDP wrapped model

        Examples:
            .. code-block:: python

                # default implementation used in Trainer
                def configure_ddp(self, model, device_ids):
                    # Lightning DDP simply routes to test_step, val_step, etc...
                    model = LightningDistributedDataParallel(
                        model,
                        device_ids=device_ids,
                        find_unused_parameters=True
                    )
                    return model

        """
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        return model 
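The find_unused_parameters=True flag shown above tells DistributedDataParallel to detect parameters that receive no gradient in a given iteration, which is required when the forward pass does not use every registered parameter, at the cost of an extra traversal of the autograd graph. Outside of Lightning, the equivalent wrapping would look roughly like this (model and device_ids as in the earlier examples):

ddp_model = DistributedDataParallel(
    model,
    device_ids=device_ids,
    find_unused_parameters=True,  # only needed if parts of the model are unused in forward
)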
Example #30
Source File: checkpoint.py    From fvcore with Apache License 2.0
def __init__(
        self,
        model: nn.Module,
        save_dir: str = "",
        *,
        save_to_disk: bool = True,
        **checkpointables: object,
    ) -> None:
        """
        Args:
            model (nn.Module): model.
            save_dir (str): a directory to save and find checkpoints.
            save_to_disk (bool): if True, save checkpoint to disk, otherwise
                disable saving for this checkpointer.
            checkpointables (object): any checkpointable objects, i.e., objects
                that have the `state_dict()` and `load_state_dict()` methods. For
                example, it can be used like
                `Checkpointer(model, "dir", optimizer=optimizer)`.
        """
        if isinstance(model, (DistributedDataParallel, DataParallel)):
            model = model.module
        self.model = model
        self.checkpointables = copy.copy(checkpointables)  # pyre-ignore
        self.logger = logging.getLogger(__name__)  # pyre-ignore
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk