Python torch.utils.data.dataset.Dataset() Examples

The following are 13 code examples of torch.utils.data.dataset.Dataset(). You can go to the original project or source file by following the link above each example. You may also want to check out all available functions and classes of the module torch.utils.data.dataset, or try the search function.
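For reference, a map-style Dataset only needs to implement __getitem__ and __len__; DataLoader handles batching and shuffling on top. A minimal sketch (the class and its contents are illustrative, not taken from any of the projects below):

from torch.utils.data.dataset import Dataset

class SquaresDataset(Dataset):
    """Toy dataset yielding (n, n**2) pairs; purely illustrative."""
    def __init__(self, size):
        self.size = size

    def __getitem__(self, index):
        return index, index ** 2

    def __len__(self):
        return self.size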
Example #1
Source File: utils.py    From hidden-networks with Apache License 2.0
import torch
from torch.utils.data.dataset import Dataset

def one_batch_dataset(dataset, batch_size):
    # Wrap `batch_size` randomly drawn samples in a fixed, single-batch Dataset.
    print("==> Grabbing a single batch")

    perm = torch.randperm(len(dataset))

    one_batch = [dataset[idx.item()] for idx in perm[:batch_size]]

    class _OneBatchWrapper(Dataset):
        def __init__(self):
            self.batch = one_batch

        def __getitem__(self, index):
            return self.batch[index]

        def __len__(self):
            return len(self.batch)

    return _OneBatchWrapper() 
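A hypothetical usage sketch (the toy TensorDataset and names are assumptions, not part of the original file): because _OneBatchWrapper always serves the same samples, iterating it lets you sanity-check that a model can overfit a single batch.

import torch
from torch.utils.data import DataLoader, TensorDataset

full = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
tiny = one_batch_dataset(full, batch_size=8)  # 8 fixed random samples
loader = DataLoader(tiny, batch_size=8)       # every epoch yields the same batch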
Example #2
Source File: video_keyframe_dataset.py    From detectron2 with Apache License 2.0
def __init__(
        self,
        video_list: List[str],
        frame_selector: Optional[FrameSelector] = None,
        transform: Optional[FrameTransform] = None,
    ):
        """
        Dataset constructor

        Args:
            video_list (List[str]): list of paths to video files
            frame_selector (Callable: KeyFrameList -> KeyFrameList):
                selects keyframes to process, keyframes are given by
                packet timestamps in timebase counts. If None, all keyframes
                are selected (default: None)
            transform (Callable: torch.Tensor -> torch.Tensor):
                transforms a batch of RGB images (tensors of size [B, H, W, 3]),
                returns a tensor of the same size. If None, no transform is
                applied (default: None)

        """
        self.video_list = video_list
        self.frame_selector = frame_selector
        self.transform = transform 
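A hypothetical construction sketch; the enclosing class name is inferred from the file name, and the paths are made up. Per the docstring above, frame_selector only needs to map a keyframe list to a keyframe list, and transform a frame tensor to a tensor:

dataset = VideoKeyframeDataset(
    video_list=["/data/videos/clip_000.mp4", "/data/videos/clip_001.mp4"],
    frame_selector=lambda keyframes: keyframes[:10],  # keep the first 10 keyframes
    transform=lambda frames: frames,                  # identity transform
)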
Example #3
Source File: datasets.py    From multimodal-vae-public with MIT License
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root             = os.path.expanduser(root)
        self.transform        = transform
        self.target_transform = target_transform
        self.train            = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. '
                               'You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file)) 
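This constructor follows the torchvision MNIST convention: download if asked, verify the files exist, then torch.load the processed tensors. A hypothetical instantiation; the class name is an assumption, since the snippet shows only __init__:

train_set = MultimodalMNIST(root='~/data/mnist', train=True, download=True)
test_set = MultimodalMNIST(root='~/data/mnist', train=False)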
Example #4
Source File: data_loader.py    From Depth-Completion with MIT License
def __init__(self, dataset_name, data_path, train=True):
        self.dataset_name = dataset_name
        if self.dataset_name not in {'matterport'}:
            raise Exception(f'Dataset name not found: {self.dataset_name}')
        self.data_root = data_path
        self.len = 0
        self.train = train
        self.scene_name = []
        self.color_name = []
        self.depth_name = []
        self.normal_name = []
        self.render_name = []
        self.boundary_name = []
        self.depth_boundary_name = []
        
        if self.dataset_name == 'matterport':
            self._load_data_name_matterport(train=self.train) 
Example #5
Source File: filterable.py    From NSCL-PyTorch-Release with MIT License
def __init__(self, owner_dataset, indices=None, filter_name=None, filter_func=None):
        """
        Args:
            owner_dataset (Dataset): the original dataset.
            indices (List[int]): a list of indices that were filtered out.
            filter_name (str): human-friendly name for the filter.
            filter_func (Callable): just for tracking.
        """

        super().__init__()
        self.owner_dataset = owner_dataset
        self.indices = indices
        self._filter_name = filter_name
        self._filter_func = filter_func 
Example #6
Source File: wider_face.py    From tiny-faces-pytorch with MIT License
def __init__(self, path, templates, img_transforms=None, dataset_root="", split="train",
                 train=True, input_size=(500, 500), heatmap_size=(63, 63),
                 pos_thresh=0.7, neg_thresh=0.3, pos_fraction=0.5, debug=False):
        super().__init__()

        self.data = []
        self.split = split

        self.load(path)

        print("Dataset loaded")
        print("{0} samples in the {1} dataset".format(len(self.data),
                                                      self.split))
        # self.data = data

        # canonical object templates obtained via clustering
        # NOTE we directly use the values from Peiyun's repository stored in "templates.json"
        self.templates = templates

        self.transforms = img_transforms
        self.dataset_root = Path(dataset_root)
        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.pos_thresh = pos_thresh
        self.neg_thresh = neg_thresh
        self.pos_fraction = pos_fraction

        # receptive field computed using a combination of values from Matconvnet
        # plus derived equations.
        self.rf = {
            'size': [859, 859],
            'stride': [8, 8],
            'offset': [-1, -1]
        }

        self.processor = DataProcessor(input_size, heatmap_size,
                                       pos_thresh, neg_thresh,
                                       templates, rf=self.rf)
        self.debug = debug 
Example #7
Source File: load_neuroimaging_data.py    From FastSurfer with Apache License 2.0
def __len__(self):
        return self.count


##
# Dataset loading (for training)
##

# Operator to load an hdf5 file for training
Example #8
Source File: dataloader.py    From pytorch-meta with MIT License
def collate_task(self, task):
        if isinstance(task, TorchDataset):
            return self.collate_fn([task[idx] for idx in range(len(task))])
        elif isinstance(task, OrderedDict):
            return OrderedDict([(key, self.collate_task(subtask))
                for (key, subtask) in task.items()])
        else:
            raise NotImplementedError() 
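For intuition, here is the same recursion as a standalone function, with default_collate standing in for self.collate_fn (a sketch, not part of the original file): a Dataset is materialized into one batch, while an OrderedDict of tasks is collated key by key.

from collections import OrderedDict
from torch.utils.data.dataset import Dataset as TorchDataset
from torch.utils.data.dataloader import default_collate

def collate_task(task, collate_fn=default_collate):
    # Materialize a Dataset into a single batch, or recurse into
    # an OrderedDict of sub-tasks.
    if isinstance(task, TorchDataset):
        return collate_fn([task[idx] for idx in range(len(task))])
    if isinstance(task, OrderedDict):
        return OrderedDict((key, collate_task(subtask, collate_fn))
                           for (key, subtask) in task.items())
    raise NotImplementedError()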
Example #9
Source File: observe_speaker.py    From Self-Supervised-Speech-Pretraining-and-Representation-Learning with MIT License
def main():

    # Load the train-clean-100 set
    tables = pd.read_csv(os.path.join(root, 'train-clean-100' + '.csv'))

    # Compute speaker dictionary
    print('[Dataset] - Computing speaker class...')
    O = tables['file_path'].tolist()
    speakers = get_all_speakers(O)
    speaker2idx = compute_speaker2idx(speakers)
    class_num = len(speaker2idx)
    print('[Dataset] - Possible speaker classes: ', class_num)
    

    train = tables.sample(frac=0.9, random_state=20190929) # random state is a seed value
    test = tables.drop(train.index)
    table = train.sort_values(by=['length'], ascending=False)

    X = table['file_path'].tolist()
    X_lens = table['length'].tolist()

    # Drop sequences that are too long (`drop`, `max_timestep`, and
    # `max_label_len` are presumably module-level settings in the full script)
    if drop and max_timestep > 0:
        table = table[table.length < max_timestep]
    if drop and max_label_len > 0:
        table = table[table.label.str.count('_')+1 < max_label_len]

    # compute utterances per speaker
    num_utt = []
    for speaker in speakers:
        if speaker in speaker2idx:
            num_utt.append(speakers[speaker])
    print('Average utterance per speaker: ', np.mean(num_utt))

    # TODO: further analysis
Example #10
Source File: dataloader.py    From Self-Supervised-Speech-Pretraining-and-Representation-Learning with MIT License
def __init__(self, run_mam, file_path, sets, bucket_size, max_timestep=0, drop=False, mam_config=None):
        super(AcousticDataset, self).__init__(file_path, sets, bucket_size, max_timestep, drop)

        self.run_mam = run_mam
        self.mam_config = mam_config
        self.sample_step = mam_config['max_input_length'] if 'max_input_length' in mam_config else 0
        if self.sample_step > 0:
            print('[Dataset] - Sampling random segments for training, sample length:', self.sample_step)
        X = self.table['file_path'].tolist()
        X_lens = self.table['length'].tolist()

        # Use bucketing to allow different batch sizes at run time
        self.X = []
        batch_x, batch_len = [], []

        for x, x_len in zip(X, X_lens):
            batch_x.append(x)
            batch_len.append(x_len)
            
            # Fill in batch_x until batch is full
            if len(batch_x) == bucket_size:
                # Halve the batch size if a sequence is too long
                if (bucket_size >= 2) and (max(batch_len) > HALF_BATCHSIZE_TIME):
                    self.X.append(batch_x[:bucket_size//2])
                    self.X.append(batch_x[bucket_size//2:])
                else:
                    self.X.append(batch_x)
                batch_x, batch_len = [], []
        
        # Gather the last batch
        if len(batch_x) > 0:
            self.X.append(batch_x) 
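The bucketing loop above, distilled into a standalone helper for clarity (a sketch, not part of the original file; HALF_BATCHSIZE_TIME is a module-level constant there):

def make_buckets(paths, lengths, bucket_size, half_time):
    # Group items into buckets of `bucket_size`; split a bucket in two
    # when its longest sequence exceeds `half_time`.
    buckets, batch_x, batch_len = [], [], []
    for x, x_len in zip(paths, lengths):
        batch_x.append(x)
        batch_len.append(x_len)
        if len(batch_x) == bucket_size:
            if bucket_size >= 2 and max(batch_len) > half_time:
                buckets.append(batch_x[:bucket_size // 2])
                buckets.append(batch_x[bucket_size // 2:])
            else:
                buckets.append(batch_x)
            batch_x, batch_len = [], []
    if batch_x:  # gather the last, possibly partial, bucket
        buckets.append(batch_x)
    return buckets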
Example #11
Source File: wider_face.py    From tiny-faces-pytorch with MIT License
def __getitem__(self, index):
        datum = self.data[index]

        image_root = self.dataset_root / "WIDER_{0}".format(self.split)
        image_path = image_root / "images" / datum['img_path']
        image = Image.open(image_path).convert('RGB')

        if self.split == 'train':
            bboxes = datum['bboxes']

            if self.debug:
                if bboxes.shape[0] == 0:
                    print(image_path)
                print("Dataset index: \t", index)
                print("image path:\t", image_path)

            img, class_map, reg_map, bboxes = self.process_inputs(image,
                                                                  bboxes)

            # convert everything to tensors
            if self.transforms is not None:
                # if img is a byte or uint8 array, it will convert from 0-255 to 0-1
                # this converts from (HxWxC) to (CxHxW) as well
                img = self.transforms(img)

            class_map = torch.from_numpy(class_map)
            reg_map = torch.from_numpy(reg_map)

            return img, class_map, reg_map

        elif self.split == 'val':
            # NOTE Return only the image and the image path.
            # Use the eval_tools to get the final results.
            if self.transforms is not None:
                # Only convert to tensor since we do normalization after rescaling
                img = transforms.functional.to_tensor(image)

            return img, datum['img_path']

        elif self.split == 'test':
            filename = datum['img_path']

            if self.transforms is not None:
                img = self.transforms(image)

            return img, filename 
Example #12
Source File: dataloader.py    From Self-Supervised-Speech-Pretraining-and-Representation-Learning with MIT License
def __init__(self, run_mam, file_path, phone_path, sets, bucket_size, max_timestep=0, drop=False, train_proportion=1.0, mam_config=None):
        super(Mel_Phone_Dataset, self).__init__(file_path, sets, bucket_size, max_timestep, drop)

        self.run_mam = run_mam
        self.mam_config = mam_config
        self.phone_path = phone_path
        self.class_num = len(pickle.load(open(os.path.join(phone_path, 'phone2idx.pkl'), 'rb')))
        print('[Dataset] - Possible phone classes: ', self.class_num)

        unaligned = pickle.load(open(os.path.join(phone_path, 'unaligned.pkl'), 'rb'))
        X = self.table['file_path'].tolist()
        X_lens = self.table['length'].tolist()
        if train_proportion < 1.0:
            print('[Dataset] - Truncating dataset size from ', len(X), end='')
            chose_proportion = int(len(X)*train_proportion)
            sample_index = sorted(random.sample(range(len(X)), chose_proportion), reverse=True)
            X = np.asarray(X)[sample_index]
            X_lens = np.asarray(X_lens)[sample_index]
            print(' to ', len(X))
            if len(X) < 200:  # if the dataset is too small, duplicate it to speed up the dataloader
                for _ in range(4): 
                    X = np.concatenate((X, X), axis=0)
                    X_lens = np.concatenate((X_lens, X_lens), axis=0)
        elif train_proportion > 1.0:
            raise ValueError('Invalid range for `train_proportion`, (0.0, 1.0] is the appropriate range!')

        # Use bucketing to allow different batch sizes at run time
        self.X = []
        batch_x, batch_len = [], []

        for x, x_len in zip(X, X_lens):
            if x not in unaligned:
                batch_x.append(x)
                batch_len.append(x_len)
                
                # Fill in batch_x until batch is full
                if len(batch_x) == bucket_size:
                    # Halve the batch size if a sequence is too long
                    if (bucket_size >= 2) and (max(batch_len) > HALF_BATCHSIZE_TIME):
                        self.X.append(batch_x[:bucket_size//2])
                        self.X.append(batch_x[bucket_size//2:])
                    else:
                        self.X.append(batch_x)
                    batch_x, batch_len = [], []
        
        # Gather the last batch
        if len(batch_x) > 0:
            # entries in batch_x already passed the `unaligned` filter above
            self.X.append(batch_x)
Example #13
Source File: dataloader.py    From Self-Supervised-Speech-Pretraining-and-Representation-Learning with MIT License
def __init__(self, run_mam, file_path, phone_path, sets, bucket_size, max_timestep=0, drop=False, mam_config=None, split='train', seed=1337):
        super(CPC_Phone_Dataset, self).__init__(file_path, sets, bucket_size, max_timestep, drop)

        assert('train-clean-100' in sets and len(sets) == 1) # `sets` must be ['train-clean-100']
        random.seed(seed)
        self.run_mam = run_mam
        self.mam_config = mam_config
        self.phone_path = phone_path
        phone_file = open(os.path.join(phone_path, 'converted_aligned_phones.txt')).readlines()
        
        self.Y = {}
        # phone_set = []
        for line in phone_file:
            line = line.strip('\n').split(' ')
            self.Y[line[0]] = [int(p) for p in line[1:]]
            # for p in line[1:]: 
                # if p not in phone_set: phone_set.append(p)
        self.class_num = 41 # len(phone_set) # uncomment the above lines if you want to recompute
        
        if split == 'train' or split == 'dev':
            usage_list = open(os.path.join(phone_path, 'train_split.txt')).readlines()
            random.shuffle(usage_list)
            percent = int(len(usage_list)*0.9)
            usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]
        elif split == 'test':
            usage_list = open(os.path.join(phone_path, 'test_split.txt')).readlines()
        else:
            raise ValueError('Invalid \'split\' argument for dataset: CPC_Phone_Dataset!')
        usage_list = [line.strip('\n') for line in usage_list]
        print('[Dataset] - Possible phone classes: ' + str(self.class_num) + ', number of data: ' + str(len(usage_list)))

        X = self.table['file_path'].tolist()
        X_lens = self.table['length'].tolist()

        # Use bucketing to allow different batch sizes at run time
        self.X = []
        batch_x, batch_len = [], []

        for x, x_len in zip(X, X_lens):
            if self.parse_x_name(x) in usage_list:
                batch_x.append(x)
                batch_len.append(x_len)
                
                # Fill in batch_x until batch is full
                if len(batch_x) == bucket_size:
                    # Halve the batch size if a sequence is too long
                    if (bucket_size >= 2) and (max(batch_len) > HALF_BATCHSIZE_TIME):
                        self.X.append(batch_x[:bucket_size//2])
                        self.X.append(batch_x[bucket_size//2:])
                    else:
                        self.X.append(batch_x)
                    batch_x, batch_len = [], []
        
        # Gather the last batch
        if len(batch_x) > 0:
            # entries in batch_x already passed the usage_list filter above
            self.X.append(batch_x)