Python tensorflow.compat.v1.uint8 Examples

The following are 30 code examples of `tensorflow.compat.v1.uint8` (a dtype constant, not a callable). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `tensorflow.compat.v1`, or try the search function.
Example #1
Source File: visualization_utils.py    From models with Apache License 2.0 6 votes vote down vote up
def draw_heatmaps_on_image(image, heatmaps):
  """Draws heatmaps on an image.

  The heatmaps are handled channel by channel and different colors are used to
  paint different heatmap channels. Each channel is scaled to an 8-bit ('L'
  mode) bitmap, which ImageDraw.bitmap uses as the mask for that channel's
  fill color.

  Args:
    image: a PIL.Image object. It is drawn on in place.
    heatmaps: a numpy array with shape [image_height, image_width, channel].
      Values are expected in [0, 1]; values outside that range are clipped.
      Note that the image_height and image_width should match the size of input
      image.
  """
  draw = ImageDraw.Draw(image)
  num_channels = heatmaps.shape[2]
  for c in range(num_channels):
    # Clip before casting: float -> uint8 casts do not saturate, so negative
    # or > 255 values would wrap (or be platform-dependent) instead of
    # clamping to the valid pixel range.
    heatmap = (heatmaps[:, :, c] * 255).clip(0, 255).astype('uint8')
    bitmap = Image.fromarray(heatmap, 'L')
    # Cycle through the palette so more channels than colors do not raise
    # IndexError; for c < len(STANDARD_COLORS) the color is unchanged.
    draw.bitmap(
        xy=[(0, 0)],
        bitmap=bitmap,
        fill=STANDARD_COLORS[c % len(STANDARD_COLORS)])
Example #2
Source File: gym_env.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def __init__(self, batch_size, *args, **kwargs):
    """Initializes the environment and a private graph/session for PNG coding.

    Args:
      batch_size: batch size, stored as `self.batch_size`.
      *args: forwarded to the parent constructor.
      **kwargs: forwarded to the parent constructor, except `store_rollouts`
        (default True), which is consumed here.
    """
    # Pop before delegating so the parent constructor never sees this key.
    self._store_rollouts = kwargs.pop("store_rollouts", True)

    super(T2TEnv, self).__init__(*args, **kwargs)

    self.batch_size = batch_size
    self._rollouts_by_epoch_and_split = collections.OrderedDict()
    self.current_epoch = None
    self._should_preprocess_on_reset = True

    # Dedicated graph holding only PNG encode/decode ops. Op creation order is
    # significant for auto-generated op names, so keep it stable.
    with tf.Graph().as_default() as png_graph:
      self._tf_graph = _Noncopyable(png_graph)

      # uint8 image placeholder -> PNG bytes.
      self._decoded_image_p = _Noncopyable(
          tf.placeholder(dtype=tf.uint8, shape=(None, None, None)))
      self._encoded_image_t = _Noncopyable(
          tf.image.encode_png(self._decoded_image_p.obj))

      # PNG bytes placeholder -> decoded image.
      self._encoded_image_p = _Noncopyable(tf.placeholder(tf.string))
      self._decoded_image_t = _Noncopyable(
          tf.image.decode_png(self._encoded_image_p.obj))

      # Created while png_graph is the default graph, so it runs that graph.
      self._session = _Noncopyable(tf.Session())
Example #3
Source File: preprocessing.py    From benchmarks with Apache License 2.0 6 votes vote down vote up
def preprocess(self, image_buffer, bbox, batch_position):
    """Preprocesses image_buffer as a function of its batch position.

    Training images go through `train_image`; evaluation images are decoded as
    3-channel JPEG and passed to `eval_image`. The result is mean-subtracted
    (MLPerf mode) or normalized with `normalized_image`, then cast.

    Args:
      image_buffer: encoded image bytes.
      bbox: bounding box forwarded to `train_image` (unused for evaluation).
      batch_position: position of the image within its batch.

    Returns:
      The preprocessed image cast to `self.dtype`.
    """
    if not self.train:
      decoded = tf.image.decode_jpeg(
          image_buffer, channels=3, dct_method='INTEGER_FAST')
      image = eval_image(decoded, self.height, self.width, batch_position,
                         self.resize_method,
                         summary_verbosity=self.summary_verbosity)
    else:
      image = train_image(image_buffer, self.height, self.width, bbox,
                          batch_position, self.resize_method, self.distortions,
                          None, summary_verbosity=self.summary_verbosity,
                          distort_color_in_yiq=self.distort_color_in_yiq,
                          fuse_decode_and_crop=self.fuse_decode_and_crop)
    # Note: image is expected to be float32 [height, width, 3] in [0, 255].

    if not self.match_mlperf:
      normalized = normalized_image(image)
    else:
      mlperf.logger.log(key=mlperf.tags.INPUT_MEAN_SUBTRACTION,
                        value=_CHANNEL_MEANS)
      normalized = image - _CHANNEL_MEANS
    return tf.cast(normalized, self.dtype)
Example #4
Source File: metrics.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def image_summary(predictions, targets, hparams):
  """Adds prediction and ground-truth image summaries for TensorBoard.

  Args:
    predictions: predicted logits; the argmax over the last axis is used as
      the predicted pixel value.
    targets: the ground-truth images.
    hparams: model hparams (unused).

  Returns:
    summary_proto: merged summary containing both image summaries.
    weights: A Tensor of zeros of the same shape as predictions.
  """
  del hparams  # Unused.
  predicted_pixels = tf.argmax(predictions, axis=-1)
  tagged_images = [
      ("prediction", tf.cast(predicted_pixels, tf.uint8)),
      ("data", tf.cast(targets, tf.uint8)),
  ]
  summaries = [tf.summary.image(tag, img, max_outputs=2)
               for tag, img in tagged_images]
  return tf.summary.merge(summaries), tf.zeros_like(predictions)
Example #5
Source File: common_video.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def _encode_gif(images, fps):
  """Encodes numpy images into a gif string via WholeVideoWriter.

  Args:
    images: A 4-D `uint8` `np.array` (or a list of 3-D images) of shape
      `[time, height, width, channels]` where `channels` is 1 or 3.
    fps: frames per second of the animation.

  Returns:
    The encoded gif string returned by the writer's `finish()`.

  Raises:
    IOError: If the ffmpeg command returns an error.
  """
  gif_writer = WholeVideoWriter(fps)
  gif_writer.write_multi(images)
  encoded = gif_writer.finish()
  return encoded
Example #6
Source File: common_layers.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def tpu_safe_image_summary(image):
  """Returns `image` in a dtype usable for an image summary.

  Under XLA compilation the image is kept as (or converted to) float32;
  otherwise it is cast to uint8.
  """
  if not is_xla_compiled():
    return tf.cast(image, tf.uint8)
  # We only support float32 images at the moment due to casting complications.
  return image if image.dtype == tf.float32 else to_float(image)


# This has been (shamefully) copied from
# GitHub tensorflow/models/blob/master/research/slim/nets/cyclegan.py
#
# tensorflow/models cannot be pip installed, and even if it were we don't want
# to depend on all the models in it.
#
# Therefore copying and forgoing any more bugfixes into it is the most
# expedient way to use this function. 
Example #7
Source File: image_utils.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def image_to_tf_summary_value(image, tag):
  """Converts a NumPy image to a tf.Summary.Value object.

  The image is converted to uint8, PNG-encoded with matplotlib and wrapped in
  a tf.Summary.Image.

  Args:
    image: 3-D NumPy array of shape [height, width, channels]; converted with
      np.asarray(..., dtype=np.uint8).
    tag: name for tf.Summary.Value for display in tensorboard.
  Returns:
    image_summary: A tf.Summary.Value object.
  """
  curr_image = np.asarray(image, dtype=np.uint8)
  height, width, n_channels = curr_image.shape
  imsave_kwargs = {}
  # If monochrome image, then reshape to [height, width]
  if n_channels == 1:
    curr_image = np.reshape(curr_image, [height, width])
    # matplotlib treats a 2-D array as scalar data: without an explicit cmap
    # it applies the default colour colormap, and without vmin/vmax it
    # rescales to the image's own min/max. Render true 8-bit grayscale.
    imsave_kwargs = {"cmap": "gray", "vmin": 0, "vmax": 255}
  s = io.BytesIO()
  matplotlib_pyplot().imsave(s, curr_image, format="png", **imsave_kwargs)
  img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                             height=height, width=width,
                             colorspace=n_channels)
  return tf.Summary.Value(tag=tag, image=img_sum)
Example #8
Source File: download_and_convert_mnist.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def _extract_images(filename, num_images):
  """Extract the images into a numpy array.

  Args:
    filename: The path to a gzip-compressed MNIST images file.
    num_images: The number of images in the file.

  Returns:
    A uint8 numpy array of shape [number_of_images, height, width, channels].

  Raises:
    ValueError: If the file holds fewer pixel bytes than requested.
  """
  print('Extracting images from: ', filename)
  expected_bytes = _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS
  with gzip.open(filename) as bytestream:
    bytestream.read(16)  # Skip the 16-byte file header.
    buf = bytestream.read(expected_bytes)
  # Fail with a clear message instead of an opaque reshape error.
  if len(buf) != expected_bytes:
    raise ValueError(
        'Expected %d bytes of image data in %s, got %d (file truncated?)' %
        (expected_bytes, filename, len(buf)))
  data = np.frombuffer(buf, dtype=np.uint8)
  return data.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
Example #9
Source File: download_and_convert_mnist.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def _extract_labels(filename, num_labels):
  """Extract the labels into a vector of int64 label IDs.

  Args:
    filename: The path to an MNIST labels file.
    num_labels: The number of labels in the file.

  Returns:
    A numpy array of shape [number_of_labels]
  """
  print('Extracting labels from: ', filename)
  with gzip.open(filename) as bytestream:
    bytestream.read(8)
    buf = bytestream.read(1 * num_labels)
    labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
  return labels 
Example #10
Source File: tf_example_decoder.py    From Object_Detection_Tracking with Apache License 2.0 6 votes vote down vote up
def _decode_masks(self, parsed_tensors):
    """Decodes the PNG masks of one example into float32 tensors.

    Args:
      parsed_tensors: dict of parsed features; reads 'image/height',
        'image/width' and 'image/object/mask'.

    Returns:
      A float32 tensor of decoded masks, or a [0, height, width] zero-size
      tensor when there are no masks.
    """
    height = parsed_tensors['image/height']
    width = parsed_tensors['image/width']
    mask_pngs = parsed_tensors['image/object/mask']

    def decode_single_mask(png_bytes):
      # Single-channel PNG -> rank-2 float32 mask.
      decoded = tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8)
      as_float = tf.cast(tf.squeeze(decoded, axis=-1), dtype=tf.float32)
      as_float.set_shape([None, None])
      return as_float

    def decode_all():
      return tf.map_fn(decode_single_mask, mask_pngs, dtype=tf.float32)

    def no_masks():
      return tf.zeros([0, height, width], dtype=tf.float32)

    has_masks = tf.greater(tf.size(mask_pngs), 0)
    return tf.cond(has_masks, decode_all, no_masks)
Example #11
Source File: robust_model.py    From interval-bound-propagation with Apache License 2.0 6 votes vote down vote up
def parse(data_dict):
  """Parse dataset from _data_gen into the same format as sst_binary.

  Args:
    data_dict: dict with a 'label' entry and a 'sentence' string tensor.

  Returns:
    dict with 'sentiment' (the label, unchanged) and 'sentence' (a
    tf.SparseTensor of single-character strings built from the sentence's
    bytes).
  """
  sentiment = data_dict['label']
  raw_sentence = data_dict['sentence']

  if six.PY3:
    def safe_chr(code):
      # Bytes >= 128 are replaced with '?'.
      return '?' if code >= 128 else chr(code)
  else:
    safe_chr = chr
  byte_to_char = np.vectorize(safe_chr)

  dense_chars = tf.decode_raw(raw_sentence, tf.uint8)
  dense_chars.set_shape((None,))
  sparse_codes = tfp.math.dense_to_sparse(dense_chars)
  # NOTE(review): np.vectorize without `otypes` raises on size-0 inputs, so
  # an empty sentence would likely fail inside py_func -- confirm.
  char_values = tf.py_func(byte_to_char, [sparse_codes.values], tf.string)
  sparse_chars = tf.SparseTensor(indices=sparse_codes.indices,
                                 values=char_values,
                                 dense_shape=sparse_codes.dense_shape)
  return {'sentiment': sentiment, 'sentence': sparse_chars}
Example #12
Source File: tensorspec_utils_test.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def test_pad_image_tensor_to_spec_shape(self):
    # Two 2x2x1 frames (of 1s and 2s) are provided while the spec asks for
    # three; the missing frame should be filled with varlen_default_value.
    def frame(value):
      return [[[value]] * 2] * 2

    varlen_spec = utils.ExtendedTensorSpec(
        shape=(3, 2, 2, 1),
        dtype=tf.uint8,
        name='varlen',
        data_format='png',
        varlen_default_value=3.0)
    prepadded_tensor = tf.convert_to_tensor(
        [[frame(1), frame(2)]], dtype=varlen_spec.dtype)
    tensor = utils.pad_or_clip_tensor_to_spec_shape(prepadded_tensor,
                                                    varlen_spec)
    with self.session() as sess:
      self.assertAllEqual(
          sess.run(tensor),
          np.array([[frame(1), frame(2), frame(3)]]))
Example #13
Source File: tfrecord_image_generator.py    From benchmarks with Apache License 2.0 6 votes vote down vote up
def _process_image(coder, name):
  """Process a single image file.

  If name is "train", a black image is returned. Otherwise, a white image is
  returned.

  Args:
    coder: instance of ImageCoder to provide TensorFlow image coding utils.
    name: string, unique identifier specifying the data set.
  Returns:
    image_buffer: bytes, JPEG encoding of RGB image.
    height: integer, image height in pixels.
    width: integer, image width in pixels.
  """
  # Read the image file.
  value = 0 if name == 'train' else 255
  height = random.randint(30, 299)
  width = random.randint(30, 299)
  image = np.full((height, width, 3), value, np.uint8)

  jpeg_data = coder.encode_jpeg(image)

  return jpeg_data, height, width 
Example #14
Source File: exporter.py    From models with Apache License 2.0 6 votes vote down vote up
def _tf_example_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of strings with tf examples.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).
      Only elements [1:3] are used, as the resize target size.

  Returns:
    a tuple of input placeholder and the output decoded images (uint8).
  """
  batch_tf_example_placeholder = tf.placeholder(
      tf.string, shape=[None], name='tf_example')
  def decode(tf_example_string_tensor):
    tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
        tf_example_string_tensor)
    image_tensor = tensor_dict[fields.InputDataFields.image]
    if input_shape is not None:
      # tf.image.resize returns float32 for the default (bilinear) method,
      # but the map_fn below declares dtype=tf.uint8. Convert back so both
      # branches yield uint8; saturate_cast guards against overshoot from
      # methods such as bicubic.
      image_tensor = tf.image.resize(image_tensor, input_shape[1:3])
      image_tensor = tf.saturate_cast(tf.round(image_tensor), tf.uint8)
    return image_tensor
  return (batch_tf_example_placeholder,
          shape_utils.static_or_dynamic_map_fn(
              decode,
              elems=batch_tf_example_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))
Example #15
Source File: detection_inference_tf1_test.py    From models with Apache License 2.0 6 votes vote down vote up
def create_mock_graph():
  """Writes a small mock detection GraphDef to get_mock_graph_path().

  The graph has a uint8 'image_tensor' placeholder plus constant
  'num_detections', 'detection_boxes' and 'detection_scores' outputs;
  'detection_classes' is computed from the sum of the input image.
  """
  g = tf.Graph()
  with g.as_default():
    in_image_tensor = tf.placeholder(
        tf.uint8, shape=[1, None, None, 3], name='image_tensor')
    tf.constant([2.0], name='num_detections')
    tf.constant(
        [[[0, 0.8, 0.7, 1], [0.1, 0.2, 0.8, 0.9], [0.2, 0.3, 0.4, 0.5]]],
        name='detection_boxes')
    tf.constant([[0.1, 0.2, 0.3]], name='detection_scores')
    tf.identity(
        tf.constant([[1.0, 2.0, 3.0]]) *
        tf.reduce_sum(tf.cast(in_image_tensor, dtype=tf.float32)),
        name='detection_classes')
    graph_def = g.as_graph_def()

  # SerializeToString() returns bytes, so open the file in binary mode.
  with tf.gfile.Open(get_mock_graph_path(), 'wb') as fl:
    fl.write(graph_def.SerializeToString())
Example #16
Source File: detection_inference_tf1_test.py    From models with Apache License 2.0 6 votes vote down vote up
def create_mock_tfrecord():
  """Writes a single-record TFRecord with a 1x1 PNG and a float feature.

  Returns:
    The PNG-encoded bytes stored in the record's image field.
  """
  one_pixel = np.array([[[123, 0, 0]]], dtype=np.uint8)
  png_buffer = six.BytesIO()
  Image.fromarray(one_pixel, 'RGB').save(png_buffer, format='png')
  encoded_image = png_buffer.getvalue()

  features = tf.train.Features(feature={
      'test_field':
          dataset_util.float_list_feature([1, 2, 3, 4]),
      standard_fields.TfExampleFields.image_encoded:
          dataset_util.bytes_feature(encoded_image),
  })
  serialized_example = tf.train.Example(features=features).SerializeToString()
  with tf.python_io.TFRecordWriter(get_mock_tfrecord_path()) as writer:
    writer.write(serialized_example)
  return encoded_image
Example #17
Source File: download_and_convert_mnist.py    From models with Apache License 2.0 6 votes vote down vote up
def _extract_labels(filename, num_labels):
  """Extract the labels into a vector of int64 label IDs.

  Args:
    filename: The path to an MNIST labels file.
    num_labels: The number of labels in the file.

  Returns:
    A numpy array of shape [number_of_labels]
  """
  print('Extracting labels from: ', filename)
  with gzip.open(filename) as bytestream:
    bytestream.read(8)
    buf = bytestream.read(1 * num_labels)
    labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
  return labels 
Example #18
Source File: download_and_convert_mnist.py    From models with Apache License 2.0 6 votes vote down vote up
def _extract_images(filename, num_images):
  """Extract the images into a numpy array.

  Args:
    filename: The path to a gzip-compressed MNIST images file.
    num_images: The number of images in the file.

  Returns:
    A uint8 numpy array of shape [number_of_images, height, width, channels].

  Raises:
    ValueError: If the file holds fewer pixel bytes than requested.
  """
  print('Extracting images from: ', filename)
  expected_bytes = _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS
  with gzip.open(filename) as bytestream:
    bytestream.read(16)  # Skip the 16-byte file header.
    buf = bytestream.read(expected_bytes)
  # Fail with a clear message instead of an opaque reshape error.
  if len(buf) != expected_bytes:
    raise ValueError(
        'Expected %d bytes of image data in %s, got %d (file truncated?)' %
        (expected_bytes, filename, len(buf)))
  data = np.frombuffer(buf, dtype=np.uint8)
  return data.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
Example #19
Source File: char_utils.py    From language with Apache License 2.0 6 votes vote down vote up
def word_to_char_ids(word, word_length):
  """Convert a string to a padded vector of character ids.

  The output is [BOW_CHAR, <byte ids>, EOW_CHAR, <PAD_CHAR padding>]. At most
  `word_length - 2` bytes of the word are kept so the two markers always fit;
  extra bytes are dropped and shorter words are padded up to `word_length`.

  Args:
    word: <string> []
    word_length: Total number of ids in the output, including both markers.

  Returns:
    char_ids: <int32> [word_length]
  """
  max_chars = word_length - 2  # Room reserved for BOW_CHAR and EOW_CHAR.
  byte_ids = tf.to_int32(tf.decode_raw(word, tf.uint8))[:max_chars]
  padding = tf.fill([max_chars - tf.shape(byte_ids)[0]], PAD_CHAR)
  char_ids = tf.concat([[BOW_CHAR], byte_ids, [EOW_CHAR], padding], 0)
  char_ids.set_shape([word_length])
  return char_ids
Example #20
Source File: inference.py    From PINTO_model_zoo with MIT License 6 votes vote down vote up
def serve_images(self, image_arrays):
    """Runs the 'prediction' signature on a list of image arrays.

    The session is built on first use if `self.sess` is not set yet.

    Args:
      image_arrays: A list of image content with each image has shape [height,
        width, 3] and uint8 type.

    Returns:
      A list of detections.
    """
    if not self.sess:
      self.build()
    # `signitures` (sic) is the attribute name used by this class.
    fetch = self.signitures['prediction']
    feed = {self.signitures['image_arrays']: image_arrays}
    return self.sess.run(fetch, feed_dict=feed)
Example #21
Source File: distri.py    From nni with MIT License 6 votes vote down vote up
def neglogp(self, x):
        """Negative log-probability of `x` under the categorical logits.

        Integer labels (uint8/int32/int64) are one-hot encoded to the size of
        the logits' last dimension; any other dtype is treated as already
        encoded and must have exactly the logits' static shape. The dense
        softmax_cross_entropy_with_logits_v2 is used instead of the sparse
        variant because the sparse implementation does not allow second-order
        derivatives.
        """
        logits = self.logits
        if x.dtype not in {tf.uint8, tf.int32, tf.int64}:
            # Already encoded: static shapes must match exactly.
            assert x.shape.as_list() == logits.shape.as_list()
        else:
            batch_dims = logits.get_shape().as_list()[:-1]
            # Compare only the dimensions known statically on both sides.
            for x_dim, logit_dim in zip(x.shape.as_list(), batch_dims):
                if x_dim is not None and logit_dim is not None:
                    assert x_dim == logit_dim, 'shape mismatch: {} in x vs {} in logits'.format(x_dim, logit_dim)
            num_classes = logits.get_shape().as_list()[-1]
            x = tf.one_hot(x, num_classes)

        return tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits,
            labels=x)
Example #22
Source File: t2r_models.py    From tensor2robot with Apache License 2.0 6 votes vote down vote up
def _transform_in_feature_specification(self, tensor_spec_struct):
    """Declares the raw image input required by the preprocess_fn.

    The 'state/image' entry is updated via `self.update_spec` to a
    (512, 640, 3) uint8 tensor with JPEG data format.

    Args:
      tensor_spec_struct: A flat spec structure {str: TensorSpec}.

    Returns:
      tensor_spec_struct: the same structure object, after `self.update_spec`
      has been applied to it.
    """
    image_spec_overrides = dict(
        shape=(512, 640, 3),
        dtype=tf.uint8,
        data_format='jpeg',
    )
    self.update_spec(tensor_spec_struct, 'state/image', **image_spec_overrides)
    return tensor_spec_struct
Example #23
Source File: evaluate.py    From compression with Apache License 2.0 5 votes vote down vote up
def eval_trained_model(config_name,
                       ckpt_dir,
                       out_dir,
                       tfds_arguments: helpers.TFDSArguments,
                       max_images=None):
  """Evaluate a trained model, saving input/output image pairs as PNGs.

  Args:
    config_name: name passed to configs.get_config.
    ckpt_dir: directory the trained model is restored from.
    out_dir: output directory for the PNGs; created if missing.
    tfds_arguments: dataset arguments forwarded to HiFiC.build_input.
    max_images: stop after this many images; None (or 0) means no limit.
  """
  hific = model.HiFiC(configs.get_config(config_name),
                      helpers.ModelMode.EVALUATION)

  # Automatically uses the validation split.
  dataset = hific.build_input(
      batch_size=1, crop_size=None, tfds_arguments=tfds_arguments)
  next_batch = tf.data.make_one_shot_iterator(dataset).get_next()

  output_image, bpp = hific.build_model(**next_batch)

  def to_uint8(batch):
    # Take the single image of the batch and round it to 8-bit pixels.
    return tf.cast(tf.round(batch[0, ...]), tf.uint8)

  input_image = to_uint8(next_batch['input_image'])
  output_image = to_uint8(output_image)

  os.makedirs(out_dir, exist_ok=True)

  with tf.Session() as sess:
    hific.restore_trained_model(sess, ckpt_dir)
    i = 0
    while not (max_images and i == max_images):
      try:
        inp_np, otp_np, bpp_np = sess.run([input_image, output_image, bpp])
      except tf.errors.OutOfRangeError:
        print('No more inputs')
        break
      print(f'Image {i}: {bpp_np:.3} bpp, saving in {out_dir}...')
      Image.fromarray(inp_np).save(
          os.path.join(out_dir, f'img_{i:010d}inp.png'))
      Image.fromarray(otp_np).save(
          os.path.join(out_dir, f'img_{i:010d}otp_{bpp_np:.3f}.png'))
      i += 1
  print('Done!')
Example #24
Source File: bls2017.py    From compression with Apache License 2.0 5 votes vote down vote up
def quantize_image(image):
  """Scales by 255, rounds, and saturate-casts to uint8 (clamping to [0, 255])."""
  return tf.saturate_cast(tf.round(image * 255), tf.uint8)
Example #25
Source File: vocabularies.py    From text-to-text-transfer-transformer with Apache License 2.0 5 votes vote down vote up
def encode_tf(self, s):
    """Encode a tf.Scalar string to a tf.Tensor.

    Each byte of `s` becomes one id equal to the byte value plus
    `self._num_special_tokens`.

    Args:
      s: a tf.Scalar with dtype tf.string
    Returns:
      a 1d tf.Tensor with dtype tf.int32
    """
    byte_ids = tf.io.decode_raw(s, tf.uint8)
    # Widen before adding the offset: uint8 arithmetic wraps modulo 256, so
    # bytes >= 256 - num_special_tokens would wrap around to small ids,
    # including those below num_special_tokens.
    return tf.dtypes.cast(byte_ids, tf.int32) + self._num_special_tokens
Example #26
Source File: glow_ops.py    From tensor2tensor with Apache License 2.0 5 votes vote down vote up
def postprocess(x, n_bits_x=8):
  """Converts x from [-0.5, 0.5] to uint8 pixel values.

  Non-finite entries are replaced by 1, which lands at the top of the range.
  The input is clipped to [-0.5, 0.5], shifted to [0, 1], scaled by
  2**n_bits_x and clipped to [0, 255]; with the default 8 bits this spans the
  full [0, 255] range.

  Args:
    x: 3-D or 4-D Tensor normalized between [-0.5, 0.5]
    n_bits_x: Number of bits representing each pixel of the output.
              Defaults to 8, to default to 256 possible values.
  Returns:
    x: 3-D or 4-D uint8 Tensor representing images or videos.
  """
  finite = tf.where(tf.is_finite(x), x, tf.ones_like(x))
  unit_range = tf.clip_by_value(finite, -0.5, 0.5) + 0.5
  scaled = unit_range * 2**n_bits_x
  return tf.cast(tf.clip_by_value(scaled, 0, 255), dtype=tf.uint8)
Example #27
Source File: mnist_dataset.py    From mesh with Apache License 2.0 5 votes vote down vote up
def dataset(directory, images_file, labels_file):
  """Download and parse MNIST dataset.

  Returns:
    A tf.data.Dataset of (image, label) pairs: image is a float32 [784]
    vector scaled to [0.0, 1.0]; label is an int32 scalar.
  """
  images_path = download(directory, images_file)
  labels_path = download(directory, labels_file)

  check_image_file_header(images_path)
  check_labels_file_header(labels_path)

  def decode_image(raw_bytes):
    # Normalize from [0, 255] to [0.0, 1.0]
    pixels = tf.cast(tf.decode_raw(raw_bytes, tf.uint8), tf.float32)
    return tf.reshape(pixels, [784]) / 255.0

  def decode_label(raw_bytes):
    # tf.string -> [tf.uint8] of length 1 -> scalar -> int32.
    label = tf.reshape(tf.decode_raw(raw_bytes, tf.uint8), [])
    return tf.to_int32(label)

  images = tf.data.FixedLengthRecordDataset(
      images_path, 28 * 28, header_bytes=16).map(decode_image)
  labels = tf.data.FixedLengthRecordDataset(
      labels_path, 1, header_bytes=8).map(decode_label)
  return tf.data.Dataset.zip((images, labels))
Example #28
Source File: tfrecord_image_generator.py    From benchmarks with Apache License 2.0 5 votes vote down vote up
def __init__(self):
    """Creates the shared session and the uint8-image -> JPEG encoding op."""
    # A single session reused by every image coding call.
    self._sess = tf.Session()

    # Placeholder for a uint8 RGB image and the op that JPEG-encodes it at
    # quality 100.
    image_in = tf.placeholder(dtype=tf.uint8)
    self._image = image_in
    self._encode_jpeg = tf.image.encode_jpeg(image_in, format='rgb',
                                             quality=100)
Example #29
Source File: sv2p.py    From tensor2tensor with Apache License 2.0 5 votes vote down vote up
def visualize_predictions(self, real_frames, gen_frames, actions=None):
    """Adds a "full_video" image summary comparing real and generated frames.

    Each sequence has its time and batch axes swapped, is unstacked and
    re-concatenated along axis 1, and the real, predicted and (optional)
    action strips are concatenated along axis 2.
    """
    def stack_along_axis1(frames):
      # Unstack along axis 1 and lay the pieces end to end along axis 1.
      return tf.concat(tf.unstack(frames, axis=1), axis=1)

    frames_gd = common_video.swap_time_and_batch_axes(real_frames)
    frames_pd = common_video.swap_time_and_batch_axes(gen_frames)
    has_actions = actions is not None
    if has_actions:
      actions = common_video.swap_time_and_batch_axes(actions)

    if self.is_per_pixel_softmax:
      # 256-way per-value logits -> argmax value, reshaped to 3 channels.
      pd_shape = common_layers.shape_list(frames_pd)
      value_logits = tf.reshape(frames_pd, [-1, 256])
      values = tf.to_float(tf.argmax(value_logits, axis=-1))
      frames_pd = tf.reshape(values, pd_shape[:-1] + [3])

    strips = [stack_along_axis1(frames_gd), stack_along_axis1(frames_pd)]
    if has_actions:
      actions = tf.clip_by_value(actions, 0, 1)
      summary("action_vid", tf.cast(actions * 255, tf.uint8))
      strips.append(stack_along_axis1(actions))
    tf.summary.image("full_video", tf.concat(strips, axis=2))
Example #30
Source File: image_utils.py    From tensor2tensor with Apache License 2.0 5 votes vote down vote up
def encode_images_as_png(images):
  """Yield images encoded as pngs.

  Args:
    images: an iterable of uint8 images of shape [height, width, channels].
      In graph mode every image must have the shape of the first one, because
      a single fixed-shape placeholder is reused.

  Yields:
    PNG-encoded bytes, one per input image. An empty input yields nothing in
    both eager and graph mode.
  """
  if tf.executing_eagerly():
    for image in images:
      yield tf.image.encode_png(image).numpy()
    return

  import itertools

  # Peek at the first image instead of indexing images[0]: an empty input then
  # yields nothing (indexing would raise IndexError) and plain iterators work.
  image_iter = iter(images)
  sentinel = object()
  first = next(image_iter, sentinel)
  if first is sentinel:
    return
  (height, width, channels) = first.shape
  with tf.Graph().as_default():
    image_t = tf.placeholder(dtype=tf.uint8, shape=(height, width, channels))
    encoded_image_t = tf.image.encode_png(image_t)
    with tf.Session() as sess:
      for image in itertools.chain([first], image_iter):
        enc_string = sess.run(encoded_image_t, feed_dict={image_t: image})
        yield enc_string