Python PIL.Image.open() Examples

The following are 30 code examples of PIL.Image.open(). Some snippets come from the same source files and do not call Image.open() themselves. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module PIL.Image, or try the search function.
Example #1
Source File: screenshot.py    From AboveTustin with MIT License 9 votes vote down vote up
def screenshot(self, name):
    """Capture the browser viewport and write it to *name*.

    If the module-level ``do_crop`` flag is truthy, the PNG is fetched into
    memory, cropped with the box ``(crop_x, crop_y, crop_width, crop_height)``
    and saved by PIL. Otherwise the browser saves the file itself.

    Returns:
        The same ``name`` that was written.
    """
    if not do_crop:
        self.browser.save_screenshot(name)
    else:
        print('cropping screenshot')
        # Grab the screenshot bytes instead of saving to disk first.
        png_bytes = self.browser.get_screenshot_as_png()
        shot = Image.open(BytesIO(png_bytes))
        # PIL's crop box is (left, upper, right, lower).
        # NOTE(review): crop_width/crop_height are used as right/lower
        # coordinates here, not sizes; confirm the config matches.
        box = (crop_x, crop_y, crop_width, crop_height)
        shot.crop(box).save(name)
    print("success saving screenshot: %s" % name)
    return name
Example #2
Source File: datasets.py    From pruning_yolov3 with GNU General Public License v3.0 8 votes vote down vote up
def convert_images2bmp(image_dirs=('../coco/images/val2014/', '../coco/images/train2014/'),
                       label_files=('../coco/trainvalno5k.txt', '../coco/5k.txt')):
    """Re-encode COCO ``*.jpg`` images as ``*.bmp`` and write matching list files.

    For each directory ``<dir>`` in *image_dirs*, a sibling ``<dir>bmp`` is
    recreated from scratch. Every ``*.jpg`` in ``<dir>`` is written there as
    ``<stem>.bmp`` via OpenCV.

    For each file in *label_files*, a copy is written under a name with
    ``5k`` replaced by ``5k_bmp``. In that copy, ``2014/`` becomes
    ``2014bmp/``, ``.jpg`` becomes ``.bmp``, and the author's absolute
    project prefix becomes ``../``.

    Args:
        image_dirs: directories holding the source ``*.jpg`` images.
        label_files: image-list text files to rewrite for the bmp copies.
    """
    # cv2.imread() jpg at 230 img/s, *.bmp at 400 img/s
    for path in image_dirs:
        src = Path(path)
        # Derive the output folder from path components. The previous
        # os.sep-based str.replace() did nothing on Windows (inputs use '/'),
        # so output == input and rmtree() deleted the source images.
        dst = src.with_name(src.name + 'bmp')
        if dst.exists():
            shutil.rmtree(str(dst))  # delete output folder
        os.makedirs(str(dst))  # make new output folder

        for f in tqdm(sorted(src.glob('*.jpg'))):
            save_name = dst / (f.stem + '.bmp')
            cv2.imwrite(str(save_name), cv2.imread(str(f)))

    for label_path in label_files:
        with open(label_path, 'r') as fh:
            lines = fh.read()
        lines = lines.replace('2014/', '2014bmp/').replace('.jpg', '.bmp').replace(
            '/Users/glennjocher/PycharmProjects/', '../')
        with open(label_path.replace('5k', '5k_bmp'), 'w') as fh:
            fh.write(lines)
Example #3
Source File: download_images.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 7 votes vote down vote up
def download_image(image_id, url, x1, y1, x2, y2, output_dir):
    """Downloads one image, crops it, resizes it and saves it locally.

    Args:
        image_id: base name of the output file (``<image_id>.png``).
        url: address to download the image from.
        x1, y1, x2, y2: crop box as fractions of the image width/height
            (left, top, right, bottom).
        output_dir: directory the 299x299 PNG is written to.

    Returns:
        True if the file already exists or was saved. False if the HTTP
        status is not 200 or an IOError occurs.
    """
    from contextlib import closing

    output_filename = os.path.join(output_dir, image_id + '.png')
    if os.path.exists(output_filename):
        # Don't download image if it's already there
        return True
    try:
        # closing() releases the connection on every path, including the
        # early non-200 return; previously the response was never closed.
        with closing(urlopen(url)) as url_file:
            if url_file.getcode() != 200:
                return False
            image_buffer = url_file.read()
        # Crop, resize and save image
        image = Image.open(BytesIO(image_buffer)).convert('RGB')
        w, h = image.size
        image = image.crop((int(x1 * w), int(y1 * h), int(x2 * w),
                            int(y2 * h)))
        # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
        image = image.resize((299, 299), resample=Image.LANCZOS)
        image.save(output_filename)
    except IOError:
        return False
    return True
Example #4
Source File: run_attacks_and_defenses.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _load_dataset_clipping(self, dataset_dir, epsilon):
  """Helper method which loads dataset and determines clipping range.

  For every ``*.png`` in *dataset_dir* (keyed by file name without
  extension), stores the image +/- epsilon clipped to [0, 255] as uint8
  in ``self.dataset_max_clip`` / ``self.dataset_min_clip``. It also counts the
  images in ``self._dataset_image_count``.

  Args:
    dataset_dir: location of the dataset.
    epsilon: maximum allowed size of adversarial perturbation.
  """
  self.dataset_max_clip = {}
  self.dataset_min_clip = {}
  self._dataset_image_count = 0
  for fname in os.listdir(dataset_dir):
    if not fname.endswith('.png'):
      continue
    image_id = fname[:-4]
    # Close each file as soon as its pixels are read, rather than relying on
    # garbage collection, which can exhaust file descriptors on big datasets.
    with Image.open(os.path.join(dataset_dir, fname)) as img:
      image = np.array(img.convert('RGB'))
    # Widen to int32 so +/- epsilon cannot wrap around before clipping.
    image = image.astype('int32')
    self._dataset_image_count += 1
    self.dataset_max_clip[image_id] = np.clip(image + epsilon,
                                              0,
                                              255).astype('uint8')
    self.dataset_min_clip[image_id] = np.clip(image - epsilon,
                                              0,
                                              255).astype('uint8')
Example #5
Source File: window.py    From LPHK with GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, master=None):
    """Set up the main frame and preload the dialog/banner images.

    The images are loaded from ``PATH + "/resources/..."`` as Tk PhotoImage
    objects in a fixed order. Grid-drawing and click-tracking state is then
    initialised.
    """
    tk.Frame.__init__(self, master)
    self.master = master
    self.init_window()

    # (attribute name, resource path relative to PATH), loaded in this order.
    resources = (
        ("about_image", "/resources/LPHK-banner.png"),
        ("info_image", "/resources/info.png"),
        ("warning_image", "/resources/warning.png"),
        ("error_image", "/resources/error.png"),
        ("alert_image", "/resources/alert.png"),
        ("scare_image", "/resources/scare.png"),
    )
    for attr_name, rel_path in resources:
        setattr(self, attr_name, ImageTk.PhotoImage(Image.open(PATH + rel_path)))

    self.grid_drawn = False
    # 9x9 grid of independent rows, all cells initially empty.
    self.grid_rects = [[None] * 9 for _ in range(9)]
    self.button_mode = "edit"
    self.last_clicked = None
    self.outline_box = None
Example #6
Source File: weather-icons.py    From unicorn-hat-hd with MIT License 6 votes vote down vote up
def loop():
    """Draw every icon file found in ``folder_path``, one after another.

    Files ending with ``icon_extension`` are opened and passed to
    ``draw_animation``. Other files are reported and skipped. Ctrl+C ends
    the whole loop early. The display is switched off exactly once when the
    loop finishes, is interrupted, or fails.
    """
    # The old text ("CRL+C to skip image") had a typo and promised a per-image
    # skip, but Ctrl+C has always ended the entire loop.
    print('Looping through all images in folder {}\n'
          'CTRL+C to stop looping'.format(folder_path))

    try:
        for img_file in os.listdir(folder_path):
            if not img_file.endswith(icon_extension):
                print('Not using this file, might be not an image: {}'.format(img_file))
                continue

            print('Drawing image: {}'.format(folder_path + img_file))
            img = Image.open(folder_path + img_file)
            draw_animation(img)
    except KeyboardInterrupt:
        pass  # user asked to stop; cleanup happens in finally
    finally:
        # Single cleanup point: previously off() ran twice after Ctrl+C and
        # not at all if draw_animation raised.
        unicorn.off()
Example #7
Source File: setup_selectarea.py    From PiPark with GNU General Public License v2.0 6 votes vote down vote up
def output_coords():
    """Write the selected boxes to ./setup_data.py and notify the user.

    The file contains ``boxes = [`` followed by one ``<coords> ,`` line per
    box whose ``get_output(SelWindow.bgcoords)`` is not None, then ``]``.
    """
    # Open the file to output the co-ordinates to; 'with' guarantees it is
    # flushed and closed (the original never closed it).
    with open('./setup_data.py', 'w') as f1:
        f1.write('boxes = [\n')

        for i in range(Boxes.length()):
            c = Boxes.get(i).get_output(SelWindow.bgcoords)

            if c is not None:
                # Same text the Python 2 statement ``print >>f1, c, ','`` wrote.
                f1.write('%s ,\n' % (c,))

        f1.write(']\n')
    # The message previously named "boxdata.py", which is not the file written.
    print('INFO: Box data saved in file setup_data.py.')
    tkMessageBox.showinfo("Pi Setup", "Box data saved in file.")

# -----------------------------------------------------------------------------
# Main Program
# ----------------------------------------------------------------------------- 
Example #8
Source File: preparation.py    From cvpr2018-hnd with MIT License 6 votes vote down vote up
def is_image_file(id, dataset, dtype, filename):
    """Decide whether *filename* should be treated as a usable image.

    The name must end (case-insensitively) with one of IMG_EXTENSIONS. For
    the 'novel' split the file must also load with ``default_loader``. A
    file that raises OSError is reported, appended to
    ``taxonomy/<dataset>/corrupted_<dtype>_<id>.txt`` and rejected.
    """
    lowered = filename.lower()
    if not any(lowered.endswith(ext) for ext in IMG_EXTENSIONS):
        return False
    if dtype != 'novel':
        return True

    try:
        default_loader(filename)
    except OSError:
        print('{filename} failed to load'.format(filename=filename))
        corrupt_log = 'taxonomy/{dataset}/corrupted_{dtype}_{id:d}.txt'.format(
            dataset=dataset, dtype=dtype, id=id)
        with open(corrupt_log, 'a') as log:
            log.write(filename + '\n')
        return False
    return True
Example #9
Source File: weather-icons.py    From unicorn-hat-hd with MIT License 6 votes vote down vote up
def weather_icons():
    """Command-line entry point for the weather-icon demo.

    ``argv[1] == 'loop'`` runs ``loop()``. A file name present in
    ``folder_path`` is drawn once and the display switched off. Any other
    argument, or an IndexError (e.g. no argument), shows ``help()``.
    """
    try:
        choice = argv[1]
        if choice == 'loop':
            loop()
        elif choice in os.listdir(folder_path):
            print('Drawing Image: {}'.format(choice))
            draw_animation(Image.open(folder_path + choice))
            unicorn.off()
        else:
            help()
    except IndexError:
        help()
Example #10
Source File: validate_submission_lib.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _prepare_sample_data(self, submission_type):
  """Write random sample inputs for validating a submission.

  Saves BATCH_SIZE random 299x299x3 uint8 images into
  ``self._sample_input_dir`` named by IMAGE_NAME_PATTERN. For
  'targeted_attack' submissions it also writes ``target_class.csv`` mapping
  each image name to a random class id in [1, 1000].

  Args:
    submission_type: type of the submission.
  """
  # write images
  batch = np.random.randint(0, 256,
                            size=[BATCH_SIZE, 299, 299, 3], dtype=np.uint8)
  for idx in range(BATCH_SIZE):
    out_path = os.path.join(self._sample_input_dir,
                            IMAGE_NAME_PATTERN.format(idx))
    Image.fromarray(batch[idx, :, :, :]).save(out_path)
  # only targeted attacks need a target-class file
  if submission_type != 'targeted_attack':
    return
  classes = np.random.randint(1, 1001, size=[BATCH_SIZE])
  csv_path = os.path.join(self._sample_input_dir, 'target_class.csv')
  row_template = IMAGE_NAME_PATTERN + ',{1}\n'
  with open(csv_path, 'w') as f:
    for idx in range(BATCH_SIZE):
      f.write(row_template.format(idx, classes[idx]))
Example #11
Source File: run_attacks_and_defenses.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, filename):
  """Initializes instance of DatasetMetadata.

  Parses a CSV file whose header includes ImageId, TrueLabel and
  TargetClass columns into two dicts keyed by image id. Rows shorter than
  the header are skipped.

  Args:
    filename: path to the metadata CSV file.

  Raises:
    IOError: a required column is missing or a label is not an integer.
  """
  self._true_labels = {}
  self._target_classes = {}
  with open(filename) as f:
    reader = csv.reader(f)
    header_row = next(reader)
    try:
      id_col, label_col, target_col = [
          header_row.index(col)
          for col in ('ImageId', 'TrueLabel', 'TargetClass')]
    except ValueError:
      raise IOError('Invalid format of dataset metadata.')
    min_len = len(header_row)
    for row in reader:
      # skip partial or empty lines
      if len(row) >= min_len:
        try:
          image_id = row[id_col]
          self._true_labels[image_id] = int(row[label_col])
          self._target_classes[image_id] = int(row[target_col])
        except (IndexError, ValueError):
          raise IOError('Invalid format of dataset metadata')
Example #12
Source File: validate_submission_lib.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _prepare_sample_data(self, submission_type):
  """Populate ``self._sample_input_dir`` with random sample inputs.

  One random 299x299x3 uint8 image per index in range(BATCH_SIZE) is saved
  under IMAGE_NAME_PATTERN. When *submission_type* is 'targeted_attack', a
  ``target_class.csv`` with a random class id in [1, 1000] per image is
  written as well.

  Args:
    submission_type: type of the submission.
  """
  # write images
  pixels = np.random.randint(0, 256,
                             size=[BATCH_SIZE, 299, 299, 3], dtype=np.uint8)
  for idx, picture in enumerate(pixels):
    file_name = IMAGE_NAME_PATTERN.format(idx)
    Image.fromarray(picture).save(
        os.path.join(self._sample_input_dir, file_name))
  # write target class for targeted attacks
  if submission_type == 'targeted_attack':
    labels = np.random.randint(1, 1001, size=[BATCH_SIZE])
    out_csv = os.path.join(self._sample_input_dir, 'target_class.csv')
    with open(out_csv, 'w') as f:
      f.writelines((IMAGE_NAME_PATTERN + ',{1}\n').format(idx, label)
                   for idx, label in enumerate(labels))
Example #13
Source File: test_imagenet_attacks.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def load_images(input_dir, metadata_file_path, batch_shape):
    """Retrieve numpy arrays of images and labels, read from a directory.

    Args:
        input_dir: directory containing ``<ImageId>.png`` files.
        metadata_file_path: CSV with at least ``ImageId`` and ``TrueLabel``
            columns. Its first ``batch_shape[0]`` data rows are loaded.
        batch_shape: shape of the returned image array; ``batch_shape[0]``
            is the number of images to load.

    Returns:
        ``(images, labels)``. ``images`` is a float64 array of shape
        *batch_shape* with pixels scaled by 1/255. ``labels`` is an int32
        array of the true labels.
    """
    num_images = batch_shape[0]
    with open(metadata_file_path) as input_file:
        reader = csv.reader(input_file)
        header_row = next(reader)
        rows = list(reader)

    row_idx_image_id = header_row.index('ImageId')
    row_idx_true_label = header_row.index('TrueLabel')
    images = np.zeros(batch_shape)
    labels = np.zeros(num_images, dtype=np.int32)
    # range instead of the Python-2-only xrange so this also runs on Python 3.
    for idx in range(num_images):
        row = rows[idx]
        filepath = os.path.join(input_dir, row[row_idx_image_id] + '.png')

        with tf.gfile.Open(filepath, 'rb') as f:
            # np.float (an alias of builtin float, i.e. float64) was removed in
            # NumPy 1.24; np.float64 is the identical dtype.
            image = np.array(
                Image.open(f).convert('RGB')).astype(np.float64) / 255.0
        images[idx, :, :, :] = image
        labels[idx] = int(row[row_idx_true_label])
    return images, labels
Example #14
Source File: data.py    From VSE-C with MIT License 6 votes vote down vote up
def __getitem__(self, index):
    """Return ``(image, target, index, img_id)`` for one caption sample.

    ``self.ids[index]`` holds the image id at position 0 and the sentence
    index at position 1. The image is loaded as RGB from ``self.root`` and
    passed through ``self.transform`` when set. The caption is tokenized into
    vocabulary ids wrapped in ``<start>``/``<end>``. This tuple is further
    passed to collate_fn.
    """
    ann_id = self.ids[index]
    img_id, sent_idx = ann_id[0], ann_id[1]
    entry = self.dataset[img_id]
    raw_caption = entry['sentences'][sent_idx]['raw']

    image = Image.open(os.path.join(self.root, entry['filename'])).convert('RGB')
    if self.transform is not None:
        image = self.transform(image)

    # Convert caption (string) to word ids.
    vocab = self.vocab
    tokens = nltk.tokenize.word_tokenize(str(raw_caption).lower())
    word_ids = ([vocab('<start>')]
                + [vocab(tok) for tok in tokens]
                + [vocab('<end>')])
    target = torch.Tensor(word_ids)
    return image, target, index, img_id
Example #15
Source File: data.py    From VSE-C with MIT License 6 votes vote down vote up
def __init__(self, data_path, data_split, vocab, cap_suffix='caps'):
    """Load captions and precomputed image features for one data split.

    Captions are read from ``<data_path>/<data_split>_<cap_suffix>.txt``, one
    per line, stripped and decoded from bytes. Features are read from
    ``<data_path>/<data_split>_ims.npy``. ``im_div`` is 5 when the feature
    count differs from the caption count and 1 otherwise. The 'dev' split is
    capped at 5000 entries.
    """
    self.vocab = vocab
    loc = data_path + '/'
    caps_file = loc + '%s_%s.txt' % (data_split, cap_suffix)
    ims_file = loc + '%s_ims.npy' % data_split

    # Captions
    self.captions = []
    with open(caps_file, 'rb') as f:
        for raw_line in f:
            text = raw_line.strip()
            self.captions.append(bytes.decode(text) if isinstance(text, bytes) else text)

    # Image features
    self.images = np.load(ims_file)
    self.length = len(self.captions)
    # rkiros data has redundancy in images, we divide by 5, 10crop doesn't
    self.im_div = 5 if self.images.shape[0] != self.length else 1
    # the development set for coco is large and so validation would be slow
    if data_split == 'dev':
        self.length = 5000
Example #16
Source File: image_segmentaion.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def get_data(img_path):
    """get the (1, 3, h, w) np.array data for the supplied image

    The image is converted to RGB and the per-channel mean
    (123.68, 116.779, 103.939) is subtracted.

    Args:
        img_path (string): the input image path

    Returns:
        np.array: image data in a (1, 3, h, w) shape
    """
    mean = np.array([123.68, 116.779, 103.939])  # (R,G,B)
    # Force 3 channels: grayscale, palette or RGBA files would otherwise give
    # arrays that do not broadcast against the 3-value mean. 'with' closes
    # the file once the pixels are copied out.
    with Image.open(img_path) as src:
        img = np.array(src.convert('RGB'), dtype=np.float32)  # (h, w, 3)
    img = img - mean.reshape(1, 1, 3)
    img = img.transpose(2, 0, 1)  # (3, h, w)
    return np.expand_dims(img, axis=0)  # (1, 3, h, w)
Example #17
Source File: data.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def __init__(self, root_dir, flist_name,
             rgb_mean=(117, 117, 117),
             cut_off_size=None,
             data_name="data",
             label_name="softmax_label"):
    """Set up an iterator over the entries listed in *flist_name*.

    Args:
        root_dir: directory that *flist_name* is joined onto.
        flist_name: list file name relative to *root_dir*.
        rgb_mean: per-channel mean, stored as ``self.mean``.
        cut_off_size: optional crop limit, stored for later use.
        data_name: name of the data output.
        label_name: name of the label output.
    """
    super(FileIter, self).__init__()
    self.root_dir = root_dir
    self.flist_name = os.path.join(self.root_dir, flist_name)
    self.mean = np.array(rgb_mean)  # (R, G, B)
    self.cut_off_size = cut_off_size
    self.data_name = data_name
    self.label_name = label_name

    # Count entries via a context manager; the old
    # len(open(...).readlines()) left an unreferenced file handle open.
    with open(self.flist_name, 'r') as flist:
        self.num_data = sum(1 for _ in flist)
    # This handle intentionally stays open; presumably it is consumed line by
    # line in self._read() — verify against that method.
    self.f = open(self.flist_name, 'r')
    self.data, self.label = self._read()
    self.cursor = -1
Example #18
Source File: data.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def get_caltech101_data():
    """Download the Caltech-101 archive into ./data and extract it if needed.

    Extraction is skipped when both ``data/101_ObjectCategories`` and
    ``data/101_ObjectCategories_test`` already exist.

    Returns:
        ``(training_path, testing_path)``.
    """
    url = "https://s3.us-east-2.amazonaws.com/mxnet-public/101_ObjectCategories.tar.gz"
    dataset_name = "101_ObjectCategories"
    data_folder = "data"
    training_path = os.path.join(data_folder, dataset_name)
    testing_path = os.path.join(data_folder, "{}_test".format(dataset_name))
    if not os.path.isdir(data_folder):
        os.makedirs(data_folder)
    tar_path = mx.gluon.utils.download(url, path=data_folder)
    if not os.path.isdir(training_path) or not os.path.isdir(testing_path):
        # 'with' closes the archive even if extraction fails.
        with tarfile.open(tar_path, "r:gz") as tar:
            # The archive comes from the network: where supported (Python
            # 3.12+ and security backports), reject members that would escape
            # data_folder (absolute paths, '..', unsafe links).
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(data_folder, filter='data')
            else:
                tar.extractall(data_folder)
        print('Data extracted')
    return training_path, testing_path
Example #19
Source File: super_resolution.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def resolve(ctx):
    """Super-resolve ``opt.resolve_img`` with ``net`` and save ``resolved.png``.

    Only the Y (luma) channel goes through the network. Cb/Cr are resized
    bicubically to the output size, then the channels are merged and
    converted back to RGB.
    """
    from PIL import Image
    if isinstance(ctx, list):
        # Use a single device for inference.
        ctx = [ctx[0]]
    net.load_parameters('superres.params', ctx=ctx)

    source = Image.open(opt.resolve_img).convert('YCbCr')
    y, cb, cr = source.split()

    # Add two leading axes before feeding the luma channel to the network.
    y_batch = mx.nd.expand_dims(mx.nd.array(y), axis=0)
    y_batch = mx.nd.expand_dims(y_batch, axis=0)
    prediction = mx.nd.reshape(net(y_batch), shape=(-3, -2)).asnumpy()
    prediction = prediction.clip(0, 255)
    out_img_y = Image.fromarray(np.uint8(prediction[0]), mode='L')

    target_size = out_img_y.size
    channels = [out_img_y,
                cb.resize(target_size, Image.BICUBIC),
                cr.resize(target_size, Image.BICUBIC)]
    Image.merge('YCbCr', channels).convert('RGB').save('resolved.png')
Example #20
Source File: ImageAnim.py    From BiblioPixelAnimations with MIT License 5 votes vote down vote up
def _loadGIFSequence(self, imagePath):
    """Return one display buffer per frame of the image at *imagePath*.

    Placement: ``self._offset`` when any component is non-zero; the origin
    when ``self.scale_to`` is set; otherwise centred in the layout, clamped
    to 0 if the image is larger.
    """
    img = Image.open(imagePath)
    if any(self._offset):
        ox, oy = self._offset
    elif self.scale_to:
        ox = oy = 0
    else:
        img_w, img_h = img.size
        ox = max(0, (self.layout.width - img_w) // 2)
        oy = max(0, (self.layout.height - img_h) // 2)

    buffers = []
    for frame in ImageSequence.Iterator(img):
        buffers.append(self._getBufferFromImage(frame, ox, oy))
    return buffers
Example #21
Source File: data.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def reset(self):
    """Rewind: set the cursor back to -1 and reopen the file list from the top."""
    self.cursor = -1
    # Swap the exhausted handle for a fresh one positioned at line one.
    old_handle = self.f
    old_handle.close()
    self.f = open(self.flist_name, 'r')
Example #22
Source File: data.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def _read_img(self, img_name, label_name):
        """Load an image/label pair, optionally random-crop both, and format them.

        Both files are opened relative to ``self.root_dir`` and must have the
        same size (checked with an ``assert``, which is stripped under -O).

        When ``self.cut_off_size`` is set, the same window is cut from the
        image and the label:

        * if the shorter side exceeds ``cut_off_size``, a random
          ``cut_off_size x cut_off_size`` window is taken;
        * otherwise, if the longer side exceeds it, the longer side is cropped
          at a random offset to the shorter side's length (a square of side
          ``min_hw``).

        ``self.mean`` is then subtracted from the image.

        Args:
            img_name: image path relative to ``self.root_dir``.
            label_name: label-map path relative to ``self.root_dir``.

        Returns:
            ``(img, label)``, img shaped (1, c, h, w) and label shaped
            (1, h, w) per the inline comments. NOTE(review): the mean has 3
            values, so the image is assumed to be 3-channel (RGB) — confirm.
        """
        img = Image.open(os.path.join(self.root_dir, img_name))
        label = Image.open(os.path.join(self.root_dir, label_name))
        assert img.size == label.size
        img = np.array(img, dtype=np.float32)  # (h, w, c)
        label = np.array(label)  # (h, w)
        if self.cut_off_size is not None:
            max_hw = max(img.shape[0], img.shape[1])
            min_hw = min(img.shape[0], img.shape[1])
            if min_hw > self.cut_off_size:
                # Both sides are larger than the limit: take a square window.
                # NOTE(review): the "- 1" in the upper bounds means the last
                # valid offsets are never sampled; confirm this is intended.
                rand_start_max = int(np.random.uniform(0, max_hw - self.cut_off_size - 1))
                rand_start_min = int(np.random.uniform(0, min_hw - self.cut_off_size - 1))
                # Apply the "long side" offset to whichever axis is longer.
                if img.shape[0] == max_hw :
                    img = img[rand_start_max : rand_start_max + self.cut_off_size, rand_start_min : rand_start_min + self.cut_off_size]
                    label = label[rand_start_max : rand_start_max + self.cut_off_size, rand_start_min : rand_start_min + self.cut_off_size]
                else :
                    img = img[rand_start_min : rand_start_min + self.cut_off_size, rand_start_max : rand_start_max + self.cut_off_size]
                    label = label[rand_start_min : rand_start_min + self.cut_off_size, rand_start_max : rand_start_max + self.cut_off_size]
            elif max_hw > self.cut_off_size:
                # Only the longer side exceeds the limit: crop it down to the
                # shorter side, producing a min_hw x min_hw square.
                rand_start = int(np.random.uniform(0, max_hw - min_hw - 1))
                if img.shape[0] == max_hw :
                    img = img[rand_start : rand_start + min_hw, :]
                    label = label[rand_start : rand_start + min_hw, :]
                else :
                    img = img[:, rand_start : rand_start + min_hw]
                    label = label[:, rand_start : rand_start + min_hw]
        reshaped_mean = self.mean.reshape(1, 1, 3)
        img = img - reshaped_mean
        # Two swaps turn (h, w, c) into (c, h, w).
        img = np.swapaxes(img, 0, 2)
        img = np.swapaxes(img, 1, 2)  # (c, h, w)
        img = np.expand_dims(img, axis=0)  # (1, c, h, w)
        label = np.array(label)  # (h, w)
        label = np.expand_dims(label, axis=0)  # (1, h, w)
        return (img, label)
Example #23
Source File: image_search.py    From ultra_secret_scripts with GNU General Public License v3.0 5 votes vote down vote up
def load_image_from_file(image_filename):
    """Load an image file as a contiguous (h, w, 3) uint8 array in BGR order.

    Any PIL mode (grayscale, palette, RGBA, ...) is converted to RGB first.
    Previously a grayscale file raised IndexError on the channel slice, and
    RGBA came back as 4 reversed channels (A, B, G, R).
    """
    # 'with' closes the file as soon as the pixels are copied out.
    with Image.open(image_filename) as img:
        rgb = np.array(img.convert('RGB'))
    # Reverse the channel axis (RGB -> BGR, presumably for OpenCV consumers).
    # copy() turns the negatively strided view into a contiguous array.
    return rgb[:, :, ::-1].copy()
Example #24
Source File: super_resolution.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def get_test_image():
    """Download and process the test image

    The image is resized to 224x224 and split into YCbCr channels.

    Returns:
        ``(input_image, img_cb, img_cr)``. ``input_image`` is the Y channel as
        an array of shape (1, 1, 224, 224); the other two are the Cb and Cr
        channel images.
    """
    side = 224
    img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
    download(img_url, 'super_res_input.jpg')
    resized = Image.open('super_res_input.jpg').resize((side, side))
    channel_y, channel_cb, channel_cr = resized.convert("YCbCr").split()
    # Add batch and channel axes in front of the (h, w) luma plane.
    input_image = np.array(channel_y)[np.newaxis, np.newaxis, :, :]
    return input_image, channel_cb, channel_cr
Example #25
Source File: unittest_utils_test.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def test_encoded_image_corresponds_to_numpy_array(self):
  """Decoding the encoded PNG must reproduce the original pixel array."""
  import io
  image, encoded = unittest_utils.create_random_image('PNG', (20, 10, 3))
  # Encoded image data is binary. io.BytesIO works on Python 2 and 3,
  # whereas the StringIO module exists only on Python 2.
  pil_image = PILImage.open(io.BytesIO(encoded))
  self.assertAllEqual(image, np.array(pil_image))
Example #26
Source File: img.py    From vergeml with MIT License 5 votes vote down vote up
def open_image(path):
    """Open image at path.

    PIL lazily opens the image, which can lead to a 'too many open files' error.
    This workaround reads the file into memory immediately and always closes
    the underlying file, even when reading the pixel data fails.

    Returns:
        An in-memory copy of the image, not tied to any open file.
    """
    # copy() forces the pixel data to load; the with-block then closes the
    # file (the old explicit close() was skipped if copy() raised).
    with Image.open(path) as img:
        return img.copy()
Example #27
Source File: preparation.py    From cvpr2018-hnd with MIT License 5 votes vote down vote up
def find_classes(id, num_workers, dataset, dtype):
    """Return the sorted class-directory names under ``datasets/<dataset>/<dtype>``.

    As a side effect, this worker's contiguous share of those classes (slice
    ``id`` of ``num_workers``) is written to
    ``taxonomy/<dataset>/classes_<dtype>_<id>.txt``. Each line holds the class
    name and its number of directory entries, separated by a tab.
    """
    root = 'datasets/{dataset}/{dtype}'.format(dataset=dataset, dtype=dtype)
    classes_path = 'taxonomy/{dataset}/classes_{dtype}_{id:d}.txt'.format(dataset=dataset, dtype=dtype, id=id)
    classes = sorted(name for name in os.listdir(root)
                     if os.path.isdir(os.path.join(root, name)))
    total = len(classes)
    first = id * total // num_workers
    last = (id + 1) * total // num_workers
    with open(classes_path, 'w') as out:
        for cname in classes[first:last]:
            count = len(os.listdir(os.path.join(root, cname)))
            out.write('{cname}\t{num}\n'.format(cname=cname, num=count))
    return classes
Example #28
Source File: preparation.py    From cvpr2018-hnd with MIT License 5 votes vote down vote up
def make_dataset(id, num_workers, dataset, dtype, classes, bias, max_num_images):
    """Write this worker's image list(s) for ``datasets/<dataset>/<dtype>``.

    Only classes in slice ``id`` of ``num_workers`` (over ``classes``) are
    processed. Each accepted file (per ``is_image_file``) is written as
    ``<path>\\t<class index + bias>``.

    * ``dtype == 'train'``: two files, ``images_train_<id>.txt`` (fs[0]) and
      ``images_val_<id>.txt`` (fs[1]). Within each class, the first
      ``max_num_images`` images go to the val list and the rest to the
      train list. NOTE(review): with a negative ``max_num_images`` the switch
      never happens, so every training image lands in the val list — confirm.
    * otherwise: one file ``images_<dtype>_<id>.txt``. At most
      ``max_num_images`` images per class are listed (no limit if negative).

    NOTE(review): the files are closed only on normal completion, not if an
    exception escapes.
    """
    dir = 'datasets/{dataset}/{dtype}'.format(dataset=dataset, dtype=dtype)
    if dtype == 'train':
        train_path = 'taxonomy/{dataset}/images_{dtype}_{id:d}.txt'.format(dataset=dataset, dtype='train', id=id)
        val_path   = 'taxonomy/{dataset}/images_{dtype}_{id:d}.txt'.format(dataset=dataset, dtype='val',   id=id)
        fs = [open(train_path, 'w'), open(val_path, 'w')]
    else:
        images_path = 'taxonomy/{dataset}/images_{dtype}_{id:d}.txt'.format(dataset=dataset, dtype=dtype, id=id)
        f = open(images_path, 'w')
    num_classes = len(classes)
    classes_id = list(enumerate(classes))
    for c, cname in classes_id[id*num_classes//num_workers:(id+1)*num_classes//num_workers]:
        d = os.path.join(dir, cname)
        num_images = 0
        stop_flag = False
        # For training data each class starts out writing to the val list.
        if dtype == 'train': f = fs[1]
        for fname in sorted(os.listdir(d)):
            path = os.path.join(d, fname)
            if is_image_file(id, dataset, dtype, path):
                num_images += 1
                f.write('{path}\t{c:d}\n'.format(path=path, c=c+bias))
                if max_num_images >= 0 and num_images >= max_num_images:
                    # Train: quota for val reached, send the rest to train.
                    # Other splits: stop listing this class.
                    if dtype == 'train': f = fs[0]
                    else:
                        stop_flag = True
                        break
            if stop_flag:
                break
    if dtype == 'train':
        fs[0].close()
        fs[1].close()
    else:
        f.close()
Example #29
Source File: preparation.py    From cvpr2018-hnd with MIT License 5 votes vote down vote up
def merge_text(num_workers, dataset, dtype, ttype):
    """Concatenate per-worker text shards into one file, then delete the shards.

    Shards ``taxonomy/<dataset>/<ttype>_<dtype>_<i>.txt`` for i in
    range(num_workers) that exist are appended, in worker order, to
    ``taxonomy/<dataset>/<ttype>_<dtype>.txt``. The output is overwritten.
    Afterwards every existing shard is removed.
    """
    path = 'taxonomy/{dataset}/{ttype}_{dtype}'.format(dataset=dataset, ttype=ttype, dtype=dtype)
    shard_paths = ['{path}_{id:d}.txt'.format(path=path, id=worker)
                   for worker in range(num_workers)]
    with open('{path}.txt'.format(path=path), 'w') as merged:
        for shard in shard_paths:
            if not os.path.isfile(shard):
                continue
            with open(shard, 'r') as part:
                merged.write(part.read())
    for shard in shard_paths:
        if os.path.isfile(shard):
            os.remove(shard)
Example #30
Source File: super_resolution.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def get_dataset(prefetch=False):
    """Fetch BSDS300 if needed and build (train, test) image-pair iterators.

    The archive is downloaded from ``dataset_url`` and extracted into
    ``dataset_path`` when ``<dataset_path>/BSDS300/images`` is missing.
    Inputs are centre-cropped to ``crop_size`` (256 rounded down to a multiple
    of ``upscale_factor``) and downscaled by ``upscale_factor``. Targets are
    only centre-cropped.

    Args:
        prefetch: wrap each iterator in ``PrefetchingIter`` when True.

    Returns:
        ``(train_iter, test_iter)``; a list of prefetching iterators if
        *prefetch* is set, otherwise a tuple.
    """
    image_path = os.path.join(dataset_path, "BSDS300/images")

    if not os.path.exists(image_path):
        # dataset_path may already exist (e.g. from an interrupted run); an
        # unconditional os.makedirs would then raise.
        if not os.path.isdir(dataset_path):
            os.makedirs(dataset_path)
        file_name = download(dataset_url)
        # The archive comes from the network: where supported (Python 3.12+
        # and security backports), refuse members escaping dataset_path.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
        with tarfile.open(file_name) as tar:
            for item in tar:
                tar.extract(item, dataset_path, **extract_kwargs)
        os.remove(file_name)

    crop_size = 256
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (ImagePairIter(os.path.join(image_path, "train"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size),
                           batch_size, color_flag, input_transform, target_transform),
             ImagePairIter(os.path.join(image_path, "test"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size),
                           test_batch_size, color_flag,
                           input_transform, target_transform))

    return [PrefetchingIter(i) for i in iters] if prefetch else iters