Python apex.amp.init() Examples
The following are 8 code examples of apex.amp.init().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module apex.amp, or try the search function.
Example #1
Source File: trainer.py From pytorch-asr with GNU General Public License v3.0 | 5 votes |
def init_distributed(use_cuda, backend="nccl", init="slurm", local_rank=-1):
    """Join a ``torch.distributed`` process group, or fall back to single-process mode.

    When ``local_rank`` is -1, the rank, world size and local rank are read
    from the environment variables of the launcher named by ``init``
    (``"slurm"`` or ``"ompi"``), and the group is initialised over
    ``tcp://$MASTER_ADDR:$MASTER_PORT``. Otherwise ``local_rank`` is used as
    given and the group is initialised with ``init_method="env://"``.

    Any ``Exception`` raised along the way is reported on stdout and the
    process carries on without a process group; nothing is re-raised.

    Args:
        use_cuda: If true, select the CUDA device for this process before
            creating the process group.
        backend: Backend name passed to ``dist.init_process_group``.
        init: Launcher whose environment variables describe this process;
            only consulted when ``local_rank`` is -1.
        local_rank: Local rank supplied by the caller, or -1 to read it from
            the launcher environment.

    Returns:
        None.
    """
    # Environment variables read for each supported launcher, in the order
    # (rank, world size, local rank).
    launcher_env = {
        "slurm": ("SLURM_PROCID", "SLURM_NTASKS", "SLURM_LOCALID"),
        "ompi": (
            "OMPI_COMM_WORLD_RANK",
            "OMPI_COMM_WORLD_SIZE",
            "OMPI_COMM_WORLD_LOCAL_RANK",
        ),
    }
    try:
        if local_rank == -1:
            if init not in launcher_env:
                # Fail before touching the CUDA device; otherwise an unknown
                # launcher only shows up later as an undefined-name error.
                raise ValueError(
                    f"unsupported init {init!r}; expected one of {sorted(launcher_env)}"
                )
            rank, world_size, local_rank = (
                int(os.environ[name]) for name in launcher_env[init]
            )
            if use_cuda:
                # Wrap around so the local rank always maps to a visible device.
                device = local_rank % torch.cuda.device_count()
                torch.cuda.set_device(device)
                print(f"set cuda device to cuda:{device}")
            master_node = os.environ["MASTER_ADDR"]
            master_port = os.environ["MASTER_PORT"]
            init_method = f"tcp://{master_node}:{master_port}"
            dist.init_process_group(
                backend=backend,
                init_method=init_method,
                world_size=world_size,
                rank=rank,
            )
            print(f"initialized as {rank}/{world_size} via {init_method}")
        else:
            if use_cuda:
                torch.cuda.set_device(local_rank)
                print(f"set cuda device to cuda:{local_rank}")
            dist.init_process_group(backend=backend, init_method="env://")
            print(f"initialized as {dist.get_rank()}/{dist.get_world_size()} via env://")
    except Exception as e:
        # Deliberate best-effort: keep running as a single process, but say
        # why the distributed setup was abandoned instead of hiding it.
        print(f"distributed init failed: {e!r}")
        print("initialized as single process")
Example #2
Source File: trainer.py From pytorch-asr with GNU General Public License v3.0 | 5 votes |
def get_amp_handle(args):
    """Return an apex amp handle when fp16 is requested, otherwise ``None``.

    ``args.fp16`` is reset to ``False`` in place whenever ``args.use_cuda``
    is falsy, so a handle is only created when both flags are set.
    """
    if not args.use_cuda:
        args.fp16 = False

    if not args.fp16:
        return None

    # Imported here so apex is only needed on the fp16 path.
    from apex import amp

    return amp.init(enabled=True, enable_caching=True, verbose=False)
Example #3
Source File: test_rnn.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def setUp(self):
    """Create an enabled amp handle, keep it on the test case, then run ``common_init``."""
    amp_handle = amp.init(enabled=True)
    self.handle = amp_handle
    # Shared fixture setup; defined outside this snippet.
    common_init(self)
Example #4
Source File: test_rnn.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def setUp(self):
    """Create an enabled amp handle, keep it on the test case, then run ``common_init``."""
    amp_handle = amp.init(enabled=True)
    self.handle = amp_handle
    # Shared fixture setup; defined outside this snippet.
    common_init(self)
Example #5
Source File: test_basic_casts.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def setUp(self):
    """Create an enabled amp handle, keep it on the test case, then run ``common_init``."""
    amp_handle = amp.init(enabled=True)
    self.handle = amp_handle
    # Shared fixture setup; defined outside this snippet.
    common_init(self)
Example #6
Source File: test_basic_casts.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def setUp(self):
    """Create an enabled amp handle, keep it on the test case, then run ``common_init``."""
    amp_handle = amp.init(enabled=True)
    self.handle = amp_handle
    # Shared fixture setup; defined outside this snippet.
    common_init(self)
Example #7
Source File: test_basic_casts.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def test_bce_is_float_with_allow_banned(self):
    """Run ``bce_common`` with an amp handle re-created using ``allow_banned=True``.

    The handle stored on the test case is deactivated first and replaced,
    and the check handed to ``bce_common`` asserts that ``fn(x).type()``
    equals ``FLOAT``.
    """
    self.handle._deactivate()
    self.handle = amp.init(enabled=True, allow_banned=True)

    def assertion(fn, x):
        return self.assertEqual(fn(x).type(), FLOAT)

    self.bce_common(assertion)
Example #8
Source File: test_promotion.py From apex with BSD 3-Clause "New" or "Revised" License | 5 votes |
def setUp(self):
    """Create an enabled amp handle, keep it on the test case, then run ``common_init``."""
    amp_handle = amp.init(enabled=True)
    self.handle = amp_handle
    # Shared fixture setup; defined outside this snippet.
    common_init(self)