Python torch.utils.data.Dataset() Examples

The following are 30 code examples of torch.utils.data.Dataset(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.utils.data, or try the search function.
Example #1
Source File: data.py    From VSE-C with MIT License 7 votes vote down vote up
def get_test_loader(split_name, data_name, vocab, crop_size, batch_size,
                    workers, opt, cap_suffix='caps'):
    """Build the (non-shuffled) data loader used for evaluating one split.

    When ``opt.data_name`` ends with ``'_precomp'`` the work is delegated to
    ``get_precomp_loader`` or, if ``opt.use_external_captions`` is truthy,
    to ``get_precomp_train_caption_loader``. Otherwise the image/caption
    roots and ids come from ``get_paths`` and the loader is built with
    ``get_loader_single``.

    ``crop_size`` is accepted but not read by this function.
    """
    dpath = os.path.join(opt.data_path, data_name)

    if opt.data_name.endswith('_precomp'):
        # Both builders are called with the same positional arguments.
        build = (get_precomp_train_caption_loader
                 if opt.use_external_captions else get_precomp_loader)
        return build(dpath, split_name, vocab, opt,
                     batch_size, False, workers, cap_suffix)

    # Build Dataset Loader from raw image / caption locations
    roots, ids = get_paths(dpath, data_name, opt.use_restval)
    transform = get_transform(data_name, split_name, opt)
    split_roots = roots[split_name]
    return get_loader_single(opt.data_name, split_name,
                             split_roots['img'],
                             split_roots['cap'],
                             vocab, transform, ids=ids[split_name],
                             batch_size=batch_size, shuffle=False,
                             num_workers=workers,
                             collate_fn=collate_fn)
Example #2
Source File: utils_NMT.py    From ConvLab with MIT License 7 votes vote down vote up
def get_seq(pairs,lang,batch_size,type,max_len):
    """Split ``pairs`` into columns and wrap them in a ``DataLoader``.

    Positions 0, 1 and 2 of every pair are collected into separate lists
    and handed to ``Dataset`` together with ``lang.word2index`` (twice) and
    ``max_len``. When ``type`` is truthy, positions 0 and 1 of each pair are
    also passed to ``lang.index_words`` and the loader shuffles.
    """
    inputs, targets, pointers = [], [], []
    for item in pairs:
        inputs.append(item[0])
        targets.append(item[1])
        pointers.append(item[2])
        if type:
            lang.index_words(item[0])
            lang.index_words(item[1])

    seq_dataset = Dataset(inputs, targets, pointers,
                          lang.word2index, lang.word2index, max_len)
    return torch.utils.data.DataLoader(
        dataset=seq_dataset,
        batch_size=batch_size,
        shuffle=type,
        collate_fn=collate_fn,
    )
Example #3
Source File: dataset.py    From seamseg with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, root_dir, split_name, transform):
        """Store the dataset location and load the split's meta-data.

        Raises:
            IOError: if the image, mask or list sub-folder derived from
                ``root_dir`` is not an existing directory (the first
                missing one, in that order, is reported).
        """
        super(ISSDataset, self).__init__()
        self.root_dir = root_dir
        self.split_name = split_name
        self.transform = transform

        # Sub-folders expected below the dataset root
        self._img_dir = path.join(root_dir, ISSDataset._IMG_DIR)
        self._msk_dir = path.join(root_dir, ISSDataset._MSK_DIR)
        self._lst_dir = path.join(root_dir, ISSDataset._LST_DIR)

        required_dirs = (self._img_dir, self._msk_dir, self._lst_dir)
        missing = next(
            (folder for folder in required_dirs if not path.isdir(folder)),
            None)
        if missing is not None:
            raise IOError("Dataset sub-folder {} does not exist".format(missing))

        # Meta-data and image list are produced by the split loader
        self._meta, self._images = self._load_split()
Example #4
Source File: data_loader.py    From LipReading with MIT License 6 votes vote down vote up
def __init__(self, audio_conf, manifest_filepath, labels, normalize=False, augment=False):
        """
        Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
        a comma. Each new line is a different sample. Example below:

        /path/to/audio.wav,/path/to/audio.txt
        ...

        :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
        :param manifest_filepath: Path to manifest csv as described above
        :param labels: String containing all the possible characters to map to
        :param normalize: Apply standard mean and deviation normalization to audio tensor
        :param augment(default False):  Apply random tempo and gain perturbations
        """
        # Iterate the file directly; each stripped line is split on commas
        with open(manifest_filepath) as f:
            ids = [line.strip().split(',') for line in f]
        self.ids = ids
        self.size = len(ids)
        # Map every label to its position in ``labels``
        self.labels_map = {label: index for index, label in enumerate(labels)}
        super(SpectrogramDataset, self).__init__(audio_conf, normalize, augment)
Example #5
Source File: utils_babi.py    From ConvLab with MIT License 6 votes vote down vote up
def get_seq(pairs,lang,batch_size,type,max_len):
    """Split ``pairs`` into four columns and wrap them in a ``DataLoader``.

    Positions 0-3 of every pair are collected into the input, target,
    pointer and gate lists that are handed to ``Dataset`` together with
    ``lang.word2index`` (twice) and ``max_len``. When ``type`` is truthy,
    positions 0 and 1 of each pair are also passed to ``lang.index_words``
    and the loader shuffles.
    """
    x_seq, y_seq, ptr_seq, gate_seq = [], [], [], []
    columns = (x_seq, y_seq, ptr_seq, gate_seq)
    for pair in pairs:
        # Column k receives element k of the pair
        for position, column in enumerate(columns):
            column.append(pair[position])
        if type:
            lang.index_words(pair[0])
            lang.index_words(pair[1])

    seq_dataset = Dataset(x_seq, y_seq, ptr_seq, gate_seq,
                          lang.word2index, lang.word2index, max_len)
    return torch.utils.data.DataLoader(
        dataset=seq_dataset,
        batch_size=batch_size,
        shuffle=type,
        collate_fn=collate_fn,
    )
Example #6
Source File: sunrgbd_detection_dataset_hd.py    From H3DNet with MIT License 6 votes vote down vote up
def __init__(self, data_path=None, split_set='train', num_points=20000,
        use_color=False, use_height=False, use_v1=False,
        augment=False, scan_idx_list=None):
        """Index the scans available for one split.

        Args:
            data_path: directory that contains the
                ``sunrgbd_pc_bbox_votes_50k_v1_<split_set>`` folder.
            split_set: suffix appended to the data folder name.
            num_points: stored on the instance; must not exceed 50000.
            use_color, use_height, augment: stored on the instance.
            use_v1: must be truthy; only the v1 data layout is handled.
            scan_idx_list: optional indices selecting a subset of the
                sorted scan names.

        Raises:
            AssertionError: if ``num_points`` exceeds 50000 or ``use_v1``
                is falsy.
        """
        assert num_points <= 50000
        self.use_v1 = use_v1
        if use_v1:
            self.data_path = os.path.join(data_path, 'sunrgbd_pc_bbox_votes_50k_v1_' + split_set)
        else:
            # The exception used to be constructed but never raised, so
            # execution continued with ``self.data_path`` unset.
            raise AssertionError("v2 data is not prepared")

        self.raw_data_path = os.path.join(ROOT_DIR, 'sunrgbd/sunrgbd_trainval')
        # A scan name is the first six characters of each file name
        self.scan_names = sorted(
            {os.path.basename(x)[0:6] for x in os.listdir(self.data_path)})

        if scan_idx_list is not None:
            self.scan_names = [self.scan_names[i] for i in scan_idx_list]
        self.num_points = num_points
        self.augment = augment
        self.use_color = use_color
        self.use_height = use_height
Example #7
Source File: base.py    From nsf with MIT License 6 votes vote down vote up
def __init__(self, num_epochs=None, *args, **kwargs):
        """Constructor.

        All positional and keyword arguments other than ``num_epochs`` are
        forwarded unchanged to the parent constructor. According to the
        original documentation these are:

            dataset: A `Dataset` object to be loaded.
            batch_size: int, the size of each batch.
            shuffle: bool, whether to shuffle the dataset after each epoch.
            drop_last: bool, whether to drop last batch if its size is less
                than `batch_size`.

        Args:
            num_epochs: int or None, number of epochs to iterate over the
                dataset. If None, defaults to infinity.
        """
        super().__init__(*args, **kwargs)
        # Iterator obtained from the parent class' __iter__
        self.finite_iterable = super().__iter__()
        self.counter = 0
        if num_epochs is None:
            num_epochs = float('inf')
        self.num_epochs = num_epochs
Example #8
Source File: MyMNIST.py    From sgd-influence with MIT License 6 votes vote down vote up
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, seed=0):
        """Load the processed training or test file into ``data``/``targets``.

        ``download()`` is called first when ``download`` is truthy.

        Raises:
            RuntimeError: if ``_check_exists()`` reports that the dataset is
                not present.
        """
        super(MNIST, self).__init__(root)
        self.seed = seed
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        data_file = self.training_file if self.train else self.test_file
        processed_path = os.path.join(self.processed_folder, data_file)
        self.data, self.targets = torch.load(processed_path)
Example #9
Source File: datasets.py    From bgd with MIT License 6 votes vote down vote up
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Load either the training or the test tensors from the processed folder.

        ``root`` is user-expanded. ``download()`` is called first when
        ``download`` is truthy. Training data goes to ``train_data`` /
        ``train_labels``; test data goes to ``test_data`` / ``test_labels``.

        Raises:
            RuntimeError: if ``_check_exists()`` reports that the dataset is
                not present.
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        def _load(file_name):
            # Both splits live in <root>/<processed_folder>/
            return torch.load(
                os.path.join(self.root, self.processed_folder, file_name))

        if self.train:
            self.train_data, self.train_labels = _load(self.training_file)
        else:
            self.test_data, self.test_labels = _load(self.test_file)
Example #10
Source File: dataset.py    From prediction-flow with MIT License 6 votes vote down vote up
def __getitem__(self, idx):
        """Assemble the record for row ``idx`` as an ordered mapping.

        Number and category features are copied from ``X_map`` as-is.
        Each sequence feature is stored padded, with an extra
        ``__<name>_length`` entry holding its original first-axis size.
        A ``'label'`` entry is appended when ``self.y`` is not None.
        """
        scalar_feats = chain(
            self.features.number_features,
            self.features.category_features)
        record = OrderedDict(
            (feat.name, self.X_map[feat.name][idx]) for feat in scalar_feats)

        for feat in self.features.sequence_features:
            seq = self.X_map[feat.name][idx]
            # Double-underscore access relies on name mangling inside the class
            record[feat.name] = Dataset.__pad_sequence(feat, seq)
            record[f"__{feat.name}_length"] = np.int64(seq.shape[0])

        if self.y is not None:
            record['label'] = self.y[idx]
        return record
Example #11
Source File: base.py    From LEDNet with MIT License 6 votes vote down vote up
def transform(self, fn, lazy=True):
        """Return a new dataset whose samples are passed through `fn`.

        Parameters
        ----------
        fn : callable
            Transformer function that receives a sample and returns the
            transformed sample.
        lazy : bool, default True
            If True, each sample is transformed on demand. If False, all
            samples are transformed immediately. A stochastic `fn` must be
            used with lazy=True, otherwise every epoch sees the same
            results.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        lazy_view = _LazyTransformDataset(self, fn)
        if not lazy:
            # Eager mode: materialise every transformed sample now
            return SimpleDataset([sample for sample in lazy_view])
        return lazy_view
Example #12
Source File: data_silo.py    From FARM with Apache License 2.0 6 votes vote down vote up
def _dataset_from_chunk(cls, chunk, processor):
        """
        Convert one chunk (= subset) of dicts into one dataset. In multiprocessing:
          * all dicts are read from a file
          * the dicts are split into chunks
          * *one chunk* is fed to *one process*
          => that *one chunk* becomes *one dataset* (this method)
          * all datasets get collected and concatenated
        :param chunk: Pairs of an index (ascending int) and a dict.
            => [(0, dict), (1, dict) ...]
        :type chunk: list of tuples
        :param processor: FARM Processor (e.g. TextClassificationProcessor)
        :return: PyTorch Dataset
        """
        payloads = [entry[1] for entry in chunk]
        positions = [entry[0] for entry in chunk]
        return processor.dataset_from_dicts(dicts=payloads, indices=positions)
Example #13
Source File: dataloader.py    From FARM with Apache License 2.0 6 votes vote down vote up
def covert_dataset_to_dataloader(dataset, sampler, batch_size):
    """
    Wrap a PyTorch Dataset in a DataLoader.

    :param dataset: Dataset to be wrapped.
    :type dataset: Dataset
    :param sampler: PyTorch sampler class (or factory); it is called with
        ``dataset`` to create the sampler instance used by the loader.
    :type sampler: Sampler
    :param batch_size: Number of samples in the batch.
    :return: A DataLoader that wraps the input Dataset.
    """
    return DataLoader(
        dataset,
        sampler=sampler(dataset),
        batch_size=batch_size,
    )
Example #14
Source File: classifier.py    From metal with Apache License 2.0 6 votes vote down vote up
def resume_training(self, train_data, model_path, valid_data=None):
        """Resume training of a classifier from a saved checkpoint.

        The checkpoint at ``model_path`` is restored through
        ``self.checkpointer``, the model is switched to training mode and
        ``_train_model`` is run with the restored state.

        Args:
            train_data: a tuple of Tensors (X,Y), a Dataset, or a DataLoader
                of X (data) and Y (labels) for the train split
            model_path: the path to the saved checkpoint for resuming
                training
            valid_data: a tuple of Tensors (X,Y), a Dataset, or a DataLoader
                of X (data) and Y (labels) for the dev split
        """
        checkpoint_state = self.checkpointer.restore(model_path)
        criterion = self._get_loss_fn()
        self.train()
        self._train_model(
            train_data=train_data,
            loss_fn=criterion,
            valid_data=valid_data,
            restore_state=checkpoint_state,
        )
Example #15
Source File: classifier.py    From metal with Apache License 2.0 6 votes vote down vote up
def _create_data_loader(self, data, **kwargs):
        """Return ``data`` as a DataLoader (or None when ``data`` is None).

        A DataLoader is returned unchanged, a Dataset is wrapped, and a
        tuple/list is first converted with ``_create_dataset``. The loader
        options come from the train config, overridden by ``kwargs``, with
        ``pin_memory`` enabled whenever the configured device is not "cpu".

        Raises:
            ValueError: if ``data`` is none of the supported types.
        """
        if data is None:
            return None

        # Loader options; built up front even though an existing DataLoader
        # is returned as-is below.
        config = dict(self.config["train_config"]["data_loader_config"])
        config.update(kwargs)
        config["pin_memory"] = self.config["device"] != "cpu"

        if isinstance(data, DataLoader):
            return data
        if isinstance(data, Dataset):
            return DataLoader(data, **config)
        if isinstance(data, (tuple, list)):
            return DataLoader(self._create_dataset(*data), **config)
        raise ValueError("Input data type not recognized.")
Example #16
Source File: datasets.py    From signaltrain with GNU General Public License v3.0 5 votes vote down vote up
def get_single_chunk(self):
        """
        Grabs audio and knobs either from files or from preloaded buffer(s),
        then cuts one random window of ``self.chunk_size`` samples out of it.

        Returns a 3-tuple ``(x_item, y_item, knobs_nn)``, each converted to
        ``self.dtype``: the input chunk, the target chunk trimmed to its last
        ``self.y_size`` samples, and the knob settings rescaled with
        ``self.effect.knob_ranges``.

        Raises AssertionError when the selected audio is not longer than
        ``self.chunk_size``.
        """
        if self.preload:  # This will typically be the case
            i = np.random.randint(0,high=len(self.x))  # pick a random line from preloaded audio
            in_audio, targ_audio, knobs_wc = self.x[i], self.y[i], self.knobs[i]  # note these might be, e.g. 10 seconds long
        else:
            in_audio, targ_audio, knobs_wc = self.read_one_new_file_pair() # read x, y, knobs

        # Grab a random chunk from within the total audio file; the same
        # start offset is used for input and target
        assert len(in_audio) > self.chunk_size, f"Error: len(in_audio)={len(in_audio)}, must be > self.chunk_size={self.chunk_size}"
        ibgn = np.random.randint(0, len(in_audio) - self.chunk_size)
        x_item = in_audio[ibgn:ibgn+self.chunk_size]
        y_item = targ_audio[ibgn:ibgn+self.chunk_size]

        if self.rerun_effect:  # re-run the effect on this chunk, and replace target audio
            # NOTE(review): go_wc's result is unpacked as (y, x), so x_item is replaced too — confirm intended
            y_item, x_item = self.effect.go_wc(x_item, knobs_wc)   # Apply the audio effect

        y_item = y_item[-self.y_size:]   # Format for expected output size (keep the trailing y_size samples)

        # normalize knobs for nn usage
        kr = self.effect.knob_ranges   # kr is an abbreviation for 'knob ranges'
        # presumably kr[:,0] / kr[:,1] are each knob's min / max, giving values in [-0.5, 0.5] — TODO confirm
        knobs_nn = (knobs_wc - kr[:,0])/(kr[:,1]-kr[:,0]) - 0.5

        if self.augment:
            x_item, y_item = do_augment(x_item, y_item)

        return x_item.astype(self.dtype, copy=False), y_item.astype(self.dtype, copy=False), knobs_nn.astype(self.dtype, copy=False)

    # required part of torch.Dataset class.  This is how DataLoader gets a new piece of data 
Example #17
Source File: eval_hooks.py    From DenseMatchingBenchmark with MIT License 5 votes vote down vote up
def __init__(self, cfg, dataset, interval=1):
        """Keep a copy of ``cfg`` plus the dataset and interval.

        ``dataset`` is asserted to be a ``Dataset`` instance.
        """
        self.cfg = cfg.copy()
        assert isinstance(dataset, Dataset), (
            "dataset must be a Dataset object, not {}".format(type(dataset))
        )
        self.interval = interval
        self.dataset = dataset
Example #18
Source File: eval_hooks.py    From DenseMatchingBenchmark with MIT License 5 votes vote down vote up
def __init__(self, cfg, dataset, interval=1):
        """Keep a copy of ``cfg`` plus the dataset and interval.

        ``dataset`` is asserted to be a ``Dataset`` instance.
        """
        self.cfg = cfg.copy()
        assert isinstance(dataset, Dataset), (
            "dataset must be a Dataset object, not {}".format(type(dataset))
        )
        self.interval = interval
        self.dataset = dataset
Example #19
Source File: cocostuff.py    From SPNet with MIT License 5 votes vote down vote up
def __repr__(self):
        """Return a multi-line summary: class name, size, split and root."""
        pieces = (
            "Dataset " + self.__class__.__name__ + "\n",
            "    Number of datapoints: {}\n".format(self.__len__()),
            "    Split: {}\n".format(self.split),
            "    Root Location: {}\n".format(self.root),
        )
        return "".join(pieces)
Example #20
Source File: dataset.py    From scVI with MIT License 5 votes vote down vote up
def __getitem__(self, idx):
        """Return ``idx`` unchanged.

        Implements the ``__getitem__`` abstract method of
        ``torch.utils.data.dataset.Dataset``; the index itself is handed
        back instead of a sample.
        """
        return idx
Example #21
Source File: utils_kvr.py    From ConvLab with MIT License 5 votes vote down vote up
def get_seq(pairs,lang,batch_size,type,max_len):
    """Split ``pairs`` into eight columns and wrap them in a ``DataLoader``.

    Positions 0-7 of every pair are collected into the input, target,
    pointer, gate and four entity lists that are handed to ``Dataset``
    together with ``lang.word2index`` (twice) and ``max_len``. When ``type``
    is truthy, positions 0 and 1 of each pair are also passed to
    ``lang.index_words`` and the loader shuffles.
    """
    x_seq, y_seq, ptr_seq, gate_seq = [], [], [], []
    entity, entity_cal, entity_nav, entity_wet = [], [], [], []
    columns = (x_seq, y_seq, ptr_seq, gate_seq,
               entity, entity_cal, entity_nav, entity_wet)
    for pair in pairs:
        # Column k receives element k of the pair
        for position, column in enumerate(columns):
            column.append(pair[position])
        if type:
            lang.index_words(pair[0])
            lang.index_words(pair[1])

    seq_dataset = Dataset(x_seq, y_seq, ptr_seq, gate_seq,
                          lang.word2index, lang.word2index, max_len,
                          entity, entity_cal, entity_nav, entity_wet)
    return torch.utils.data.DataLoader(
        dataset=seq_dataset,
        batch_size=batch_size,
        shuffle=type,
        collate_fn=collate_fn,
    )
Example #22
Source File: inc_ext.py    From iccv2019-inc with MIT License 5 votes vote down vote up
def _load_meta(self):
        """Read the class names from the metadata pickle and index them.

        Sets ``self.classes`` from the pickle entry named by
        ``self.meta['key']`` and ``self.class_to_idx`` as the name -> position
        mapping.

        Raises:
            RuntimeError: if ``check_integrity`` rejects the metadata file.
        """
        meta_path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(meta_path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')

        # Only Python 3's pickle.load is given an encoding argument
        load_options = {} if sys.version_info[0] == 2 else {'encoding': 'latin1'}
        with open(meta_path, 'rb') as infile:
            data = pickle.load(infile, **load_options)
            self.classes = data[self.meta['key']]

        self.class_to_idx = {
            name: position for position, name in enumerate(self.classes)}
Example #23
Source File: vis_hooks.py    From DenseMatchingBenchmark with MIT License 5 votes vote down vote up
def __init__(self, dataset, cfg, interval=1):
        """Keep a copy of ``cfg`` plus the dataset and interval.

        Raises:
            TypeError: if ``dataset`` is not a ``Dataset`` instance.
        """
        self.cfg = cfg.copy()
        if not isinstance(dataset, Dataset):
            raise TypeError("dataset must be a Dataset object, not {}".format(type(dataset)))
        self.dataset = dataset
        self.interval = interval
Example #24
Source File: data_utils.py    From cloudml-samples with Apache License 2.0 5 votes vote down vote up
def download_data():
    """Download the data from Google Cloud Storage.

    Fetches ``ml-engine/sonar/sonar.all-data`` from the
    ``cloud-samples-data`` bucket into the local file ``sonar.all-data``.
    """
    client = storage.Client()
    # Locate the object inside the public bucket, then fetch it
    source_blob = client.bucket('cloud-samples-data').blob(
        'ml-engine/sonar/sonar.all-data')
    source_blob.download_to_filename('sonar.all-data')
Example #25
Source File: utils.py    From audio with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def __init__(self, dataset: Dataset, location: str = ".cached") -> None:
        """Remember the wrapped dataset and set up one empty cache slot per item."""
        self._id = id(self)
        self.location = location
        self.dataset = dataset
        # One ``None`` placeholder for every index of the dataset
        self._cache: List = [None for _ in range(len(dataset))]
Example #26
Source File: classifier.py    From metal with Apache License 2.0 5 votes vote down vote up
def _create_dataset(self, *data):
        """Convert every input with ``_to_torch`` (as FloatTensor) and bundle them in a TensorDataset."""
        tensors = tuple(
            self._to_torch(item, dtype=torch.FloatTensor) for item in data)
        return TensorDataset(*tensors)
Example #27
Source File: convert_lmdb.py    From torch-toolbox with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def generate_lmdb_dataset(
        data_set: Dataset,
        save_dir: str,
        name: str,
        num_workers=0,
        max_size_rate=1.0,
        write_frequency=5000):
    """Serialise every sample of ``data_set`` into ``<save_dir>/<name>.lmdb``.

    Samples are stored under ``get_key(idx)``. The sample count is stored
    under ``b'__len__'``. When the dataset exposes both ``classes`` and
    ``class_to_idx`` they are stored under ``b'classes'`` and
    ``b'class_to_idx'``.

    Args:
        data_set: dataset to export.
        save_dir: output directory, passed to ``check_dir`` first.
        name: base name of the LMDB file.
        num_workers: worker count for the DataLoader that reads samples.
        max_size_rate: multiplier applied to the 1099511627776-byte map size.
        write_frequency: the write transaction is committed and reopened
            every this many samples.
    """
    # Identity collate: each batch is the list of raw samples (one per batch)
    data_loader = DataLoader(
        data_set,
        num_workers=num_workers,
        collate_fn=lambda x: x)
    num_samples = len(data_set)
    check_dir(save_dir)
    lmdb_path = os.path.join(save_dir, '{}.lmdb'.format(name))
    db = lmdb.open(lmdb_path, subdir=False,
                   map_size=int(1099511627776 * max_size_rate),
                   readonly=False, meminit=True, map_async=True)
    try:
        txn = db.begin(write=True)
        for idx, data in enumerate(tqdm(data_loader)):
            txn.put(get_key(idx), dumps_pyarrow(data[0]))
            if idx % write_frequency == 0 and idx > 0:
                # Periodic commit so one transaction does not hold everything
                txn.commit()
                txn = db.begin(write=True)
        txn.put(b'__len__', dumps_pyarrow(num_samples))

        # Class metadata is optional: only a missing attribute is tolerated,
        # errors while writing are no longer swallowed.
        try:
            classes = data_set.classes
            class_to_idx = data_set.class_to_idx
        except AttributeError:
            pass
        else:
            txn.put(b'classes', dumps_pyarrow(classes))
            txn.put(b'class_to_idx', dumps_pyarrow(class_to_idx))

        txn.commit()
        db.sync()
    finally:
        # Always release the environment, even when a step above fails
        db.close()
Example #28
Source File: classifier.py    From metal with Apache License 2.0 5 votes vote down vote up
def _get_predictions(self, data, break_ties="random", return_probs=False, **kwargs):
        """Run batched prediction over a labeled dataset.

        Args:
            data: a Pytorch DataLoader, Dataset, or tuple with Tensors (X,Y):
                X: The input for the predict method
                Y: An [n] or [n, 1] torch.Tensor or np.ndarray of target
                    labels in {1,...,k}
            break_ties: How to break ties when making predictions
            return_probs: Return the predicted probabilities as well

        Returns:
            Y_p: A Tensor of predictions
            Y: A Tensor of labels
            [Optionally: Y_s: An [n, k] np.ndarray of predicted probabilities]
        """
        loader = self._create_data_loader(data)
        pred_batches, label_batches, prob_batches = [], [], []

        # Probabilities are always requested from predict(); whether they
        # are returned is decided at the end.
        for Xb, Yb in loader:
            label_batches.append(self._to_numpy(Yb))

            # Optionally move to device
            if self.config["device"] != "cpu":
                Xb = place_on_gpu(Xb)

            Y_pb, Y_sb = self.predict(
                Xb, break_ties=break_ties, return_probs=True, **kwargs
            )
            pred_batches.append(self._to_numpy(Y_pb))
            prob_batches.append(self._to_numpy(Y_sb))

        Y_p = self._stack_batches(pred_batches)
        Y = self._stack_batches(label_batches)
        Y_s = self._stack_batches(prob_batches)
        if return_probs:
            return Y_p, Y, Y_s
        return Y_p, Y
Example #29
Source File: datasets.py    From signaltrain with GNU General Public License v3.0 5 votes vote down vote up
def process_audio(self):  # TODO: not used yet following torchaudio
        """ Render raw audio as pytorch-friendly file. TODO: not done yet.

        Does nothing when ``self.processed_dir`` already exists. Otherwise
        lists the ``input_*`` and ``target_*`` files under ``self.path``,
        stores the sorted target list on ``self.target_filenames`` and
        prints how many pairs were found.

        Raises AssertionError when the numbers of input and target files
        differ.
        """
        if os.path.exists(self.processed_dir):
            return

        # get a list of available audio files.  Note that knob settings are included to the target filenames
        input_filenames = sorted(glob.glob(self.path+'/'+'input_*'))
        self.target_filenames = sorted(glob.glob(self.path+'/'+'target_*'))
        # Compare against the attribute: a bare ``target_filenames`` is undefined here
        assert len(input_filenames) == len(self.target_filenames)   # TODO: One can imagine a scheme with multiple targets per input
        print("Dataset: Found",self.__len__(),"raw audio i-o pairs in path",self.path)
Example #30
Source File: vis_hooks.py    From DenseMatchingBenchmark with MIT License 5 votes vote down vote up
def __init__(self, dataset, cfg, interval=1):
        """Keep a copy of ``cfg`` plus the dataset and interval.

        Raises:
            TypeError: if ``dataset`` is not a ``Dataset`` instance.
        """
        self.cfg = cfg.copy()
        if not isinstance(dataset, Dataset):
            raise TypeError("dataset must be a Dataset object, not {}".format(type(dataset)))
        self.dataset = dataset
        self.interval = interval