Python torch.utils.data.dataloader.DataLoader() Examples

The following are 30 code examples of torch.utils.data.dataloader.DataLoader(), collected from open-source projects. Each example notes its source file, project, and license. You may also want to check out all other available functions and classes of the torch.utils.data.dataloader module.
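Before the project examples, here is a minimal, self-contained sketch of the basic pattern they all share: wrap a Dataset in a DataLoader, then iterate it to get collated batches. The toy dataset and batch size below are illustrative, not taken from any example.

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import DataLoader

# A toy dataset of 100 (feature, label) pairs.
features = torch.randn(100, 8)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(features, labels)

# batch_size controls how many samples are collated per yielded batch;
# shuffle=True reshuffles the sample order every epoch.
loader = DataLoader(dataset, batch_size=16, shuffle=True)

for xb, yb in loader:
    print(xb.shape, yb.shape)  # torch.Size([16, 8]) torch.Size([16]); the last batch may be smaller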
Example #1
Source File: inference_utils.py    From allRank with Apache License 2.0
def __rank_slates(dataloader: DataLoader, model: LTRModel) -> Tuple[torch.Tensor, torch.Tensor]:
    reranked_X = []
    reranked_y = []
    model.eval()
    device = get_torch_device()
    with torch.no_grad():
        for xb, yb, _ in dataloader:
            X = xb.type(torch.float32).to(device=device)
            y_true = yb.to(device=device)

            input_indices = torch.ones_like(y_true).type(torch.long)
            mask = (y_true == losses.PADDED_Y_VALUE)
            scores = model.score(X, mask, input_indices)

            scores[mask] = float('-inf')

            _, indices = scores.sort(descending=True, dim=-1)
            indices_X = torch.unsqueeze(indices, -1).repeat_interleave(X.shape[-1], -1)
            reranked_X.append(torch.gather(X, dim=1, index=indices_X).cpu())
            reranked_y.append(torch.gather(y_true, dim=1, index=indices).cpu())

    combined_X = torch.cat(reranked_X)
    combined_y = torch.cat(reranked_y)
    return combined_X, combined_y 
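The reranking above hinges on sorting scores and then using torch.gather with the sort indices to reorder labels (and, with an expanded index, features) along the slate dimension. A standalone illustration of that mechanic with toy tensors:

import torch

scores = torch.tensor([[0.2, 0.9, 0.5]])
y_true = torch.tensor([[0., 1., 2.]])

# Sort descending by score; the indices give the new ordering per slate.
_, indices = scores.sort(descending=True, dim=-1)   # tensor([[1, 2, 0]])
reranked = torch.gather(y_true, dim=1, index=indices)
print(reranked)  # tensor([[1., 2., 0.]]) -- labels reordered by descending score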
Example #2
Source File: test_dataloaders.py    From pytorch-lightning with Apache License 2.0
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. """
    model = EvalModelTemplate()
    original_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):

        def __iter__(self):
            return iter(original_dataset)

        def __len__(self):
            return len(original_dataset)

    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert _has_len(dataloader)
    assert _has_iterable_dataset(dataloader)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=3,
    )
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.test(model, test_dataloaders=[dataloader]) 
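For context, a minimal Lightning-independent sketch of an IterableDataset inside a DataLoader. With iterable-style datasets the loader draws samples from __iter__, so shuffle and sampler options do not apply; defining __len__ is optional, which is exactly what the test above probes.

import torch
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader

class RangeStream(IterableDataset):
    # Streams the integers 0..n-1 (toy example, not Lightning's template model).
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(torch.arange(self.n))

    def __len__(self):
        return self.n

loader = DataLoader(RangeStream(10), batch_size=4)
for batch in loader:
    print(batch)  # tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9])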
Example #3
Source File: deepSVDD_trainer.py    From Deep-SVDD-PyTorch with MIT License
def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps=0.1):
        """Initialize hypersphere center c as the mean from an initial forward pass on the data."""
        n_samples = 0
        c = torch.zeros(net.rep_dim, device=self.device)

        net.eval()
        with torch.no_grad():
            for data in train_loader:
                # get the inputs of the batch
                inputs, _, _ = data
                inputs = inputs.to(self.device)
                outputs = net(inputs)
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)

        c /= n_samples

        # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
        c[(abs(c) < eps) & (c < 0)] = -eps
        c[(abs(c) < eps) & (c > 0)] = eps

        return c 
Example #4
Source File: dataloader.py    From EverybodyDanceNow_reproduce_pytorch with MIT License
def copy(loader):
        """
        Initialize an sDataLoader from an existing DataLoader
        :param loader: an instance of DataLoader
        :type loader: DataLoader
        :return: a new instance of sDataLoader
        :rtype: sDataLoader
        """
        if not isinstance(loader, DataLoader):
            logger.warning('loader should be an instance of DataLoader, but got {}'.format(type(loader)))
            return loader

        new_loader = sDataLoader(loader.dataset)
        for k, v in loader.__dict__.items():
            setattr(new_loader, k, v)
        return new_loader 
Example #5
Source File: DeepSAD_trainer.py    From Deep-SAD-PyTorch with MIT License
def init_center_c(self, train_loader: DataLoader, net: BaseNet, eps=0.1):
        """Initialize hypersphere center c as the mean from an initial forward pass on the data."""
        n_samples = 0
        c = torch.zeros(net.rep_dim, device=self.device)

        net.eval()
        with torch.no_grad():
            for data in train_loader:
                # get the inputs of the batch
                inputs, _, _, _ = data
                inputs = inputs.to(self.device)
                outputs = net(inputs)
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)

        c /= n_samples

        # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
        c[(abs(c) < eps) & (c < 0)] = -eps
        c[(abs(c) < eps) & (c > 0)] = eps

        return c 
Example #6
Source File: eval.py    From kaggle_carvana_segmentation with MIT License
def predict(self, skip_folds=None):
        for fold, (train_index, val_index) in enumerate(self.folds):
            prefix = ('fold' + str(fold) + "_") if self.test else ""
            if skip_folds is not None:
                if fold in skip_folds:
                    continue
            self.prev_name = None
            ds_cls = ValDataset if not self.test else SequentialDataset
            val_dataset = ds_cls(self.ds, val_index, stage='test', config=self.config)
            val_dl = PytorchDataLoader(val_dataset, batch_size=self.config.predict_batch_size, num_workers=self.num_workers, drop_last=False)
            weights_path = os.path.join(self.config.models_dir, 'albu')
            model = read_model(weights_path, self.folder, fold)
            pbar = val_dl if self.config.dbg else tqdm.tqdm(val_dl, total=len(val_dl))
            for data in pbar:
                self.show_mask = 'mask' in data and self.show_mask
                if 'mask' not in data:
                    self.need_dice = False

                predicted = self.predict_samples(model, data)
                self.process_data(predicted, model, data, prefix=prefix)

                if not self.config.dbg and self.need_dice:
                    pbar.set_postfix(dice="{:.5f}".format(np.mean(self.dice)))
            if self.config.use_crop:
                self.on_image_constructed(prefix=prefix) 
Example #7
Source File: imagenetloader.py    From DTC with MIT License
def ImageNetLoader882(batch_size, num_workers, split='train', shuffle=False, path='data_shallow14/datasets/ImageNet/'):
    img_split = 'images/'+split
    classes_118, class_to_idx_118 = find_classes_from_file(path+'imagenet_rand118/imagenet_118.txt')
    samples_118 = make_dataset(path+img_split, classes_118, class_to_idx_118)
    classes_1000, _ = find_classes_from_folder(path+img_split)
    classes_882 = list(set(classes_1000) - set(classes_118))
    class_to_idx_882 = {classes_882[i]: i for i in range(len(classes_882))}
    samples_882 = make_dataset(path+img_split, classes_882, class_to_idx_882)
    if split=='train':
        transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    else:
        transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    dataset = ImageFolder(transform=transform, samples=samples_882)
    dataloader_882 = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 
    return dataloader_882 
Example #8
Source File: imagenetloader.py    From DTC with MIT License
def ImageNetLoader82from882(batch_size, num_workers, num_val_cls=30, path='data_shallow14/datasets/ImageNet/'):
    classes_118, class_to_idx_118 = find_classes_from_file(path+'imagenet_rand118/imagenet_118.txt')
    samples_118 = make_dataset(path+'images/train', classes_118, class_to_idx_118)
    classes_1000, _ = find_classes_from_folder(path+'images/train')
    classes_882 = list(set(classes_1000) - set(classes_118))
    classes_val = classes_882[800:800+num_val_cls]
    class_to_idx_val = {classes_val[i]: i for i in range(len(classes_val))}
    samples_val = make_dataset(path+'images/train', classes_val, class_to_idx_val)
    transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    dataset_val = ImageFolder(transform=transform, samples=samples_val)
    dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    return dataloader_val 
Example #9
Source File: data.py    From Pytorch-NCE with MIT License
def get_dataloader(self, filename, bs=1):
        full_path = os.path.join(self.base_path, filename)
        if self.concat:
            dataset = ContLMDataset(full_path, vocab=self.vocab, bptt=self.bptt)
        else:
            dataset = LMDataset(full_path, vocab=self.vocab, bptt=self.bptt)
        return DataLoader(
            dataset=dataset,
            batch_size=bs,
            shuffle=self.shuffle,
            pin_memory=self.pin_memory,
            collate_fn=pad_collate_fn,
            # num_workers=1,
            # waiting for a new torch version to support
            # drop_last=True,
        ) 
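pad_collate_fn is defined elsewhere in the Pytorch-NCE project. As an assumption about its shape rather than the project's actual code, a typical pad-to-longest collate function can be built on torch.nn.utils.rnn.pad_sequence:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate_fn(batch):
    # Hypothetical stand-in: pad variable-length 1-D sequences to the longest
    # in the batch and return the original lengths so padding can be masked.
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths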
Example #10
Source File: test_dataloaders.py    From pytorch-lightning with Apache License 2.0
def test_val_dataloader_not_implemented_error_failed(tmpdir):
    """Test not_implemented_error train data loader (e.g. IterableDataset)"""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model) 
Example #11
Source File: test_dataloaders.py    From pytorch-lightning with Apache License 2.0
def test_test_inf_dataloader_error(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    model = EvalModelTemplate()
    model.test_dataloader = model.test_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.test(model) 
Example #12
Source File: test_dataloaders.py    From pytorch-lightning with Apache License 2.0
def test_test_dataloader_not_implemented_error_failed(tmpdir):
    """Test not_implemented_error train data loader (e.g. IterableDataset)"""
    model = EvalModelTemplate()
    model.test_dataloader = model.test_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.test(model) 
Example #13
Source File: imagenetloader.py    From DTC with MIT License
def ImageNetLoader30(batch_size, num_workers=2, path='data_shallow14/datasets/ImageNet/', subset='A', aug=None, shuffle=False, subfolder='train'):
    # dataloader of 30 classes
    classes_30, class_to_idx_30 = find_classes_from_file(path+'imagenet_rand118/imagenet_30_{}.txt'.format(subset))
    samples_30 = make_dataset(path+'images/{}'.format(subfolder), classes_30, class_to_idx_30)
    if aug is None:
        transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    elif aug=='once':
        transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    elif aug=='twice':
        transform = TransformTwice(transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            RandomTranslateWithReflect(4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
        ]))
    dataset = ImageFolder(transform=transform, samples=samples_30)
    dataloader_30 = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 
    return dataloader_30 
Example #14
Source File: imagenetloader.py    From DTC with MIT License
def ImageNetLoader800from882(batch_size, num_workers, split='train', path='data_shallow14/datasets/ImageNet/'):
    # this dataloader uses the first 800 of the 882 classes as the training split (the remaining classes serve as val)
    classes_118, class_to_idx_118 = find_classes_from_file(path+'imagenet_rand118/imagenet_118.txt')
    samples_118 = make_dataset(path+'images/train', classes_118, class_to_idx_118)
    classes_1000, _ = find_classes_from_folder(path+'images/train')
    classes_882 = list(set(classes_1000) - set(classes_118))
    classes_train = classes_882[:800]
    class_to_idx_train = {classes_train[i]: i for i in range(len(classes_train))}
    samples_800 = make_dataset(path+'images/'+split, classes_train, class_to_idx_train)
    if split=='train':
        transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    dataset_800 = ImageFolder(transform=transform, samples=samples_800)
    dataloader_800 = DataLoader(dataset_800, batch_size=batch_size, shuffle=(split == 'train'), num_workers=num_workers, pin_memory=True)
    return dataloader_800 
Example #15
Source File: gqa.py    From lxmert with MIT License
def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
    dset = GQADataset(splits)
    tset = GQATorchDataset(dset)
    evaluator = GQAEvaluator(dset)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True
    )

    return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) 
Example #16
Source File: nlvr2.py    From lxmert with MIT License
def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
    dset = NLVR2Dataset(splits)
    tset = NLVR2TorchDataset(dset)
    evaluator = NLVR2Evaluator(dset)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True
    )

    return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) 
Example #17
Source File: vqa.py    From lxmert with MIT License
def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
    dset = VQADataset(splits)
    tset = VQATorchDataset(dset)
    evaluator = VQAEvaluator(dset)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True
    )

    return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) 
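DataTuple is defined elsewhere in lxmert; judging by its use in these three functions, it is presumably a namedtuple bundling the dataset, loader, and evaluator. An assumed minimal definition:

from collections import namedtuple

# Assumed shape, inferred from the call sites above.
DataTuple = namedtuple('DataTuple', ['dataset', 'loader', 'evaluator'])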
Example #18
Source File: rebalance_dataset_ensemble.py    From swagaf with MIT License
def __init__(self, instances, inds, train=True, recompute_assignments=False):
        self.instances = instances
        self.inds = inds
        self.train = train
        self.recompute_assignments = recompute_assignments

        self.dataloader = DataLoader(dataset=self, batch_size=128 if not recompute_assignments else 16,
                                     shuffle=self.train, num_workers=0,
                                     collate_fn=self.collate, drop_last=self.train) 
Example #19
Source File: load_data.py    From swagaf with MIT License
def __init__(self, fold, mode):
        self.mode = mode
        self.fold = fold
        self.instances, self.vocab = load_lm_data(fold=self.fold, mode=self.mode)
        self.dataloader = DataLoader(dataset=self, batch_size=32,
                                     shuffle=self.mode == 'train', num_workers=0,
                                     collate_fn=self.collate, drop_last=self.mode == 'train')
        self.indexer = ELMoTokenCharactersIndexer() 
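Both swagaf examples use the same pattern: the Dataset subclass constructs its own DataLoader, passing dataset=self and a bound method as collate_fn so that batching logic lives on the class. A minimal sketch of the pattern with a toy class (not swagaf's code):

from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

class SelfLoading(Dataset):
    def __init__(self, items):
        self.items = items
        # The dataset hands itself to its own DataLoader; collate is a bound
        # method, so batching can use instance state (vocab, mode, etc.).
        self.dataloader = DataLoader(dataset=self, batch_size=2,
                                     collate_fn=self.collate, num_workers=0)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

    def collate(self, batch):
        return list(batch)  # keep raw items rather than tensor-stacking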
Example #20
Source File: inference_utils.py    From allRank with Apache License 2.0
def __create_data_loader(ds: LibSVMDataset, config: Config) -> DataLoader:
    return DataLoader(ds, batch_size=config.data.batch_size, num_workers=config.data.num_workers, shuffle=False) 
Example #21
Source File: test_dataloaders.py    From pytorch-lightning with Apache License 2.0
def test_train_dataloader_not_implemented_error_failed(tmpdir):
    """Test not_implemented_error train data loader (e.g. IterableDataset)"""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model) 
Example #22
Source File: test_dataloaders.py    From pytorch-lightning with Apache License 2.0
def test_val_inf_dataloader_error(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model) 
Example #23
Source File: test_dataloaders.py    From pytorch-lightning with Apache License 2.0
def test_train_inf_dataloader_error(tmpdir):
    """Test inf train data loader (e.g. IterableDataset)"""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model) 
Example #24
Source File: eval.py    From dsb2018_topcoders with MIT License
def predict(self, fold, val_indexes):
        prefix = ('fold' + str(fold) + "_") if self.test else ""
        val_dataset = SequentialDataset(self.ds, val_indexes, stage='test', config=self.config, transforms=self.val_transforms)
        val_dl = PytorchDataLoader(val_dataset, batch_size=self.config.predict_batch_size, num_workers=self.num_workers, drop_last=False)
        model = read_model(self.folder, fold)
        pbar = tqdm.tqdm(val_dl, total=len(val_dl))
        for data in pbar:
            samples = data['image']
            # predicted = predict(model, samples, flips=self.flips)
            predicted = predict8tta(model, samples, self.config.sigmoid)
            self.process_batch(predicted, model, data, prefix=prefix)
        self.post_predict_action(prefix=prefix) 
Example #25
Source File: test-torch-dataloader.py    From Jacinle with MIT License
def test_torch_dataloader(self):
        ds = _FakeDataset()
        dl = DataLoader(ds, num_workers=2)
        res = list(dl)
        self.assertEqual(as_float(res[0]), as_float(res[1])) 
Example #26
Source File: __init__.py    From EDSR-PyTorch with MIT License
def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            datasets = []
            for d in args.data_train:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                datasets.append(getattr(m, module_name)(args, name=d))

            self.loader_train = dataloader.DataLoader(
                MyConcatDataset(datasets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )

        self.loader_test = []
        for d in args.data_test:
            if d in ['Set5', 'Set14', 'B100', 'Urban100']:
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=not args.cpu,
                    num_workers=args.n_threads,
                )
            ) 
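The EDSR loader resolves dataset classes dynamically: importlib.import_module loads data/<name>.py and getattr fetches the class named after the dataset. The pattern in isolation (module and class names illustrative):

from importlib import import_module

# E.g. a module data/div2k.py that defines a class DIV2K.
module_name = 'DIV2K'
m = import_module('data.' + module_name.lower())
dataset_cls = getattr(m, module_name)  # instantiate as in the example: dataset_cls(args, name=d)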
Example #27
Source File: train.py    From kaggle_carvana_segmentation with MIT License
def train(ds, folds, config, num_workers=0, transforms=None, skip_folds=None):
    os.makedirs(os.path.join('..', 'weights'), exist_ok=True)
    os.makedirs(os.path.join('..', 'logs'), exist_ok=True)

    for fold, (train_idx, val_idx) in enumerate(folds):
        if skip_folds and fold in skip_folds:
            continue

        tr = TrainDataset(ds, train_idx, config, transform=transforms)
        val = ValDataset(ds, val_idx, config, transform=None)
        train_loader = PytorchDataLoader(tr,
                                         batch_size=config.batch_size,
                                         shuffle=True,
                                         drop_last=True,
                                         num_workers=num_workers,
                                         pin_memory=True)
        val_loader = PytorchDataLoader(val,
                                       batch_size=config.batch_size,
                                       shuffle=False,
                                       drop_last=False,
                                       num_workers=num_workers,
                                       pin_memory=True)
        trainer = PytorchTrain(fold=fold,
                               config=config,
                               metrics=[('soft dice', dice_loss),
                                        ('hard dice', dice_clamp),
                                        ('bce', nn.modules.loss.BCELoss())])
        trainer.fit(train_loader, val_loader)
        trainer.writer.close() 
Example #28
Source File: __init__.py    From reid_baseline_with_syncbn with MIT License
def make_dataloader(cfg, num_gpus=1):
    train_trm = get_trm(cfg, is_train=True)
    val_trm = get_trm(cfg, is_train=False)

    num_workers = cfg.DATALOADER.NUM_WORKERS * num_gpus
    dataset = init_dataset(cfg)

    num_classes = dataset.num_train_pids
    train_set = ImageDataset(dataset.train, cfg, train_trm)
    if cfg.DATALOADER.SAMPLER == 'softmax':
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH * num_gpus, shuffle=True,
            num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH * num_gpus,
            sampler=RandomIdentitySampler(dataset.train,
                cfg.SOLVER.IMS_PER_BATCH * num_gpus,
                cfg.DATALOADER.NUM_INSTANCE * num_gpus),
            num_workers=num_workers, collate_fn=train_collate_fn
        )

    val_set = ImageDataset(dataset.query + dataset.gallery, cfg, val_trm)
    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH * num_gpus, shuffle=False,
        num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    return train_loader, val_loader, len(dataset.query), num_classes 
Example #29
Source File: __init__.py    From margipose with Apache License 2.0
def make_dataloader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                    num_workers=0, pin_memory=False, drop_last=False):
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
        batch_sampler=batch_sampler, collate_fn=collate, num_workers=num_workers,
        pin_memory=pin_memory, drop_last=drop_last, worker_init_fn=worker_init
    ) 
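worker_init is defined elsewhere in margipose. A common shape for such a worker_init_fn, sketched here as an assumption rather than the project's actual code, is to fold the per-worker torch seed into numpy so numpy-based augmentations differ across DataLoader workers:

import numpy as np
import torch

def worker_init(worker_id):
    # torch gives each DataLoader worker a distinct seed; propagate it to
    # numpy so random augmentations are not repeated across workers.
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)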
Example #30
Source File: train.py    From PJ_NLP with Apache License 2.0
def train():
    data = np.load(conf.emb_dict_path)
    emb_mat = t.from_numpy(data['vec'])
    word2id = data['word2id'].item()
    del data
    vocab_size = len(word2id)
    print('vocab size : {}'.format(vocab_size))

    dataset = ZhiHuData(conf.train_data)
    data_loader = DataLoader(dataset=dataset, batch_size=conf.batch_size)

    Model = name_model[model_name]
    model = Model(vocab_size, emb_mat).cuda()

    # print the model's parameter count
    get_params_num(model)
    optimizer = model.get_optimizer(conf.lr1, conf.lr2)
    best_score = 0
    step = 0
    for epoch in range(conf.epochs):
        print('epoch:===>', epoch)
        for i, batch in tqdm.tqdm(enumerate(data_loader)):
            title, content, label = batch
            title, content, label = Variable(title.cuda()), Variable(content.cuda()), Variable(label.cuda())
            optimizer.zero_grad()
            output = model(title, content)
            loss = model.loss_fn(output, label.float())
            loss.backward()
            optimizer.step()
            step += 1
            writer.add_scalar('train loss', loss, step)

        scores, prec_, recall_ = val(model)

        if best_score < scores:
            best_score = scores
            t.save(model, conf.model_all_path.format(model_name))
            # t.save(model.state_dict(), conf.model_dict_path.format(model_name))
    # visualization: log the model graph
    writer.add_graph(model, (title, content))
    writer.close()