Python tensorpack.dataflow.BatchData() Examples
The following are 11 code examples of tensorpack.dataflow.BatchData(), drawn from open-source projects. The attribution line above each example names the original source file, project, and license. You may also want to check out all other available functions and classes of the module tensorpack.dataflow.
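Before the project examples, here is a minimal, self-contained sketch of what BatchData does (the data and batch size below are made up for illustration): it wraps an existing DataFlow and yields datapoints whose components are grouped along a new batch dimension.

import numpy as np
from tensorpack.dataflow import BatchData, DataFromList

# Ten hypothetical datapoints, each a [feature, label] pair.
points = [[np.random.rand(4).astype('float32'), i % 2] for i in range(10)]
ds = DataFromList(points, shuffle=False)

# Group 4 datapoints per batch; remainder=True also yields the final,
# smaller batch instead of dropping it (the default is remainder=False).
ds = BatchData(ds, 4, remainder=True)
ds.reset_state()

for feats, labels in ds.get_data():   # newer tensorpack also supports `for ... in ds`
    print(feats.shape, labels.shape)  # (4, 4) (4,) / (4, 4) (4,) / (2, 4) (2,)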
Example #1
Source File: data.py From sequential-imagenet-dataloader with MIT License
def __init__(self, mode, batch_size=256, shuffle=False, num_workers=25, cache=50000,
             collate_fn=default_collate, drop_last=False, cuda=False):
    # enumerate standard imagenet augmentors
    imagenet_augmentors = fbresnet_augmentor(mode == 'train')

    # load the lmdb if we can find it
    lmdb_loc = os.path.join(os.environ['IMAGENET'], 'ILSVRC-%s.lmdb' % mode)
    ds = td.LMDBData(lmdb_loc, shuffle=False)
    ds = td.LocallyShuffleData(ds, cache)
    ds = td.PrefetchData(ds, 5000, 1)
    ds = td.LMDBDataPoint(ds)
    ds = td.MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
    ds = td.AugmentImageComponent(ds, imagenet_augmentors)
    ds = td.PrefetchDataZMQ(ds, num_workers)
    self.ds = td.BatchData(ds, batch_size)
    self.ds.reset_state()

    self.batch_size = batch_size
    self.num_workers = num_workers
    self.cuda = cuda
    # self.drop_last = drop_last
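The constructor above only builds the dataflow; the surrounding class presumably exposes it as a Python iterator. A hypothetical sketch of such companion methods (not part of the snippet shown, names assumed):

def __iter__(self):
    # Each datapoint from BatchData is [images, labels] as numpy arrays,
    # e.g. images of shape (batch_size, H, W, 3).
    for images, labels in self.ds.get_data():
        yield images, labels

def __len__(self):
    # BatchData.size() reports the number of batches per epoch.
    return self.ds.size()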
Example #2
Source File: utils_tp.py From imgclsmob with MIT License
def get_imagenet_dataflow(datadir, is_train, batch_size, augmentors, parallel=None):
    """
    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
    """
    assert datadir is not None
    assert isinstance(augmentors, list)
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count() // 2)  # assuming hyperthreading
    if is_train:
        ds = dataset.ILSVRC12(datadir, "train", shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logging.warning("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchDataZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, "val", shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            im = np.flip(im, axis=2)  # BGR -> RGB
            # print("fname={}".format(fname))
            im = aug.augment(im)
            return im, cls

        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        # ds = MapData(ds, mapf)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchDataZMQ(ds, 1)
        # ds = PrefetchData(ds, 1)
    return ds
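A hedged usage sketch for the helper above (the path and augmentor list are placeholders; any tensorpack imgaug augmentors would do):

from tensorpack.dataflow import imgaug

train_augs = [imgaug.ResizeShortestEdge(256), imgaug.RandomCrop(224)]
ds = get_imagenet_dataflow('/data/imagenet', is_train=True,
                           batch_size=64, augmentors=train_augs)
ds.reset_state()
images, labels = next(ds.get_data())  # images: (64, 224, 224, 3), labels: (64,)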
Example #3
Source File: dataietr.py From face_landmark with Apache License 2.0
def build_iter(self):
    ds = DataFromGenerator(self.generator)
    ds = BatchData(ds, self.batch_size)
    ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
    ds.reset_state()
    ds = ds.get_data()
    return ds
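Examples #3, #7, and #9 all wrap a Python generator in this way; the pattern can be reproduced stand-alone as follows (the generator contents are invented for illustration):

import numpy as np
from tensorpack.dataflow import BatchData, DataFromGenerator

def gen():
    while True:  # DataFromGenerator is typically fed an endless generator
        yield [np.zeros((112, 112, 3), 'float32'), np.random.randint(10)]

ds = DataFromGenerator(gen)
ds = BatchData(ds, 8)
ds.reset_state()
images, labels = next(ds.get_data())  # images: (8, 112, 112, 3), labels: (8,)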
Example #4
Source File: misc.py From petridishnn with MIT License
def preprocess_data_flow(ds, options, is_train, do_multiprocess=False):
    ds_size = ds.size()
    while options.batch_size > ds_size:
        options.batch_size //= 2
    ds = BatchData(ds, max(1, options.batch_size // options.nr_gpu),
                   remainder=not is_train)
    if do_multiprocess:
        ds = PrefetchData(ds, 5, 5)
    return ds
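For concreteness, with options.batch_size = 256 and options.nr_gpu = 8 the call above builds BatchData(ds, 32), i.e. 32 datapoints per GPU tower, and keeps the final partial batch only at evaluation time. A hypothetical invocation (the options object is mocked; raw_val_ds stands for any existing DataFlow with a known size):

from types import SimpleNamespace

options = SimpleNamespace(batch_size=256, nr_gpu=8)
val_ds = preprocess_data_flow(raw_val_ds, options, is_train=False)
# -> BatchData(raw_val_ds, max(1, 256 // 8) = 32, remainder=True)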
Example #5
Source File: critic.py From petridishnn with MIT License
def critic_dataflow_factory(ctrl, data, is_train):
    """ Generate a critic dataflow """
    if ctrl.critic_type == CriticTypes.CONV:
        ds = ConvCriticDataFlow(data, shuffle=is_train, max_depth=ctrl.controller_max_depth)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=False)
    elif ctrl.critic_type == CriticTypes.LSTM:
        ds = LSTMCriticDataFlow(data, shuffle=is_train)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=True)
    return ds
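The use_list flag matters here: with use_list=False each batch component is stacked into a numpy array with an extra batch axis, while use_list=True keeps each component as a plain Python list, which suits the LSTM critic's variable-length inputs. A minimal sketch (data made up):

import numpy as np
from tensorpack.dataflow import BatchData, DataFromList

pts = [[np.arange(n)] for n in (3, 5, 2, 4)]  # ragged, variable-length components
ds = BatchData(DataFromList(pts, shuffle=False), 2, use_list=True)
ds.reset_state()
batch, = next(ds.get_data())  # a list of 2 arrays; lengths 3 and 5 preserved
# with use_list=False, BatchData would instead try to stack them into one array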
Example #6
Source File: data_loader.py From ADL with MIT License
def get_data(split, option):
    is_training = split == 'train'
    parallel = multiprocessing.cpu_count() // 2
    ds = get_data_flow(split, is_training, option)
    augmentors = fbresnet_augmentor(is_training, option)
    ds = AugmentImageCoordinates(ds, augmentors, coords_index=2, copy=False)
    if is_training:
        ds = PrefetchDataZMQ(ds, parallel)
    ds = BatchData(ds, option.batch_size, remainder=not is_training)
    return ds
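AugmentImageCoordinates applies the same geometric augmentations to an image and to a set of keypoint coordinates; coords_index=2 says the coordinates live in the third component of each datapoint. A sketch of the implied datapoint layout (shapes and the meaning of dp[1] are assumptions):

# dp[0]: image, (H, W, 3) array
# dp[1]: label or other metadata, untouched by the augmentor
# dp[2]: coordinates, (N, 2) float32 array of (x, y) points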
Example #7
Source File: dataietr.py From faceboxes-tensorflow with Apache License 2.0
def build_iter(self):
    ds = DataFromGenerator(self.generator)
    ds = BatchData(ds, self.num_gpu * self.batch_size)
    ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
    ds.reset_state()
    ds = ds.get_data()
    return ds
Example #8
Source File: resnet50_for_embedding.py From will-people-like-your-image with GNU Lesser General Public License v3.0
def get_data(lmdb_path, txt_path):
    if txt_path:
        ds = arod_dataflow_from_txt.Triplets(lmdb_path, txt_path, IMAGE_HEIGHT, IMAGE_WIDTH)
    else:
        ds = arod_provider.Triplets(lmdb_path, IMAGE_HEIGHT, IMAGE_WIDTH)
    ds.reset_state()
    cpu = min(10, multiprocessing.cpu_count())
    ds = PrefetchDataZMQ(ds, cpu)
    ds = BatchData(ds, BATCH_SIZE)
    return ds
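Because remainder defaults to False, the final incomplete batch of triplets is silently dropped here. The effect is easy to verify with BatchData.size() on toy data (the 100-point list and batch size 32 are made up):

import numpy as np
from tensorpack.dataflow import BatchData, DataFromList

toy = DataFromList([[np.zeros(3)] for _ in range(100)], shuffle=False)
print(BatchData(toy, 32).size())                  # 3 -- the last 4 points are dropped
print(BatchData(toy, 32, remainder=True).size())  # 4 -- final partial batch kept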
Example #9
Source File: dataietr.py From PINTO_model_zoo with MIT License
def build_iter(self):
    ds = DataFromGenerator(self.generator)
    ds = BatchData(ds, self.batch_size)
    ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
    ds.reset_state()
    ds = ds.get_data()
    return ds
Example #10
Source File: mnist-keras-v2.py From tensorpack with Apache License 2.0
def get_data():
    def f(dp):
        im = dp[0][:, :, None]
        onehot = np.eye(10)[dp[1]]
        return [im, onehot]

    train = BatchData(MapData(dataset.Mnist('train'), f), 128)
    test = BatchData(MapData(dataset.Mnist('test'), f), 256)
    return train, test
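A quick way to sanity-check the shapes produced above (this downloads MNIST on first use):

train, test = get_data()
train.reset_state()
images, onehots = next(train.get_data())
print(images.shape, onehots.shape)  # (128, 28, 28, 1) (128, 10)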
Example #11
Source File: embedding_data.py From tensorpack with Apache License 2.0
def get_test_data(batch=128):
    ds = dataset.Mnist('test')
    ds = BatchData(ds, batch)
    return ds
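And a matching consumption sketch for this last helper; with MNIST's 10,000 test images and the default remainder=False, a batch size of 128 yields 78 full batches and drops the final 16 samples:

ds = get_test_data(128)
ds.reset_state()
for images, labels in ds.get_data():
    pass  # images: (128, 28, 28) float32, labels: (128,) int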