Python torch.utils.data.dataset.ConcatDataset() Examples

The following are 7 code examples of torch.utils.data.dataset.ConcatDataset(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.utils.data.dataset, or try the search function.
Example #1
Source File: charades_ego_plus_charades.py    From PyVideoResearch with GNU General Public License v3.0 6 votes vote down vote up
def get(cls, args, splits=('train', 'val', 'val_video')):
    """Assemble combined CharadesEgo + Charades datasets for the given splits.

    'train_file', 'val_file' and 'data' on *args* each hold two
    ';'-separated values: the first for CharadesEgo, the second for
    Charades. Returns (train, val, val_video) datasets; the first two
    are ConcatDatasets mixing both sources when their split is requested.
    """
    ego_args = copy.deepcopy(args)
    charades_args = copy.deepcopy(args)
    # Split the paired config values between the two dataset configs.
    for key in ('train_file', 'val_file', 'data'):
        parts = getattr(args, key).split(';')
        setattr(ego_args, key, parts[0])
        setattr(charades_args, key, parts[1])

    if 'train' in splits or 'val' in splits:
        train_ego, val_ego, _ = CharadesEgoMeta.get(ego_args, splits=splits)
    else:
        train_ego, val_ego = None, None
    train_set, val_set, valvideo_set = super(CharadesEgoPlusCharades, cls).get(charades_args, splits=splits)

    if 'train' in splits:
        # Negate Charades targets — presumably to mark non-ego samples
        # for the downstream loss; confirm against the training code.
        train_set.target_transform = transforms.Lambda(lambda x: -x)
        # Replicate the ego dataset 3x ("magic number to balance" sizes).
        train_set = ConcatDataset([train_set] + [train_ego] * 3)
    if 'val' in splits:
        val_set.target_transform = transforms.Lambda(lambda x: -x)
        val_set = ConcatDataset([val_set] + [val_ego] * 3)
    return train_set, val_set, valvideo_set
Example #2
Source File: charades_ego_video_plus_charades.py    From PyVideoResearch with GNU General Public License v3.0 6 votes vote down vote up
def get(cls, args, splits=('train', 'val', 'val_video')):
    """Assemble combined CharadesEgoVideo + Charades datasets.

    Each of 'train_file', 'val_file' and 'data' on *args* carries two
    ';'-separated entries — index 0 for the ego dataset, index 1 for
    Charades. Returns (train, val, val_video); train/val become
    ConcatDatasets over both sources when requested in *splits*.
    """
    ego_args = copy.deepcopy(args)
    charades_args = copy.deepcopy(args)
    for key in ('train_file', 'val_file', 'data'):
        parts = getattr(args, key).split(';')
        setattr(ego_args, key, parts[0])
        setattr(charades_args, key, parts[1])

    if 'train' in splits or 'val' in splits:
        train_ego, val_ego, _ = CharadesEgoVideoMeta.get(ego_args, splits=splits)
    else:
        train_ego, val_ego = None, None
    train_set, val_set, valvideo_set = super(CharadesEgoVideoPlusCharades, cls).get(charades_args, splits=splits)

    if 'train' in splits:
        # Charades targets are negated — presumably a source marker for
        # the loss; confirm with the training loop.
        train_set.target_transform = transforms.Lambda(lambda x: -x)
        # 3x ego replication ("magic number to balance" dataset sizes).
        train_set = ConcatDataset([train_set] + [train_ego] * 3)
    if 'val' in splits:
        val_set.target_transform = transforms.Lambda(lambda x: -x)
        val_set = ConcatDataset([val_set] + [val_ego] * 3)
    return train_set, val_set, valvideo_set
Example #3
Source File: charades_ego_plus_charades3.py    From PyVideoResearch with GNU General Public License v3.0 6 votes vote down vote up
def get(cls, args, splits=('train', 'val', 'val_video')):
    """Build CharadesEgo + Charades datasets (1:1 mixing variant).

    *args* fields 'train_file', 'val_file' and 'data' are ';'-separated
    pairs: first value for CharadesEgo, second for Charades. Returns
    (train, val, val_video) datasets.
    """
    ego_args = copy.deepcopy(args)
    charades_args = copy.deepcopy(args)
    for key in ('train_file', 'val_file', 'data'):
        parts = getattr(args, key).split(';')
        setattr(ego_args, key, parts[0])
        setattr(charades_args, key, parts[1])

    need_ego = 'train' in splits or 'val' in splits
    if need_ego:
        train_ego, val_ego, _ = CharadesEgoMeta.get(ego_args, splits=splits)
    else:
        train_ego, val_ego = None, None
    train_set, val_set, valvideo_set = super(CharadesEgoPlusCharades3, cls).get(charades_args, splits=splits)

    if 'train' in splits:
        # Negated targets presumably flag Charades (non-ego) samples;
        # confirm downstream.
        train_set.target_transform = transforms.Lambda(lambda x: -x)
        # This variant mixes the ego data only once (balance factor 1).
        train_set = ConcatDataset([train_set] + [train_ego] * 1)
    if 'val' in splits:
        val_set.target_transform = transforms.Lambda(lambda x: -x)
        val_set = ConcatDataset([val_set] + [val_ego] * 1)
    return train_set, val_set, valvideo_set
Example #4
Source File: charades_ego_plus_charades2.py    From PyVideoResearch with GNU General Public License v3.0 6 votes vote down vote up
def get(cls, args, splits=('train', 'val', 'val_video')):
    """Build CharadesEgo + Charades datasets (6x ego replication variant).

    The 'train_file', 'val_file' and 'data' attributes of *args* are
    ';'-separated pairs (ego config first, Charades config second).
    Returns (train, val, val_video) datasets.
    """
    ego_args = copy.deepcopy(args)
    charades_args = copy.deepcopy(args)
    for key in ('train_file', 'val_file', 'data'):
        parts = getattr(args, key).split(';')
        setattr(ego_args, key, parts[0])
        setattr(charades_args, key, parts[1])

    if 'train' in splits or 'val' in splits:
        train_ego, val_ego, _ = CharadesEgoMeta.get(ego_args, splits=splits)
    else:
        train_ego, val_ego = None, None
    train_set, val_set, valvideo_set = super(CharadesEgoPlusCharades2, cls).get(charades_args, splits=splits)

    if 'train' in splits:
        # Charades targets are negated — presumably a dataset-source
        # marker; confirm with the consumer of these labels.
        train_set.target_transform = transforms.Lambda(lambda x: -x)
        # 6x ego replication ("magic number to balance" dataset sizes).
        train_set = ConcatDataset([train_set] + [train_ego] * 6)
    if 'val' in splits:
        val_set.target_transform = transforms.Lambda(lambda x: -x)
        val_set = ConcatDataset([val_set] + [val_ego] * 6)
    return train_set, val_set, valvideo_set
Example #5
Source File: dataset.py    From Jacinle with MIT License 5 votes vote down vote up
def __add__(self, other):
    """Support `dataset_a + dataset_b`, yielding their concatenation."""
    # Imported lazily so torch is only required when datasets are combined.
    from torch.utils.data.dataset import ConcatDataset
    combined = ConcatDataset([self, other])
    return combined
Example #6
Source File: charadesegoplusrgb.py    From actor-observer with GNU General Public License v3.0 5 votes vote down vote up
def get(cls, args):
    """Build CharadesEgo + original-Charades RGB datasets.

    Returns a 4-tuple: (train, val, val_video, val_video_ego) where
    train/val are ConcatDatasets mixing Charades with the ego data
    replicated 6x, and val_video_ego is a separate CharadesMeta
    evaluation dataset over the egocentric test data.
    """
    ego_train, ego_val, _ = charadesego.CharadesEgo.get(args)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Point the parent loader at the original Charades files.
    charades_args = copy.deepcopy(args)
    charades_args.train_file = args.original_charades_train
    charades_args.val_file = args.original_charades_test
    charades_args.data = args.original_charades_data
    train_set, val_set, valvideo_set = super(CharadesEgoPlusRGB, cls).get(charades_args)

    # Triplicate each sample so Charades items match the ego datasets'
    # three-view format — presumably (anchor, positive, negative); confirm.
    for ds in (train_set, val_set, valvideo_set):
        ds.transform.transforms.append(transforms.Lambda(lambda x: [x, x, x]))
    # Negated targets mark the Charades (non-ego) samples.
    train_set.target_transform = transforms.Lambda(lambda x: -x)
    val_set.target_transform = transforms.Lambda(lambda x: -x)

    valvideoego_set = CharadesMeta(
        args.data, 'val_video',
        args.egocentric_test_data,
        args.cache,
        args.cache_buster,
        transform=transforms.Compose([
            transforms.Resize(int(256. / 224 * args.inputsize)),
            transforms.CenterCrop(args.inputsize),
            transforms.ToTensor(),
            normalize,
        ]))

    # 6x ego replication roughly balances the dataset sizes.
    train_set = ConcatDataset([train_set] + [ego_train] * 6)
    val_set = ConcatDataset([val_set] + [ego_val] * 6)
    return train_set, val_set, valvideo_set, valvideoego_set
Example #7
Source File: dataloaders.py    From ignite with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
def get_train_val_loaders(
    root_path: str,
    train_transforms: Callable,
    val_transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    val_batch_size: Optional[int] = None,
    with_sbd: Optional[str] = None,
    limit_train_num_samples: Optional[int] = None,
    limit_val_num_samples: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create train / val / train-eval dataloaders.

    Args:
        root_path: dataset root passed to the train/val dataset builders.
        train_transforms, val_transforms: transform callables applied via
            TransformedDataset.
        batch_size, num_workers: training-loader settings.
        val_batch_size: evaluation batch size; defaults to 4x batch_size.
        with_sbd: if given, path to SBD data to concatenate onto training.
        limit_train_num_samples, limit_val_num_samples: if given, take a
            random subset of that size from the respective dataset.

    Returns:
        (train_loader, val_loader, train_eval_loader); the last evaluates
        on a random training subset no larger than the val dataset.
    """
    train_ds = get_train_dataset(root_path)
    val_ds = get_val_dataset(root_path)

    if with_sbd is not None:
        sbd_ds = get_train_noval_sbdataset(with_sbd)
        train_ds = ConcatDataset([train_ds, sbd_ds])

    def _random_subset(ds, limit):
        # Seeding with the limit keeps the subset reproducible per limit.
        # NOTE(review): this mutates numpy's global RNG state.
        np.random.seed(limit)
        chosen = np.random.permutation(len(ds))[:limit]
        return Subset(ds, chosen)

    if limit_train_num_samples is not None:
        train_ds = _random_subset(train_ds, limit_train_num_samples)
    if limit_val_num_samples is not None:
        val_ds = _random_subset(val_ds, limit_val_num_samples)

    # Random training subset (val-sized) for evaluation on training data.
    if len(val_ds) < len(train_ds):
        np.random.seed(len(val_ds))
        eval_indices = np.random.permutation(len(train_ds))[: len(val_ds)]
        train_eval_ds = Subset(train_ds, eval_indices)
    else:
        train_eval_ds = train_ds

    train_ds = TransformedDataset(train_ds, transform_fn=train_transforms)
    val_ds = TransformedDataset(val_ds, transform_fn=val_transforms)
    # Train-eval uses the val transforms: no training-time augmentation.
    train_eval_ds = TransformedDataset(train_eval_ds, transform_fn=val_transforms)

    if val_batch_size is None:
        val_batch_size = batch_size * 4

    train_loader = idist.auto_dataloader(
        train_ds, shuffle=True, batch_size=batch_size,
        num_workers=num_workers, drop_last=True,
    )
    val_loader = idist.auto_dataloader(
        val_ds, shuffle=False, batch_size=val_batch_size,
        num_workers=num_workers, drop_last=False,
    )
    train_eval_loader = idist.auto_dataloader(
        train_eval_ds, shuffle=False, batch_size=val_batch_size,
        num_workers=num_workers, drop_last=False,
    )

    return train_loader, val_loader, train_eval_loader