Python tensorpack.dataflow.BatchData() Examples

The following are 11 code examples of tensorpack.dataflow.BatchData(), collected from open-source projects. The source file and originating project for each example are listed above its code. You may also want to check out all available functions/classes of the module tensorpack.dataflow.
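Before the project code, here is a minimal, self-contained sketch of what BatchData() does (illustrative only, not taken from any project below): it wraps an existing dataflow and stacks every batch_size consecutive datapoints, component by component, into numpy arrays. It assumes a recent tensorpack in which dataflows are directly iterable; on older versions, iterate over ds.get_data() instead.

import numpy as np
from tensorpack.dataflow import DataFromList, BatchData

# a toy dataflow: each datapoint is [image, label]
points = [[np.zeros((4, 4), dtype='float32'), i % 10] for i in range(10)]
ds = DataFromList(points, shuffle=False)

# stack every 4 datapoints; remainder=True keeps the final, smaller batch
ds = BatchData(ds, 4, remainder=True)
ds.reset_state()  # must be called once before iterating

for images, labels in ds:
    print(images.shape, labels.shape)  # (4, 4, 4) (4,) ... then (2, 4, 4) (2,)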
Example #1
Source File: data.py    From sequential-imagenet-dataloader with MIT License
def __init__(self, mode, batch_size=256, shuffle=False, num_workers=25, cache=50000,
             collate_fn=default_collate, drop_last=False, cuda=False):
        # enumerate standard imagenet augmentors
        imagenet_augmentors = fbresnet_augmentor(mode == 'train')

        # load the lmdb if we can find it
        lmdb_loc = os.path.join(os.environ['IMAGENET'], 'ILSVRC-%s.lmdb' % mode)
        ds = td.LMDBData(lmdb_loc, shuffle=False)
        ds = td.LocallyShuffleData(ds, cache)  # approximate shuffle within a buffer of `cache` datapoints
        ds = td.PrefetchData(ds, 5000, 1)
        ds = td.LMDBDataPoint(ds)  # deserialize each LMDB value into a datapoint
        ds = td.MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)  # decode JPEG bytes
        ds = td.AugmentImageComponent(ds, imagenet_augmentors)
        ds = td.PrefetchDataZMQ(ds, num_workers)  # run the pipeline in `num_workers` processes
        self.ds = td.BatchData(ds, batch_size)  # stack datapoints into batches
        self.ds.reset_state()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.cuda = cuda
        #self.drop_last = drop_last 
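A note on the ordering in Example #1: the expensive decode/augment stages run inside the worker processes forked by PrefetchDataZMQ, and BatchData is applied last so batching happens once in the parent process. A hypothetical consumer of this wrapper (the class name is assumed; it is not shown in the excerpt) would look roughly like:

loader = ImageNetLoader('train', batch_size=256)  # hypothetical wrapper class
for images, labels in loader.ds.get_data():
    pass  # images: (256, H, W, 3) uint8, labels: (256,)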
Example #2
Source File: utils_tp.py    From imgclsmob with MIT License
def get_imagenet_dataflow(datadir,
                          is_train,
                          batch_size,
                          augmentors,
                          parallel=None):
    """
    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
    """
    assert datadir is not None
    assert isinstance(augmentors, list)
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count() // 2)  # assuming hyperthreading
    if is_train:
        ds = dataset.ILSVRC12(datadir, "train", shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logging.warning("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchDataZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, "val", shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            im = np.flip(im, axis=2)  # convert OpenCV's BGR to RGB
            # print("fname={}".format(fname))
            im = aug.augment(im)
            return im, cls
        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        # ds = MapData(ds, mapf)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchDataZMQ(ds, 1)
        # ds = PrefetchData(ds, 1)
    return ds 
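Note how the two branches treat the last, incomplete batch differently: remainder=False during training drops it so every batch has a fixed shape, while remainder=True during validation keeps it so every image is evaluated exactly once. A hypothetical call (the data directory and augmentor list are assumptions, not part of the excerpt):

train_df = get_imagenet_dataflow('/data/imagenet', is_train=True,
                                 batch_size=64, augmentors=fbresnet_augmentor(True))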
Example #3
Source File: dataietr.py    From face_landmark with Apache License 2.0
def build_iter(self):

        ds = DataFromGenerator(self.generator)
        ds = BatchData(ds, self.batch_size)
        ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
        ds.reset_state()
        ds = ds.get_data()
        return ds 
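Because build_iter() ends with ds.get_data(), the caller receives a plain Python generator rather than a DataFlow object. A minimal consumption sketch (the surrounding class and its self.generator attribute are assumed from context):

it = data_iter.build_iter()  # data_iter: an instance of the class above
batch = next(it)  # a list with one stacked array per datapoint component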
Example #4
Source File: misc.py    From petridishnn with MIT License
def preprocess_data_flow(ds, options, is_train, do_multiprocess=False):
    ds_size = ds.size()
    while options.batch_size > ds_size:
        options.batch_size //= 2
    ds = BatchData(ds, max(1, options.batch_size // options.nr_gpu),
        remainder=not is_train)
    if do_multiprocess:
        ds = PrefetchData(ds, 5, 5)
    return ds 
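For example, with options.batch_size = 128 and options.nr_gpu = 4, the dataflow yields batches of 32 datapoints, one per-GPU slice of the global batch; the halving loop only fires when the requested batch size exceeds the dataset size, which can happen with tiny debugging datasets.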
Example #5
Source File: critic.py    From petridishnn with MIT License
def critic_dataflow_factory(ctrl, data, is_train):
    """
    Generate a critic dataflow
    """
    if ctrl.critic_type == CriticTypes.CONV:
        ds = ConvCriticDataFlow(data, shuffle=is_train, max_depth=ctrl.controller_max_depth)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=False)
    elif ctrl.critic_type == CriticTypes.LSTM:
        ds = LSTMCriticDataFlow(data, shuffle=is_train)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=True)
    return ds 
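The use_list flag changes how BatchData assembles each batch: with use_list=False every component is stacked into a single numpy array, which requires uniform shapes and suits the fixed-size inputs of the CONV critic, while use_list=True keeps each component as a plain Python list, accommodating the variable-length sequences consumed by the LSTM critic.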
Example #6
Source File: data_loader.py    From ADL with MIT License
def get_data(split, option):
    is_training = split == 'train'
    parallel = multiprocessing.cpu_count() // 2
    ds = get_data_flow(split, is_training, option)
    augmentors = fbresnet_augmentor(is_training, option)
    ds = AugmentImageCoordinates(ds, augmentors, coords_index=2, copy=False)
    if is_training:
        ds = PrefetchDataZMQ(ds, parallel)
    ds = BatchData(ds, option.batch_size, remainder=not is_training)
    return ds 
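AugmentImageCoordinates applies each geometric augmentor jointly to the image and to the coordinate component at index 2, so landmark locations stay aligned with the transformed image; PrefetchDataZMQ is enabled only for training, where throughput matters most.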
Example #7
Source File: dataietr.py    From faceboxes-tensorflow with Apache License 2.0
def build_iter(self):

        ds = DataFromGenerator(self.generator)
        ds = BatchData(ds, self.num_gpu * self.batch_size)  # one combined batch covering all GPUs
        ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
        ds.reset_state()
        ds = ds.get_data()
        return ds 
Example #8
Source File: resnet50_for_embedding.py    From will-people-like-your-image with GNU Lesser General Public License v3.0
def get_data(lmdb_path, txt_path):

        if txt_path:
            ds = arod_dataflow_from_txt.Triplets(lmdb_path, txt_path, IMAGE_HEIGHT, IMAGE_WIDTH)
        else:
            ds = arod_provider.Triplets(lmdb_path, IMAGE_HEIGHT, IMAGE_WIDTH)

        ds.reset_state()
        cpu = min(10, multiprocessing.cpu_count())
        ds = PrefetchDataZMQ(ds, cpu)
        ds = BatchData(ds, BATCH_SIZE)
        return ds 
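Note that reset_state() is called here on the inner dataflow only; when iterating the returned dataflow manually, the outer PrefetchDataZMQ/BatchData wrappers still need their own reset_state() call (tensorpack trainers perform this automatically on the dataflow they are given).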
Example #9
Source File: dataietr.py    From PINTO_model_zoo with MIT License
def build_iter(self):

        ds = DataFromGenerator(self.generator)
        ds = BatchData(ds, self.batch_size)
        ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
        ds.reset_state()
        ds = ds.get_data()
        return ds 
Example #10
Source File: mnist-keras-v2.py    From tensorpack with Apache License 2.0
def get_data():
    def f(dp):
        im = dp[0][:, :, None]
        onehot = np.eye(10)[dp[1]]
        return [im, onehot]

    train = BatchData(MapData(dataset.Mnist('train'), f), 128)
    test = BatchData(MapData(dataset.Mnist('test'), f), 256)
    return train, test 
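Here MapData applies f to each datapoint before batching: dp[0] is a 28x28 float image that f expands to shape (28, 28, 1), and dp[1] is an integer label that np.eye(10)[dp[1]] turns into a one-hot vector, so the train dataflow yields batches shaped (128, 28, 28, 1) and (128, 10).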
Example #11
Source File: embedding_data.py    From tensorpack with Apache License 2.0
def get_test_data(batch=128):
    ds = dataset.Mnist('test')
    ds = BatchData(ds, batch)
    return ds
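As with every dataflow on this page, reset_state() must be called once before manual iteration. A minimal evaluation sketch over this test data (assuming a recent tensorpack in which dataflows are directly iterable):

ds = get_test_data(batch=128)
ds.reset_state()
for images, labels in ds:
    pass  # images: (128, 28, 28) float32, labels: (128,)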