Python tensorflow.uint8() Examples
The following are 30 code examples of tensorflow.uint8().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: dataset.py From DNA-GAN with MIT License | 7 votes |
def parse_fn(self, serialized_example):
    """Parse one serialized example into a resized image and feature values.

    Args:
        serialized_example: a serialized example, as accepted by
            `tf.parse_single_example`.

    Returns:
        A tuple `(image, feature_val_list)`. `image` is the raw-decoded
        picture reshaped to `[height, width, 3]` (height and width are read
        from the record) and resized to `[self.height, self.width]`.
        `feature_val_list` holds one float32 tensor per name in
        `self.feature_list`, in the same order.
    """
    # Fixed image keys first; every name in self.feature_list is then read
    # as a scalar int64 feature.
    feature_spec = {
        'image/id_name': tf.FixedLenFeature([], tf.string),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/encoded': tf.FixedLenFeature([], tf.string),
    }
    feature_spec.update(
        (name, tf.FixedLenFeature([], tf.int64)) for name in self.feature_list)
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The pixels are stored as raw uint8 bytes, so the shape has to be
    # rebuilt from the height/width stored next to them.
    pixels = tf.decode_raw(parsed['image/encoded'], tf.uint8)
    stored_height = tf.cast(parsed['image/height'], tf.int32)
    stored_width = tf.cast(parsed['image/width'], tf.int32)
    image = tf.image.resize_images(
        tf.reshape(pixels, [stored_height, stored_width, 3]),
        size=[self.height, self.width])

    feature_val_list = []
    for name in self.feature_list:
        feature_val_list.append(tf.cast(parsed[name], tf.float32))
    return image, feature_val_list
Example #2
Source File: categorical_calibration_layer.py From lattice with Apache License 2.0 | 6 votes |
def call(self, inputs):
    """Standard Keras call() method.

    Inputs that are not already uint8/int32/int64 are cast to int32. When
    `self.default_input_value` is set, inputs equal to it are mapped to the
    last bucket (`self.num_buckets - 1`). The bucket indices are then
    one-hot encoded and combined with `self.kernel`.
    """
    integer_dtypes = (tf.uint8, tf.int32, tf.int64)
    if inputs.dtype not in integer_dtypes:
        inputs = tf.cast(inputs, dtype=tf.int32)

    if self.default_input_value is not None:
        default_tensor = tf.constant(
            int(self.default_input_value),
            dtype=inputs.dtype,
            name=DEFAULT_INPUT_VALUE_NAME)
        last_bucket = tf.zeros_like(inputs) + (self.num_buckets - 1)
        is_default = tf.equal(inputs, default_tensor)
        inputs = tf.where(is_default, last_bucket, inputs)

    # We can't use tf.gather_nd(self.kernel, inputs) as it doesn't support
    # constraints (constraint functions are not supported for IndexedSlices).
    # Instead we use matrix multiplication by one-hot encoding of the index.
    if self.units != 1:
        one_hot = tf.one_hot(inputs, axis=1, depth=self.num_buckets)
        return tf.reduce_sum(one_hot * self.kernel, axis=1)

    # This can be slightly faster as it uses matmul.
    bucket_index = tf.squeeze(inputs, axis=[-1])
    return tf.matmul(
        tf.one_hot(bucket_index, depth=self.num_buckets), self.kernel)
Example #3
Source File: train.py From imgcomp-cvpr with GNU General Public License v3.0 | 6 votes |
def get_mse_per_img(inp, otp, cast_to_int):
    """Compute the mean squared error of each image in a batch.

    :param inp: NCHW
    :param otp: NCHW
    :param cast_to_int: if True, both inp and otp are cast to int32 before the
        difference is squared, so the error reflects quantized pixel values.
        The squared error is always converted to float32 before averaging.
    :return: float32 tensor of shape (N,)
    """
    scope_suffix = 'int' if cast_to_int else 'float'
    with tf.name_scope('mse_{}'.format(scope_suffix)):
        if cast_to_int:
            # Values are expected to be in 0...255, i.e., uint8, but tf.square
            # does not support uint8's
            inp = tf.cast(inp, tf.int32)
            otp = tf.cast(otp, tf.int32)
        difference = otp - inp
        squared_error_float = tf.to_float(tf.square(difference))
        # Average over every axis except the batch axis.
        return tf.reduce_mean(squared_error_float, axis=[1, 2, 3])
Example #4
Source File: exporter.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def _encoded_image_string_tensor_input_placeholder():
    """Returns input that accepts a batch of PNG or JPEG strings.

    Returns:
      a tuple of input placeholder and the output decoded images.
    """
    def _decode_one(encoded_image_string_tensor):
        """Decodes a single encoded image string into a 3-channel tensor."""
        decoded = tf.image.decode_image(encoded_image_string_tensor, channels=3)
        # Height and width stay unknown; only the channel count is fixed.
        decoded.set_shape((None, None, 3))
        return decoded

    encoded_batch = tf.placeholder(
        dtype=tf.string, shape=[None], name='encoded_image_string_tensor')
    decoded_batch = tf.map_fn(
        _decode_one,
        elems=encoded_batch,
        dtype=tf.uint8,
        parallel_iterations=32,
        back_prop=False)
    return (encoded_batch, decoded_batch)
Example #5
Source File: exporter.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def _tf_example_input_placeholder():
    """Returns input that accepts a batch of strings with tf examples.

    Returns:
      a tuple of input placeholder and the output decoded images.
    """
    def _decode_one(tf_example_string_tensor):
        """Decodes one serialized tf example and returns its image tensor."""
        decoder = tf_example_decoder.TfExampleDecoder()
        decoded_fields = decoder.decode(tf_example_string_tensor)
        return decoded_fields[fields.InputDataFields.image]

    serialized_batch = tf.placeholder(
        tf.string, shape=[None], name='tf_example')
    image_batch = shape_utils.static_or_dynamic_map_fn(
        _decode_one,
        elems=serialized_batch,
        dtype=tf.uint8,
        parallel_iterations=32,
        back_prop=False)
    return (serialized_batch, image_batch)
Example #6
Source File: detection_inference_test.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def create_mock_graph():
    """Builds a small constant-valued graph and writes it to the mock path.

    The graph has a uint8 `image_tensor` placeholder and named
    `num_detections`, `detection_boxes`, `detection_scores` and
    `detection_classes` nodes. Only `detection_classes` depends on the input:
    it scales a constant by the sum of all input pixels. The serialized
    GraphDef is written to `get_mock_graph_path()`.
    """
    graph = tf.Graph()
    with graph.as_default():
        image_input = tf.placeholder(
            tf.uint8, shape=[1, None, None, 3], name='image_tensor')
        tf.constant([2.0], name='num_detections')
        mock_boxes = [[[0, 0.8, 0.7, 1],
                       [0.1, 0.2, 0.8, 0.9],
                       [0.2, 0.3, 0.4, 0.5]]]
        tf.constant(mock_boxes, name='detection_boxes')
        tf.constant([[0.1, 0.2, 0.3]], name='detection_scores')
        class_base = tf.constant([[1.0, 2.0, 3.0]])
        pixel_sum = tf.reduce_sum(tf.cast(image_input, dtype=tf.float32))
        tf.identity(class_base * pixel_sum, name='detection_classes')
        serialized_graph = graph.as_graph_def().SerializeToString()

    with tf.gfile.Open(get_mock_graph_path(), 'w') as fl:
        fl.write(serialized_graph)
Example #7
Source File: visualization_utils.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def draw_keypoints_on_image_array(image,
                                  keypoints,
                                  color='red',
                                  radius=2,
                                  use_normalized_coordinates=True):
    """Draws keypoints on an image (numpy array), modifying it in place.

    Args:
      image: a numpy array with shape [height, width, 3].
      keypoints: a numpy array with shape [num_keypoints, 2].
      color: color to draw the keypoints with. Default is red.
      radius: keypoint radius. Default value is 2.
      use_normalized_coordinates: if True (default), treat keypoint values as
        relative to the image. Otherwise treat them as absolute.
    """
    # Drawing happens on a PIL copy of the array.
    as_uint8 = np.uint8(image)
    canvas = Image.fromarray(as_uint8).convert('RGB')
    draw_keypoints_on_image(
        canvas, keypoints, color, radius, use_normalized_coordinates)
    # Copy the drawn pixels back so the caller's array reflects the result.
    drawn_pixels = np.array(canvas)
    np.copyto(image, drawn_pixels)
Example #8
Source File: detection_inference_test.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def create_mock_tfrecord():
    """Writes a one-example TFRecord file used as a test fixture.

    The example holds a 1x1 RGB PNG (pixel value [123, 0, 0]) under the
    standard encoded-image key, plus a float list under 'test_field'. The
    record is written to `get_mock_tfrecord_path()`.
    """
    # Local import keeps this block self-contained. PNG data is binary, so it
    # needs a bytes buffer; io.BytesIO exists on both Python 2 and Python 3,
    # unlike the Python 2-only StringIO module.
    import io

    pil_image = Image.fromarray(
        np.array([[[123, 0, 0]]], dtype=np.uint8), 'RGB')
    image_output_stream = io.BytesIO()
    pil_image.save(image_output_stream, format='png')
    encoded_image = image_output_stream.getvalue()

    feature_map = {
        'test_field':
            dataset_util.float_list_feature([1, 2, 3, 4]),
        standard_fields.TfExampleFields.image_encoded:
            dataset_util.bytes_feature(encoded_image),
    }
    tf_example = tf.train.Example(
        features=tf.train.Features(feature=feature_map))
    with tf.python_io.TFRecordWriter(get_mock_tfrecord_path()) as writer:
        writer.write(tf_example.SerializeToString())
Example #9
Source File: common_layers.py From fine-lm with MIT License | 6 votes |
def summarize_video(video, prefix, max_outputs=1):
    """Summarize the video using image summaries starting with prefix.

    Args:
      video: a rank-5 tensor in the format
        [batch, time, height, width, channels].
      prefix: string prepended to every summary name.
      max_outputs: forwarded to `tf.summary.image`.

    Raises:
      ValueError: if `video` does not have exactly 5 dimensions.
    """
    video_shape = shape_list(video)
    if len(video_shape) != 5:
        raise ValueError("Assuming videos given as tensors in the format "
                         "[batch, time, height, width, channels] but got one "
                         "of shape: %s" % str(video_shape))
    # No summaries are emitted in eager mode.
    if tf.contrib.eager.in_eager_mode():
        return

    static_num_frames = video.get_shape().as_list()[1]
    if static_num_frames is not None:
        # Time dimension is known statically: one summary per frame.
        for frame_index in range(video_shape[1]):
            frame = tf.cast(video[:, frame_index, :, :, :], tf.uint8)
            tf.summary.image(
                "%s_frame_%d" % (prefix, frame_index),
                frame,
                max_outputs=max_outputs)
        return

    # Time dimension unknown at graph-construction time: last frame only.
    last_frame = tf.cast(video[:, -1, :, :, :], tf.uint8)
    tf.summary.image(
        "%s_last_frame" % prefix, last_frame, max_outputs=max_outputs)
Example #10
Source File: visualization_utils_test.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def create_colorful_test_image(self):
    """This function creates an image that can be used to test vis functions.

    It makes an image composed of four colored rectangles, each 100 pixels
    high and 200 pixels wide, arranged in a 2x2 grid.

    Returns:
      colorful test numpy array image (uint8, shape [200, 400, 3]).
    """
    def _solid_block(rgb):
        """Returns a 100x200 uint8 block filled with one RGB value."""
        return np.full([100, 200, 3], rgb, dtype=np.uint8)

    top_row = np.concatenate(
        (_solid_block((255, 128, 128)), _solid_block((255, 255, 0))), axis=1)
    bottom_row = np.concatenate(
        (_solid_block((255, 0, 255)), _solid_block((128, 128, 128))), axis=1)
    return np.concatenate((top_row, bottom_row), axis=0)
Example #11
Source File: metrics.py From fine-lm with MIT License | 6 votes |
def image_summary(predictions, targets, hparams):
    """Reshapes predictions and passes it to tensorboard.

    Args:
      predictions : The predicted image (logits).
      targets : The ground truth.
      hparams: model hparams (unused).

    Returns:
      summary_proto: containing the summary images.
      weights: A Tensor of zeros of the same shape as predictions.
    """
    del hparams
    # The arg-max over the last axis of the logits is the predicted value.
    predicted_image = tf.cast(tf.argmax(predictions, axis=-1), tf.uint8)
    target_image = tf.cast(targets, tf.uint8)
    image_summaries = [
        tf.summary.image("prediction", predicted_image, max_outputs=2),
        tf.summary.image("data", target_image, max_outputs=2),
    ]
    merged = tf.summary.merge(image_summaries)
    return merged, tf.zeros_like(predictions)
Example #12
Source File: freeze_model.py From deep_sort with GNU General Public License v3.0 | 6 votes |
def main():
    """Restores a checkpoint and writes a frozen graph definition.

    Builds the network on a uint8 `images` placeholder of shape
    (None, 128, 64, 3), restores variables from `args.checkpoint_in`,
    converts them to constants with the `features` node as the only output,
    and writes the serialized GraphDef to `args.graphdef_out`.
    """
    args = parse_args()

    with tf.Session(graph=tf.Graph()) as session:
        input_var = tf.placeholder(
            tf.uint8, (None, 128, 64, 3), name="images")
        float_images = tf.cast(input_var, tf.float32)
        image_var = tf.map_fn(
            lambda frame: _preprocess(frame), float_images, back_prop=False)

        build_network = _network_factory()
        features, _ = build_network(image_var, reuse=None)
        # Give the output a stable name so it can be looked up after freezing.
        features = tf.identity(features, name="features")

        saver = tf.train.Saver(slim.get_variables_to_restore())
        saver.restore(session, args.checkpoint_in)

        # Tensor names look like "<op>:<index>"; freezing wants the op name.
        output_node_names = [features.name.split(":")[0]]
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            session,
            tf.get_default_graph().as_graph_def(),
            output_node_names)
        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(frozen_graph_def.SerializeToString())
Example #13
Source File: vfn_train.py From view-finding-network with GNU General Public License v3.0 | 6 votes |
def read_and_decode_aug(filename_queue):
    """Reads one record and returns an augmented image pair.

    The record's 'image_raw' bytes are decoded as uint8 and reshaped to
    [227, 227, 6]. Random flip, brightness and contrast are applied to the
    6-channel tensor as a whole, before it is split along the channel axis.

    Args:
      filename_queue: queue of TFRecord filenames passed to
        `tf.TFRecordReader.read`.

    Returns:
      The result of splitting the 6-channel image into two parts along the
      third dimension.
    """
    reader = tf.TFRecordReader()
    _key, serialized_example = reader.read(filename_queue)
    # Defaults are not specified since both keys are required.
    feature_spec = {'image_raw': tf.FixedLenFeature([], tf.string)}
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    raw_pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    stacked = tf.reshape(raw_pixels, [227, 227, 6])
    stacked = tf.image.random_flip_left_right(stacked)

    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    stacked = tf.cast(stacked, tf.float32) * (1. / 255) - 0.5
    stacked = tf.image.random_brightness(stacked, 0.01)
    stacked = tf.image.random_contrast(stacked, 0.95, 1.05)
    # 3rd dimension two parts
    return tf.split(stacked, 2, 2)
Example #14
Source File: vfn_train.py From view-finding-network with GNU General Public License v3.0 | 6 votes |
def read_and_decode(filename_queue):
    """Reads one record and returns its image pair without augmentation.

    The record's 'image_raw' bytes are decoded as uint8, reshaped to
    [227, 227, 6], rescaled to floats in [-0.5, 0.5] and split in two along
    the channel axis.

    Args:
      filename_queue: queue of TFRecord filenames passed to
        `tf.TFRecordReader.read`.

    Returns:
      The result of splitting the 6-channel image into two parts along the
      third dimension.
    """
    reader = tf.TFRecordReader()
    _key, serialized_example = reader.read(filename_queue)
    # Defaults are not specified since both keys are required.
    feature_spec = {'image_raw': tf.FixedLenFeature([], tf.string)}
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    raw_pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    stacked = tf.reshape(raw_pixels, [227, 227, 6])
    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    stacked = tf.cast(stacked, tf.float32) * (1. / 255) - 0.5
    # 3rd dimension two parts
    return tf.split(stacked, 2, 2)
Example #15
Source File: visualization_utils_test.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def test_draw_bounding_boxes_on_image_tensors_with_additional_channels(self):
    """Tests the case where input image tensor has more than 3 channels."""
    category_index = {1: {'id': 1, 'name': 'dog'}}
    single_image = self.create_test_image_with_five_channels()
    # Batch of two identical five-channel images.
    images_np = np.stack([single_image, single_image], axis=0)

    with tf.Graph().as_default():
        images_tensor = tf.constant(value=images_np, dtype=tf.uint8)
        # No detections: the second dimension of every input is empty.
        boxes = tf.constant(0, dtype=tf.float32, shape=[2, 0, 4])
        classes = tf.constant(0, dtype=tf.int64, shape=[2, 0])
        scores = tf.constant(0, dtype=tf.float32, shape=[2, 0])
        images_with_boxes = (
            visualization_utils.draw_bounding_boxes_on_image_tensors(
                images_tensor,
                boxes,
                classes,
                scores,
                category_index,
                min_score_thresh=0.2))

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            final_images_np = sess.run(images_with_boxes)
            # The output is expected to have exactly three channels.
            expected_shape = (2, 100, 200, 3)
            self.assertEqual(expected_shape, final_images_np.shape)
Example #16
Source File: visualization_utils.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def draw_keypoints_on_image_array(image,
                                  keypoints,
                                  color='red',
                                  radius=2,
                                  use_normalized_coordinates=True):
    """Draws keypoints on an image (numpy array); `image` is updated in place.

    Args:
      image: a numpy array with shape [height, width, 3].
      keypoints: a numpy array with shape [num_keypoints, 2].
      color: color to draw the keypoints with. Default is red.
      radius: keypoint radius. Default value is 2.
      use_normalized_coordinates: if True (default), treat keypoint values as
        relative to the image. Otherwise treat them as absolute.
    """
    pil_canvas = Image.fromarray(np.uint8(image)).convert('RGB')
    draw_keypoints_on_image(pil_canvas,
                            keypoints,
                            color,
                            radius,
                            use_normalized_coordinates)
    # The PIL image is a copy, so write its pixels back into the input array.
    np.copyto(image, np.array(pil_canvas))
Example #17
Source File: nclt.py From hierarchical_loc with BSD 3-Clause "New" or "Revised" License | 6 votes |
def __init__(self, fin, scale=1.0, fmask=None):
    """Loads an undistortion map and derives a mask and crop limits from it.

    Args:
      fin: path of the text file holding the map. The digits and commas of
        its first line give the map dimensions; every following line is
        space-separated and supplies two indices and two map values.
      scale: not used by this constructor.
      fmask: not used by this constructor.
    """
    self.fin = fin

    # read in distort
    with open(fin, 'r') as map_file:
        header = map_file.readline().rstrip()
        dims = re.sub(r'[^0-9,]', '', header).split(',')
        self.mapu = np.zeros((int(dims[1]), int(dims[0])), dtype=np.float32)
        self.mapv = np.zeros((int(dims[1]), int(dims[0])), dtype=np.float32)
        for line in map_file.readlines():
            parts = line.rstrip().split(' ')
            u_value = float(parts[3])
            self.mapu[int(parts[0]), int(parts[1])] = u_value
            v_value = float(parts[2])
            self.mapv[int(parts[0]), int(parts[1])] = v_value

    # generate a mask: remap an all-ones image, then erode the result
    ones = np.ones(self.mapu.shape, dtype=np.uint8)
    self.mask = ones
    self.mask = cv2.remap(self.mask, self.mapu, self.mapv, cv2.INTER_LINEAR)
    erosion_kernel = np.ones((30, 30), np.uint8)
    self.mask = cv2.erode(self.mask, erosion_kernel, iterations=1)

    # crop black regions out: first/last non-zero index along the middle
    # row (x limits) and the middle column (y limits)
    h, w = self.mask.shape
    middle_row_hits = np.where(self.mask[int(h/2), :])[0]
    self.x_lim = [np.min(middle_row_hits), np.max(middle_row_hits)]
    middle_col_hits = np.where(self.mask[:, int(w/2)])[0]
    self.y_lim = [np.min(middle_col_hits), np.max(middle_col_hits)]
Example #18
Source File: download_and_convert_mnist.py From DOTA_models with Apache License 2.0 | 6 votes |
def _extract_labels(filename, num_labels): """Extract the labels into a vector of int64 label IDs. Args: filename: The path to an MNIST labels file. num_labels: The number of labels in the file. Returns: A numpy array of shape [number_of_labels] """ print('Extracting labels from: ', filename) with gzip.open(filename) as bytestream: bytestream.read(8) buf = bytestream.read(1 * num_labels) labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64) return labels
Example #19
Source File: download_and_convert_mnist.py From DOTA_models with Apache License 2.0 | 6 votes |
def _extract_images(filename, num_images):
    """Extract the images into a numpy array.

    Args:
      filename: The path to an MNIST images file.
      num_images: The number of images in the file.

    Returns:
      A numpy array of shape [number_of_images, height, width, channels].
    """
    print('Extracting images from: ', filename)
    byte_count = _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS
    with gzip.open(filename) as bytestream:
        # Skip the 16-byte header that precedes the pixel bytes.
        bytestream.read(16)
        raw_pixels = bytestream.read(byte_count)
    flat = np.frombuffer(raw_pixels, dtype=np.uint8)
    return flat.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
Example #20
Source File: tensor.py From spleeter with MIT License | 6 votes |
def from_float32_to_uint8(
        tensor,
        tensor_key='tensor',
        min_key='min',
        max_key='max'):
    """Rescales a tensor by its own min/max and casts the result to uint8.

    :param tensor: tensor to quantize.
    :param tensor_key: key under which the uint8 tensor is returned.
    :param min_key: key under which the tensor minimum is returned.
    :param max_key: key under which the tensor maximum is returned.
    :returns: dict holding the quantized tensor together with the minimum
        and maximum that were used for the rescaling.
    """
    tensor_min = tf.reduce_min(tensor)
    tensor_max = tf.reduce_max(tensor)
    offset = tensor - tensor_min
    # The small epsilon keeps the denominator non-zero for constant tensors.
    value_range = tensor_max - tensor_min + 1e-16
    quantized = tf.cast(offset / value_range * 255.9999, dtype=tf.uint8)
    return {
        tensor_key: quantized,
        min_key: tensor_min,
        max_key: tensor_max
    }
Example #21
Source File: dataset.py From disentangling_conditional_gans with MIT License | 6 votes |
def __init__(self, resolution=1024, num_channels=3, dtype='uint8', dynamic_range=[0, 255], label_size=0, label_dtype='float32'):
    """Stores the dataset description and creates its TF state variables.

    Args:
      resolution: image side length; must be a power of two.
      num_channels: number of image channels.
      dtype: image dtype name.
      dynamic_range: two-element value range, stored as given.
        NOTE(review): the default list object is shared by every instance
        built with the default, so it should be treated as read-only.
      label_size: stored as given.
      label_dtype: label dtype name.
    """
    log2_resolution = int(np.log2(resolution))

    self.resolution = resolution
    self.resolution_log2 = log2_resolution
    # Shape is channels-first: [channels, height, width].
    self.shape = [num_channels, resolution, resolution]
    self.dtype = dtype
    self.dynamic_range = dynamic_range
    self.label_size = label_size
    self.label_dtype = label_dtype

    self._tf_minibatch_var = None
    self._tf_lod_var = None
    self._tf_minibatch_np = None
    self._tf_labels_np = None

    # Only power-of-two resolutions are accepted.
    assert self.resolution == 2 ** self.resolution_log2
    with tf.name_scope('Dataset'):
        self._tf_minibatch_var = tf.Variable(
            np.int32(0), name='minibatch_var')
        self._tf_lod_var = tf.Variable(np.int32(0), name='lod_var')
Example #22
Source File: visualization_utils.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def add_cdf_image_summary(values, name):
    """Adds a tf.summary.image for a CDF plot of the values.

    Normalizes `values` such that they sum to 1, plots the cumulative
    distribution function and creates a tf image summary.

    Args:
      values: a 1-D float32 tensor containing the values.
      name: name for the image summary.
    """

    def cdf_plot(values):
        """Numpy function to plot CDF."""
        normalized_values = values / np.sum(values)
        sorted_values = np.sort(normalized_values)
        cumulative_values = np.cumsum(sorted_values)
        fraction_of_examples = (
            np.arange(cumulative_values.size, dtype=np.float32) /
            cumulative_values.size)
        fig = plt.figure(frameon=False)
        try:
            # Integer position spec; the string form '111' is deprecated in
            # newer matplotlib releases.
            ax = fig.add_subplot(111)
            ax.plot(fraction_of_examples, cumulative_values)
            ax.set_ylabel('cumulative normalized values')
            ax.set_xlabel('fraction of examples')
            fig.canvas.draw()
            # Size in inches times dpi is a float pair; reshape needs ints.
            width, height = fig.get_size_inches() * fig.get_dpi()
            # frombuffer replaces the deprecated binary mode of np.fromstring.
            image = np.frombuffer(
                fig.canvas.tostring_rgb(), dtype='uint8').reshape(
                    1, int(height), int(width), 3)
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this each call would leak one figure.
            plt.close(fig)
        return image

    cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8)
    tf.summary.image(name, cdf_plot)
Example #23
Source File: visualization_utils.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def draw_bounding_box_on_image_array(current_frame_number,
                                     image,
                                     ymin,
                                     xmin,
                                     ymax,
                                     xmax,
                                     color='red',
                                     thickness=4,
                                     display_str_list=(),
                                     use_normalized_coordinates=True):
    """Adds a bounding box to an image (numpy array), modifying it in place.

    Args:
      current_frame_number: forwarded unchanged to
        `draw_bounding_box_on_image`.
      image: a numpy array with shape [height, width, 3].
      ymin: ymin of bounding box in normalized coordinates (same below).
      xmin: xmin of bounding box.
      ymax: ymax of bounding box.
      xmax: xmax of bounding box.
      color: color to draw bounding box. Default is red.
      thickness: line thickness. Default value is 4.
      display_str_list: list of strings to display in box
        (each to be shown on its own line).
      use_normalized_coordinates: If True (default), treat coordinates
        ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
        coordinates as absolute.

    Returns:
      The `(is_vehicle_detected, csv_line, update_csv)` triple produced by
      `draw_bounding_box_on_image`.
    """
    canvas = Image.fromarray(np.uint8(image)).convert('RGB')
    detection_result = draw_bounding_box_on_image(
        current_frame_number,
        canvas,
        ymin,
        xmin,
        ymax,
        xmax,
        color,
        thickness,
        display_str_list,
        use_normalized_coordinates)
    is_vehicle_detected, csv_line, update_csv = detection_result
    # Copy the drawn pixels back so the caller's array reflects the result.
    np.copyto(image, np.array(canvas))
    return is_vehicle_detected, csv_line, update_csv
Example #24
Source File: visualization_utils.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def encode_image_array_as_png_str(image):
    """Encodes a numpy array into a PNG string.

    Args:
      image: a numpy array with shape [height, width, 3].

    Returns:
      PNG encoded image string.
    """
    pil_image = Image.fromarray(np.uint8(image))
    png_buffer = six.BytesIO()
    pil_image.save(png_buffer, format='PNG')
    # Read the bytes out before the buffer is closed.
    encoded = png_buffer.getvalue()
    png_buffer.close()
    return encoded
Example #25
Source File: eval_util.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def _resize_detection_masks(args):
    """Reframes box-relative masks to image size and binarizes them.

    Args:
      args: a `(detection_boxes, detection_masks, image_shape)` triple; the
        first two entries of `image_shape` are passed on as the target
        image dimensions.

    Returns:
      A uint8 tensor that is 1 where the reframed mask exceeds 0.5 and 0
      elsewhere.
    """
    boxes, masks, image_shape = args
    image_height, image_width = image_shape[0], image_shape[1]
    reframed = ops.reframe_box_masks_to_image_masks(
        masks, boxes, image_height, image_width)
    above_threshold = tf.greater(reframed, 0.5)
    return tf.cast(above_threshold, tf.uint8)
Example #26
Source File: eval_util.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def _resize_groundtruth_masks(args):
    """Resizes groundtruth masks to `image_shape` with nearest-neighbor.

    Args:
      args: a `(mask, image_shape)` pair.

    Returns:
      The resized masks cast to uint8, without the temporary channel axis.
    """
    mask, image_shape = args
    # resize_images works on tensors with a channel axis, so add one at
    # position 3 and remove it again afterwards.
    with_channel = tf.expand_dims(mask, 3)
    resized = tf.image.resize_images(
        with_channel,
        image_shape,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        align_corners=True)
    without_channel = tf.squeeze(resized, 3)
    return tf.cast(without_channel, tf.uint8)
Example #27
Source File: visualization_utils.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def save_image_array_as_png(image, output_path):
    """Saves an image (represented as a numpy array) to PNG.

    Args:
      image: a numpy array with shape [height, width, 3].
      output_path: path to which image should be written.
    """
    as_uint8 = np.uint8(image)
    rgb_image = Image.fromarray(as_uint8).convert('RGB')
    # tf.gfile is used for the output file rather than the built-in open().
    with tf.gfile.Open(output_path, 'w') as fid:
        rgb_image.save(fid, 'PNG')
Example #28
Source File: train.py From tensorflow-data with MIT License | 5 votes |
def parser(record):
    """Parses one serialized record into a feature dict and a label.

    Args:
      record: a serialized example holding a raw-bytes "image_raw" feature
        and an int64 "label" feature.

    Returns:
      A tuple `({'image': image}, label)` where `image` holds the decoded
      bytes as float32 (it is not reshaped here) and `label` is int32.
    """
    feature_spec = {
        "image_raw": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, feature_spec)

    raw_pixels = tf.decode_raw(parsed["image_raw"], tf.uint8)
    image = tf.cast(raw_pixels, tf.float32)
    label = tf.cast(parsed["label"], tf.int32)
    return {'image': image}, label
Example #29
Source File: visualization_utils.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def add_cdf_image_summary(values, name):
    """Adds a tf.summary.image for a CDF plot of the values.

    Normalizes `values` such that they sum to 1, plots the cumulative
    distribution function and creates a tf image summary.

    Args:
      values: a 1-D float32 tensor containing the values.
      name: name for the image summary.
    """

    def cdf_plot(values):
        """Numpy function to plot CDF."""
        normalized_values = values / np.sum(values)
        sorted_values = np.sort(normalized_values)
        cumulative_values = np.cumsum(sorted_values)
        fraction_of_examples = (
            np.arange(cumulative_values.size, dtype=np.float32) /
            cumulative_values.size)
        fig = plt.figure(frameon=False)
        try:
            # Integer position spec; the string form '111' is deprecated in
            # newer matplotlib releases.
            ax = fig.add_subplot(111)
            ax.plot(fraction_of_examples, cumulative_values)
            ax.set_ylabel('cumulative normalized values')
            ax.set_xlabel('fraction of examples')
            fig.canvas.draw()
            width, height = fig.get_size_inches() * fig.get_dpi()
            # frombuffer replaces the deprecated binary mode of np.fromstring.
            image = np.frombuffer(
                fig.canvas.tostring_rgb(), dtype='uint8').reshape(
                    1, int(height), int(width), 3)
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this each call would leak one figure.
            plt.close(fig)
        return image

    cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8)
    tf.summary.image(name, cdf_plot)
Example #30
Source File: inputs_test.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def test_returns_resized_masks(self):
    """Checks the outputs of transform_input_data with an 8x8 fake resizer.

    Asserts that the retained original image is uint8 with shape [8, 8, 3],
    that the original spatial shape is still reported as [4, 4], and that
    the instance masks come back with shape [2, 8, 8].
    """
    input_fields = fields.InputDataFields
    tensor_dict = {
        input_fields.image:
            tf.constant(np.random.rand(4, 4, 3).astype(np.float32)),
        input_fields.groundtruth_instance_masks:
            tf.constant(np.random.rand(2, 4, 4).astype(np.float32)),
        input_fields.groundtruth_classes:
            tf.constant(np.array([3, 1], np.int32)),
        input_fields.original_image_spatial_shape:
            tf.constant(np.array([4, 4], np.int32))
    }

    def fake_image_resizer_fn(image, masks=None):
        """Resizes the image (and masks, if given) to 8x8."""
        resized_image = tf.image.resize_images(image, [8, 8])
        if masks is None:
            return [resized_image, tf.shape(resized_image)]
        # Masks are [num, h, w]; move the mask index last for resizing,
        # then move it back to the front.
        masks_last = tf.transpose(masks, [1, 2, 0])
        resized_masks = tf.transpose(
            tf.image.resize_images(masks_last, [8, 8]), [2, 0, 1])
        return [resized_image, resized_masks, tf.shape(resized_image)]

    num_classes = 3
    with self.test_session() as sess:
        transformed_inputs = sess.run(
            inputs.transform_input_data(
                tensor_dict=tensor_dict,
                model_preprocess_fn=_fake_model_preprocessor_fn,
                image_resizer_fn=fake_image_resizer_fn,
                num_classes=num_classes,
                retain_original_image=True))

    original_image = transformed_inputs[input_fields.original_image]
    self.assertAllEqual(original_image.dtype, tf.uint8)
    self.assertAllEqual(
        transformed_inputs[input_fields.original_image_spatial_shape],
        [4, 4])
    self.assertAllEqual(
        transformed_inputs[input_fields.original_image].shape, [8, 8, 3])
    self.assertAllEqual(
        transformed_inputs[input_fields.groundtruth_instance_masks].shape,
        [2, 8, 8])