Python torch.distributed.is_initialized() Examples

The following are 30 code examples of torch.distributed.is_initialized(), drawn from open-source projects. The source file, project, and license are noted above each example. You may also want to check out the other available functions and classes of the torch.distributed module.
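Before the project examples, here is a minimal, self-contained sketch of the pattern most of them share: guard calls into torch.distributed behind is_available() and is_initialized() so the same code works in both single-process and multi-process runs. The launch command and the gloo/env:// initialization mentioned in the comments are illustrative assumptions, not something taken from the projects below.

import os

import torch
import torch.distributed as dist

def get_world_size() -> int:
    # Fall back to 1 when torch.distributed is unavailable or uninitialized,
    # so single-process runs need no special casing (same guard as the
    # get_world_size/get_rank helpers in the examples below).
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1

def average_scalar(value: float) -> float:
    # Average a Python scalar across processes; a no-op outside distributed mode.
    tensor = torch.tensor([value], dtype=torch.float32)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor /= dist.get_world_size()
    return tensor.item()

if __name__ == "__main__":
    # Under e.g. `torchrun --nproc_per_node=2 script.py`, RANK/WORLD_SIZE are
    # set in the environment and init_process_group picks them up via env://.
    # Under plain `python script.py`, this branch is skipped and the helpers
    # above behave as ordinary single-process code.
    if "RANK" in os.environ:
        dist.init_process_group(backend="gloo", init_method="env://")
    print(average_scalar(3.0))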
Example #1
Source File: distributed.py    From virtex with MIT License
def average_across_processes(t: Union[torch.Tensor, Dict[str, torch.Tensor]]):
    r"""
    Averages a tensor, or a dict of tensors, across all processes in a process
    group. After averaging, every process holds the same mean value.

    .. note::

        Nested dicts of tensors are not supported.

    Parameters
    ----------
    t: torch.Tensor or Dict[str, torch.Tensor]
        A tensor or dict of tensors to average across processes.
    """
    if dist.is_initialized():
        if isinstance(t, torch.Tensor):
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            t /= get_world_size()
        elif isinstance(t, dict):
            for k in t:
                dist.all_reduce(t[k], op=dist.ReduceOp.SUM)
                t[k] /= dist.get_world_size() 
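As a point of reference, a hypothetical call site for the function above (assuming the surrounding training script has already initialized the process group) could look like this:

metrics = {"loss": torch.tensor(0.25), "accuracy": torch.tensor(0.91)}
average_across_processes(metrics)  # in-place: every rank now holds the mean values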
Example #2
Source File: distributed_communicator.py    From CrypTen with MIT License
def scatter(self, scatter_list, src, size=None, device=None):
    """Scatters a list of tensors to all parties."""
    assert dist.is_initialized(), "initialize the communicator first"
    if src != self.get_rank():
        if size is None:
            size = scatter_list[self.get_rank()].size()
        if device is None:
            try:
                device = scatter_list[self.get_rank()].device
            except Exception:
                pass
        tensor = torch.empty(size=size, dtype=torch.long, device=device)
        dist.scatter(tensor, [], src, group=self.main_group)
    else:
        scatter_list = [s.data for s in scatter_list]
        tensor = scatter_list[self.get_rank()]
        dist.scatter(tensor, scatter_list, src, group=self.main_group)
    return tensor
Example #3
Source File: distributed_communicator.py    From CrypTen with MIT License
def reduce(self, input, dst, op=ReduceOp.SUM, batched=False):
    """Reduces the input data across all parties."""
    assert dist.is_initialized(), "initialize the communicator first"

    if batched:
        assert isinstance(input, list), "batched reduce input must be a list"
        reqs = []
        result = [x.clone().data for x in input]
        for tensor in result:
            reqs.append(
                dist.reduce(
                    tensor, dst, op=op, group=self.main_group, async_op=True
                )
            )
        for req in reqs:
            req.wait()
    else:
        assert torch.is_tensor(
            input.data
        ), "unbatched input for reduce must be a torch tensor"
        result = input.clone()
        dist.reduce(result.data, dst, op=op, group=self.main_group)

    return result if dst == self.get_rank() else None
Example #4
Source File: distributed_communicator.py    From CrypTen with MIT License
def all_reduce(self, input, op=ReduceOp.SUM, batched=False):
    """Reduces the input data across all parties; all get the final result."""
    assert dist.is_initialized(), "initialize the communicator first"

    if batched:
        assert isinstance(input, list), "batched reduce input must be a list"
        reqs = []
        result = [x.clone() for x in input]
        for tensor in result:
            reqs.append(
                dist.all_reduce(
                    tensor.data, op=op, group=self.main_group, async_op=True
                )
            )
        for req in reqs:
            req.wait()
    else:
        assert torch.is_tensor(
            input.data
        ), "unbatched input for reduce must be a torch tensor"
        result = input.clone()
        dist.all_reduce(result.data, op=op, group=self.main_group)
    return result
Example #5
Source File: distributed_communicator.py    From CrypTen with MIT License
def all_gather(self, tensor):
    """Gathers tensors from all parties in a list."""
    assert dist.is_initialized(), "initialize the communicator first"
    result = []
    for _ in range(self.get_world_size()):
        result.append(torch.empty(size=tensor.size(), dtype=torch.long))
    dist.all_gather(result, tensor, group=self.main_group)
    return result
Example #6
Source File: comm.py    From ACNet with MIT License
def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size() 
Example #7
Source File: comm.py    From ACNet with MIT License
def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank() 
Example #8
Source File: util.py    From allennlp with Apache License 2.0
def is_distributed() -> bool:
    """
    Checks if the distributed process group is available and has been initialized
    """
    return dist.is_available() and dist.is_initialized() 
Example #9
Source File: distributed_communicator.py    From CrypTen with MIT License
def get_rank(self):
    """Returns the rank of the current process."""
    assert dist.is_initialized(), "initialize the communicator first"
    return dist.get_rank()
Example #10
Source File: distributed_communicator.py    From CrypTen with MIT License
def gather(self, tensor, dst):
    """Gathers a list of tensors in a single party."""
    assert dist.is_initialized(), "initialize the communicator first"
    if self.get_rank() == dst:
        result = []
        for _ in range(self.get_world_size()):
            result.append(torch.empty(size=tensor.size(), dtype=torch.long))
        dist.gather(tensor, result, dst, group=self.main_group)
        return result
    dist.gather(tensor, [], dst, group=self.main_group)
    return [None]
Example #11
Source File: torch_utils.py    From ACNet with MIT License
def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size() 
Example #12
Source File: distributed_communicator.py    From CrypTen with MIT License
def barrier(self):
    """Synchronizes all processes.

    This collective blocks processes until the whole group enters this
    function.
    """
    assert dist.is_initialized(), "initialize the communicator first"
    dist.barrier(group=self.main_group)
Example #13
Source File: distributed_communicator.py    From CrypTen with MIT License
def get_world_size(self):
    """Returns the size of the world."""
    assert dist.is_initialized(), "initialize the communicator first"
    return self.world_size
Example #14
Source File: multiprocess_test_case.py    From CrypTen with MIT License
def get_random_linear(in_channels, out_channels):
    linear = torch.nn.Linear(in_channels, out_channels)
    if dist.is_initialized():
        # Broadcast this tensor to the world so that the generated random tensor
        # is in sync in all distributed processes. See T45688819 for more
        # information.
        comm.get().broadcast(linear.weight, 0)
        comm.get().broadcast(linear.bias, 0)

    return linear 
Example #15
Source File: distributed.py    From ocr-pytorch with MIT License
def get_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank() 
Example #16
Source File: distributed_communicator.py    From CrypTen with MIT License
def isend(self, tensor, dst):
    """Sends the specified tensor to the destination dst."""
    assert dist.is_initialized(), "initialize the communicator first"
    return dist.isend(tensor.data, dst, group=self.main_group)
Example #17
Source File: distributed_communicator.py    From CrypTen with MIT License
def get_distributed_backend(self):
    """Returns name of torch.distributed backend used."""
    assert dist.is_initialized(), "initialize the communicator first"
    return dist.get_backend()
Example #18
Source File: torch_utils.py    From ACNet with MIT License
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier() 
Example #19
Source File: torch_utils.py    From ACNet with MIT License
def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank() 
Example #20
Source File: torch_utils.py    From Centripetal-SGD with Apache License 2.0
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier() 
Example #21
Source File: dist_utils.py    From mmcv with Apache License 2.0
def get_dist_info():
    if TORCH_VERSION < '1.0':
        initialized = dist._initialized
    else:
        if dist.is_available():
            initialized = dist.is_initialized()
        else:
            initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size 
Example #22
Source File: __init__.py    From pytorch_image_classification with MIT License
def apply_data_parallel_wrapper(config: yacs.config.CfgNode,
                                model: nn.Module) -> nn.Module:
    local_rank = config.train.dist.local_rank
    if dist.is_available() and dist.is_initialized():
        if config.train.dist.use_sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank)
    else:
        if config.device == 'cuda':
            model = nn.DataParallel(model)
    return model 
Example #23
Source File: dist.py    From pytorch_image_classification with MIT License
def get_rank() -> int:
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    else:
        return dist.get_rank() 
Example #24
Source File: comm.py    From remote_sensing_object_detection_2019 with MIT License
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier() 
Example #25
Source File: comm.py    From remote_sensing_object_detection_2019 with MIT License
def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank() 
Example #26
Source File: comm.py    From remote_sensing_object_detection_2019 with MIT License
def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size() 
Example #27
Source File: distributed.py    From virtex with MIT License
def get_rank() -> int:
    r"""Return rank of current process in the process group."""
    return dist.get_rank() if dist.is_initialized() else 0 
Example #28
Source File: distributed.py    From virtex with MIT License
def get_world_size() -> int:
    r"""Return number of processes in the process group, each uses 1 GPU."""
    return dist.get_world_size() if dist.is_initialized() else 1 
Example #29
Source File: distributed.py    From virtex with MIT License
def synchronize() -> None:
    r"""Synchronize (barrier) all processes in a process group."""
    if dist.is_initialized():
        dist.barrier() 
Example #30
Source File: gfl_head.py    From mmdetection with Apache License 2.0
def reduce_mean(tensor):
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor