Python tensorflow.int64() Examples

The following are 30 code examples of tensorflow.int64(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: common_layers.py    From fine-lm with MIT License 6 votes vote down vote up
def ones_matrix_band_part(rows, cols, num_lower, num_upper, out_shape=None):
  """Return a [rows, cols] band matrix of ones, optionally reshaped.

  When all four size arguments are Python ints the band is computed with numpy
  and returned as a float32 constant; in that path a negative `num_lower` is
  replaced by `rows - 1` and a negative `num_upper` by `cols - 1`. Otherwise
  the band is built with `tf.matrix_band_part`, with the bounds cast to int64.

  Args:
    rows: number of rows (int or Tensor).
    cols: number of columns (int or Tensor).
    num_lower: number of sub-diagonals to keep.
    num_upper: number of super-diagonals to keep.
    out_shape: optional shape the result is reshaped to.

  Returns:
    The band Tensor.
  """
  sizes = (rows, cols, num_lower, num_upper)
  if not all(isinstance(size, int) for size in sizes):
    # At least one size is only known at graph run time: stay in TensorFlow.
    band = tf.matrix_band_part(
        tf.ones([rows, cols]),
        tf.cast(num_lower, tf.int64),
        tf.cast(num_upper, tf.int64))
    if out_shape:
      band = tf.reshape(band, out_shape)
    return band

  # Everything is a Python constant, so the mask can be precomputed in numpy.
  if num_lower < 0:
    num_lower = rows - 1
  if num_upper < 0:
    num_upper = cols - 1
  keep_below = np.tri(cols, rows, num_lower).T
  keep_above = np.tri(rows, cols, num_upper)
  band = np.ones((rows, cols)) * keep_below * keep_above
  if out_shape:
    band = band.reshape(out_shape)
  return tf.constant(band, tf.float32)
Example #2
Source File: ops_test.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def test_indices_to_dense_vector_int(self):
    """Checks that indices_to_dense_vector yields an int64 indicator vector.

    25 distinct random positions out of 500 are set to 1; both the values and
    the dtype of the op output must match the numpy reference.
    """
    vector_size = 500
    hot_count = 25
    hot_positions = np.random.permutation(np.arange(vector_size))[0:hot_count]

    # numpy reference: zeros everywhere except the chosen positions.
    reference = np.zeros(vector_size, dtype=np.int64)
    reference[hot_positions] = 1

    indicator = ops.indices_to_dense_vector(
        tf.constant(hot_positions), vector_size, 1, dtype=tf.int64)

    with self.test_session() as sess:
      actual = sess.run(indicator)
      self.assertAllEqual(actual, reference)
      self.assertEqual(actual.dtype, reference.dtype)
Example #3
Source File: multi_problem.py    From fine-lm with MIT License 6 votes vote down vote up
def add_task_id(self, task, example):
    """Prefix the example's targets with the task id (code-switching mode).

    If `task` has a `class_labels` attribute, the targets are first re-encoded
    with `discretization.int_to_bit(..., 1, base=10)`, shifted by 50, cast to
    int64 and squeezed on the last axis. The new targets are the concatenation
    along axis 0 of the inputs (only when `task.has_inputs`; they are popped
    from `example`), `[task.task_id]` and the targets.

    Args:
      task: object providing `task_id` and `has_inputs`.
      example: feature mapping; modified in place.

    Returns:
      The same `example` with "targets" replaced.
    """
    if hasattr(task, "class_labels"):
      # TODO(urvashik): handle the case where num_labels > 9
      label_digits = discretization.int_to_bit(example["targets"], 1, base=10)
      example["targets"] = tf.cast(label_digits + 50, tf.int64)
      example["targets"] = tf.squeeze(example["targets"], axis=[-1])

    if task.has_inputs:
      prefix = [example.pop("inputs")]
    else:
      prefix = []

    example["targets"] = tf.concat(
        prefix + [[task.task_id], example["targets"]], 0)
    return example
Example #4
Source File: bulk_component.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def build_cross_entropy_loss(logits, gold):
  """Builds a mean sparse softmax cross entropy over the valid rows.

  Rows whose gold label is not greater than -1 are dropped from both `gold`
  and `logits` before anything else is computed.

  Args:
    logits: float Tensor of scores.
    gold: int Tensor of label indices; -1 is the magic "skip this row" value.

  Returns:
    cost, correct, total: the summed cross entropy divided by the number of
        valid labels, the number of rows whose gold label is the top-1
        prediction, and the number of valid labels.
  """
  # Flat list of row indices that carry a real label.
  keep_rows = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
  gold = tf.gather(gold, keep_rows)
  logits = tf.gather(logits, keep_rows)

  top1_hits = tf.nn.in_top_k(logits, gold, 1)
  correct = tf.reduce_sum(tf.to_int32(top1_hits))
  total = tf.size(gold)

  per_row_cost = (
      tf.contrib.nn.deprecated_flipped_sparse_softmax_cross_entropy_with_logits(
          logits, tf.cast(gold, tf.int64)))
  cost = tf.reduce_sum(per_row_cost) / tf.cast(total, tf.float32)
  return cost, correct, total
Example #5
Source File: problem.py    From fine-lm with MIT License 6 votes vote down vote up
def serving_input_fn(self, hparams):
    """Builds the serving input receiver fed by serialized examples.

    A string placeholder of shape [None] receives the serialized examples.
    They are decoded, preprocessed in PREDICT mode, passed through
    `maybe_reverse_and_copy` and `cast_ints_to_int32`, padded into a single
    batch holding every received example, and shape-standardized.

    Args:
      hparams: hyperparameters forwarded to `preprocess_example`.

    Returns:
      A `tf.estimator.export.ServingInputReceiver` whose receiver tensor is
      the placeholder. "targets" is dropped from the features when the
      problem has inputs.
    """
    mode = tf.estimator.ModeKeys.PREDICT
    serialized = tf.placeholder(
        dtype=tf.string, shape=[None], name="serialized_example")

    dataset = (
        tf.data.Dataset.from_tensor_slices(serialized)
        .map(self.decode_example)
        .map(lambda ex: self.preprocess_example(ex, mode, hparams))
        .map(self.maybe_reverse_and_copy)
        .map(data_reader.cast_ints_to_int32))

    # One batch that contains everything that was fed to the placeholder.
    num_examples = tf.shape(serialized, out_type=tf.int64)[0]
    dataset = dataset.padded_batch(num_examples, dataset.output_shapes)
    dataset = dataset.map(standardize_shapes)
    features = tf.contrib.data.get_single_element(dataset)

    if self.has_inputs:
      features.pop("targets", None)

    return tf.estimator.export.ServingInputReceiver(
        features=features, receiver_tensors=serialized)
Example #6
Source File: problem.py    From fine-lm with MIT License 6 votes vote down vote up
def decode_example(self, serialized_example):
    """Decode a serialized tensorflow.Example into a dict of Tensors.

    The reading spec comes from `self.example_reading_spec()`. When it
    supplies no decoders, every data field gets a plain `Tensor` decoder keyed
    by its own name.

    Args:
      serialized_example: the serialized example to decode.

    Returns:
      A dict mapping each decoded item name to its Tensor.
    """
    data_fields, decoders = self.example_reading_spec()
    # Necessary to rejoin examples in the correct order with the Cloud ML Engine
    # batch prediction API.
    data_fields["batch_prediction_key"] = tf.FixedLenFeature([1], tf.int64, 0)
    if decoders is None:
      decoders = {}
      for field in data_fields:
        decoders[field] = tf.contrib.slim.tfexample_decoder.Tensor(field)

    example_decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(
        data_fields, decoders)

    # Sorted names give a deterministic pairing of names and decoded tensors.
    item_names = sorted(decoders)
    tensors = example_decoder.decode(serialized_example, items=item_names)
    return dict(zip(item_names, tensors))
Example #7
Source File: probclass.py    From imgcomp-cvpr with GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, pc: _Network3D, config, centers, sess, freqs_resolution=1e9):
        """
        Builds the graph mapping a context of symbols to next-symbol frequencies.

        :param pc: network whose class provides get_context_shape and whose logits() is used
        :param config: passed to get_context_shape
        :param centers: tensor indexed by the symbols to obtain the network input
        :param sess: Must be set at the latest before using get_pr or get_freqs
        :param freqs_resolution: factor applied to the probabilities before the cast to int64
        """
        self.pc_class = pc.__class__
        self.config = config
        self.input_ctx_shape = self.pc_class.get_context_shape(config)
        self.input_ctx = tf.placeholder(tf.int64, self.input_ctx_shape)  # symbols!

        # Add a leading batch dimension, then a trailing one for the 3d conv.
        ctx = tf.expand_dims(self.input_ctx, 0)
        ctx = tf.expand_dims(ctx, -1)

        # Here, in contrast to pc.bitcost(...), q does not need to be padded, as it is part of some context.
        # Logits will be a 1111L vector, i.e., prediction of the next pixel
        q = tf.gather(centers, ctx)
        self.pr = tf.nn.softmax(pc.logits(q, is_training=False))

        scaled_pr = self.pr * freqs_resolution
        self.freqs = tf.squeeze(tf.cast(scaled_pr, tf.int64))
        self.sess = sess

        self._get_freqs = None
Example #8
Source File: skip_thoughts_model_test.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def build_inputs(self):
    """Builds the model inputs, using random ids instead of disk I/O.

    In "encode" mode the parent implementation is used unchanged. Otherwise
    the three id tensors are random int64 values in [0, vocab_size) of shape
    [batch_size, 15], and each mask is `ones_like` its id tensor.
    """
    if self.mode == "encode":
      # Encode mode doesn't read from disk, so defer to parent.
      return super(SkipThoughtsModel, self).build_inputs()

    def _random_ids():
      # Replace disk I/O with random Tensors.
      return tf.random_uniform(
          [self.config.batch_size, 15],
          minval=0,
          maxval=self.config.vocab_size,
          dtype=tf.int64)

    self.encode_ids = _random_ids()
    self.decode_pre_ids = _random_ids()
    self.decode_post_ids = _random_ids()

    self.encode_mask = tf.ones_like(self.encode_ids)
    self.decode_pre_mask = tf.ones_like(self.decode_pre_ids)
    self.decode_post_mask = tf.ones_like(self.decode_post_ids)
Example #9
Source File: show_and_tell_model_test.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def build_inputs(self):
    """Builds the model inputs, using random tensors instead of disk I/O.

    In "inference" mode the parent implementation is used unchanged.
    Otherwise images are uniform in [-1, 1) with shape
    [batch_size, image_height, image_width, 3], both sequences are random
    int64 ids in [0, vocab_size) of shape [batch_size, 15], and the input
    mask is `ones_like` the input sequences.
    """
    if self.mode == "inference":
      # Inference mode doesn't read from disk, so defer to parent.
      return super(ShowAndTellModel, self).build_inputs()

    # Replace disk I/O with random Tensors.
    cfg = self.config
    self.images = tf.random_uniform(
        shape=[cfg.batch_size, cfg.image_height, cfg.image_width, 3],
        minval=-1,
        maxval=1)

    def _random_ids():
      return tf.random_uniform(
          [self.config.batch_size, 15],
          minval=0,
          maxval=self.config.vocab_size,
          dtype=tf.int64)

    self.input_seqs = _random_ids()
    self.target_seqs = _random_ids()
    self.input_mask = tf.ones_like(self.input_seqs)
Example #10
Source File: ops_test.py    From object_detector_app with MIT License 6 votes vote down vote up
def test_indices_to_dense_vector_int(self):
    """Verifies the int64 output of ops.indices_to_dense_vector.

    Picks 25 distinct random indices below 500, asks the op for a dense
    vector with value 1 at those indices, and compares values and dtype with
    a numpy-built expectation.
    """
    length = 500
    picked = 25
    chosen = np.random.permutation(np.arange(length))[0:picked]

    wanted = np.zeros(length, dtype=np.int64)
    wanted[chosen] = 1

    chosen_tensor = tf.constant(chosen)
    dense = ops.indices_to_dense_vector(
        chosen_tensor, length, 1, dtype=tf.int64)

    with self.test_session() as sess:
      got = sess.run(dense)
      self.assertAllEqual(got, wanted)
      self.assertEqual(got.dtype, wanted.dtype)
Example #11
Source File: inputs.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def _read_single_sequence_example(file_list, tokens_shape=None):
  """Reads one SequenceExample from TFRecord files and parses it.

  Args:
    file_list: TFRecord file names fed to a string input producer.
    tokens_shape: per-step shape of the token-id feature; a falsy value means
      scalar (`[]`).

  Returns:
    seq_key, ctx, sequence: the record key from the reader, and the context
    and sequence feature dicts returned by
    `tf.parse_single_sequence_example`.
  """
  tf.logging.info('Constructing TFRecordReader from files: %s', file_list)
  filename_queue = tf.train.string_input_producer(file_list)
  seq_key, serialized_record = tf.TFRecordReader().read(filename_queue)

  wrapper = data_utils.SequenceWrapper
  sequence_spec = {
      wrapper.F_TOKEN_ID:
          tf.FixedLenSequenceFeature(tokens_shape or [], dtype=tf.int64),
      wrapper.F_LABEL:
          tf.FixedLenSequenceFeature([], dtype=tf.int64),
      wrapper.F_WEIGHT:
          tf.FixedLenSequenceFeature([], dtype=tf.float32),
  }
  ctx, sequence = tf.parse_single_sequence_example(
      serialized_record, sequence_features=sequence_spec)
  return seq_key, ctx, sequence
Example #12
Source File: gym_problems.py    From fine-lm with MIT License 6 votes vote down vote up
def extra_reading_spec(self):
    """Returns the extra stored fields and the decoders that read them.

    "frame_number", "action" and "reward" are each a fixed-length int64
    feature of shape [1], decoded by a `Tensor` decoder keyed on the same name.

    Returns:
      (data_fields, decoders) dicts keyed by field name.
    """
    # TODO(piotrmilos): shouldn't done be included here?
    field_names = ("frame_number", "action", "reward")
    data_fields = {
        field: tf.FixedLenFeature([1], tf.int64) for field in field_names
    }
    decoders = {
        field: tf.contrib.slim.tfexample_decoder.Tensor(tensor_key=field)
        for field in field_names
    }
    return data_fields, decoders
Example #13
Source File: cifar10.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def loss(logits, labels):
  """Computes the total loss: batch cross entropy plus the collected losses.

  The mean sparse softmax cross entropy over the batch is added to the
  'losses' collection, and the sum of everything in that collection is
  returned.

  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]

  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  int_labels = tf.cast(labels, tf.int64)
  per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=int_labels, logits=logits, name='cross_entropy_per_example')
  batch_mean = tf.reduce_mean(per_example, name='cross_entropy')
  tf.add_to_collection('losses', batch_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  all_losses = tf.get_collection('losses')
  return tf.add_n(all_losses, name='total_loss')
Example #14
Source File: variables.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def global_step(device=''):
  """Returns the global step variable, creating it on first use.

  If the GLOBAL_STEP collection already holds an entry, its first element is
  returned. Otherwise a scalar, non-trainable int64 variable named
  'global_step', initialized to zero, is created and registered in the
  VARIABLES_TO_RESTORE, GLOBAL_VARIABLES and GLOBAL_STEP collections.

  Args:
    device: Optional device to place the variable. It can be an string or a
      function that is called to get the device for the variable.

  Returns:
    the tensor representing the global step variable.
  """
  existing = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)
  if existing:
    return existing[0]

  collections = [
      VARIABLES_TO_RESTORE,
      tf.GraphKeys.GLOBAL_VARIABLES,
      tf.GraphKeys.GLOBAL_STEP,
  ]
  # Get the device for the variable.
  placement = variable_device(device, 'global_step')
  with tf.device(placement):
    return tf.get_variable('global_step', shape=[], dtype=tf.int64,
                           initializer=tf.zeros_initializer(),
                           trainable=False, collections=collections)
Example #15
Source File: text_problems.py    From fine-lm with MIT License 5 votes vote down vote up
def example_reading_spec(self):
    """Extends the parent reading spec with a variable-length int64 "context"."""
    parent = super(QuestionAndContext2TextProblem, self)
    data_fields, data_items_to_decoders = parent.example_reading_spec()
    data_fields["context"] = tf.VarLenFeature(tf.int64)
    return (data_fields, data_items_to_decoders)
Example #16
Source File: video_generated.py    From fine-lm with MIT License 5 votes vote down vote up
def extra_reading_spec(self):
    """Returns the extra stored field "frame_number" and its decoder.

    Returns:
      (data_fields, decoders): "frame_number" is a fixed-length int64 feature
      of shape [1], read by a `Tensor` decoder keyed on the same name.
    """
    field = "frame_number"
    data_fields = {field: tf.FixedLenFeature([1], tf.int64)}
    decoders = {
        field: tf.contrib.slim.tfexample_decoder.Tensor(tensor_key=field),
    }
    return data_fields, decoders
Example #17
Source File: tf_metrics.py    From tudouNLP with MIT License 5 votes vote down vote up
def precision(labels, predictions, num_classes, pos_indices=None,
              weights=None, average='micro'):
    """Streaming multi-class precision for Tensorflow.

    The value is the first result of `metrics_from_confusion_matrix` applied
    to the streaming confusion matrix; the update op is the same computation
    applied to the confusion matrix's update op.

    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The true labels
    predictions : Tensor of tf.int32 or tf.int64
        The predictions, same shape as labels
    num_classes : int
        The number of classes
    pos_indices : list of int, optional
        The indices of the positive classes, default is all
    weights : Tensor of tf.int32, optional
        Mask, must be of compatible shape with labels
    average : str, optional
        'micro': counts the total number of true positives, false
            positives, and false negatives for the classes in
            `pos_indices` and infer the metric from it.
        'macro': will compute the metric separately for each class in
            `pos_indices` and average. Will not account for class
            imbalance.
        'weighted': will compute the metric separately for each class in
            `pos_indices` and perform a weighted average by the total
            number of true labels for each class.

    Returns
    -------
    tuple of (scalar float Tensor, update_op)
    """
    confusion, confusion_update = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    value, _, _ = metrics_from_confusion_matrix(
        confusion, pos_indices, average=average)
    update_op, _, _ = metrics_from_confusion_matrix(
        confusion_update, pos_indices, average=average)
    return (value, update_op)
Example #18
Source File: exp12_user_dataset_low_API_2.py    From LearningTensorflow with MIT License 5 votes vote down vote up
def read_tfrecord(tf_filename, size):
    """Read one example from a TFRecord file and return its image tensor.

    The record stores 'image_raw' as `size[0]*size[1]*size[2]` float32 values
    plus scalar int64 'height' and 'width'. Only the image is used, and it is
    reshaped to `size`.
    """
    filename_queue = tf.train.string_input_producer([tf_filename])
    __, serialized_example = tf.TFRecordReader().read(filename_queue)

    flat_length = size[0]*size[1]*size[2]
    feature_spec = {
        'image_raw': tf.FixedLenFeature([flat_length], tf.float32),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)
    return tf.reshape(parsed['image_raw'], size)
# end change
############################################################################## 
Example #19
Source File: exp14_user_dataset_high_API_2.py    From LearningTensorflow with MIT License 5 votes vote down vote up
def _parse_function(example_proto):
  """Parse one serialized example and return its image as [180, 180, 1].

  'image_raw' is read as 180*180*1 float32 values; the scalar int64 'height'
  and 'width' are parsed as well but not used.
  """
  feature_spec = {
      'image_raw': tf.FixedLenFeature([180*180*1], tf.float32),
      'height': tf.FixedLenFeature([], tf.int64),
      'width': tf.FixedLenFeature([], tf.int64),
  }
  parsed = tf.parse_single_example(example_proto, feature_spec)
  return tf.reshape(parsed['image_raw'], [180,180,1])
Example #20
Source File: exp11_user_dataset_low_API_1.py    From LearningTensorflow with MIT License 5 votes vote down vote up
def read_tfrecord(tf_filename, size):
    """Read one example from a TFRecord file and return its uint8 image.

    'image_raw' is stored as a byte string, decoded with `tf.decode_raw` as
    uint8 and reshaped to `size`. The scalar int64 'height' and 'width' are
    parsed but not used.
    """
    filename_queue = tf.train.string_input_producer([tf_filename])
    __, serialized_example = tf.TFRecordReader().read(filename_queue)

    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)
    pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    return tf.reshape(pixels, size)
Example #21
Source File: tf_metrics.py    From tudouNLP with MIT License 5 votes vote down vote up
def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None,
          average='micro', beta=1):
    """Streaming multi-class fbeta for Tensorflow.

    The value is the third result of `metrics_from_confusion_matrix` applied
    to the streaming confusion matrix; the update op is the same computation
    applied to the confusion matrix's update op.

    Parameters
    ----------
    labels : Tensor of tf.int32 or tf.int64
        The true labels
    predictions : Tensor of tf.int32 or tf.int64
        The predictions, same shape as labels
    num_classes : int
        The number of classes
    pos_indices : list of int, optional
        The indices of the positive classes, default is all
    weights : Tensor of tf.int32, optional
        Mask, must be of compatible shape with labels
    average : str, optional
        'micro': counts the total number of true positives, false
            positives, and false negatives for the classes in
            `pos_indices` and infer the metric from it.
        'macro': will compute the metric separately for each class in
            `pos_indices` and average. Will not account for class
            imbalance.
        'weighted': will compute the metric separately for each class in
            `pos_indices` and perform a weighted average by the total
            number of true labels for each class.
    beta : int, optional
        Weight of precision in harmonic mean

    Returns
    -------
    tuple of (scalar float Tensor, update_op)
    """
    confusion, confusion_update = _streaming_confusion_matrix(
        labels, predictions, num_classes, weights)
    _, _, score = metrics_from_confusion_matrix(
        confusion, pos_indices, average=average, beta=beta)
    _, _, update_op = metrics_from_confusion_matrix(
        confusion_update, pos_indices, average=average, beta=beta)
    return (score, update_op)
Example #22
Source File: Metrics.py    From NTM-One-Shot-TF with MIT License 5 votes vote down vote up
def accuracy_instance(predictions, targets, n=[1, 2, 3, 4, 5, 10], nb_classes=5, nb_samples_per_class=10, batch_size=1):
    """Set up the per-instance accuracy accumulators and their update step.

    `targets` is cast to the dtype of `predictions`. Two zero-filled float32
    accumulators are created: `accuracy` of shape
    (batch_size, nb_samples_per_class) and `indices` of shape
    (batch_size, nb_classes + 1).

    NOTE(review): nothing in the visible portion calls `step_`, uses `n`, or
    returns a value; the rest of this function appears to be missing from
    this excerpt -- confirm against the original source.
    """
    targets = tf.cast(targets, predictions.dtype)

    accuracy = tf.constant(value=0, shape=(batch_size, nb_samples_per_class), dtype=tf.float32)
    indices = tf.constant(value=0, shape=(batch_size, nb_classes+1), dtype=tf.float32)

    def step_(carry, elems):
        """Fold one (prediction, target) pair into the accumulators.

        `carry` is the pair (accuracy, indices) and `elems` is the pair
        (p, t). For each batch row, `indices[row, t]` (the number of times
        that target was already counted for the row) selects the column of
        `accuracy` that receives 1.0 when `p == t` and 0.0 otherwise; then
        `indices[row, t]` is incremented by one.

        Returns the updated [accuracy, indices].
        """
        # Tuple parameters in the signature were removed in Python 3
        # (PEP 3113), so the two pairs are unpacked here instead.
        accuracy, indices = carry
        p, t = elems

        p = tf.cast(p, tf.int32)
        t = tf.cast(t, tf.int32)
        ##Accuracy Update
        batch_range = tf.cast(tf.range(0, batch_size), dtype=tf.int32)
        gather = tf.cast(tf.gather_nd(indices,tf.stack([tf.range(0,p.get_shape().as_list()[0]), t], axis=1)), tf.int32)
        index = tf.cast(tf.stack([batch_range, gather], axis=1), dtype=tf.int64)
        val = tf.cast(tf.equal(p, t), tf.float32)
        delta = tf.SparseTensor(indices=index, values=val, dense_shape=tf.cast(accuracy.get_shape().as_list(), tf.int64))
        accuracy = accuracy + tf.sparse_tensor_to_dense(delta)
        ##Index Update
        index = tf.cast(tf.stack([batch_range, t], axis=1), dtype=tf.int64)
        val = tf.constant(1.0, shape=[batch_size])
        delta = tf.SparseTensor(indices=index, values=val, dense_shape=tf.cast(indices.get_shape().as_list(), dtype=tf.int64))
        indices = indices + tf.sparse_tensor_to_dense(delta)
        return [accuracy, indices]
Example #23
Source File: text_problems.py    From fine-lm with MIT License 5 votes vote down vote up
def example_reading_spec(self):
    """Returns the variable-length int64 fields stored for each example.

    "targets" is always present and "inputs" only when the problem has
    inputs. When `packed_length` is set, "<name>_segmentation" and
    "<name>_position" fields are added for targets and, if present, inputs.

    Returns:
      (data_fields, None): no custom decoders are supplied.
    """
    data_fields = {"targets": tf.VarLenFeature(tf.int64)}
    if self.has_inputs:
      data_fields["inputs"] = tf.VarLenFeature(tf.int64)

    if self.packed_length:
      packed_keys = []
      if self.has_inputs:
        packed_keys += ["inputs_segmentation", "inputs_position"]
      packed_keys += ["targets_segmentation", "targets_position"]
      for key in packed_keys:
        data_fields[key] = tf.VarLenFeature(tf.int64)

    data_items_to_decoders = None
    return (data_fields, data_items_to_decoders)
Example #24
Source File: mnist_eager.py    From dockerfiles with Apache License 2.0 5 votes vote down vote up
def compute_accuracy(predictions, labels):
  """Returns the fraction of rows whose argmax matches between the inputs.

  The argmax is taken along axis 1 of both `predictions` and `labels`; the
  number of matches is divided by the static size of the first dimension of
  `predictions`.
  """
  predicted_class = tf.argmax(predictions, axis=1, output_type=tf.int64)
  labeled_class = tf.argmax(labels, axis=1, output_type=tf.int64)
  matches = tf.cast(tf.equal(predicted_class, labeled_class), dtype=tf.float32)
  num_correct = tf.reduce_sum(matches)
  return num_correct / float(predictions.shape[0].value)
Example #25
Source File: image_utils.py    From fine-lm with MIT License 5 votes vote down vote up
def example_reading_spec(self):
    """Extends the parent spec with a variable-length int64 class label.

    The label is stored under "image/class/label" and exposed to the model as
    "targets".
    """
    label_key = "image/class/label"
    parent = super(Image2TextProblem, self)
    data_fields, data_items_to_decoders = parent.example_reading_spec()

    data_fields[label_key] = tf.VarLenFeature(tf.int64)
    targets_decoder = tf.contrib.slim.tfexample_decoder.Tensor(label_key)
    data_items_to_decoders["targets"] = targets_decoder
    return data_fields, data_items_to_decoders
Example #26
Source File: image_utils.py    From fine-lm with MIT License 5 votes vote down vote up
def example_reading_spec(self):
    """Extends the parent spec with a fixed-length int64 class label.

    The label is stored under "image/class/label" with shape (1,) and exposed
    to the model as "targets".
    """
    label_key = "image/class/label"
    parent = super(Image2ClassProblem, self)
    data_fields, data_items_to_decoders = parent.example_reading_spec()

    data_fields[label_key] = tf.FixedLenFeature((1,), tf.int64)
    targets_decoder = tf.contrib.slim.tfexample_decoder.Tensor(label_key)
    data_items_to_decoders["targets"] = targets_decoder
    return data_fields, data_items_to_decoders
Example #27
Source File: problem_hparams.py    From fine-lm with MIT License 5 votes vote down vote up
def example_reading_spec(self):
    """Returns the stored fields: inputs, targets and two audio scalars.

    "inputs" and "targets" are variable-length int64 features;
    "audio/sample_count" and "audio/sample_width" are scalar int64 features.
    No custom decoders are supplied.
    """
    data_fields = {}
    data_fields["inputs"] = tf.VarLenFeature(tf.int64)
    data_fields["audio/sample_count"] = tf.FixedLenFeature((), tf.int64)
    data_fields["audio/sample_width"] = tf.FixedLenFeature((), tf.int64)
    data_fields["targets"] = tf.VarLenFeature(tf.int64)
    return data_fields, None
Example #28
Source File: babi_qa.py    From fine-lm with MIT License 5 votes vote down vote up
def example_reading_spec(self):
    """Uses the parent spec, with 'targets' as a fixed-length int64 of shape [1]."""
    parent = super(BabiQa, self)
    data_fields, data_items_to_decoders = parent.example_reading_spec()
    data_fields['targets'] = tf.FixedLenFeature([1], tf.int64)
    return (data_fields, data_items_to_decoders)
Example #29
Source File: video_utils.py    From fine-lm with MIT License 5 votes vote down vote up
def example_reading_spec(self):
    """Extends the parent spec with a fixed-length int64 class label.

    The label is stored under "image/class/label" with shape (1,) and exposed
    to the model as "targets".
    """
    label_key = "image/class/label"
    parent = super(Video2ClassProblem, self)
    data_fields, data_items_to_decoders = parent.example_reading_spec()

    data_fields[label_key] = tf.FixedLenFeature((1,), tf.int64)
    targets_decoder = tf.contrib.slim.tfexample_decoder.Tensor(label_key)
    data_items_to_decoders["targets"] = targets_decoder
    return data_fields, data_items_to_decoders
Example #30
Source File: common_layers.py    From fine-lm with MIT License 5 votes vote down vote up
def embedding(x,
              vocab_size,
              dense_size,
              name=None,
              reuse=None,
              multiplier=1.0,
              symbol_dropout_rate=0.0,
              embedding_var=None,
              dtype=tf.float32):
  """Look up dense vectors for the ids in x, returning at most 4 dimensions.

  Args:
    x: ids to embed.
    vocab_size: number of rows of the embedding table created when
      `embedding_var` is None.
    dense_size: number of columns of that table.
    name: variable scope name (default name "embedding").
    reuse: passed to the variable scope.
    multiplier: the embeddings are scaled by this value unless it equals 1.0.
    symbol_dropout_rate: `dropout_no_scaling` is applied to x with keep
      probability `1.0 - symbol_dropout_rate`.
    embedding_var: optional existing table used instead of creating "kernel".
    dtype: dtype of the variable scope and of the gather.

  Returns:
    The embedded Tensor. A result with 5 static dimensions has axis 3
    squeezed; results with fewer than 5 are returned unchanged.
  """
  with tf.variable_scope(
      name, default_name="embedding", values=[x], reuse=reuse, dtype=dtype):
    if embedding_var is None:
      embedding_var = tf.get_variable("kernel", [vocab_size, dense_size])
    # On the backwards pass, we want to convert the gradient from
    # an indexed-slices to a regular tensor before sending it back to the
    # parameter server. This avoids excess computation on the parameter server.
    if not tf.contrib.eager.in_eager_mode():
      embedding_var = convert_gradient_to_tensor(embedding_var)

    kept_ids = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
    embedded = gather(embedding_var, kept_ids, dtype)
    if multiplier != 1.0:
      embedded *= multiplier

    rank = len(embedded.shape.as_list())
    if rank < 5:
      return embedded
    assert rank == 5
    # If we had an extra channel dimension, assume it's 1, i.e. shape[3] == 1.
    return tf.squeeze(embedded, 3)