Python tensorflow.python.ops.array_ops.slice() Examples

The following are 30 code examples of tensorflow.python.ops.array_ops.slice(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.ops.array_ops, or try the search function.
Example #1
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def tensors_to_item(self, keys_to_tensors):
    """Builds the output tensor from the decoded index and value tensors.

    Args:
      keys_to_tensors: mapping from key to tensor. It must contain
        `self._indices_key` and `self._values_key` (and `self._shape_key` when
        one is configured). The index and value entries are read through their
        `.indices`, `.values` and `.dense_shape` attributes.

    Returns:
      A `SparseTensor`, or its dense equivalent when `self._densify` is set.
    """
    indices = keys_to_tensors[self._indices_key]
    values = keys_to_tensors[self._values_key]
    # Shape precedence: the tensor stored under the shape key, then
    # `self._shape`, and finally the dense shape of the decoded indices.
    if self._shape_key:
      shape = keys_to_tensors[self._shape_key]
      if isinstance(shape, sparse_tensor.SparseTensor):
        shape = sparse_ops.sparse_tensor_to_dense(shape)
    elif self._shape:
      shape = self._shape
    else:
      shape = indices.dense_shape
    indices_shape = array_ops.shape(indices.indices)
    rank = indices_shape[1]
    ids = math_ops.to_int64(indices.values)
    # Keep every index column except the last one, which is replaced by the
    # ids carried in `indices.values`.
    indices_columns_to_preserve = array_ops.slice(
        indices.indices, [0, 0], array_ops.stack([-1, rank - 1]))
    new_indices = array_ops.concat(
        [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1)

    tensor = sparse_tensor.SparseTensor(new_indices, values.values, shape)
    if self._densify:
      tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
    return tensor
Example #2
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _gini(self, class_counts):
    """Calculate the Gini impurity.

    If c(i) denotes the i-th class count and c = sum_i c(i) then
      score = 1 - sum_i ( c(i) / c )^2

    Args:
      class_counts: A 2-D tensor of per-class counts, usually a slice or
        gather from variables.node_sums.

    Returns:
      A 1-D tensor of the Gini impurities for each row in the input.
    """
    # Column 0 is skipped; every remaining count is smoothed by adding 1.
    smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
    sums = math_ops.reduce_sum(smoothed, 1)
    sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)

    return 1.0 - sum_squares / (sums * sums)
Example #3
Source File:    From Multiview2Novelview with MIT License 6 votes vote down vote up
def _attention(self, query, attn_states):
    """Computes the attention read-out and the shifted attention states.

    Args:
      query: tensor that is linearly projected to `self._attn_vec_size` units
        and used to score the attention states.
      attn_states: tensor holding the attention states; it is reshaped here
        to [-1, self._attn_length, 1, self._attn_size].

    Returns:
      A `(new_attns, new_attn_states)` pair: the softmax-weighted sum of the
      states reshaped to [-1, self._attn_size], and `attn_states` with the
      first entry along axis 1 sliced off.
    """
    with vs.variable_scope("attention"):
      attn_w = vs.get_variable(
          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
      attn_v = vs.get_variable("attn_v", [self._attn_vec_size])
      states = array_ops.reshape(
          attn_states, [-1, self._attn_length, 1, self._attn_size])
      state_features = nn_ops.conv2d(states, attn_w, [1, 1, 1, 1], "SAME")
      # The projection layer is created on first use and cached on the cell.
      if self._linear3 is None:
        self._linear3 = _Linear(query, self._attn_vec_size, True)
      projected_query = array_ops.reshape(
          self._linear3(query), [-1, 1, 1, self._attn_vec_size])
      scores = math_ops.reduce_sum(
          attn_v * math_ops.tanh(state_features + projected_query), [2, 3])
      weights = array_ops.reshape(
          nn_ops.softmax(scores), [-1, self._attn_length, 1, 1])
      context = math_ops.reduce_sum(weights * states, [1, 2])
      new_attns = array_ops.reshape(context, [-1, self._attn_size])
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states
Example #4
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _variance(self, sums, squares):
    """Calculate the variance for each row of the input tensors.

    Variance is V = E[x^2] - (E[x])^2.

    Args:
      sums: A tensor containing output sums, usually a slice from
        variables.node_sums.  Should contain the number of examples seen
        in index 0 so we can calculate expected value.
      squares: Same as sums, but sums of squares.

    Returns:
      A 1-D tensor of the variances for each row in the input.
    """
    # Column 0 holds the per-row example count used as the divisor.
    total_count = array_ops.slice(sums, [0, 0], [-1, 1])
    e_x = sums / total_count
    e_x2 = squares / total_count

    return math_ops.reduce_sum(e_x2 - math_ops.square(e_x), 1)
Example #5
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _weighted_gini(self, class_counts):
    """Our split score is the Gini impurity times the number of examples.

    If c(i) denotes the i-th class count and c = sum_i c(i) then
      score = c * (1 - sum_i ( c(i) / c )^2 )
            = c - sum_i c(i)^2 / c

    Args:
      class_counts: A 2-D tensor of per-class counts, usually a slice or
        gather from variables.node_sums.

    Returns:
      A 1-D tensor of the Gini impurities for each row in the input.
    """
    # Column 0 is skipped; every remaining count is smoothed by adding 1.
    smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
    sums = math_ops.reduce_sum(smoothed, 1)
    sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)

    return sums - sum_squares / sums
Example #6
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def call(self, inputs, state):
    """Run this multi-layer cell on inputs, starting from state.

    Args:
      inputs: input fed to the first cell; each cell's output feeds the next.
      state: when `self._state_is_tuple` is set, a sequence with one state per
        cell; otherwise a single tensor in which the cell states are laid out
        side by side along axis 1.

    Returns:
      A `(output, new_state)` pair, where `new_state` is a tuple of per-cell
      states or their concatenation along axis 1, mirroring the input layout.

    Raises:
      ValueError: if tuple state is expected but `state` is not a sequence.
    """
    cur_state_pos = 0
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % i):
        if self._state_is_tuple:
          if not nest.is_sequence(state):
            raise ValueError(
                "Expected state to be a tuple of length %d, but received: %s" %
                (len(self.state_size), state))
          cur_state = state[i]
        else:
          # Concatenated layout: cut this cell's columns out of the flat state.
          cur_state = array_ops.slice(state, [0, cur_state_pos],
                                      [-1, cell.state_size])
          cur_state_pos += cell.state_size
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)

    new_states = (tuple(new_states) if self._state_is_tuple else
                  array_ops.concat(new_states, 1))

    return cur_inp, new_states
Example #7
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def average_impurity(self):
    """Constructs a TF graph for evaluating the average leaf impurity of a tree.

    If in regression mode, this is the leaf variance. If in classification mode,
    this is the gini impurity.

    Returns:
      The last op in the graph.
    """
    children = array_ops.squeeze(array_ops.slice(
        self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    # `where` on the 1-D mask yields one index column; squeeze it away so the
    # leaf ids form a 1-D tensor, the same way `children` is squeezed above.
    leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
                                                 squeeze_dims=[1]))
    counts = array_ops.gather(self.variables.node_sums, leaves)
    gini = self._weighted_gini(counts)
    # Guard against step 1, when there often are no leaves yet.
    def impurity():
      return gini
    # Since average impurity can be used for loss, when there's no data just
    # return a big number so that loss always decreases.
    def big():
      return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.
    return control_flow_ops.cond(math_ops.greater(
        array_ops.shape(leaves)[0], 0), impurity, big)
Example #8
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _PadGrad(op, grad):
  """Gradient for Pad."""
  # Padding only adds values around the original tensor, so the gradient with
  # respect to the input is the region of `grad` that the input occupied.
  padded_input = op.inputs[0]
  paddings = op.inputs[1]  # [Rank(x), 2]
  # First column of the paddings: the amount added before each dimension.
  first_column_size = array_ops.stack([array_ops.rank(padded_input), 1])
  leading_pad = array_ops.slice(paddings, [0, 0], first_column_size)
  # Flatten [Rank(x), 1] into a 1-D offset vector.
  offsets = array_ops.reshape(leading_pad, [-1])
  input_shape = array_ops.shape(padded_input)
  input_grad = array_ops.slice(grad, offsets, input_shape)
  # The paddings input receives no gradient.
  return input_grad, None

# ReverseSequence is just a permutation.  The gradient permutes back. 
Example #9
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  Args:
    op: the MatrixSetDiag op; `op.inputs[0]` is the input matrix batch and
      `op.inputs[1]` the diagonal that was written into it.
    grad: gradient with respect to the op's output.

  Returns:
    A `(grad_input, grad_diag)` pair: `grad` with its diagonal zeroed, and
    the diagonal part of `grad`.
  """
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    # Everything is known statically: build the diagonal shape in Python.
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    # Otherwise compute the diagonal shape from the runtime shape of `grad`.
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(
          diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)
Example #10
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _split_logits(self, logits):
    """Splits logits for heads.

    Args:
      logits: the logits tensor.

    Returns:
      A list of logits for the individual heads.
    """
    all_logits = []
    begin = 0
    for head in self._heads:
      # Each head owns `logits_dimension` consecutive columns of `logits`.
      current_logits_size = head.logits_dimension
      current_logits = array_ops.slice(logits, [0, begin],
                                       [-1, current_logits_size])
      all_logits.append(current_logits)
      begin += current_logits_size
    return all_logits
Example #11
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _flatten_outer_dims(logits):
  """Flattens logits' outer dimensions and keep its last dimension.

  Args:
    logits: tensor to flatten.

  Returns:
    `logits` reshaped to 2-D as [-1, last_dim]. When every outer dimension is
    statically known, the static shape of the result is set as well.
  """
  rank = array_ops.rank(logits)
  last_dim_size = array_ops.slice(
      array_ops.shape(logits), [math_ops.subtract(rank, 1)], [1])
  output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))

  # Set output shape if known.
  shape = logits.get_shape()
  if shape is not None and shape.dims is not None:
    shape = shape.as_list()
    outer_dims = shape[:-1]
    # The flattened size is only known when no outer dimension is unknown.
    if all(d is not None for d in outer_dims):
      product = 1
      for d in outer_dims:
        product *= d
      output.set_shape([product, shape[-1]])

  return output
Example #12
Source File:    From Multiview2Novelview with MIT License 6 votes vote down vote up
def _make_tf_features(self, input_feat):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, batch x num_units.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
        for that frequency index. Here output_dim is feature_size.

    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2)[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    num_feats = int((input_size - self._feature_size) / (
        self._frequency_skip)) + 1
    freq_inputs = []
    for f in range(num_feats):
      # Window of `_feature_size` columns, shifted by `_frequency_skip` per step.
      cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip],
                                  [-1, self._feature_size])
      freq_inputs.append(cur_input)
    return freq_inputs
Example #13
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _make_tf_features(self, input_feat):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, batch x num_units.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
        for that frequency index. Here output_dim is feature_size.

    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2)[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    num_feats = int((input_size - self._feature_size) / (
        self._frequency_skip)) + 1
    freq_inputs = []
    for f in range(num_feats):
      # Window of `_feature_size` columns, shifted by `_frequency_skip` per step.
      cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip],
                                  [-1, self._feature_size])
      freq_inputs.append(cur_input)
    return freq_inputs
Example #14
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _attention(self, query, attn_states):
    """Computes the attention read-out and the shifted attention states.

    Args:
      query: tensor that is linearly projected to `self._attn_vec_size` units
        and used to score the attention states.
      attn_states: tensor holding the attention states; it is reshaped here
        to [-1, self._attn_length, 1, self._attn_size].

    Returns:
      A `(new_attns, new_attn_states)` pair: the softmax-weighted sum of the
      states reshaped to [-1, self._attn_size], and `attn_states` with the
      first entry along axis 1 sliced off.
    """
    with vs.variable_scope("attention"):
      attn_w = vs.get_variable(
          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
      attn_v = vs.get_variable("attn_v", [self._attn_vec_size])
      states = array_ops.reshape(
          attn_states, [-1, self._attn_length, 1, self._attn_size])
      state_features = nn_ops.conv2d(states, attn_w, [1, 1, 1, 1], "SAME")
      projected_query = array_ops.reshape(
          _linear(query, self._attn_vec_size, True),
          [-1, 1, 1, self._attn_vec_size])
      scores = math_ops.reduce_sum(
          attn_v * math_ops.tanh(state_features + projected_query), [2, 3])
      weights = array_ops.reshape(
          nn_ops.softmax(scores), [-1, self._attn_length, 1, 1])
      context = math_ops.reduce_sum(weights * states, [1, 2])
      new_attns = array_ops.reshape(context, [-1, self._attn_size])
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states
Example #15
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _get_input_for_group(self, inputs, group_id, group_size):
    """Slices inputs into groups to prepare for processing by cell's groups.

    Args:
      inputs: cell input or its previous state,
              a Tensor, 2D, [batch x num_units]
      group_id: group id, a Scalar, for which to prepare input
      group_size: size of the group

    Returns:
      subset of inputs corresponding to group "group_id",
      a Tensor, 2D, [batch x num_units/number_of_groups]
    """
    # The group's columns start at `group_id * group_size`.
    return array_ops.slice(input_=inputs,
                           begin=[0, group_id * group_size],
                           size=[self._batch_size, group_size],
                           name=("GLSTM_group%d_input_generation" % group_id))
Example #16
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def soft_inference_graph(self, data):
    # Ops are placed on the device the assigner returns for this layer.
    with ops.device(self.device_assigner.get_device(self.layer_num)):
      # NOTE(review): the right-hand side of this assignment is missing from
      # this listing (the parenthesis is never closed), so the snippet does
      # not parse as shown. Presumably a routing op over `data` produces the
      # pair -- TODO restore from the upstream source.
      path_probability, path = (

      # Keep `num_leaves` columns of the unpacked path tensor, starting at
      # column `num_nodes - num_leaves - 1`.
      output = array_ops.slice(
          self.training_ops.unpack_path(path, path_probability),
          [0, self.params.num_nodes - self.params.num_leaves - 1],
          [-1, self.params.num_leaves])

      return output 
Example #17
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def inference_graph(self, data):
    # Ops are placed on the device given by `self.device_assigner`.
    with ops.device(self.device_assigner):
      # NOTE(review): the arguments and closing parenthesis of this call are
      # missing from this listing, so the snippet does not parse as shown --
      # TODO restore from the upstream source.
      routing_probabilities = gen_training_ops.k_feature_routing_function(

      # NOTE(review): the slice below has lost its input tensor; presumably it
      # was `routing_probabilities` -- TODO confirm. The remaining arguments
      # keep `num_leaves` columns starting at `num_nodes - num_leaves - 1`.
      output = array_ops.slice(
          [0, self.params.num_nodes - self.params.num_leaves - 1],
          [-1, self.params.num_leaves])

      return output 
Example #18
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def inference_graph(self, data):
    # Ops are placed on the device the assigner returns for this layer.
    with ops.device(self.device_assigner.get_device(self.layer_num)):
      # NOTE(review): the arguments and closing parenthesis of this call are
      # missing from this listing, so the snippet does not parse as shown --
      # TODO restore from the upstream source.
      routing_probabilities = self.training_ops.k_feature_routing_function(

      # NOTE(review): the slice below has lost its input tensor; presumably it
      # was `routing_probabilities` -- TODO confirm. The remaining arguments
      # keep `num_leaves` columns starting at `num_nodes - num_leaves - 1`.
      output = array_ops.slice(
          [0, self.params.num_nodes - self.params.num_leaves - 1],
          [-1, self.params.num_leaves])

      return output 
Example #19
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def soft_inference_graph(self, data):
    # Ops are placed on the device given by `self.device_assigner`.
    with ops.device(self.device_assigner):
      # NOTE(review): the right-hand side of this assignment is missing from
      # this listing (the parenthesis is never closed), so the snippet does
      # not parse as shown. Presumably a routing op over `data` produces the
      # pair -- TODO restore from the upstream source.
      path_probability, path = (

      # Keep `num_leaves` columns of the unpacked path tensor, starting at
      # column `num_nodes - num_leaves - 1`.
      output = array_ops.slice(
          gen_training_ops.unpack_path(path, path_probability),
          [0, self.params.num_nodes - self.params.num_leaves - 1],
          [-1, self.params.num_leaves])

      return output 
Example #20
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _gini(self, class_counts):
    """Calculate the Gini impurity.

    If c(i) denotes the i-th class count and c = sum_i c(i) then
      score = 1 - sum_i ( c(i) / c )^2

    Args:
      class_counts: A 2-D tensor of per-class counts, usually a slice or
        gather from variables.node_sums.

    Returns:
      A 1-D tensor of the Gini impurities for each row in the input.
    """
    # Column 0 is skipped; every remaining count is smoothed by adding 1.
    smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
    sums = math_ops.reduce_sum(smoothed, 1)
    sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)

    return 1.0 - sum_squares / (sums * sums)
Example #21
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _weighted_gini(self, class_counts):
    """Our split score is the Gini impurity times the number of examples.

    If c(i) denotes the i-th class count and c = sum_i c(i) then
      score = c * (1 - sum_i ( c(i) / c )^2 )
            = c - sum_i c(i)^2 / c

    Args:
      class_counts: A 2-D tensor of per-class counts, usually a slice or
        gather from variables.node_sums.

    Returns:
      A 1-D tensor of the Gini impurities for each row in the input.
    """
    # Column 0 is skipped; every remaining count is smoothed by adding 1.
    smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1])
    sums = math_ops.reduce_sum(smoothed, 1)
    sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1)

    return sums - sum_squares / sums
Example #22
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _variance(self, sums, squares):
    """Calculate the variance for each row of the input tensors.

    Variance is V = E[x^2] - (E[x])^2.

    Args:
      sums: A tensor containing output sums, usually a slice from
        variables.node_sums.  Should contain the number of examples seen
        in index 0 so we can calculate expected value.
      squares: Same as sums, but sums of squares.

    Returns:
      A 1-D tensor of the variances for each row in the input.
    """
    # Column 0 holds the per-row example count used as the divisor.
    total_count = array_ops.slice(sums, [0, 0], [-1, 1])
    e_x = sums / total_count
    e_x2 = squares / total_count

    return math_ops.reduce_sum(e_x2 - math_ops.square(e_x), 1)
Example #23
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def average_impurity(self):
    """Constructs a TF graph for evaluating the average leaf impurity of a tree.

    If in regression mode, this is the leaf variance. If in classification mode,
    this is the gini impurity.

    Returns:
      The last op in the graph.
    """
    children = array_ops.squeeze(array_ops.slice(
        self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    # `where` on the 1-D mask yields one index column; squeeze it away so the
    # leaf ids form a 1-D tensor, the same way `children` is squeezed above.
    leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
                                                 squeeze_dims=[1]))
    counts = array_ops.gather(self.variables.node_sums, leaves)
    gini = self._weighted_gini(counts)
    # Guard against step 1, when there often are no leaves yet.
    def impurity():
      return gini
    # Since average impurity can be used for loss, when there's no data just
    # return a big number so that loss always decreases.
    def big():
      return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.
    return control_flow_ops.cond(math_ops.greater(
        array_ops.shape(leaves)[0], 0), impurity, big)
Example #24
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def tensors_to_item(self, keys_to_tensors):
    """Builds the output tensor from the decoded index and value tensors.

    Args:
      keys_to_tensors: mapping from key to tensor. It must contain
        `self._indices_key` and `self._values_key` (and `self._shape_key` when
        one is configured). The index and value entries are read through their
        `.indices`, `.values` and `.dense_shape` attributes.

    Returns:
      A `SparseTensor`, or its dense equivalent when `self._densify` is set.
    """
    indices = keys_to_tensors[self._indices_key]
    values = keys_to_tensors[self._values_key]
    # Shape precedence: the tensor stored under the shape key, then
    # `self._shape`, and finally the dense shape of the decoded indices.
    if self._shape_key:
      shape = keys_to_tensors[self._shape_key]
      if isinstance(shape, sparse_tensor.SparseTensor):
        shape = sparse_ops.sparse_tensor_to_dense(shape)
    elif self._shape:
      shape = self._shape
    else:
      shape = indices.dense_shape
    indices_shape = array_ops.shape(indices.indices)
    rank = indices_shape[1]
    ids = math_ops.to_int64(indices.values)
    # Keep every index column except the last one, which is replaced by the
    # ids carried in `indices.values`.
    indices_columns_to_preserve = array_ops.slice(
        indices.indices, [0, 0], array_ops.stack([-1, rank - 1]))
    new_indices = array_ops.concat(
        [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1)

    tensor = sparse_tensor.SparseTensor(new_indices, values.values, shape)
    if self._densify:
      tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
    return tensor
Example #25
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def __call__(self, inputs, state, scope=None):
    """Run this multi-layer cell on inputs, starting from state.

    Args:
      inputs: input fed to the first cell; each cell's output feeds the next.
      state: when `self._state_is_tuple` is set, a sequence with one state per
        cell; otherwise a single tensor in which the cell states are laid out
        side by side along axis 1.
      scope: optional variable scope; defaults to "multi_rnn_cell".

    Returns:
      A `(output, new_state)` pair, where `new_state` is a tuple of per-cell
      states or their concatenation along axis 1, mirroring the input layout.

    Raises:
      ValueError: if tuple state is expected but `state` is not a sequence.
    """
    with vs.variable_scope(scope or "multi_rnn_cell"):
      cur_state_pos = 0
      cur_inp = inputs
      new_states = []
      for i, cell in enumerate(self._cells):
        with vs.variable_scope("cell_%d" % i):
          if self._state_is_tuple:
            if not nest.is_sequence(state):
              raise ValueError(
                  "Expected state to be a tuple of length %d, but received: %s"
                  % (len(self.state_size), state))
            cur_state = state[i]
          else:
            # Concatenated layout: cut this cell's columns out of the flat
            # state.
            cur_state = array_ops.slice(
                state, [0, cur_state_pos], [-1, cell.state_size])
            cur_state_pos += cell.state_size
          cur_inp, new_state = cell(cur_inp, cur_state)
          new_states.append(new_state)
    new_states = (tuple(new_states) if self._state_is_tuple else
                  array_ops.concat(new_states, 1))
    return cur_inp, new_states
Example #26
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _attention(self, query, attn_states):
    """Computes the attention read-out and the shifted attention states.

    Args:
      query: tensor that is linearly projected to `self._attn_vec_size` units
        and used to score the attention states.
      attn_states: tensor holding the attention states; it is reshaped here
        to [-1, self._attn_length, 1, self._attn_size].

    Returns:
      A `(new_attns, new_attn_states)` pair: the softmax-weighted sum of the
      states reshaped to [-1, self._attn_size], and `attn_states` with the
      first entry along axis 1 sliced off.
    """
    with vs.variable_scope("attention"):
      attn_w = vs.get_variable(
          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
      attn_v = vs.get_variable("attn_v", [self._attn_vec_size])
      states = array_ops.reshape(
          attn_states, [-1, self._attn_length, 1, self._attn_size])
      state_features = nn_ops.conv2d(states, attn_w, [1, 1, 1, 1], "SAME")
      projected_query = array_ops.reshape(
          _linear(query, self._attn_vec_size, True),
          [-1, 1, 1, self._attn_vec_size])
      scores = math_ops.reduce_sum(
          attn_v * math_ops.tanh(state_features + projected_query), [2, 3])
      weights = array_ops.reshape(
          nn_ops.softmax(scores), [-1, self._attn_length, 1, 1])
      context = math_ops.reduce_sum(weights * states, [1, 2])
      new_attns = array_ops.reshape(context, [-1, self._attn_size])
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states
Example #27
Source File:    From lambda-packs with MIT License 6 votes vote down vote up
def _split_logits(self, logits):
    """Splits logits for heads.

    Args:
      logits: the logits tensor.

    Returns:
      A list of logits for the individual heads.
    """
    all_logits = []
    begin = 0
    for head in self._heads:
      # Each head owns `logits_dimension` consecutive columns of `logits`.
      current_logits_size = head.logits_dimension
      current_logits = array_ops.slice(logits, [0, begin],
                                       [-1, current_logits_size])
      all_logits.append(current_logits)
      begin += current_logits_size
    return all_logits
Example #28
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _PadGrad(op, grad):
  """Gradient for Pad."""
  # Padding only adds values around the original tensor, so the gradient with
  # respect to the input is the region of `grad` that the input occupied.
  padded_input = op.inputs[0]
  paddings = op.inputs[1]  # [Rank(x), 2]
  # First column of the paddings: the amount added before each dimension.
  first_column_size = array_ops.stack([array_ops.rank(padded_input), 1])
  leading_pad = array_ops.slice(paddings, [0, 0], first_column_size)
  # Flatten [Rank(x), 1] into a 1-D offset vector.
  offsets = array_ops.reshape(leading_pad, [-1])
  input_shape = array_ops.shape(padded_input)
  input_grad = array_ops.slice(grad, offsets, input_shape)
  # The paddings input receives no gradient.
  return input_grad, None

# ReverseSequence is just a permutation.  The gradient permutes back. 
Example #29
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _flatten_outer_dims(logits):
  """Flattens logits' outer dimensions and keep its last dimension.

  Args:
    logits: tensor to flatten.

  Returns:
    `logits` reshaped to 2-D as [-1, last_dim]. When every outer dimension is
    statically known, the static shape of the result is set as well.
  """
  rank = array_ops.rank(logits)
  last_dim_size = array_ops.slice(
      array_ops.shape(logits), [math_ops.subtract(rank, 1)], [1])
  output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))

  # Set output shape if known.
  shape = logits.get_shape()
  if shape is not None and shape.dims is not None:
    shape = shape.as_list()
    outer_dims = shape[:-1]
    # The flattened size is only known when no outer dimension is unknown.
    if all(d is not None for d in outer_dims):
      product = 1
      for d in outer_dims:
        product *= d
      output.set_shape([product, shape[-1]])

  return output
Example #30
Source File:    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  Args:
    op: the MatrixSetDiag op; `op.inputs[0]` is the input matrix batch and
      `op.inputs[1]` the diagonal that was written into it.
    grad: gradient with respect to the op's output.

  Returns:
    A `(grad_input, grad_diag)` pair: `grad` with its diagonal zeroed, and
    the diagonal part of `grad`.
  """
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    # Everything is known statically: build the diagonal shape in Python.
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    # Otherwise compute the diagonal shape from the runtime shape of `grad`.
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(
          diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)