Python torch.utils.data.SubsetRandomSampler() Examples

The following are two code examples of torch.utils.data.SubsetRandomSampler(). Each example links back to its original project and source file above the code. You may also want to check out all available functions/classes of the module torch.utils.data, or try the search function.
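SubsetRandomSampler takes a sequence of dataset indices and yields them in random order, re-shuffled on every pass, which makes it a common choice for train/validation splits that share a single dataset. Before the full examples, here is a minimal sketch; the toy TensorDataset, batch size, and 80/20 split are made up for illustration:

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy dataset of 100 samples (hypothetical; any Dataset works the same way).
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))

# Shuffle the indices once, then carve out an 80/20 train/validation split.
indices = torch.randperm(len(dataset)).tolist()
train_sampler = SubsetRandomSampler(indices[:80])
val_sampler = SubsetRandomSampler(indices[80:])

# Each loader draws only its own subset, re-shuffled every epoch.
train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=16, sampler=val_sampler)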
Example #1
Source File: main.py    From aster.pytorch with MIT License
import numpy as np
from torch.utils.data import ConcatDataset, DataLoader, SubsetRandomSampler
# AlignCollate is a collate function defined in the aster.pytorch project.

def get_dataloader(synthetic_dataset, real_dataset, height, width, batch_size, workers,
                   is_train, keep_ratio):
  num_synthetic_dataset = len(synthetic_dataset)
  num_real_dataset = len(real_dataset)

  # Drop as many synthetic samples as there are real ones, so the combined
  # index list keeps the epoch size fixed at num_synthetic_dataset.
  synthetic_indices = list(np.random.permutation(num_synthetic_dataset))
  synthetic_indices = synthetic_indices[num_real_dataset:]
  # Real samples sit after the synthetic ones in the ConcatDataset, so their
  # indices are offset by num_synthetic_dataset.
  real_indices = list(np.random.permutation(num_real_dataset) + num_synthetic_dataset)
  concated_indices = synthetic_indices + real_indices
  assert len(concated_indices) == num_synthetic_dataset

  # Sample only the selected indices, in random order, from the concatenation.
  sampler = SubsetRandomSampler(concated_indices)
  concated_dataset = ConcatDataset([synthetic_dataset, real_dataset])
  print('total images: ', len(concated_dataset))

  data_loader = DataLoader(concated_dataset, batch_size=batch_size, num_workers=workers,
    shuffle=False, pin_memory=True, drop_last=True, sampler=sampler,
    collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
  return concated_dataset, data_loader
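Note the shuffle=False: DataLoader treats sampler as mutually exclusive with shuffle=True, and the SubsetRandomSampler already randomizes the order. Discarding the first num_real_dataset synthetic indices keeps the epoch the size of the synthetic dataset while mixing in every real sample.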
Example #2
Source File: cgnn.py    From cgnn with Apache License 2.0
import json
import os

import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
# GraphDataset, graph_collate, and create_model are defined in the cgnn project.

def main(device, model_param, optimizer_param, scheduler_param, dataset_param, dataloader_param,
         num_epochs, seed, load_model):
    print("Seed:", seed)
    print()
    torch.manual_seed(seed)

    dataloader_param["collate_fn"] = graph_collate

    # Create the dataset
    dataset = GraphDataset(dataset_param["dataset_path"], dataset_param["target_name"])

    # Split the dataset into training, validation, and test sets.
    split_file_path = dataset_param["split_file"]
    if split_file_path is not None and os.path.isfile(split_file_path):
        with open(split_file_path) as f:
            split = json.load(f)
    else:
        print("No split file. Default split: 256 (train), 32 (val), 32 (test)")
        split = {"train": range(256), "val": range(256, 288), "test": range(288, 320)}
    print(" ".join(["{}: {}".format(k, len(x)) for k, x in split.items()]))

    # Create a CGNN model
    model = create_model(device, model_param, optimizer_param, scheduler_param)
    if load_model:
        print("Loading weights from model.pth")
        model.load()

    # Train: each sampler draws its own index subset from the shared dataset,
    # re-shuffled every epoch.
    train_sampler = SubsetRandomSampler(split["train"])
    val_sampler = SubsetRandomSampler(split["val"])
    train_dl = DataLoader(dataset, sampler=train_sampler, **dataloader_param)
    val_dl = DataLoader(dataset, sampler=val_sampler, **dataloader_param)
    model.train(train_dl, val_dl, num_epochs)
    if num_epochs > 0:
        model.save()

    # Test: a Subset keeps the test examples in a fixed order.
    test_set = Subset(dataset, split["test"])
    test_dl = DataLoader(test_set, **dataloader_param)
    outputs, targets = model.evaluate(test_dl)
    names = [dataset.graph_names[i] for i in split["test"]]
    df_predictions = pd.DataFrame({"name": names, "prediction": outputs, "target": targets})
    df_predictions.to_csv("test_predictions.csv", index=False)

    print("\nEND")