Python apex.amp.init() Examples

The following are 8 code examples of apex.amp.init(), drawn from open-source projects. The source file and project each example comes from are listed above the code. You may also want to check out the other available functions and classes of the module apex.amp.
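The examples below use the original handle-based amp API, in which amp.init() returns an amp_handle whose scale_loss() context manager wraps the backward pass. A minimal sketch of that pattern follows; the model, optimizer, and data are placeholders for illustration and are not taken from any of the projects below.

import torch
import torch.nn.functional as F
from apex import amp

# Placeholder model and optimizer, for illustration only.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Old handle-based API: amp.init() returns a handle used for loss scaling.
amp_handle = amp.init(enabled=True)

x = torch.randn(4, 10, device="cuda")
target = torch.randint(0, 2, (4,), device="cuda")

optimizer.zero_grad()
loss = F.cross_entropy(model(x), target)

# scale_loss() applies loss scaling for fp16 training before backward().
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()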
Example #1
Source File: trainer.py    From pytorch-asr with GNU General Public License v3.0
import os

import torch
import torch.distributed as dist


def init_distributed(use_cuda, backend="nccl", init="slurm", local_rank=-1):
    #try:
    #    mp.set_start_method('spawn')  # spawn, forkserver, and fork
    #except RuntimeError:
    #    pass

    try:
        if local_rank == -1:
            # No explicit local rank given: read rank, world size, and local rank
            # from the launcher's environment (SLURM or Open MPI).
            if init == "slurm":
                rank = int(os.environ['SLURM_PROCID'])
                world_size = int(os.environ['SLURM_NTASKS'])
                local_rank = int(os.environ['SLURM_LOCALID'])
                #master_node = os.environ['SLURM_TOPOLOGY_ADDR']
                #master_port = '23456'
            elif init == "ompi":
                rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
                world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
                local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

            if use_cuda:
                device = local_rank % torch.cuda.device_count()
                torch.cuda.set_device(device)
                print(f"set cuda device to cuda:{device}")

            master_node = os.environ["MASTER_ADDR"]
            master_port = os.environ["MASTER_PORT"]
            init_method = f"tcp://{master_node}:{master_port}"
            #init_method = "env://"
            dist.init_process_group(backend=backend, init_method=init_method, world_size=world_size, rank=rank)
            print(f"initialized as {rank}/{world_size} via {init_method}")
        else:
            # Explicit local rank (e.g. from torch.distributed.launch): use env:// init.
            if use_cuda:
                torch.cuda.set_device(local_rank)
                print(f"set cuda device to cuda:{local_rank}")
            dist.init_process_group(backend=backend, init_method="env://")
            print(f"initialized as {dist.get_rank()}/{dist.get_world_size()} via env://")
    except Exception as e:
        # Fall back to single-process (non-distributed) execution.
        print(f"initialized as single process: {e}")
Example #2
Source File: trainer.py    From pytorch-asr with GNU General Public License v3.0
def get_amp_handle(args):
    # Mixed precision requires CUDA; fall back to fp32 when running on CPU.
    if not args.use_cuda:
        args.fp16 = False
    if args.fp16:
        from apex import amp
        amp_handle = amp.init(enabled=True, enable_caching=True, verbose=False)
        return amp_handle
    else:
        return None
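Because get_amp_handle returns None when fp16 is disabled, the caller presumably branches on the handle in the training step. A hedged sketch of that pattern; args, model, optimizer, criterion, and the batch tensors are assumed to exist in the surrounding loop.

amp_handle = get_amp_handle(args)

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
if amp_handle is not None:
    # Mixed precision path: scale the loss before backward().
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    # Plain fp32 path.
    loss.backward()
optimizer.step()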
Example #3
Source File: test_rnn.py    From apex with BSD 3-Clause "New" or "Revised" License
def setUp(self):
    self.handle = amp.init(enabled=True)
    common_init(self)
Example #4
Source File: test_rnn.py    From apex with BSD 3-Clause "New" or "Revised" License
def setUp(self):
    self.handle = amp.init(enabled=True)
    common_init(self)
Example #5
Source File: test_basic_casts.py    From apex with BSD 3-Clause "New" or "Revised" License
def setUp(self):
    self.handle = amp.init(enabled=True)
    common_init(self)
Example #6
Source File: test_basic_casts.py    From apex with BSD 3-Clause "New" or "Revised" License
def setUp(self):
    self.handle = amp.init(enabled=True)
    common_init(self)
Example #7
Source File: test_basic_casts.py    From apex with BSD 3-Clause "New" or "Revised" License
def test_bce_is_float_with_allow_banned(self):
    self.handle._deactivate()
    self.handle = amp.init(enabled=True, allow_banned=True)
    assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
    self.bce_common(assertion)
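The allow_banned flag appears to lift amp's restriction on ops it treats as unsafe in fp16, such as binary_cross_entropy, running them in fp32 instead. A rough sketch of the behavior the test checks; the exact semantics without the flag are an assumption, not shown in the source.

import torch
import torch.nn.functional as F
from apex import amp

# Assumption: with allow_banned=True, amp executes banned ops like BCE in fp32.
handle = amp.init(enabled=True, allow_banned=True)

probs = torch.rand(8, device="cuda").half()
target = torch.rand(8, device="cuda").half()

loss = F.binary_cross_entropy(probs, target)
print(loss.type())  # expected: torch.cuda.FloatTensor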
Example #8
Source File: test_promotion.py    From apex with BSD 3-Clause "New" or "Revised" License
def setUp(self):
    self.handle = amp.init(enabled=True)
    common_init(self)