Python torch.utils.data Examples
The following are 30 code examples of torch.utils.data.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch.utils, or try the search function.
Example #1
Source File: train_pointlk.py From pointnet-registration-framework with MIT License | 6 votes |
def eval_1(self, model, testloader, device):
    """Evaluate ``model`` on ``testloader`` with gradients disabled.

    Returns:
        (ave_vloss, ave_gloss): the two losses returned by
        ``self.compute_loss``, each averaged over the number of batches.
        Raises ZeroDivisionError when ``testloader`` yields nothing.
    """
    model.eval()
    total_loss = 0.0
    total_loss_g = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in testloader:
            loss, loss_g = self.compute_loss(model, batch, device)
            total_loss += loss.item()
            total_loss_g += loss_g.item()
            n_batches += 1
    return float(total_loss) / n_batches, float(total_loss_g) / n_batches
Example #2
Source File: model.py From ganomaly with MIT License | 6 votes |
def set_input(self, input: torch.Tensor):
    """Copy one batch into the model's preallocated buffers.

    Args:
        input: an indexable batch whose ``input[0]`` is the input data and
            ``input[1]`` the ground truth. The ``torch.Tensor`` annotation is
            kept from the original signature even though a pair is indexed.
    """
    images, targets = input[0], input[1]
    with torch.no_grad():
        self.input.resize_(images.size()).copy_(images)
        self.gt.resize_(targets.size()).copy_(targets)
        # Only the shape of ``label`` is updated here; its values are not written.
        self.label.resize_(targets.size())

        # Keep a copy of the first batch as the fixed input.
        if self.total_steps == self.opt.batchsize:
            self.fixed_input.resize_(images.size()).copy_(images)
Example #3
Source File: data.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def __init__(self, data, transform=lambda data: data, one_hot=None, shuffle=False, dir=None):
    """
    Load the cached data (.pkl) into memory.
    :author 申瑞珉 (Ruimin Shen)
    :param data: A list containing the data samples (dict). When ``shuffle`` is true,
        this very list is shuffled in place (no copy is made).
    :param transform: A function that transforms the labels in a dict (usually a sequence of data augmentation operations).
    :param one_hot: If an int (total number of classes) is given, the class label (key "cls") will be generated in one-hot format.
    :param shuffle: Shuffle the loaded dataset.
    :param dir: The directory to store the exception data.
    """
    if shuffle:
        random.shuffle(data)
    self.data = data
    self.transform = transform
    if one_hot is None:
        self.one_hot = None
    else:
        # NOTE(review): the class count is passed positionally; newer
        # scikit-learn OneHotEncoder signatures differ -- confirm the version.
        self.one_hot = sklearn.preprocessing.OneHotEncoder(one_hot, dtype=np.float32)
    self.dir = dir
Example #4
Source File: data.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def __init__(self, resize, sizes, maintain=1, transform_image=lambda image: image, transform_tensor=None, dir=None):
    """
    Unify multiple data samples (e.g., resize images to the same size and pad bounding box labels to the same count) to form a batch.
    :author 申瑞珉 (Ruimin Shen)
    :param resize: A function to resize the image and labels.
    :param sizes: The image sizes to be randomly chosen from.
    :param maintain: How many times a size is maintained before another is chosen; must be positive.
    :param transform_image: A function to transform the resized image.
    :param transform_tensor: A function to standardize an image into a tensor.
    :param dir: The directory to store the exception data.
    """
    self.resize, self.sizes = resize, sizes
    assert maintain > 0
    # ``_maintain`` is a working copy of the configured value (presumably a
    # countdown consumed when picking sizes elsewhere -- confirm).
    self.maintain = self._maintain = maintain
    self.transform_image = transform_image
    self.transform_tensor = transform_tensor
    self.dir = dir
Example #5
Source File: train_classifier.py From pointnet-registration-framework with MIT License | 6 votes |
def eval_1(self, model, testloader, device):
    """Evaluate a classifier with gradients disabled.

    Returns:
        (ave_loss, accuracy): the summed loss and the number of correct argmax
        predictions, both divided by the total sample count (the sum of
        ``output.size(0)`` over batches).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0.0
    n_samples = 0
    with torch.no_grad():
        for batch in testloader:
            target, output, loss = self.compute_loss(model, batch, device)
            loss_sum += loss.item()
            n_samples += output.size(0)
            predicted = output.max(dim=1)[1]
            n_correct += (predicted == target).sum().item()
    return float(loss_sum) / n_samples, float(n_correct) / n_samples
Example #6
Source File: train_classifier.py From pointnet-registration-framework with MIT License | 6 votes |
def train_1(self, model, trainloader, optimizer, device):
    """Run one training epoch over ``trainloader``.

    Returns:
        (running_loss, accuracy): the summed loss and the number of correct
        argmax predictions, both divided by the total sample count.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0.0
    n_samples = 0
    for batch in trainloader:
        target, output, loss = self.compute_loss(model, batch, device)

        # forward + backward + optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n_samples += output.size(0)
        predicted = output.max(dim=1)[1]
        n_correct += (predicted == target).sum().item()
    return float(loss_sum) / n_samples, float(n_correct) / n_samples
Example #7
Source File: build.py From R2CNN.pytorch with MIT License | 6 votes |
def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0
):
    """Wrap ``sampler`` into a batch sampler of ``images_per_batch`` indices.

    When ``aspect_grouping`` is truthy (a scalar or a list/tuple of bin edges),
    indices are grouped by quantized aspect ratio via
    ``samplers.GroupedBatchSampler``; otherwise a plain ``BatchSampler`` is
    used. Neither drops the last incomplete batch. If ``num_iters`` is given,
    the result is wrapped in ``samplers.IterationBasedBatchSampler``.
    """
    if not aspect_grouping:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch, drop_last=False
        )
    else:
        if isinstance(aspect_grouping, (list, tuple)):
            bins = aspect_grouping
        else:
            bins = [aspect_grouping]
        group_ids = _quantize(_compute_aspect_ratios(dataset), bins)
        batch_sampler = samplers.GroupedBatchSampler(
            sampler, group_ids, images_per_batch, drop_uneven=False
        )
    if num_iters is None:
        return batch_sampler
    return samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
Example #8
Source File: data.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def __call__(self, batch):
    """Resize, transform and pad every sample in ``batch``, then collate them.

    One ``(height, width)`` from ``self.next_size()`` is used for the whole
    batch, and labels are padded to the largest ``len(data['cls'])`` in it.
    If processing a sample raises and ``self.dir`` is set, that sample (possibly
    already partially transformed) is pickled to
    ``<dir>/<module>.<ClassName>.pkl``, overwriting any previous dump, and the
    exception is re-raised.
    """
    height, width = self.next_size()
    dim = max(len(data['cls']) for data in batch)
    _batch = []
    for data in batch:
        try:
            data = self.resize(data, height, width)
            data['image'] = self.transform_image(data['image'])
            data = padding_labels(data, dim)
            if self.transform_tensor is not None:
                data['tensor'] = self.transform_tensor(data['image'])
        except Exception:
            # Was a bare ``except:``; narrowed so KeyboardInterrupt/SystemExit
            # propagate without writing a debug dump. Always re-raised.
            if self.dir is not None:
                os.makedirs(self.dir, exist_ok=True)
                name = self.__module__ + '.' + type(self).__name__
                with open(os.path.join(self.dir, name + '.pkl'), 'wb') as f:
                    pickle.dump(data, f)
            raise
        _batch.append(data)
    return torch.utils.data.dataloader.default_collate(_batch)
Example #9
Source File: DataSet.py From ext_portrait_segmentation with MIT License | 6 votes |
def __getitem__(self, idx):
    '''
    Load one image/label pair.

    :param idx: Index of the image file
    :return: ``(image, label)``, or ``(image, label, edgemap)`` when ``self.edge`` is set.
    :raises FileNotFoundError: if OpenCV cannot read the image or label file.
    '''
    image_name = self.imList[idx]
    label_name = self.labelList[idx]
    image = cv2.imread(image_name)
    label = cv2.imread(label_name, 0)
    # cv2.imread returns None instead of raising for missing/unreadable files.
    if image is None:
        raise FileNotFoundError('could not read image file: %s' % image_name)
    if label is None:
        raise FileNotFoundError('could not read label file: %s' % label_name)
    # Binarize the mask (>200 -> 255, else 0). Previously this was only
    # applied when a transform was set; without one the raw label leaked out.
    label = 255 * ((label > 200).astype(np.uint8))
    if self.transform:
        [image, label] = self.transform(image, label)
    if not self.edge:
        return (image, label)

    # The edge path expects ``label`` to be a torch tensor (as produced by the
    # transform); the band between dilation and erosion is the boundary.
    np_label = 255 * label.data.numpy().astype(np.uint8)
    kernel = np.ones((self.kernel_size, self.kernel_size), np.uint8)
    erosion = cv2.erode(np_label, kernel, iterations=1)
    dilation = cv2.dilate(np_label, kernel, iterations=1)
    boundary_mask = torch.from_numpy(dilation - erosion) > 0
    edgemap = 255 * torch.ones_like(label)
    edgemap[boundary_mask] = label[boundary_mask]
    return (image, label, edgemap)
Example #10
Source File: eval.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def get_loader(self):
    """Build the evaluation ``DataLoader``.

    Samples are loaded from ``<cache_dir>/<phase>.pkl`` for every phase listed
    under ``[eval] phase``. ``Collate`` is given ``[size]`` (from
    ``[image] size``) as its only candidate size. The worker count is
    ``[data] workers``, or the CPU count when that option is missing.
    """
    config = self.config
    phases = config.get('eval', 'phase').split()
    paths = [os.path.join(self.cache_dir, phase + '.pkl') for phase in phases]
    dataset = utils.data.Dataset(utils.data.load_pickles(paths))
    logging.info('num_examples=%d' % len(dataset))
    size = tuple(int(dim) for dim in config.get('image', 'size').split())
    try:
        workers = config.getint('data', 'workers')
    except configparser.NoOptionError:
        workers = multiprocessing.cpu_count()
    collate_fn = utils.data.Collate(
        transform.parse_transform(config, config.get('transform', 'resize_eval')),
        [size],
        transform_image=transform.get_transform(config, config.get('transform', 'image_test').split()),
        transform_tensor=transform.get_transform(config, config.get('transform', 'tensor').split()),
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=self.args.batch_size,
        num_workers=workers,
        collate_fn=collate_fn,
    )
Example #11
Source File: receptive_field_analyzer.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def __init__(self, args, config):
    """Build the detector and a loader over single-pixel probe images.

    One all-zero image of size ``[image] size`` is pushed through the network to
    record the output grid (``self.rows``, ``self.cols``), its centre cell
    (``self.i``, ``self.j``) and the baseline output at that cell
    (``self.output``).
    """
    self.args = args
    self.config = config
    self.model_dir = utils.get_model_dir(config)
    self.category = utils.get_category(config)
    self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
    self.dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config), self.anchors, len(self.category))
    self.dnn.eval()
    # Log the total byte size of the network's state dict.
    logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.dnn.state_dict().values())))
    if torch.cuda.is_available():
        self.dnn.cuda()
    self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
    # ``Variable(..., volatile=True)`` has been ignored (with a warning) since
    # PyTorch 0.4, so the probe pass was building an autograd graph.
    with torch.no_grad():
        output = self.dnn(utils.ensure_device(torch.zeros(1, 3, self.height, self.width)))
    _, _, self.rows, self.cols = output.size()
    self.i, self.j = self.rows // 2, self.cols // 2
    self.output = output[:, :, self.i, self.j]
    dataset = Dataset(self.height, self.width)
    try:
        workers = self.config.getint('data', 'workers')
    except configparser.NoOptionError:
        workers = multiprocessing.cpu_count()
    self.loader = torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size, num_workers=workers)
Example #12
Source File: receptive_field_analyzer.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def __call__(self):
    """Probe every input pixel and compare the centre-cell output to the baseline.

    Each probe image is all zeros except one pixel ``(y, x)`` set to 1 in every
    channel. ``changed[y, x]`` receives ``prod(output == self.output)`` for that
    probe, i.e. it is truthy when the centre output is *identical* to the
    all-zero baseline. NOTE(review): that is the opposite of what the name
    ``changed`` suggests -- confirm against callers before relying on it.
    """
    # ``np.bool`` was removed in NumPy 1.24; the builtin gives the same dtype.
    changed = np.zeros([self.height, self.width], bool)
    for yx in tqdm.tqdm(self.loader):
        batch_size = yx.size(0)
        tensor = torch.zeros(batch_size, 3, self.height, self.width)
        for i, _yx in enumerate(torch.unbind(yx)):
            y, x = torch.unbind(_yx)
            tensor[i, :, y, x] = 1
        tensor = utils.ensure_device(tensor)
        # Replaces ``Variable(..., volatile=True)``, a no-op since PyTorch 0.4.
        with torch.no_grad():
            output = self.dnn(tensor)
        output = output[:, :, self.i, self.j]
        cmp = output == self.output
        cmp = torch.prod(cmp, -1).data
        for _yx, c in zip(torch.unbind(yx), torch.unbind(cmp)):
            y, x = torch.unbind(_yx)
            changed[y, x] = c
    return changed
Example #13
Source File: test_pointlk.py From pointnet-registration-framework with MIT License | 6 votes |
def eval_1(self, model, testloader, device):
    """Run registration on each test batch, record results and log the error.

    For every ``(p0, p1, igt)`` batch, ``self.do_estimate`` yields an estimate
    ``g_hat``; ``ptlk.se3.log(g_hat @ igt)`` is reduced to its mean L2 norm and
    logged. A header plus one record per batch is written to ``self.filename``.
    """
    model.eval()
    with open(self.filename, 'w') as fout:
        self.eval_1__header(fout)
        with torch.no_grad():
            for step, (p0, p1, igt) in enumerate(testloader):
                res = self.do_estimate(p0, p1, model, device)  # --> [1, 4, 4]
                ig_gt = igt.cpu().contiguous().view(-1, 4, 4)  # --> [1, 4, 4]
                g_hat = res.cpu().contiguous().view(-1, 4, 4)  # --> [1, 4, 4]
                # The product is the identity matrix (and its log the zero
                # vector) when the estimate is exact.
                dx = ptlk.se3.log(g_hat.bmm(ig_gt))  # --> [1, 6]
                mean_err = dx.norm(p=2, dim=1).mean()
                self.eval_1__write(fout, ig_gt, g_hat)
                LOGGER.info('test, %d/%d, %f', step, len(testloader), mean_err)
Example #14
Source File: test_pointlk.py From pointnet-registration-framework with MIT License | 6 votes |
def run(args, testset, action):
    """Create the model from ``action`` and evaluate it on ``testset``.

    ``args.device`` is replaced by a ``torch.device`` (forced to CPU when CUDA
    is unavailable). If ``args.pretrained`` is set, those weights are loaded.
    The test set is iterated in order, one sample per batch.
    """
    device_name = args.device if torch.cuda.is_available() else 'cpu'
    args.device = torch.device(device_name)

    LOGGER.debug('Testing (PID=%d), %s', os.getpid(), args)

    model = action.create_model()
    if args.pretrained:
        assert os.path.isfile(args.pretrained)
        # NOTE: torch.load unpickles the checkpoint -- only load trusted files.
        state = torch.load(args.pretrained, map_location='cpu')
        model.load_state_dict(state)
    model.to(args.device)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=1, shuffle=False, num_workers=args.workers)

    LOGGER.debug('tests, begin')
    action.eval_1(model, testloader, args.device)
    LOGGER.debug('tests, end')
Example #15
Source File: prefix_dataset.py From hyrnn with Apache License 2.0 | 6 votes |
def __init__(self, root, num=10, split="train", download=False):
    """Load one pickled split and its vocabulary maps from ``root/<_suffix>``.

    :param root: dataset root directory.
    :param num: one of 10, 30 or 50.
    :param split: "train", "test" or "valid" (stored on disk as "dev").
    :param download: call ``self.download()`` instead of ``self._check_integrity()``.
    """
    assert num in {10, 30, 50}
    assert split in {"train", "test", "valid"}
    self.num = num
    self.split = split
    self.root = root
    if download:
        self.download()
    else:
        self._check_integrity()

    def _load(filename):
        # Close each file deterministically (the previous bare ``open`` leaked
        # the handle). NOTE: pickle.load runs arbitrary code; trusted files only.
        with open(os.path.join(root, self._suffix, filename), "rb") as f:
            return pickle.load(f)

    name = {"train": "train", "test": "test", "valid": "dev"}[split]
    self.data = _load(name)
    self.id2word = _load("id_to_word")
    self.word2id = _load("word_to_id")
Example #16
Source File: train.py From ICDAR-2019-SROIE with MIT License | 6 votes |
def trainBatch(net, criterion, optimizer):
    """Run one optimisation step of ``net`` on the next batch of ``train_iter``.

    Relies on module-level names defined elsewhere in the script
    (``train_iter``, ``converter``, ``utils`` and the ``image``/``text``/
    ``length`` buffers).

    Returns:
        The value of ``criterion(...)`` divided by the batch size.
    """
    # ``next()`` works on every iterator; the ``.next()`` method is the
    # Python 2 protocol and is missing from modern DataLoader iterators.
    cpu_images, cpu_texts = next(train_iter)
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    # Use the network passed in rather than silently reaching for a global.
    preds = net(image)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    net.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
Example #17
Source File: build.py From Res2Net-maskrcnn with MIT License | 6 votes |
def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0
):
    """Wrap ``sampler`` into a batch sampler of ``images_per_batch`` indices.

    When ``aspect_grouping`` is truthy (a scalar or a list/tuple of bin edges),
    indices are grouped by quantized aspect ratio via
    ``samplers.GroupedBatchSampler``; otherwise a plain ``BatchSampler`` is
    used. Neither drops the last incomplete batch. If ``num_iters`` is given,
    the result is wrapped in ``samplers.IterationBasedBatchSampler``.
    """
    if not aspect_grouping:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch, drop_last=False
        )
    else:
        if isinstance(aspect_grouping, (list, tuple)):
            bins = aspect_grouping
        else:
            bins = [aspect_grouping]
        group_ids = _quantize(_compute_aspect_ratios(dataset), bins)
        batch_sampler = samplers.GroupedBatchSampler(
            sampler, group_ids, images_per_batch, drop_uneven=False
        )
    if num_iters is None:
        return batch_sampler
    return samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
Example #18
Source File: train.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def get_loader(self):
    """Build the shuffled training ``DataLoader``.

    With CUDA available, both the configured worker count and the batch size
    are multiplied by ``torch.cuda.device_count()`` and memory is pinned. When
    ``[data] workers`` is missing, the CPU count is used instead. Failing
    samples are dumped under ``<model_dir>/exception``.
    """
    config = self.config
    cuda = torch.cuda.is_available()
    exception_dir = os.path.join(self.model_dir, 'exception')
    paths = [os.path.join(self.cache_dir, phase + '.pkl') for phase in config.get('train', 'phase').split()]
    dataset = utils.data.Dataset(
        utils.data.load_pickles(paths),
        transform=transform.augmentation.get_transform(config, config.get('transform', 'augmentation').split()),
        one_hot=None if config.getboolean('train', 'cross_entropy') else len(self.category),
        shuffle=config.getboolean('data', 'shuffle'),
        dir=exception_dir,
    )
    logging.info('num_examples=%d' % len(dataset))
    try:
        workers = config.getint('data', 'workers')
        if cuda:
            workers = workers * torch.cuda.device_count()
    except configparser.NoOptionError:
        workers = multiprocessing.cpu_count()
    collate_fn = utils.data.Collate(
        transform.parse_transform(config, config.get('transform', 'resize_train')),
        utils.train.load_sizes(config),
        maintain=config.getint('data', 'maintain'),
        transform_image=transform.get_transform(config, config.get('transform', 'image_train').split()),
        transform_tensor=transform.get_transform(config, config.get('transform', 'tensor').split()),
        dir=exception_dir,
    )
    if cuda:
        batch_size = self.args.batch_size * torch.cuda.device_count()
    else:
        batch_size = self.args.batch_size
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        collate_fn=collate_fn,
        pin_memory=cuda,
    )
Example #19
Source File: train.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 6 votes |
def iterate(self, data):
    """Run one training step on ``data`` and return diagnostics.

    Tensors in ``data`` are moved to the device in place. The weighted loss
    (each term scaled by ``[hparam] <key>``) is back-propagated, gradients are
    optionally clipped to ``[train] clip`` and the optimizer steps.

    Returns:
        dict with the spatial sizes, inputs, predictions, debug info and the
        raw, weighted and total losses.
    """
    for key in data:
        t = data[key]
        if torch.is_tensor(t):
            data[key] = utils.ensure_device(t)
    tensor = torch.autograd.Variable(data['tensor'])
    pred = pybenchmark.profile('inference')(model._inference)(self.inference, tensor)
    height, width = data['image'].size()[1:3]
    rows, cols = pred['feature'].size()[-2:]
    loss, debug = pybenchmark.profile('loss')(model.loss)(self.anchors, norm_data(data, height, width, rows, cols), pred, self.config.getfloat('model', 'threshold'))
    loss_hparam = {key: loss[key] * self.config.getfloat('hparam', key) for key in loss}
    loss_total = sum(loss_hparam.values())
    self.optimizer.zero_grad()
    loss_total.backward()
    # Only the config lookup may legitimately be missing; keep the try narrow
    # so errors from the clipping itself are not mistaken for "no option".
    try:
        clip = self.config.getfloat('train', 'clip')
    except configparser.NoOptionError:
        clip = None
    if clip is not None:
        # ``clip_grad_norm`` is deprecated in favour of the in-place
        # ``clip_grad_norm_``; fall back for very old PyTorch releases.
        clip_grad_norm = getattr(nn.utils, 'clip_grad_norm_', None) or nn.utils.clip_grad_norm
        clip_grad_norm(self.inference.parameters(), clip)
    self.optimizer.step()
    return dict(
        height=height, width=width, rows=rows, cols=cols,
        data=data, pred=pred, debug=debug,
        loss_total=loss_total, loss=loss, loss_hparam=loss_hparam,
    )
Example #20
Source File: train.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 5 votes |
def norm_data(data, height, width, rows, cols, keys='yx_min, yx_max'):
    """Return a shallow copy of ``data`` with selected entries rescaled.

    Each key in the ", "-separated ``keys`` string is multiplied by
    ``[rows / height, cols / width]`` (a float32 tensor shaped ``[1, 1, 2]``
    and moved via ``utils.ensure_device``), presumably converting image
    coordinates to output-grid units. ``data`` itself is not modified.
    """
    ratio = np.array([rows / height, cols / width], dtype=np.float32).reshape([1, 1, 2])
    scale = utils.ensure_device(torch.from_numpy(ratio))
    scaled = {name: data[name] for name in data}
    for name in keys.split(', '):
        scaled[name] = scaled[name] * scale
    return scaled
Example #21
Source File: dataset.py From decaNLP with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, path, format, fields, skip_header=False, subsample=False, **kwargs):
    """Create a TabularDataset given a path, file format, and field list.

    Arguments:
        path (str): Path to the data file (``~`` is expanded).
        format (str): The format of the data file. One of "CSV", "TSV",
            "JSON" or "dict" (case-insensitive).
        fields (list(tuple(str, Field)) or dict[str: tuple(str, Field)]): For CSV and
            TSV formats, list of tuples of (name, field). The list should be in
            the same order as the columns in the CSV or TSV file, while tuples of
            (name, None) represent columns that will be ignored. For JSON format,
            dictionary whose keys are the JSON keys and whose values are tuples of
            (name, field). This allows the user to rename columns from their JSON key
            names and also enables selecting a subset of columns to load
            (since JSON keys not present in the input dictionary are ignored).
        skip_header (bool): Whether to skip the first line of the input file.
        subsample: Accepted for API compatibility; not used by this method.
        Remaining keyword arguments: passed to the parent Dataset constructor.
    """
    make_example = {
        'json': Example.fromJSON, 'dict': Example.fromdict,
        'tsv': Example.fromTSV, 'csv': Example.fromCSV}[format.lower()]

    with io.open(os.path.expanduser(path), encoding="utf8") as f:
        if skip_header:
            # Default avoids a stray StopIteration when the file is empty.
            next(f, None)
        examples = [make_example(line, fields) for line in f]

    if make_example in (Example.fromdict, Example.fromJSON):
        # Flatten the {key: field-or-list-of-fields} mapping into a flat list.
        fields, field_dict = [], fields
        for field in field_dict.values():
            if isinstance(field, list):
                fields.extend(field)
            else:
                fields.append(field)

    super(TabularDataset, self).__init__(examples, fields, **kwargs)
Example #22
Source File: dataset.py From decaNLP with BSD 3-Clause "New" or "Revised" License | 5 votes |
def download(cls, root, check=None):
    """Download and extract an online archive (.zip, .gz, .tgz, .bz2 or .tar).

    Arguments:
        root (str): Folder to download data to.
        check (str or None): Folder whose existence indicates that the dataset has
            already been downloaded, or None to check the existence of root/{cls.name}.
    Returns:
        dataset_path (str): Path to extracted dataset.
    Raises:
        ValueError: (older Pythons without tar extraction filters) if a tar
            member would be written outside the extraction directory.
    """
    path = os.path.join(root, cls.name)
    check = path if check is None else check

    def _extract_tar(tar):
        # Archives come from the network: never let a member escape ``path``
        # (tar path traversal, CVE-2007-4559).
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=path, filter='data')
            return
        base = os.path.realpath(path)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(base, member.name))
            if os.path.commonpath([base, target]) != base:
                raise ValueError('unsafe path in archive: {}'.format(member.name))
        tar.extractall(path=path)

    if not os.path.isdir(check):
        for url in cls.urls:
            if isinstance(url, tuple):
                url, filename = url
            else:
                filename = os.path.basename(url)
            zpath = os.path.join(path, filename)
            if not os.path.isfile(zpath):
                os.makedirs(os.path.dirname(zpath), exist_ok=True)
                print('downloading {}'.format(filename))
                download_from_url(url, zpath)
            ext = os.path.splitext(filename)[-1]
            if ext == '.zip':
                with zipfile.ZipFile(zpath, 'r') as zfile:
                    print('extracting')
                    zfile.extractall(path)
            elif ext in ['.gz', '.tgz']:
                with tarfile.open(zpath, 'r:gz') as tar:
                    _extract_tar(tar)
            elif ext in ['.bz2', '.tar']:
                with tarfile.open(zpath) as tar:
                    _extract_tar(tar)
    return os.path.join(path, cls.dirname)
Example #23
Source File: train.py From yolo2-pytorch with GNU Lesser General Public License v3.0 | 5 votes |
def copy_image(self, **kwargs):
    """Snapshot the tensors needed for a visualisation as NumPy arrays.

    Requires the keyword arguments step, height, width, rows, cols, data, pred
    and debug. Only the ``image``/``yx_min``/``yx_max``/``cls`` entries of
    ``data`` and whichever of ``yx_min``/``yx_max``/``iou``/``logits`` exist in
    ``pred`` are kept; every tensor is cloned and moved to the CPU first.
    """
    step = kwargs['step']
    height = kwargs['height']
    width = kwargs['width']
    rows = kwargs['rows']
    cols = kwargs['cols']
    data = kwargs['data']
    pred = kwargs['pred']
    debug = kwargs['debug']

    def to_numpy(t):
        return t.clone().cpu().numpy()

    data = {key: to_numpy(data[key]) for key in ('image', 'yx_min', 'yx_max', 'cls')}
    pred = {key: to_numpy(pred[key].data) for key in ('yx_min', 'yx_max', 'iou', 'logits') if key in pred}
    # With 0/1 masks: 1 where positive, 0 where negative, 0.5 otherwise.
    matching = (debug['positive'].float() - debug['negative'].float() + 1) / 2
    return dict(
        step=step, height=height, width=width, rows=rows, cols=cols,
        data=data, pred=pred, matching=to_numpy(matching.data),
    )
Example #24
Source File: dataset.py From decaNLP with BSD 3-Clause "New" or "Revised" License | 5 votes |
def splits(cls, path=None, root='.data', train=None, validation=None, test=None, **kwargs):
    """Create Dataset objects for multiple splits of a dataset.

    Arguments:
        path (str): Common prefix of the splits' file paths, or None to use
            the result of cls.download(root).
        root (str): Root dataset storage directory. Default is '.data'.
        train (str): Suffix to add to path for the train set, or None for no
            train set. Default is None.
        validation (str): Suffix to add to path for the validation set, or None
            for no validation set. Default is None.
        test (str): Suffix to add to path for the test set, or None for no test
            set. Default is None.
        Remaining keyword arguments: Passed to the constructor of the
            Dataset (sub)class being used.

    Returns:
        split_datasets (tuple(Dataset)): Datasets for the train, validation and
        test splits, in that order, omitting any split that was not requested.
    """
    if path is None:
        path = cls.download(root)
    datasets = []
    for suffix in (train, validation, test):
        if suffix is not None:
            datasets.append(cls(os.path.join(path, suffix), **kwargs))
    return tuple(datasets)
Example #25
Source File: build.py From Res2Net-maskrcnn with MIT License | 5 votes |
def make_data_sampler(dataset, shuffle, distributed):
    """Choose the index sampler for ``dataset``.

    Distributed runs get ``samplers.DistributedSampler``; otherwise a
    ``RandomSampler`` when ``shuffle`` is truthy, else a ``SequentialSampler``.
    """
    if distributed:
        return samplers.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler_cls = torch.utils.data.sampler.RandomSampler
    else:
        sampler_cls = torch.utils.data.sampler.SequentialSampler
    return sampler_cls(dataset)
Example #26
Source File: DataSet.py From ext_portrait_segmentation with MIT License | 5 votes |
def __init__(self, imList, labelList, Enc=True, transform=None, edge=False):
    '''
    :param imList: image list (already processed and pickled by loadData.py)
    :param labelList: label list (already processed and pickled by loadData.py)
    :param Enc: sets ``self.kernel_size`` to 5 when true, otherwise 15
    :param transform: Type of transformation. See CVTransforms.py for supported transformations
    :param edge: stored as ``self.edge``
    '''
    self.imList = imList
    self.labelList = labelList
    self.transform = transform
    print("This num of data is " +str(len(imList)))
    self.edge = edge
    self.kernel_size = 5 if Enc else 15
Example #27
Source File: train.py From PoseWarper with Apache License 2.0 | 5 votes |
def parse_args():
    """Parse the command line for keypoint-network training.

    ``--cfg`` is required; everything after the options is collected into
    ``opts`` for config overrides. The directory options default to ''.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')
    # general
    parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
    parser.add_argument('opts', help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER)
    # philly
    directory_options = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--prevModelDir', 'prev Model directory'),
    )
    for flag, description in directory_options:
        parser.add_argument(flag, help=description, type=str, default='')
    return parser.parse_args()
Example #28
Source File: test.py From PoseWarper with Apache License 2.0 | 5 votes |
def parse_args():
    """Parse the command line for keypoint-network testing.

    ``--cfg`` is required; everything after the options is collected into
    ``opts`` for config overrides. The directory options default to ''.
    """
    # The description was copy-pasted from the training script; this is the
    # test entry point, so say so in ``--help``.
    parser = argparse.ArgumentParser(description='Test keypoints network')
    # general
    parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
    parser.add_argument('opts', help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER)
    # philly
    parser.add_argument('--modelDir', help='model directory', type=str, default='')
    parser.add_argument('--logDir', help='log directory', type=str, default='')
    parser.add_argument('--dataDir', help='data directory', type=str, default='')
    parser.add_argument('--prevModelDir', help='prev Model directory', type=str, default='')
    return parser.parse_args()
Example #29
Source File: movielens.py From pytorch-fm with MIT License | 5 votes |
def __init__(self, dataset_path, sep=',', engine='c', header='infer'):
    """Load the first three columns of a ratings file.

    Columns 0 and 1 become zero-based integer ids in ``self.items``
    (presumably user and item ids), and column 2 is passed through
    ``self.__preprocess_target`` into float32 ``self.targets``.

    :param dataset_path: file passed to ``pandas.read_csv``.
    :param sep: forwarded to ``pandas.read_csv``.
    :param engine: forwarded to ``pandas.read_csv``.
    :param header: forwarded to ``pandas.read_csv``.
    """
    data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=header).to_numpy()[:, :3]
    # ``np.int`` / ``np.long`` were aliases of the builtin int that NumPy 1.24
    # removed (AttributeError); use an explicit 64-bit integer dtype instead.
    self.items = data[:, :2].astype(np.int64) - 1  # -1 because ID begins from 1
    self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32)
    self.field_dims = np.max(self.items, axis=0) + 1
    self.user_field_idx = np.array((0,), dtype=np.int64)
    self.item_field_idx = np.array((1,), dtype=np.int64)
Example #30
Source File: horovod_distributed.py From pytorch-distributed with MIT License | 5 votes |
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Train ``model`` for one epoch.

    Tracks batch time, data-loading time, loss and top-1/top-5 accuracy,
    moves each batch to CUDA with ``non_blocking=True``, and displays progress
    every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()

    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        # time spent waiting for the loader
        data_time.update(time.time() - tick)
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        output = model(images)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)
        top1.update(acc1[0], n)
        top5.update(acc5[0], n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # full iteration time, then restart the clock
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)