Python torch.backends() Examples
The following are 2 code examples of torch.backends. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
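Note that torch.backends is a module rather than a callable; the examples below exercise its submodules, most commonly torch.backends.cudnn. As a quick orientation, here is a minimal sketch (not taken from either example) of the usual pattern of querying and toggling backend flags; the attribute names are standard PyTorch, but the script itself is only illustrative:

import torch
from torch.backends import cudnn

# Query which backends this build of PyTorch can use.
print('cuDNN available:', cudnn.is_available())
print('MKL available:', torch.backends.mkl.is_available())

# Let cuDNN benchmark convolution algorithms for fixed input shapes (faster training),
# or flip the flags the other way for reproducible, deterministic kernels.
cudnn.benchmark = True
cudnn.deterministic = False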
Example #1
Source File: main.py From SMIT with MIT License | 5 votes |
def main(config):
    from torch.backends import cudnn

    # For fast training
    cudnn.benchmark = True

    data_loader = get_loader(
        config.mode_data,
        config.image_size,
        config.batch_size,
        config.dataset_fake,
        config.mode,
        num_workers=config.num_workers,
        all_attr=config.ALL_ATTR,
        c_dim=config.c_dim)

    from misc.scores import set_score
    if set_score(config):
        return

    if config.mode == 'train':
        from train import Train
        Train(config, data_loader)
        from test import Test
        test = Test(config, data_loader)
        test(dataset=config.dataset_real)

    elif config.mode == 'test':
        from test import Test
        test = Test(config, data_loader)
        if config.DEMO_PATH:
            test.DEMO(config.DEMO_PATH)
        else:
            test(dataset=config.dataset_real)
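The only torch.backends usage in this example is cudnn.benchmark = True, which lets cuDNN auto-tune its convolution algorithms. This speeds up training when input sizes stay constant between iterations, but can slow things down when they vary. A small illustrative guard (an assumption for illustration, not part of SMIT) that only enables the autotuner when a CUDA/cuDNN build is actually in use:

import torch
from torch.backends import cudnn

# Illustrative: enable the cuDNN autotuner only when a GPU build with cuDNN is available
# and the model sees fixed-size inputs; otherwise keep the default behaviour.
if torch.cuda.is_available() and cudnn.is_available():
    cudnn.benchmark = True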
Example #2
Source File: demo.py From aster.pytorch with MIT License | 4 votes |
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    args.cuda = args.cuda and torch.cuda.is_available()
    if args.cuda:
        print('using cuda.')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (32, 100)

    dataset_info = DataInfo(args.voc_type)

    # Create model
    model = ModelBuilder(
        arch=args.arch,
        rec_num_classes=dataset_info.rec_num_classes,
        sDim=args.decoder_sdim,
        attDim=args.attDim,
        max_len_labels=args.max_len,
        eos=dataset_info.char2id[dataset_info.EOS],
        STN_ON=args.STN_ON)

    # Load from checkpoint
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])

    # Define the device unconditionally so that img.to(device) below also works on CPU.
    device = torch.device("cuda" if args.cuda else "cpu")
    model = model.to(device)
    if args.cuda:
        model = nn.DataParallel(model)

    # Evaluation
    model.eval()
    img = image_process(args.image_path)
    with torch.no_grad():
        img = img.to(device)
        input_dict = {}
        input_dict['images'] = img.unsqueeze(0)
        # TODO: testing should be more clean.
        # To be compatible with the lmdb-based testing, construct some placeholder variables.
        rec_targets = torch.IntTensor(1, args.max_len).fill_(1)
        rec_targets[:, args.max_len - 1] = dataset_info.char2id[dataset_info.EOS]
        input_dict['rec_targets'] = rec_targets
        input_dict['rec_lengths'] = [args.max_len]
        output_dict = model(input_dict)

    pred_rec = output_dict['output']['pred_rec']
    pred_str, _ = get_str_list(pred_rec, input_dict['rec_targets'], dataset=dataset_info)
    print('Recognition result: {0}'.format(pred_str[0]))
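This example combines cudnn.benchmark = True with torch.backends.cudnn.deterministic = True: the deterministic flag restricts cuDNN to deterministic kernels, while the seeding calls at the top pin every RNG, so repeated runs of the demo produce the same prediction. A minimal, self-contained reproducibility sketch along the same lines (the helper name and seed value are arbitrary; disabling benchmark as well is a stricter choice than the original code makes):

import random
import numpy as np
import torch

def seed_everything(seed=123):
    # Seed every RNG a typical inference/training script touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels and disable the autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()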