Python torch.multiprocessing.spawn() Examples
The following are 30 code examples of torch.multiprocessing.spawn(), drawn from open-source projects. Each example links to its original project and source file. You may also want to check out the other available functions and classes of the torch.multiprocessing module.
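Every example below relies on the same contract: mp.spawn(fn, args=..., nprocs=N, join=True) starts N processes and calls fn(i, *args) in each one, where the process index i is prepended to the arguments automatically. A minimal self-contained sketch of that contract (not taken from any of the projects below):

import torch.multiprocessing as mp

def worker(rank, world_size, message):
    # rank (0 .. world_size - 1) is supplied by mp.spawn, not by the caller
    print(f"[{rank}/{world_size}] {message}")

if __name__ == "__main__":
    # launch 4 processes; join=True blocks until all of them exit
    mp.spawn(worker, args=(4, "hello"), nprocs=4, join=True)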
Example #1
Source File: local_timer_example.py From elastic with BSD 3-Clause "New" or "Revised" License | 8 votes

def test_torch_mp_example(self):
    # in practice set the max_interval to a larger value (e.g. 60 seconds)
    mp_queue = mp.get_context("spawn").Queue()
    server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
    server.start()

    world_size = 8

    # all processes should complete successfully
    # since start_process does NOT take context as parameter argument yet
    # this method WILL FAIL (hence the test is disabled)
    torch_mp.spawn(
        fn=_happy_function, args=(mp_queue,), nprocs=world_size, join=True
    )

    with self.assertRaises(Exception):
        # torch.multiprocessing.spawn kills all sub-procs
        # if one of them gets killed
        torch_mp.spawn(
            fn=_stuck_function, args=(mp_queue,), nprocs=world_size, join=True
        )

    server.stop()
Example #2
Source File: imagenet_torch_loader.py From pytorch_quantization with MIT License | 6 votes

def main():
    if cfg.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if cfg.dist_url == "env://" and cfg.world_size == -1:
        cfg.world_size = int(os.environ["WORLD_SIZE"])

    cfg.distributed = cfg.world_size > 1 or cfg.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if cfg.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        cfg.world_size = ngpus_per_node * cfg.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, cfg))
    else:
        # Simply call main_worker function
        main_worker(cfg.gpu, ngpus_per_node, cfg)
Example #3
Source File: distributed_slurm_main.py From pytorch-distributed with MIT License | 6 votes

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        # torch.backends.cudnn.enabled = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.local_rank = int(os.environ["SLURM_PROCID"])
    args.world_size = int(os.environ["SLURM_NPROCS"])
    ngpus_per_node = torch.cuda.device_count()

    job_id = os.environ["SLURM_JOBID"]
    args.dist_url = "file://{}.{}".format(os.path.realpath(args.dist_file), job_id)

    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
Example #4
Source File: extractive_summarization_cnndm_distributed_train.py From nlp-recipes with MIT License | 6 votes

def main():
    print("NCCL_IB_DISABLE: {}".format(os.getenv("NCCL_IB_DISABLE")))
    args = parser.parse_args()
    print("quick_run is {}".format(args.quick_run))
    print("output_dir is {}".format(args.output_dir))
    print("data_dir is {}".format(args.data_dir))
    print("cache_dir is {}".format(args.cache_dir))

    # shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.cache_dir, exist_ok=True)

    ngpus_per_node = torch.cuda.device_count()
    processor = ExtSumProcessor(model_name=args.model_name)
    summarizer = ExtractiveSummarizer(
        processor, args.model_name, args.encoder, args.max_pos_length, args.cache_dir
    )

    mp.spawn(
        main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, summarizer, args)
    )
Example #5
Source File: test_server.py From DetNAS with MIT License | 6 votes

def _inference(self, cand):
    # bn_statistic
    parent_conn, child_conn = mp.Pipe()
    args = dict({"local_rank": 0, "distributed": False})
    mp.spawn(
        bn_statistic, nprocs=self.ngpus_per_node,
        args=(self.ngpus_per_node, cfg, args, cand, child_conn))
    salt = parent_conn.recv()

    # fitness
    parent_conn, child_conn = mp.Pipe()
    args = dict({"local_rank": 0, "distributed": False})
    mp.spawn(
        fitness, nprocs=self.ngpus_per_node,
        args=(self.ngpus_per_node, cfg, args, cand, salt, child_conn))

    if os.path.isfile(os.path.join(cfg.OUTPUT_DIR, salt + ".pth")):
        os.remove(os.path.join(cfg.OUTPUT_DIR, salt + ".pth"))

    return parent_conn.recv()
Example #6
Source File: train_distributed.py From helen with MIT License | 6 votes

def setup(rank, device_ids, args):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=len(device_ids))

    train_file, test_file, batch_size, epochs, gpu_mode, num_workers, retrain_model, \
        retrain_model_path, gru_layers, hidden_size, learning_rate, weight_decay, \
        model_dir, stats_dir, total_callers, train_mode = args

    # issue with semaphore lock: https://github.com/pytorch/pytorch/issues/2517
    # mp.set_start_method('spawn')

    # Explicitly setting seed to make sure that models created in two processes
    # start from same random weights and biases.
    # https://github.com/pytorch/pytorch/issues/2517
    torch.manual_seed(42)

    train(train_file, test_file, batch_size, epochs, gpu_mode, num_workers,
          retrain_model, retrain_model_path, gru_layers, hidden_size,
          learning_rate, weight_decay, model_dir, stats_dir, train_mode,
          total_callers, rank, device_ids[rank])
    cleanup()
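The cleanup() call at the end is a helper that is not shown in this snippet. In most DDP training scripts it simply tears down the process group that setup() created; a minimal sketch under that assumption:

import torch.distributed as dist

def cleanup():
    # Assumed implementation -- the original helper is not shown above.
    # Destroys the default process group created by init_process_group().
    dist.destroy_process_group()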
Example #7
Source File: predict_gpu.py From helen with MIT License | 6 votes

def predict_gpu(file_chunks, output_filepath, model_path, batch_size,
                total_callers, devices, num_workers):
    """
    Create a prediction table/dictionary of an image set using a trained model.
    :param file_chunks: Path to chunked files
    :param batch_size: Batch size used for prediction
    :param model_path: Path to a trained model
    :param output_filepath: Path to output directory
    :param total_callers: Number of callers to spawn
    :param devices: List of available CUDA devices
    :param num_workers: Number of workers to be used by the dataloader
    :return: Prediction dictionary
    """
    # create the arguments to send for prediction
    args = (output_filepath, model_path, batch_size, num_workers)

    # spawn the processes to call the prediction method
    mp.spawn(setup,
             args=(total_callers, args, file_chunks, devices),
             nprocs=total_callers,
             join=True)
Example #8
Source File: make_sem_seg_labels.py From irn with MIT License | 6 votes

def run(args):
    model = getattr(importlib.import_module(args.irn_network), 'EdgeDisplacement')()
    model.load_state_dict(torch.load(args.irn_weights_name), strict=False)
    model.eval()

    n_gpus = torch.cuda.device_count()

    dataset = voc12.dataloader.VOC12ClassificationDatasetMSF(args.infer_list,
                                                             voc12_root=args.voc12_root,
                                                             scales=(1.0,))
    dataset = torchutils.split_dataset(dataset, n_gpus)

    print("[", end='')
    multiprocessing.spawn(_work, nprocs=n_gpus, args=(model, dataset, args), join=True)
    print("]")

    torch.cuda.empty_cache()
Example #9
Source File: make_cam.py From irn with MIT License | 6 votes

def run(args):
    model = getattr(importlib.import_module(args.cam_network), 'CAM')()
    model.load_state_dict(torch.load(args.cam_weights_name + '.pth'), strict=True)
    model.eval()

    n_gpus = torch.cuda.device_count()

    dataset = voc12.dataloader.VOC12ClassificationDatasetMSF(args.train_list,
                                                             voc12_root=args.voc12_root,
                                                             scales=args.cam_scales)
    dataset = torchutils.split_dataset(dataset, n_gpus)

    print('[ ', end='')
    multiprocessing.spawn(_work, nprocs=n_gpus, args=(model, dataset, args), join=True)
    print(']')

    torch.cuda.empty_cache()
Example #10
Source File: cam_to_ir_label.py From irn with MIT License | 5 votes

def run(args):
    dataset = voc12.dataloader.VOC12ImageDataset(args.train_list,
                                                 voc12_root=args.voc12_root,
                                                 img_normal=None,
                                                 to_torch=False)
    dataset = torchutils.split_dataset(dataset, args.num_workers)

    print('[ ', end='')
    multiprocessing.spawn(_work, nprocs=args.num_workers, args=(dataset, args), join=True)
    print(']')
Example #11
Source File: imagenet.py From pytorch-dp with Apache License 2.0 | 5 votes

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    if args.gpu is not None:
        warnings.warn(
            "You have chosen a specific GPU. This will completely "
            "disable data parallelism."
        )

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
Example #12
Source File: multiprocessing_distributed.py From pytorch-distributed with MIT License | 5 votes

def main():
    args = parser.parse_args()
    mp.spawn(main_worker, nprocs=4, args=(4, args))
Example #13
Source File: train.py From semseg with MIT License | 5 votes

def main():
    args = get_parser()
    check(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.train_gpu)

    if args.manual_seed is not None:
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        torch.cuda.manual_seed_all(args.manual_seed)
        cudnn.benchmark = False
        cudnn.deterministic = True

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    args.ngpus_per_node = len(args.train_gpu)

    if len(args.train_gpu) == 1:
        args.sync_bn = False
        args.distributed = False
        args.multiprocessing_distributed = False

    if args.multiprocessing_distributed:
        port = find_free_port()
        args.dist_url = f"tcp://127.0.0.1:{port}"
        args.world_size = args.ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=args.ngpus_per_node,
                 args=(args.ngpus_per_node, args))
    else:
        main_worker(args.train_gpu, args.ngpus_per_node, args)
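Neither find_free_port() here nor _find_free_port() in the detectron2 launch examples below is shown on this page. Such helpers conventionally use the same socket trick; a minimal sketch under that assumption:

import socket

def find_free_port():
    # Bind to port 0 so the OS assigns an unused port, then release it.
    # The port is only likely (not guaranteed) to still be free when reused.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port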
Example #14
Source File: native.py From ignite with BSD 3-Clause "New" or "Revised" License | 5 votes

def spawn(
    fn: Callable,
    args: Tuple,
    kwargs_dict: Optional[Mapping] = None,
    nproc_per_node: int = 1,
    nnodes: int = 1,
    node_rank: int = 0,
    master_addr: str = "127.0.0.1",
    master_port: int = 2222,
    backend: str = "nccl",
    **kwargs
):
    world_size = nnodes * nproc_per_node

    spawn_kwargs = {
        "join": kwargs.get("join", True),
        "daemon": kwargs.get("daemon", False),
    }
    # start_method in pytorch >= 1.5
    if LooseVersion(torch.__version__) >= LooseVersion("1.5.0"):
        spawn_kwargs["start_method"] = kwargs.get("start_method", "spawn")

    mp.spawn(
        _NativeDistModel._dist_worker_task_fn,
        nprocs=nproc_per_node,
        args=(
            backend,
            fn,
            args,
            kwargs_dict,
            world_size,
            nproc_per_node,
            node_rank,
            master_addr,
            master_port,
            kwargs,
        ),
        **spawn_kwargs,
    )
Example #15
Source File: launch.py From detectron2 with Apache License 2.0 | 5 votes

def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0,
           dist_url=None, args=()):
    """
    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine (one per machine)
        dist_url (str): url to connect to for distributed training, including
            protocol, e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        args (tuple): arguments passed to main_func
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto cannot work with distributed training."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine, machine_rank,
                  dist_url, args),
            daemon=False,
        )
    else:
        main_func(*args)
Example #16
Source File: main.py From PyTorch with MIT License | 5 votes

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
Example #17
Source File: main.py From PyTorch with MIT License | 5 votes

def main():
    mp.spawn(
        run_worker,
        args=(args.world_size,),
        nprocs=args.world_size,
        join=True
    )
Example #18
Source File: main.py From TF2 with Apache License 2.0 | 5 votes

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
Example #19
Source File: main.py From online-normalization with BSD 3-Clause "New" or "Revised" License | 5 votes

def main():
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    if args.distributed:
        raise NotImplementedError('multiprocessing with ON not implemented')

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
Example #20
Source File: local_timer_example.py From elastic with BSD 3-Clause "New" or "Revised" License | 5 votes

def test_example_start_method_spawn(self):
    self._run_example_with(start_method="spawn")
Example #21
Source File: launch.py From detectron2 with Apache License 2.0 | 5 votes

def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0,
           dist_url=None, args=()):
    """
    Launch multi-gpu or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol,
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost.
        args (tuple): arguments passed to main_func
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # https://github.com/pytorch/pytorch/pull/14391
        # TODO prctl in spawned processes
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine, machine_rank,
                  dist_url, args),
            daemon=False,
        )
    else:
        main_func(*args)
Example #22
Source File: main.py From GroupNorm-reproduce with Apache License 2.0 | 5 votes

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
Example #23
Source File: train_net.py From DSGN with MIT License | 5 votes

def main():
    args = get_parser()

    if args.debug:
        args.savemodel = './outputs/debug/'
        args.btrain = 1
        args.workers = 0

    global cfg
    exp = Experimenter(args.savemodel, cfg_path=args.cfg)
    cfg = exp.config

    reset_seed(args.seed)

    cfg.debug = args.debug
    cfg.warmup = getattr(cfg, 'warmup', True) if not args.debug else False

    ### distributed training ###
    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    ngpus_per_node = torch.cuda.device_count()
    print('ngpus_per_node: {}'.format(ngpus_per_node))
    args.ngpus_per_node = ngpus_per_node

    args.distributed = ngpus_per_node > 0 and (args.world_size > 1 or args.multiprocessing_distributed)
    args.multiprocessing_distributed = args.distributed

    if args.distributed and args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args, cfg, exp))
    else:
        # Simply call main_worker function
        main_worker(0, ngpus_per_node, args, cfg, exp)
Example #24
Source File: conftest.py From pytorch-lightning with Apache License 2.0 | 5 votes

def pytest_pyfunc_call(pyfuncitem):
    if pyfuncitem.get_closest_marker("spawn"):
        testfunction = pyfuncitem.obj
        funcargs = pyfuncitem.funcargs
        testargs = tuple([funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames])

        mp.spawn(wraps, (testfunction, testargs))
        return True
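The wraps helper that mp.spawn launches here is not shown on this page. Because mp.spawn always prepends the process index to the argument list, a plausible minimal sketch (an assumption, not the original code) is:

def wraps(i, fn, args):
    # i is the process index injected by mp.spawn; the test itself
    # does not need it, so it is discarded (hypothetical helper)
    return fn(*args)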
Example #25
Source File: conftest.py From pytorch-lightning with Apache License 2.0 | 5 votes

def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "spawn: spawn test in a separate process using torch.multiprocessing.spawn"
    )
Example #26
Source File: test_converters.py From pytorch-lightning with Apache License 2.0 | 5 votes

def test_numpy_metric_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()

    world_size = 2
    mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)

    # dist.destroy_process_group()
Example #27
Source File: test_converters.py From pytorch-lightning with Apache License 2.0 | 5 votes

def test_tensor_metric_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()

    world_size = 2
    mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)

    # dist.destroy_process_group()
Example #28
Source File: test_converters.py From pytorch-lightning with Apache License 2.0 | 5 votes

def test_sync_reduce_ddp():
    """Make sure sync-reduce works with DDP"""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)

    # dist.destroy_process_group()
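The spawned workers in this and the two previous examples (_ddp_test_fn, _ddp_test_numpy_metric, _ddp_test_tensor_metric) are not shown here; each one has to join the process group itself. A minimal sketch of such a worker, assuming set_random_master_port() exported MASTER_ADDR and MASTER_PORT to the environment:

import torch
import torch.distributed as dist

def _ddp_test_fn(rank, worldsize):
    # rank is prepended automatically by mp.spawn
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)

    # every process contributes its rank; after the all_reduce each
    # process holds the sum 0 + 1 + ... + (worldsize - 1)
    tensor = torch.tensor([float(rank)])
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    assert tensor.item() == sum(range(worldsize))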
Example #29
Source File: abstractive_summarization_bertsum_cnndm_distributed_train.py From nlp-recipes with MIT License | 5 votes

def main():
    args = parser.parse_args()
    print("NCCL_IB_DISABLE: {}".format(os.getenv("NCCL_IB_DISABLE")))
    print("quick_run is {}".format(args.quick_run))
    print("output_dir is {}".format(args.output_dir))
    print("data_dir is {}".format(args.data_dir))
    print("cache_dir is {}".format(args.cache_dir))

    TOP_N = -1
    if args.quick_run.lower() == "false":
        TOP_N = 10

    train_dataset, test_dataset = CNNDMSummarizationDataset(
        top_n=TOP_N, local_cache_path=args.data_dir, prepare_extractive=False
    )

    ngpus_per_node = torch.cuda.device_count()
    processor = BertSumAbsProcessor(
        cache_dir=args.cache_dir, max_src_len=args.max_pos_length
    )
    summarizer = BertSumAbs(
        processor, cache_dir=args.cache_dir, max_pos_length=args.max_pos_length
    )

    mp.spawn(
        main_worker,
        nprocs=ngpus_per_node,
        args=(ngpus_per_node, summarizer, train_dataset, test_dataset, args),
    )
Example #30
Source File: imagenet.py From Compact-Global-Descriptor with BSD 2-Clause "Simplified" License | 5 votes

def main():
    # Use CUDA
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    use_cuda = torch.cuda.is_available()
    gpus = list(range(len(args.gpu_id.split(','))))

    # Random seed
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.manualSeed)

    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu_id, ngpus_per_node, args)