Python tensorflow.read_file() Examples
The following are 30
code examples of tensorflow.read_file().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: camvid.py From MachineLearning with Apache License 2.0 | 6 votes |
def CamVid_reader_seq(filename_queue, seq_length):
    """Read a sequence of PNG image/label pairs and decode them into tensors.

    Args:
        filename_queue: pair ``(image_paths, label_paths)``; each element is
            split along axis 0 into ``seq_length`` pieces.
        seq_length: number of pieces (frames) to split each element into.

    Returns:
        Tuple ``(image_seq, label_seq)`` of Python lists with one entry per
        frame: float32 images reshaped to
        ``(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH)`` and int64 labels reshaped
        to ``(IMAGE_HEIGHT, IMAGE_WIDTH, 1)``.
    """
    image_paths = tf.split(axis=0, num_or_size_splits=seq_length, value=filename_queue[0])
    label_paths = tf.split(axis=0, num_or_size_splits=seq_length, value=filename_queue[1])

    images, labels = [], []
    for image_path, label_path in zip(image_paths, label_paths):
        # Squeeze each split piece down to a scalar path for read_file
        # (assumes each piece holds exactly one path — TODO confirm).
        raw_image = tf.read_file(tf.squeeze(image_path))
        raw_label = tf.read_file(tf.squeeze(label_path))
        decoded_image = tf.image.decode_png(raw_image)
        decoded_label = tf.image.decode_png(raw_label)
        images.append(tf.cast(
            tf.reshape(decoded_image, (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH)),
            tf.float32))
        labels.append(tf.cast(
            tf.reshape(decoded_label, (IMAGE_HEIGHT, IMAGE_WIDTH, 1)),
            tf.int64))
    return images, labels
Example #2
Source File: image_reader_cuda.py From Siamese-RPN-tensorflow with MIT License | 6 votes |
def read_from_disk(self, queue):
    """Sample a (template, detection) frame pair and crop/resize both.

    ``queue[0]`` is the template frame index. ``self.node`` is searched for the
    last entry <= that index and the entry after it (presumably sorted
    sequence boundaries — TODO confirm). A detection index is then chosen 30..99
    frames before or after the template, switching direction when the offset
    would cross those boundaries.

    Returns:
        ``(template_p, template_label_p, detection_p, detection_label_p,
        offset, ratio, detection, detection_label, index_t, index_d)``.
    """
    index_t = queue[0]
    # Positions of all boundaries <= index_t; the last one starts the segment.
    below = tf.reshape(tf.where(tf.less_equal(self.node, index_t)), [-1])
    seg_start = self.node[below[-1]]
    seg_end = self.node[below[-1] + 1]

    # Random frame gap in [30, 100), drawn by shuffling the candidate list.
    gap = tf.random_shuffle(list(range(30, 100)))[0]
    backward = tf.cond(tf.greater(index_t - gap, seg_start),
                       lambda: index_t - gap,
                       lambda: index_t + gap)
    forward = tf.cond(tf.less(index_t + gap, seg_end),
                      lambda: index_t + gap,
                      lambda: index_t - gap)
    # Randomly pick one of the two candidate detection indices.
    index_d = tf.random_shuffle([backward, forward])[0]

    template_bytes = tf.read_file(self.img_list[index_t])
    template = tf.image.decode_jpeg(template_bytes, channels=3)
    template = template[:, :, ::-1]  # reverse channel order (RGB -> BGR)
    detection_bytes = tf.read_file(self.img_list[index_d])
    detection = tf.image.decode_jpeg(detection_bytes, channels=3)
    detection = detection[:, :, ::-1]

    template_label = self.label_list[index_t]
    detection_label = self.label_list[index_d]
    template_p, template_label_p, _, _ = self.crop_resize(template, template_label, 1)
    detection_p, detection_label_p, offset, ratio = self.crop_resize(
        detection, detection_label, 2)
    return (template_p, template_label_p, detection_p, detection_label_p,
            offset, ratio, detection, detection_label, index_t, index_d)
Example #3
Source File: inputs.py From TF-SegNet with MIT License | 6 votes |
def dataset_reader(filename_queue):
    """Decode one PNG image/label pair (formerly named ``CamVid_reader``).

    Args:
        filename_queue: indexable pair of string tensors
            ``(image_path, label_path)``.

    Returns:
        ``(image, label)`` reshaped to
        ``(FLAGS.image_h, FLAGS.image_w, FLAGS.image_c)`` and
        ``(FLAGS.image_h, FLAGS.image_w, 1)``; both keep the dtype produced by
        ``tf.image.decode_png``.
    """
    image_path, label_path = filename_queue[0], filename_queue[1]

    encoded_image = tf.read_file(image_path)
    encoded_label = tf.read_file(label_path)
    # Labels are stored as PNG too, not JPEG.
    decoded_image = tf.image.decode_png(encoded_image)
    decoded_label = tf.image.decode_png(encoded_label)

    image = tf.reshape(decoded_image, (FLAGS.image_h, FLAGS.image_w, FLAGS.image_c))
    label = tf.reshape(decoded_label, (FLAGS.image_h, FLAGS.image_w, 1))
    return image, label
Example #4
Source File: dataset_input.py From LiTS---Liver-Tumor-Segmentation-Challenge with MIT License | 6 votes |
def _parse_function(image, mask):
    """Read and decode an image/mask file pair as float32 tensors.

    The decoder is picked by the module-level ``image_type`` ('jpg', 'png' or
    'bmp'). The image keeps its stored channel count (channels=0) and the mask
    is decoded with one channel. Both are converted with
    ``tf.image.convert_image_dtype``.

    Raises:
        TypeError: if ``image_type`` is not one of the supported formats.
    """
    image_bytes = tf.read_file(image)
    mask_bytes = tf.read_file(mask)

    decoders = {
        'jpg': tf.image.decode_jpeg,
        'png': tf.image.decode_png,
        'bmp': tf.image.decode_bmp,
    }
    decode = decoders.get(image_type)
    if decode is None:
        raise TypeError('==> Error: Only support jpg, png and bmp.')

    image_decoded = decode(image_bytes, 0)
    mask_decoded = decode(mask_bytes, 1)
    # already in 0~1 after dtype conversion
    image_decoded = tf.image.convert_image_dtype(image_decoded, tf.float32)
    mask_decoded = tf.image.convert_image_dtype(mask_decoded, tf.float32)
    return image_decoded, mask_decoded
Example #5
Source File: image_reader.py From LIP_JPPNet with MIT License | 6 votes |
def read_images_from_disk(input_queue, input_size, random_scale, random_mirror):  # optional pre-processing arguments
    """Read one JPEG image, reorder it to BGR and subtract ``IMG_MEAN``.

    NOTE(review): only the image at ``input_queue[0]`` is read. The mask
    (presumably ``input_queue[1]``) and the pre-processing arguments
    ``input_size``, ``random_scale`` and ``random_mirror`` are accepted but not
    used in this function.

    Args:
        input_queue: tf queue with paths to the image and its mask; only
            element 0 is consumed here.
        input_size: unused.
        random_scale: unused.
        random_mirror: unused.

    Returns:
        A single float32 tensor: the BGR image minus ``IMG_MEAN``.
    """
    img_contents = tf.read_file(input_queue[0])
    img = tf.image.decode_jpeg(img_contents, channels=3)
    # Reorder channels RGB -> BGR, then cast to float for mean subtraction.
    img_r, img_g, img_b = tf.split(value=img, num_or_size_splits=3, axis=2)
    img = tf.cast(tf.concat([img_b, img_g, img_r], 2), dtype=tf.float32)
    # Extract mean.
    img -= IMG_MEAN
    return img
Example #6
Source File: data.py From generative-compression with MIT License | 6 votes |
def load_inference(filenames, labels, batch_size, resize=(32, 32)):
    """Build a batched dataset of PNG images for inference.

    Used for single-image estimation over multiple stochastic forward passes.
    Each image is decoded, converted to float32, standardized per image and
    resized.

    Args:
        filenames: image file paths.
        labels: labels aligned with ``filenames``.
        batch_size: number of examples per batch.
        resize: ``(height, width)`` passed to ``tf.image.resize_images``.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(image, label)``.
    """
    def _preprocess_inference(image_path, label):
        # Preprocess individual images during inference. ``resize`` comes from
        # the enclosing call: ``dataset.map`` only passes (image_path, label),
        # so a hard-coded default here would ignore the caller's value.
        image_path = tf.squeeze(image_path)
        image = tf.image.decode_png(tf.read_file(image_path))
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.image.resize_images(image, size=resize)
        return image, label

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(_preprocess_inference)
    dataset = dataset.batch(batch_size)
    return dataset
Example #7
Source File: style_transfer_modeler.py From lambda-deep-learning-demo with Apache License 2.0 | 6 votes |
def compute_style_feature(self):
    """Build Gram-matrix target ops for every style layer of the style image.

    The style image at ``self.config.style_image_path`` is processed as follows:
    - decoded as JPEG;
    - cast to float;
    - mean-subtracted with the VGG preprocessing helper;
    - given a batch axis;
    - run through ``self.feature_net``.

    Returns:
        dict mapping each name in ``self.style_layers`` to
        ``self.compute_gram(features[name], data_format)``. The same dict is
        stored on ``self.style_features_target_op``.
    """
    encoded = tf.read_file(self.config.style_image_path)
    decoded = tf.image.decode_jpeg(
        encoded,
        channels=self.config.image_depth,
        dct_method="INTEGER_ACCURATE")
    centered = vgg_preprocessing._mean_image_subtraction(tf.to_float(decoded))
    batch = tf.expand_dims(centered, 0)

    # Only intermediate features are needed; the logits are discarded.
    (_, features), self.feature_net_init_flag = self.feature_net(
        batch,
        self.config.data_format,
        is_training=False,
        init_flag=self.feature_net_init_flag,
        ckpt_path=self.config.feature_net_path)

    targets = self.style_features_target_op = {}
    for layer_name in self.style_layers:
        targets[layer_name] = self.compute_gram(features[layer_name],
                                                self.config.data_format)
    return self.style_features_target_op
Example #8
Source File: object_detection_mscoco_inputter.py From lambda-deep-learning-demo with Apache License 2.0 | 6 votes |
def parse_fn(self, image_id, file_name, classes, boxes):
    """Decode one detection sample and optionally augment it.

    Returns:
        ``([image_id], image, classes, boxes, scale, translation, [file_name])``
        where ``image`` is the 3-channel PNG decode cast to float.
        ``scale`` and ``translation`` stay ``[0, 0]`` unless ``self.augmenter``
        replaces them.
    """
    image = tf.to_float(tf.image.decode_png(tf.read_file(file_name), channels=3))
    scale, translation = [0, 0], [0, 0]

    if self.augmenter:
        training = self.config.mode == "train"
        image, classes, boxes, scale, translation = self.augmenter.augment(
            image, classes, boxes, self.config.resolution,
            is_training=training,
            speed_mode=False)

    return ([image_id], image, classes, boxes, scale, translation, [file_name])
Example #9
Source File: style_transfer_csv_inputter.py From lambda-deep-learning-demo with Apache License 2.0 | 6 votes |
def parse_fn(self, image_path):
    """Decode one JPEG for the style-transfer pipeline.

    The result depends on the mode:
    - "infer": the image is cast to float and mean-subtracted;
    - otherwise, with ``self.augmenter`` set: the image is passed through it;
    - otherwise: the raw decoded image is returned.

    Returns:
        A 1-tuple ``(image,)``.
    """
    image = tf.image.decode_jpeg(
        tf.read_file(image_path),
        channels=self.config.image_depth,
        dct_method="INTEGER_ACCURATE")

    if self.config.mode == "infer":
        image = vgg_preprocessing._mean_image_subtraction(tf.to_float(image))
    elif self.augmenter:
        image = self.augmenter.augment(
            image,
            self.config.image_height,
            self.config.image_width,
            self.config.resize_side_min,
            self.config.resize_side_max,
            is_training=(self.config.mode == "train"),
            speed_mode=self.config.augmenter_speed_mode)
    return (image,)
Example #10
Source File: image_classification_csv_inputter.py From lambda-deep-learning-demo with Apache License 2.0 | 6 votes |
def parse_fn(self, image_path, label):
    """Decode one JPEG, optionally augment it, and one-hot encode its label.

    Returns:
        ``(image, label)`` where ``label`` is one-hot over
        ``self.config.num_classes``.
    """
    encoded = tf.read_file(image_path)
    image = tf.image.decode_jpeg(
        encoded,
        channels=self.config.image_depth,
        dct_method="INTEGER_ACCURATE")

    if self.augmenter:
        training = self.config.mode == "train"
        image = self.augmenter.augment(
            image,
            self.config.image_height,
            self.config.image_width,
            is_training=training,
            speed_mode=self.config.augmenter_speed_mode)

    one_hot_label = tf.one_hot(label, depth=self.config.num_classes)
    return (image, one_hot_label)
Example #11
Source File: image_segmentation_csv_inputter.py From lambda-deep-learning-demo with Apache License 2.0 | 6 votes |
def parse_fn(self, image_path, label_path):
    """Decode one segmentation sample (image + label map).

    In "infer" mode only the image is read. It is cast to float and
    mean-subtracted, and ``image[0]`` is returned in the label slot
    (``label_path`` is not read).

    In other modes the PNG label is decoded as a single-channel int64 map. When
    ``self.augmenter`` is set, image and label are passed through it;
    otherwise the un-augmented pair is returned.

    Returns:
        ``(image, label)``, or whatever ``self.augmenter.augment`` returns.
    """
    image = tf.read_file(image_path)
    image = tf.image.decode_png(image, channels=self.config.image_depth)

    if self.config.mode == "infer":
        image = tf.to_float(image)
        image = vgg_preprocessing._mean_image_subtraction(image)
        # No ground truth at inference; presumably a placeholder so the output
        # keeps the (image, label) arity — TODO confirm.
        label = image[0]
        return image, label

    label = tf.read_file(label_path)
    label = tf.image.decode_png(label, channels=1)
    label = tf.cast(label, dtype=tf.int64)

    if self.augmenter:
        is_training = (self.config.mode == "train")
        return self.augmenter.augment(image,
                                      label,
                                      self.config.output_height,
                                      self.config.output_width,
                                      self.config.resize_side_min,
                                      self.config.resize_side_max,
                                      is_training=is_training,
                                      speed_mode=self.config.augmenter_speed_mode)
    # Without an augmenter, return the decoded pair explicitly; falling
    # through would return None, which Dataset.map cannot handle.
    return image, label
Example #12
Source File: data.py From generative-compression with MIT License | 6 votes |
def load_inference(filenames, labels, batch_size, resize=(32, 32)):
    """Build a batched dataset of PNG images for inference.

    Used for single-image estimation over multiple stochastic forward passes.
    Each image is decoded, converted to float32, standardized per image and
    resized.

    Args:
        filenames: image file paths.
        labels: labels aligned with ``filenames``.
        batch_size: number of examples per batch.
        resize: ``(height, width)`` passed to ``tf.image.resize_images``.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(image, label)``.
    """
    def _preprocess_inference(image_path, label):
        # Preprocess individual images during inference. ``resize`` comes from
        # the enclosing call: ``dataset.map`` only passes (image_path, label),
        # so a hard-coded default here would ignore the caller's value.
        image_path = tf.squeeze(image_path)
        image = tf.image.decode_png(tf.read_file(image_path))
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.image.resize_images(image, size=resize)
        return image, label

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(_preprocess_inference)
    dataset = dataset.batch(batch_size)
    return dataset
Example #13
Source File: predict_spatial.py From sign-language-gesture-recognition with MIT License | 6 votes |
def read_tensor_from_image_file(frames, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
    """Decode image files, stack them, resize and normalize the batch.

    The decoder is picked from each file name's extension (case-insensitive):
    - ``.png``: PNG decoder;
    - ``.gif``: GIF decoder, with size-1 dims squeezed;
    - ``.bmp``: BMP decoder;
    - anything else: JPEG decoder.

    Frames are cast to float32, stacked, bilinearly resized to
    ``(input_height, input_width)`` and normalized as
    ``(x - input_mean) / input_std``.

    NOTE(review): the BMP branch does not force 3 channels, so a 4-channel BMP
    would make ``tf.stack`` fail when mixed with 3-channel frames.

    Args:
        frames: iterable of image file paths as Python strings.
        input_height: output height after resizing.
        input_width: output width after resizing.
        input_mean: value subtracted before scaling.
        input_std: divisor applied after mean subtraction.

    Returns:
        The evaluated normalized batch returned by ``Session.run``.
    """
    input_name = "file_reader"
    # Create all read ops first, keeping each file name to choose a decoder.
    readers = [(tf.read_file(frame, input_name), frame) for frame in frames]

    decoded_frames = []
    for file_reader, file_name in readers:
        # Compare extensions case-insensitively so e.g. "X.PNG" uses decode_png.
        lowered = file_name.lower()
        if lowered.endswith(".png"):
            image_reader = tf.image.decode_png(file_reader, channels=3, name="png_reader")
        elif lowered.endswith(".gif"):
            image_reader = tf.squeeze(tf.image.decode_gif(file_reader, name="gif_reader"))
        elif lowered.endswith(".bmp"):
            image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader")
        else:
            image_reader = tf.image.decode_jpeg(file_reader, channels=3, name="jpeg_reader")
        decoded_frames.append(image_reader)

    float_caster = [tf.cast(image_reader, tf.float32) for image_reader in decoded_frames]
    float_caster = tf.stack(float_caster)
    resized = tf.image.resize_bilinear(float_caster, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    # Close the session after the single run so its resources are released.
    with tf.Session() as sess:
        return sess.run(normalized)
Example #14
Source File: data_loader.py From tensorflow_multigpu_imagenet with MIT License | 6 votes |
def preprocess(self, filename):
    """Decode, resize, augment and standardize a single image file.

    Steps:
    1. Decode with ``tf.image.decode_jpeg`` using ``self.raw_size[2]`` channels.
    2. Resize to ``(self.raw_size[0], self.raw_size[1])``.
    3. Apply ``self._train_preprocess`` or ``self._test_preprocess``, chosen by
       ``self.is_training``.
    4. Standardize per image.

    Args:
        filename: scalar string tensor with the image path.

    Returns:
        float32 image tensor with static shape ``self.processed_size``.
    """
    # Read examples from files in the filename queue.
    file_content = tf.read_file(filename)
    # Decoded with decode_jpeg; inputs are presumably JPEG (TODO confirm whether
    # PNG/GIF files occur — tf.image.decode_image handles all three).
    reshaped_image = tf.to_float(tf.image.decode_jpeg(file_content, channels=self.raw_size[2]))
    # Resize to raw_size[0] x raw_size[1].
    reshaped_image = tf.image.resize_images(reshaped_image, (self.raw_size[0], self.raw_size[1]))
    if self.is_training:
        reshaped_image = self._train_preprocess(reshaped_image)
    else:
        reshaped_image = self._test_preprocess(reshaped_image)
    # Subtract off the mean and divide by the variance of the pixels.
    reshaped_image = tf.image.per_image_standardization(reshaped_image)
    # Set the shapes of tensors.
    reshaped_image.set_shape(self.processed_size)
    return reshaped_image
Example #15
Source File: kitti_segnet.py From MachineLearning with Apache License 2.0 | 6 votes |
def CamVid_reader_seq(filename_queue, seq_length):
    """Read a sequence of PNG image/label pairs and decode them into tensors.

    Args:
        filename_queue: pair ``(image_paths, label_paths)``; each element is
            split along axis 0 into ``seq_length`` pieces.
        seq_length: number of pieces (frames) to split each element into.

    Returns:
        Tuple ``(image_seq, label_seq)`` of Python lists with one entry per
        frame: float32 images reshaped to
        ``(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH)`` and int64 labels reshaped
        to ``(IMAGE_HEIGHT, IMAGE_WIDTH, 1)``.
    """
    image_paths = tf.split(axis=0, num_or_size_splits=seq_length, value=filename_queue[0])
    label_paths = tf.split(axis=0, num_or_size_splits=seq_length, value=filename_queue[1])

    images, labels = [], []
    for image_path, label_path in zip(image_paths, label_paths):
        # Squeeze each split piece down to a scalar path for read_file
        # (assumes each piece holds exactly one path — TODO confirm).
        raw_image = tf.read_file(tf.squeeze(image_path))
        raw_label = tf.read_file(tf.squeeze(label_path))
        decoded_image = tf.image.decode_png(raw_image)
        decoded_label = tf.image.decode_png(raw_label)
        images.append(tf.cast(
            tf.reshape(decoded_image, (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH)),
            tf.float32))
        labels.append(tf.cast(
            tf.reshape(decoded_label, (IMAGE_HEIGHT, IMAGE_WIDTH, 1)),
            tf.int64))
    return images, labels
Example #16
Source File: data_reader.py From Learning2AdaptForStereo with Apache License 2.0 | 6 votes |
def read_image_from_disc(image_path, shape=None, dtype=tf.uint8):
    """Build ops that read and decode one image file, returned as float32.

    Args:
        image_path: scalar string tensor with the path of the image to load.
        shape: optional static shape to set on the decoded image; when omitted,
            ``[None, None, 3]`` is used.
        dtype: decode dtype. ``tf.uint8`` decodes with
            ``tf.image.decode_image``; any other value decodes with
            ``tf.image.decode_png(..., dtype=dtype)`` (e.g. ``tf.uint16`` for
            16-bit PNGs).

    Returns:
        The decoded image cast to ``tf.float32``. Values are cast, not rescaled.
    """
    image_raw = tf.read_file(image_path)
    if dtype == tf.uint8:
        image = tf.image.decode_image(image_raw)
    else:
        image = tf.image.decode_png(image_raw, dtype=dtype)
    # The decoders leave the static shape (partly) unknown; pin it here.
    if shape is None:
        image.set_shape([None, None, 3])
    else:
        image.set_shape(shape)
    return tf.cast(image, dtype=tf.float32)
Example #17
Source File: dataset.py From PSPNet-Keras-tensorflow with MIT License | 5 votes |
def load_image(self, image_path, is_jpeg):
    """Read an image file and decode it as JPEG or PNG.

    Args:
        image_path: scalar string tensor with the file path.
        is_jpeg: scalar boolean predicate for ``tf.cond``; true selects the
            JPEG decoder, false the PNG decoder.

    Returns:
        The decoded image with ``self.data_spec.channels`` channels, with the
        channel order reversed (RGB -> BGR) when ``self.data_spec.expects_bgr``
        is set.
    """
    # Read the file
    file_data = tf.read_file(image_path)
    # Decode the image data
    img = tf.cond(
        is_jpeg,
        lambda: tf.image.decode_jpeg(file_data, channels=self.data_spec.channels),
        lambda: tf.image.decode_png(file_data, channels=self.data_spec.channels))
    if self.data_spec.expects_bgr:
        # Convert from RGB channel ordering to BGR, matching how OpenCV orders
        # channels. Since TF 1.0, tf.reverse takes integer axes; the old
        # boolean mask [False, False, True] is rejected. Axis 2 is channels.
        img = tf.reverse(img, axis=[2])
    return img
Example #18
Source File: MapillaryLike_instance.py From MOTSFusion with MIT License | 5 votes |
def load_annotation(self, img, img_filename, annotation_filename):
    """Load a single-channel uint16 PNG annotation matching ``img``.

    ``annotation_filename`` may carry a ':'-separated suffix (presumably an
    instance id — TODO confirm). Only the first ':'-separated token is read
    from disk, while the full string is passed on to
    ``self.postproc_annotation``.

    Returns:
        The annotation after ``self.postproc_annotation``. Its static shape is
        ``img``'s spatial dims plus one channel.
    """
    path_only = tf.string_split([annotation_filename], ':').values[0]
    raw = tf.read_file(path_only)
    annotation = tf.image.decode_png(raw, dtype=tf.uint16, channels=1)
    spatial_dims = img.get_shape().as_list()[:-1]
    annotation.set_shape(spatial_dims + [1])
    return self.postproc_annotation(annotation_filename, annotation)
Example #19
Source File: Dataset.py From MOTSFusion with MIT License | 5 votes |
def load_annotation(self, img, img_filename, annotation_filename):
    """Load a single-channel annotation image matching ``img``'s spatial shape.

    Returns:
        The decoded annotation after ``self.postproc_annotation``. Its static
        shape is ``img``'s spatial dims plus one channel.
    """
    annotation = tf.image.decode_image(tf.read_file(annotation_filename), channels=1)
    target_shape = img.get_shape().as_list()[:-1] + [1]
    annotation.set_shape(target_shape)
    return self.postproc_annotation(annotation_filename, annotation)
Example #20
Source File: Dataset.py From MOTSFusion with MIT License | 5 votes |
def load_image(self, img_filename):
    """Read an image file as a 3-channel float32 tensor.

    Pixels are converted with ``tf.image.convert_image_dtype``, which rescales
    integer inputs into [0, 1]. The static shape is set to ``(None, None, 3)``.
    """
    encoded = tf.read_file(img_filename)
    decoded = tf.image.decode_image(encoded, channels=3)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    # decode_image leaves the rank unknown; pin it to H x W x 3.
    as_float.set_shape((None, None, 3))
    return as_float
Example #21
Source File: train.py From DeepBlending with Apache License 2.0 | 5 votes |
def read_jpg(self, image_path):
    """Read ``<self.data_folder>/<image_path>`` as JPEG and convert it to float32.

    ``tf.image.convert_image_dtype`` rescales the uint8 pixels into [0, 1].
    """
    full_path = tf.string_join([self.data_folder, image_path], "/")
    encoded = tf.read_file(full_path)
    decoded = tf.image.decode_jpeg(encoded)
    return tf.image.convert_image_dtype(decoded, tf.float32)
############################################
# Helper functions for our training losses #
############################################
Example #22
Source File: utils.py From R3Det_Tensorflow with MIT License | 5 votes |
def build_dataset(self, filenames, labels, is_training):
    """Build input dataset.

    Each file is read, its raw bytes are passed to ``self.get_preprocess_fn()``,
    and the result is cast to float32 and batched by ``self.batch_size``.

    Returns:
        ``(images, labels)`` from a one-shot iterator's ``get_next()``.
    """
    # CondConv layers can only be called with a known batch dimension, so at
    # eval time the last partial batch is dropped. Use a batch size that
    # evenly divides the number of files to evaluate every example.
    drop_remainder = 'condconv' in self.model_name and not is_training
    if drop_remainder and len(filenames) % self.batch_size != 0:
        tf.logging.warn('Remaining examples in last batch are not being '
                        'evaluated.')

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.constant(filenames), tf.constant(labels)))

    def _parse_function(filename, label):
        raw_bytes = tf.read_file(filename)
        preprocess = self.get_preprocess_fn()
        decoded = preprocess(raw_bytes, is_training, image_size=self.image_size)
        return tf.cast(decoded, tf.float32), label

    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(self.batch_size, drop_remainder=drop_remainder)
    images, labels = dataset.make_one_shot_iterator().get_next()
    return images, labels
Example #23
Source File: inference_wrapper.py From SiamFC-tf with MIT License | 5 votes |
def build_inputs(self):
    """Create the inference input placeholders.

    Sets ``self.image``, a float32 ``(None, None, 3)`` placeholder named
    "input" that takes an already-decoded image. Also sets
    ``self.target_bbox_feed``, a float32 placeholder of shape ``[4]``.
    """
    self.image = tf.placeholder(tf.float32, shape=(None, None, 3), name="input")
    # Target box as [center y, center x, height, width].
    self.target_bbox_feed = tf.placeholder(dtype=tf.float32,
                                           shape=[4],
                                           name='target_bbox_feed')
Example #24
Source File: image_utils.py From cv-tricks.com with MIT License | 5 votes |
def create_tensorflow_image_loader(session, expand_dims=True,
                                   options=None, run_metadata=None):
    """Tensorflow image loader

    Results seem to deviate quite a bit from yahoo image loader due to
    different jpeg encoders/decoders and different image resize
    implementations between PIL, skimage and tensorflow

    Only supports jpeg images.

    Relevant tensorflow issues:
        * https://github.com/tensorflow/tensorflow/issues/6720
        * https://github.com/tensorflow/tensorflow/issues/12753

    The read/preprocess ops are built once, in ``session.graph``, and each call
    feeds a path. Building them inside ``load_image`` would add new ops to the
    graph on every call, so the graph would grow without bound.

    Args:
        session: ``tf.Session`` used to run the loader ops.
        expand_dims: if true, prepend a batch dimension of size 1.
        options: forwarded to ``session.run``.
        run_metadata: forwarded to ``session.run``.

    Returns:
        ``load_image(image_path)``, which returns the evaluated image.
    """
    import tensorflow as tf

    with session.graph.as_default():
        path_feed = tf.placeholder(tf.string, shape=[])
        image_op = __tf_jpeg_process(tf.read_file(path_feed))
        if expand_dims:
            image_op = tf.expand_dims(image_op, axis=0)

    def load_image(image_path):
        return session.run(image_op,
                           feed_dict={path_feed: image_path},
                           options=options,
                           run_metadata=run_metadata)

    return load_image
Example #25
Source File: utils.py From StarGAN-Tensorflow with MIT License | 5 votes |
def image_processing(self, filename, label, fix_label):
    """Load a JPEG, resize it, scale it to [-1, 1] and maybe augment it.

    The image is resized to ``self.load_size`` x ``self.load_size``. When
    ``self.augment_flag`` is set, each example is augmented with probability
    0.5 via ``augmentation(img, augment_size)``.

    Returns:
        ``(img, label, fix_label)``; the labels are passed through unchanged.
    """
    x = tf.read_file(filename)
    x_decode = tf.image.decode_jpeg(x, channels=self.channels)
    img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
    # Map [0, 255] -> [-1, 1].
    img = tf.cast(img, tf.float32) / 127.5 - 1

    if self.augment_flag:
        augment_size = self.load_size + (30 if self.load_size == 256 else 15)
        # Draw the coin flip as a graph op so it is re-sampled per example.
        # random.random() runs once at graph-build time, so it would augment
        # either every example or none. Assumes augmentation() returns a
        # float32 image compatible with the un-augmented branch.
        p = tf.random_uniform(shape=[], minval=0.0, maxval=1.0)
        base = img
        img = tf.cond(p > 0.5,
                      lambda: augmentation(base, augment_size),
                      lambda: base)

    return img, label, fix_label
Example #26
Source File: data_loader.py From DeepMatchVO with MIT License | 5 votes |
def load_test_batch(self, image_sequence_names):
    """Create an initializable iterator over batches of decoded test images.

    Each path is read and JPEG-decoded under ``tf.device('/cpu:0')``. The
    dataset is batched by ``self.batch_size``, and ``self.batch_size * 4``
    batches are prefetched (``prefetch`` runs after ``batch``).

    Returns:
        The initializable iterator; it must be initialized before use.
    """
    def _decode_on_cpu(path):
        with tf.device('/cpu:0'):
            encoded = tf.read_file(path)
            decoded = tf.image.decode_jpeg(encoded)
        return decoded

    dataset = tf.data.Dataset.from_tensor_slices(image_sequence_names)
    dataset = dataset.map(_decode_on_cpu)
    dataset = dataset.batch(self.batch_size)
    dataset = dataset.prefetch(self.batch_size * 4)
    return dataset.make_initializable_iterator()
Example #27
Source File: utils.py From UNIT-Tensorflow with MIT License | 5 votes |
def image_processing(self, filename):
    """Load a JPEG, resize it, scale it to [-1, 1] and maybe augment it.

    The image is resized to ``self.load_size`` x ``self.load_size``. When
    ``self.augment_flag`` is set, each example is augmented with probability
    0.5 via ``augmentation(img, augment_size)``.

    Returns:
        The float32 image tensor.
    """
    x = tf.read_file(filename)
    x_decode = tf.image.decode_jpeg(x, channels=self.channels)
    img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
    # Map [0, 255] -> [-1, 1].
    img = tf.cast(img, tf.float32) / 127.5 - 1

    if self.augment_flag:
        augment_size = self.load_size + (30 if self.load_size == 256 else 15)
        # Draw the coin flip as a graph op so it is re-sampled per example.
        # random.random() runs once at graph-build time, so it would augment
        # either every example or none. Assumes augmentation() returns a
        # float32 image compatible with the un-augmented branch.
        p = tf.random_uniform(shape=[], minval=0.0, maxval=1.0)
        base = img
        img = tf.cond(p > 0.5,
                      lambda: augmentation(base, augment_size),
                      lambda: base)

    return img
Example #28
Source File: slim_classifier.py From CVTron with Apache License 2.0 | 5 votes |
def classify(self, img_file, model_name, model_path):
    """Classify one JPEG with a TF-Slim network restored from ``model_path``.

    Label names are read from ``labels.txt`` in ``model_path``. If that file is
    missing, an error is logged and ``None`` is returned.

    Returns:
        list of ``{'type': label_name, 'prob': str(probability)}`` dicts sorted
        by probability, highest first; or ``None`` when there is no label map.
    """
    if not dataset_utils.has_labels(model_path, 'labels.txt'):
        tf.logging.error('No label map')
        return None
    labels_to_names = dataset_utils.read_label_file(model_path, 'labels.txt')
    label_keys = list(labels_to_names.keys())

    with tf.Graph().as_default():
        preprocess = preprocessing_factory.get_preprocessing(
            model_name, is_training=False)
        network_fn = nets_factory.get_network_fn(
            model_name, num_classes=len(label_keys), is_training=False)

        encoded = tf.read_file(img_file)
        decoded = tf.image.decode_jpeg(encoded, channels=3)
        side = network_fn.default_image_size
        batch = tf.expand_dims(preprocess(decoded, side, side), axis=0)

        logits, _ = network_fn(batch)
        probabilities = tf.nn.softmax(logits)
        predictions = tf.argmax(logits, 1)

        init_fn = slim.assign_from_checkpoint_fn(
            tf.train.latest_checkpoint(model_path),
            slim.get_model_variables(scope_map[model_name]))

        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True
        with tf.Session(config=session_config) as sess:
            init_fn(sess)
            probs, _ = sess.run([probabilities, predictions])

        # Pair the i-th probability with the i-th label key.
        scored = [
            {'type': labels_to_names[label_keys[idx]], 'prob': str(score)}
            for idx, score in enumerate(probs[0])
        ]
        return sorted(scored, key=lambda entry: float(entry['prob']), reverse=True)
Example #29
Source File: data.py From generative-compression with MIT License | 5 votes |
def load_cGAN_dataset(image_paths, semantic_map_paths, batch_size, test=False,
                      augment=False, downsample=False, training_dataset='cityscapes'):
    """Load image dataset with semantic label maps for conditional GAN.

    Each element pairs an image PNG with its semantic-map PNG. Both are decoded
    with 3 channels and mapped from [0, 1] to [-1, 1]. For
    ``training_dataset == 'ADE20k'``, both are cropped/padded to width 512 and a
    height trimmed down to a multiple of 16.

    Args:
        image_paths: image file paths.
        semantic_map_paths: semantic-map file paths aligned with
            ``image_paths``.
        batch_size: examples per batch.
        test: if true, repeat the dataset indefinitely.
        augment: accepted for API compatibility; unused here.
        downsample: accepted for API compatibility; unused here.
        training_dataset: dataset name; 'ADE20k' enables the resize above.

    Returns:
        A shuffled, batched ``tf.data.Dataset`` of ``(image, semantic_map)``.
    """
    def _parser(image_path, semantic_map_path):

        def _aspect_preserving_width_resize(image, width=512):
            # If training on ADE20k: trim the height to a multiple of 16.
            height_i = tf.shape(image)[0]
            new_height = height_i - tf.floormod(height_i, 16)
            return tf.image.resize_image_with_crop_or_pad(image, new_height, width)

        def _image_decoder(path):
            # Decode ``path`` itself; reading the enclosing ``image_path`` here
            # would make the semantic map a copy of the image.
            im = tf.image.decode_png(tf.read_file(path), channels=3)
            im = tf.image.convert_image_dtype(im, dtype=tf.float32)
            return 2 * im - 1  # [0,1] -> [-1,1] (tanh range)

        image, semantic_map = _image_decoder(image_path), _image_decoder(semantic_map_path)

        print('Training on', training_dataset)
        # Compare strings by value; ``is`` tests object identity.
        if training_dataset == 'ADE20k':
            image = _aspect_preserving_width_resize(image)
            semantic_map = _aspect_preserving_width_resize(semantic_map)
        # im.set_shape([512,1024,3]) # downscaled cityscapes

        return image, semantic_map

    # from_tensor_slices takes one nested structure; pass a tuple so each
    # element is an (image_path, semantic_map_path) pair.
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, semantic_map_paths))
    dataset = dataset.map(_parser)
    dataset = dataset.shuffle(buffer_size=8)
    dataset = dataset.batch(batch_size)

    if test:
        dataset = dataset.repeat()

    return dataset
Example #30
Source File: data.py From generative-compression with MIT License | 5 votes |
def load_cGAN_dataset(image_paths, semantic_map_paths, batch_size, test=False,
                      augment=False, downsample=False, training_dataset='cityscapes'):
    """Load image dataset with semantic label maps for conditional GAN.

    Each element pairs an image PNG with its semantic-map PNG. Both are decoded
    with 3 channels and mapped from [0, 1] to [-1, 1]. For
    ``training_dataset == 'ADE20k'``, both are cropped/padded to width 512 and a
    height trimmed down to a multiple of 16.

    Args:
        image_paths: image file paths.
        semantic_map_paths: semantic-map file paths aligned with
            ``image_paths``.
        batch_size: examples per batch.
        test: if true, repeat the dataset indefinitely.
        augment: accepted for API compatibility; unused here.
        downsample: accepted for API compatibility; unused here.
        training_dataset: dataset name; 'ADE20k' enables the resize above.

    Returns:
        A shuffled, batched ``tf.data.Dataset`` of ``(image, semantic_map)``.
    """
    def _parser(image_path, semantic_map_path):

        def _aspect_preserving_width_resize(image, width=512):
            # If training on ADE20k: trim the height to a multiple of 16.
            height_i = tf.shape(image)[0]
            new_height = height_i - tf.floormod(height_i, 16)
            return tf.image.resize_image_with_crop_or_pad(image, new_height, width)

        def _image_decoder(path):
            # Decode ``path`` itself; reading the enclosing ``image_path`` here
            # would make the semantic map a copy of the image.
            im = tf.image.decode_png(tf.read_file(path), channels=3)
            im = tf.image.convert_image_dtype(im, dtype=tf.float32)
            return 2 * im - 1  # [0,1] -> [-1,1] (tanh range)

        image, semantic_map = _image_decoder(image_path), _image_decoder(semantic_map_path)

        print('Training on', training_dataset)
        # Compare strings by value; ``is`` tests object identity.
        if training_dataset == 'ADE20k':
            image = _aspect_preserving_width_resize(image)
            semantic_map = _aspect_preserving_width_resize(semantic_map)
        # im.set_shape([512,1024,3]) # downscaled cityscapes

        return image, semantic_map

    # from_tensor_slices takes one nested structure; pass a tuple so each
    # element is an (image_path, semantic_map_path) pair.
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, semantic_map_paths))
    dataset = dataset.map(_parser)
    dataset = dataset.shuffle(buffer_size=8)
    dataset = dataset.batch(batch_size)

    if test:
        dataset = dataset.repeat()

    return dataset