Python tensorpack.imgaug.AugmentorList() Examples
The following are 5 code examples of tensorpack.imgaug.AugmentorList().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions and classes of the module tensorpack.imgaug, or try the search function.
Example #1
Source File: imagenet_utils.py From ghostnet with Apache License 2.0 | 5 votes |
def get_imagenet_dataflow(
        datadir, name, batch_size,
        augmentors, meta_dir=None, parallel=None):
    """
    Build an ILSVRC12 (ImageNet) dataflow for training or evaluation.

    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html

    Args:
        datadir (str): directory with the ILSVRC12 data.
        name (str): one of 'train', 'val', 'test'.
        batch_size (int): batch size of the returned dataflow.
        augmentors (list): list of imgaug augmentors to apply to each image.
        meta_dir (str): optional metadata directory passed to the dataset.
        parallel (int): number of workers; defaults to
            min(40, cpu_count // 2) (assuming hyperthreading).

    Returns:
        A batched DataFlow yielding (image, label).
    """
    assert name in ['train', 'val', 'test']
    assert datadir is not None
    assert isinstance(augmentors, list)
    isTrain = name == 'train'
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count() // 2)  # assuming hyperthreading

    if isTrain:
        ds = dataset.ILSVRC12(datadir, name, meta_dir=meta_dir, shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            # FIX: use `warning` — `Logger.warn` is a deprecated alias in the
            # stdlib logging module (and example #5 in this file already uses it).
            logger.warning("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchDataZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, name, meta_dir=meta_dir, shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            # Decode + augment on worker threads; cv2.imread releases the GIL.
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            im = aug.augment(im)
            return im, cls
        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchDataZMQ(ds, 1)
    return ds
Example #2
Source File: imagenet_utils.py From benchmarks with The Unlicense | 5 votes |
def get_val_dataflow(
        datadir, batch_size,
        augmentors=None, parallel=None,
        num_splits=None, split_index=None):
    """
    Build a validation dataflow over ILSVRC12, optionally sharded.

    Args:
        datadir (str): directory with the ILSVRC12 data.
        batch_size (int): batch size of the returned dataflow.
        augmentors (list): augmentors; defaults to fbresnet_augmentor(False).
        parallel (int): number of mapping threads; defaults to min(40, cpu_count).
        num_splits (int): if given, shard the validation set into this many splits.
        split_index (int): which shard to take (required with num_splits).

    Returns:
        A batched DataFlow yielding (image, label).
    """
    if augmentors is None:
        augmentors = fbresnet_augmentor(False)
    assert datadir is not None
    assert isinstance(augmentors, list)

    if parallel is None:
        parallel = min(40, mp.cpu_count())

    if num_splits is None:
        flow = dataset.ILSVRC12Files(datadir, 'val', shuffle=False)
    else:
        # Shard the validation files: take the split_index-th contiguous chunk.
        assert split_index < num_splits
        file_source = dataset.ILSVRC12Files(datadir, 'val', shuffle=False)
        file_source.reset_state()
        file_list = list(file_source.get_data())
        logger.info("Number of validation data = {}".format(len(file_list)))
        chunk = len(file_list) // num_splits
        lo = chunk * split_index
        hi = min(chunk * (split_index + 1), len(file_list))
        logger.info("Local validation split = {} - {}".format(lo, hi))
        flow = DataFromList(file_list[lo: hi], shuffle=False)

    augmentor_list = imgaug.AugmentorList(augmentors)

    def decode_and_augment(dp):
        # Read the image from disk and apply the augmentor pipeline.
        fname, cls = dp
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        im = augmentor_list.augment(im)
        return im, cls

    flow = MultiThreadMapData(
        flow, parallel, decode_and_augment,
        buffer_size=min(2000, flow.size()), strict=True)
    flow = BatchData(flow, batch_size, remainder=True)
    # do not fork() under MPI
    return flow
Example #3
Source File: imagenet_utils.py From webvision-2.0-benchmarks with Apache License 2.0 | 5 votes |
def get_imagenet_dataflow(
        datadir, name, batch_size,
        augmentors, parallel=None):
    """
    Build an Imagenet5k dataflow for training or evaluation.

    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html

    Args:
        datadir (str): dataset root; metadata is read from datadir/meta.
        name (str): one of 'train', 'val', 'test'.
        batch_size (int): batch size of the returned dataflow.
        augmentors (list): list of imgaug augmentors to apply to each image.
        parallel (int): number of workers; defaults to min(40, cpu_count).

    Returns:
        A batched DataFlow yielding (image, label).
    """
    assert name in ['train', 'val', 'test']
    assert datadir is not None
    assert isinstance(augmentors, list)
    isTrain = name == 'train'
    meta_dir = os.path.join(datadir, "meta")
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count())

    if isTrain:
        ds = Imagenet5k(datadir, name, meta_dir=meta_dir, shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            # FIX: `Logger.warn` is a deprecated alias in stdlib logging;
            # use `warning` as the other examples in this file do.
            logger.warning("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchDataZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = Imagenet5kFiles(datadir, name, meta_dir=meta_dir, shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            # Decode + augment on worker threads.
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            im = aug.augment(im)
            return im, cls
        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchDataZMQ(ds, 1)
    return ds
Example #4
Source File: imagenet_utils.py From LQ-Nets with MIT License | 5 votes |
def get_imagenet_dataflow(
        datadir, name, batch_size,
        augmentors, parallel=None):
    """
    Build an ILSVRC12 dataflow for training or evaluation.

    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html

    Args:
        datadir (str): directory with the ILSVRC12 data.
        name (str): one of 'train', 'val', 'test'.
        batch_size (int): batch size of the returned dataflow.
        augmentors (list): list of imgaug augmentors to apply to each image.
        parallel (int): number of workers; defaults to min(30, cpu_count).

    Returns:
        A batched DataFlow yielding (image, label).
    """
    assert name in ['train', 'val', 'test']
    assert datadir is not None
    assert isinstance(augmentors, list)
    if parallel is None:
        parallel = min(30, multiprocessing.cpu_count())

    is_training = name == 'train'
    if not is_training:
        # Evaluation path: read file names, decode + augment in threads.
        flow = dataset.ILSVRC12Files(datadir, name, shuffle=False)
        augmentor_list = imgaug.AugmentorList(augmentors)

        def decode_and_augment(dp):
            fname, cls = dp
            image = cv2.imread(fname, cv2.IMREAD_COLOR)
            image = augmentor_list.augment(image)
            return image, cls

        flow = MultiThreadMapData(
            flow, parallel, decode_and_augment, buffer_size=2000, strict=True)
        flow = BatchData(flow, batch_size, remainder=True)
        flow = PrefetchDataZMQ(flow, 1)
        return flow

    # Training path: augment in the dataflow, prefetch with ZMQ processes.
    flow = dataset.ILSVRC12(datadir, name, shuffle=True)
    flow = AugmentImageComponent(flow, augmentors, copy=False)
    flow = PrefetchDataZMQ(flow, parallel)
    flow = BatchData(flow, batch_size, remainder=False)
    return flow
Example #5
Source File: data_helper.py From tf-mobilenet-v2 with MIT License | 5 votes |
def get_imagenet_dataflow(
        datadir, name, batch_size,
        augmentors, parallel=None):
    """
    Build an ILSVRC12 dataflow with retrying image reads.

    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html

    Args:
        datadir (str): directory with the ILSVRC12 data.
        name (str): one of 'train', 'val', 'test'.
        batch_size (int): batch size of the returned dataflow.
        augmentors (list): list of imgaug augmentors to apply to each image.
        parallel (int): number of workers; defaults to min(40, cpu_count // 6).

    Returns:
        A batched DataFlow yielding (image, label).
    """
    assert name in ['train', 'val', 'test']
    assert datadir is not None
    assert isinstance(augmentors, list)
    isTrain = name == 'train'
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count() // 6)

    if isTrain:
        ds = dataset.ILSVRC12(datadir, name, shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logger.warning("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchData(ds, 1000, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            fname, cls = dp
            # Fallback image in case every read attempt fails.
            im = np.zeros((256, 256, 3), dtype=np.uint8)
            for _ in range(30):
                try:
                    im = cv2.imread(fname, cv2.IMREAD_COLOR)
                    im = aug.augment(im)
                    break
                except Exception as e:
                    # BUG FIX: the original called
                    # logger.warning(str(e), 'file=', fname), whose extra
                    # positional args are %-format arguments to logging; with
                    # no placeholders in the message this raises a formatting
                    # error and the filename is never shown. Use lazy %-style.
                    logger.warning("%s, file=%s", e, fname)
                    time.sleep(1)
            return im, cls
        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchData(ds, 100, 1)
    return ds