Python tensorflow.python.ops.array_ops.slice() Examples
The following are 30 code examples of `tensorflow.python.ops.array_ops.slice()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow.python.ops.array_ops`, or try the search function.
Example #1
Source File: tfexample_decoder.py From lambda-packs with MIT License | 6 votes |
def tensors_to_item(self, keys_to_tensors):
  """Builds a SparseTensor whose last index column is replaced by ids.

  Every index column of the sparse `indices` tensor except the last one is
  kept; the last column is replaced by the (int64-cast) values of `indices`.
  The dense shape comes from `keys_to_tensors[self._shape_key]` (densified if
  it is sparse) when `self._shape_key` is set, else from `self._shape`, else
  from `indices.dense_shape`.

  Args:
    keys_to_tensors: mapping from feature keys to parsed tensors.

  Returns:
    A SparseTensor holding the values of the sparse `values` tensor, or a
    dense Tensor (filled with `self._default_value`) when `self._densify`
    is set.
  """
  sparse_indices = keys_to_tensors[self._indices_key]
  sparse_values = keys_to_tensors[self._values_key]

  if self._shape_key:
    dense_shape = keys_to_tensors[self._shape_key]
    if isinstance(dense_shape, sparse_tensor.SparseTensor):
      dense_shape = sparse_ops.sparse_tensor_to_dense(dense_shape)
  elif self._shape:
    dense_shape = self._shape
  else:
    dense_shape = sparse_indices.dense_shape

  index_rank = array_ops.shape(sparse_indices.indices)[1]
  id_values = math_ops.to_int64(sparse_indices.values)
  # Keep all index columns but the last; the ids take the last column's place.
  leading_columns = array_ops.slice(
      sparse_indices.indices, [0, 0], array_ops.stack([-1, index_rank - 1]))
  id_column = array_ops.reshape(id_values, [-1, 1])
  combined_indices = array_ops.concat([leading_columns, id_column], 1)

  result = sparse_tensor.SparseTensor(
      combined_indices, sparse_values.values, dense_shape)
  if not self._densify:
    return result
  return sparse_ops.sparse_tensor_to_dense(result, self._default_value)
Example #2
Source File: tensor_forest.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _gini(self, class_counts):
  """Calculate the Gini impurity.

  If c(i) denotes the i-th class count and c = sum_i c(i) then
  score = 1 - sum_i ( c(i) / c )^2

  Column 0 of `class_counts` is skipped (presumably a total-count column --
  confirm), and each remaining count is smoothed by adding 1 before scoring.

  Args:
    class_counts: A 2-D tensor of per-class counts, usually a slice or
      gather from variables.node_sums.

  Returns:
    A 1-D tensor of the Gini impurities for each row in the input.
  """
  per_class = array_ops.slice(class_counts, [0, 1], [-1, -1])
  smoothed = 1.0 + per_class  # add-one smoothing
  total = math_ops.reduce_sum(smoothed, 1)
  squared_total = math_ops.reduce_sum(math_ops.square(smoothed), 1)
  return 1.0 - squared_total / (total * total)
Example #3
Source File: rnn_cell.py From Multiview2Novelview with MIT License | 6 votes |
def _attention(self, query, attn_states):
  """Computes one attention read over `attn_states`.

  `attn_states` is reshaped to [-1, attn_length, 1, attn_size] (presumably it
  is [batch, attn_length, attn_size] -- confirm), scored against the query
  passed through the cached `self._linear3`, and softmax-weighted.

  Args:
    query: tensor fed to `self._linear3` (built lazily on the first call).
    attn_states: the stored attention states.

  Returns:
    A `(new_attns, new_attn_states)` tuple: the attention-weighted sum of the
    states reshaped to [-1, attn_size], and `attn_states` with index 0 of
    axis 1 dropped.
  """
  with vs.variable_scope("attention"):
    attn_w = vs.get_variable(
        "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
    attn_v = vs.get_variable("attn_v", [self._attn_vec_size])
    hidden = array_ops.reshape(
        attn_states, [-1, self._attn_length, 1, self._attn_size])
    hidden_features = nn_ops.conv2d(hidden, attn_w, [1, 1, 1, 1], "SAME")
    if self._linear3 is None:
      # Create the query projection once and reuse it on later calls.
      self._linear3 = _Linear(query, self._attn_vec_size, True)
    projected_query = array_ops.reshape(
        self._linear3(query), [-1, 1, 1, self._attn_vec_size])
    scores = math_ops.reduce_sum(
        attn_v * math_ops.tanh(hidden_features + projected_query), [2, 3])
    weights = nn_ops.softmax(scores)
    weighted_hidden = array_ops.reshape(
        weights, [-1, self._attn_length, 1, 1]) * hidden
    context = math_ops.reduce_sum(weighted_hidden, [1, 2])
    new_attns = array_ops.reshape(context, [-1, self._attn_size])
    new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
    return new_attns, new_attn_states
Example #4
Source File: tensor_forest.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _variance(self, sums, squares):
  """Calculate the variance for each row of the input tensors.

  Variance is V = E[x^2] - (E[x])^2.

  Both `sums` and `squares` are divided by column 0 of `sums`; column 0 itself
  also takes part in the final row sum. Rows whose count is zero divide by
  zero.

  Args:
    sums: A tensor containing output sums, usually a slice from
      variables.node_sums. Should contain the number of examples seen
      in index 0 so we can calculate expected value.
    squares: Same as sums, but sums of squares.

  Returns:
    A 1-D tensor of the variances for each row in the input.
  """
  count = array_ops.slice(sums, [0, 0], [-1, 1])
  mean = sums / count
  mean_of_squares = squares / count
  return math_ops.reduce_sum(mean_of_squares - math_ops.square(mean), 1)
Example #5
Source File: tensor_forest.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _weighted_gini(self, class_counts):
  """Our split score is the Gini impurity times the number of examples.

  If c(i) denotes the i-th class count and c = sum_i c(i) then
  score = c * (1 - sum_i ( c(i) / c )^2 )
        = c - sum_i c(i)^2 / c

  Column 0 of `class_counts` is left out, and the remaining counts are
  smoothed by adding 1 before scoring.

  Args:
    class_counts: A 2-D tensor of per-class counts, usually a slice or
      gather from variables.node_sums.

  Returns:
    A 1-D tensor of the Gini impurities for each row in the input.
  """
  per_class = array_ops.slice(class_counts, [0, 1], [-1, -1])
  smoothed = 1.0 + per_class  # add-one smoothing
  total = math_ops.reduce_sum(smoothed, 1)
  squared_total = math_ops.reduce_sum(math_ops.square(smoothed), 1)
  return total - squared_total / total
Example #6
Source File: rnn_cell_impl.py From lambda-packs with MIT License | 6 votes |
def call(self, inputs, state):
  """Run this multi-layer cell on inputs, starting from state.

  Each sub-cell runs under its own "cell_<i>" variable scope and its output
  feeds the next one. With `self._state_is_tuple`, `state` must be a sequence
  indexed per cell; otherwise `state` is one tensor whose columns are handed
  out to the cells in order, `cell.state_size` columns each.

  Returns:
    `(output, new_states)`: the last cell's output and the per-cell new states,
    either as a tuple or concatenated along axis 1.

  Raises:
    ValueError: if `self._state_is_tuple` is set but `state` is not a sequence.
  """
  layer_input = inputs
  column_offset = 0
  next_states = []
  for idx, cell in enumerate(self._cells):
    with vs.variable_scope("cell_%d" % idx):
      if not self._state_is_tuple:
        cell_state = array_ops.slice(
            state, [0, column_offset], [-1, cell.state_size])
        column_offset += cell.state_size
      elif nest.is_sequence(state):
        cell_state = state[idx]
      else:
        raise ValueError(
            "Expected state to be a tuple of length %d, but received: %s" %
            (len(self.state_size), state))
      layer_input, cell_new_state = cell(layer_input, cell_state)
      next_states.append(cell_new_state)

  if self._state_is_tuple:
    next_states = tuple(next_states)
  else:
    next_states = array_ops.concat(next_states, 1)
  return layer_input, next_states
Example #7
Source File: tensor_forest.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def average_impurity(self):
  """Constructs a TF graph for evaluating the average leaf impurity of a tree.

  If in regression mode, this is the leaf variance. If in classification mode,
  this is the gini impurity.

  NOTE(review): the body below always scores the leaves with
  `self._weighted_gini`; no regression/variance branch is visible here.

  Returns:
    The last op in the graph: the per-leaf scores from `self._weighted_gini`,
    or a one-element float32 tensor holding 10000000. when the tree has no
    leaves yet.
  """
  # Rows whose first `tree` column equals LEAF_NODE are leaves.
  children = array_ops.squeeze(array_ops.slice(
      self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
  is_leaf = math_ops.equal(constants.LEAF_NODE, children)
  leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
                                               squeeze_dims=[1]))
  counts = array_ops.gather(self.variables.node_sums, leaves)
  gini = self._weighted_gini(counts)

  # Guard against step 1, when there often are no leaves yet.
  def impurity():
    return gini

  # Since average impurity can be used for loss, when there's no data just
  # return a big number so that loss always decreases.
  def big():
    # This branch only runs when `leaves` is empty, so the per-leaf `gini`
    # (and therefore `ones_like(gini)`) is empty too and would never carry
    # the big value. Emit an explicit one-element tensor instead.
    return array_ops.ones([1], dtype=dtypes.float32) * 10000000.

  return control_flow_ops.cond(math_ops.greater(
      array_ops.shape(leaves)[0], 0), impurity, big)
Example #8
Source File: array_grad.py From lambda-packs with MIT License | 6 votes |
def _PadGrad(op, grad):
  """Gradient for Pad.

  Pad introduces values around the original tensor, so the gradient with
  respect to the input is the slice of `grad` covering the unpadded region.

  Args:
    op: the Pad op. `op.inputs[0]` is the input `x`, and `op.inputs[1]` is the
      [Rank(x), 2] paddings tensor. A third input, if present (presumably a
      constant pad value -- confirm), receives no gradient.
    grad: gradient with respect to the Pad output.

  Returns:
    The gradient for `x`, followed by `None` for each remaining input.
  """
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  # slice() needs `begin` and `size` to share a dtype. `begin` inherits the
  # paddings' dtype (possibly int64), while shape() defaults to int32.
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  if len(op.inputs) == 3:
    return x_grad, None, None
  return x_grad, None


# ReverseSequence is just a permutation. The gradient permutes back.
Example #9
Source File: array_grad.py From lambda-packs with MIT License | 6 votes |
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  The input matrix receives `grad` with its diagonal replaced by zeros. The
  diagonal input receives the diagonal part of `grad`.

  Args:
    op: the MatrixSetDiag op (`op.inputs[0]` the input, `op.inputs[1]` the
      diagonal).
    grad: gradient with respect to the op's output.

  Returns:
    A `(grad_input, grad_diag)` pair.
  """
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  static_shapes_known = (batch_shape.is_fully_defined() and
                         matrix_shape.is_fully_defined())
  if static_shapes_known:
    # Diagonal length is min(rows, cols) for every batch entry.
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    # Derive the same shape from the runtime shape of `grad`.
    with ops.colocate_with(grad):
      dynamic_shape = array_ops.shape(grad)
      dynamic_rank = array_ops.rank(grad)
      dynamic_batch = array_ops.slice(dynamic_shape, [0], [dynamic_rank - 2])
      dynamic_matrix = array_ops.slice(dynamic_shape, [dynamic_rank - 2], [2])
      shortest_side = math_ops.reduce_min(dynamic_matrix)
      diag_shape = array_ops.concat([dynamic_batch, [shortest_side]], 0)
  zero_diag = array_ops.zeros(diag_shape, dtype=grad.dtype)
  grad_input = array_ops.matrix_set_diag(grad, zero_diag)
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)
Example #10
Source File: head.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _split_logits(self, logits):
  """Splits logits for heads.

  Consecutive column ranges of `logits` go to the heads in order. Each range
  is `head.logits_dimension` columns wide.

  Args:
    logits: the logits tensor.

  Returns:
    A list of logits for the individual heads.
  """
  per_head = []
  offset = 0
  for head in self._heads:
    width = head.logits_dimension
    per_head.append(array_ops.slice(logits, [0, offset], [-1, width]))
    offset += width
  return per_head
Example #11
Source File: nn_ops.py From lambda-packs with MIT License | 6 votes |
def _flatten_outer_dims(logits):
  """Flattens logits' outer dimensions and keep its last dimension.

  Args:
    logits: tensor to flatten (presumably of rank >= 1; not checked here).

  Returns:
    `logits` reshaped to [-1, last_dim]. When every static dimension except
    the last is known, their product is also set as the static leading size.
  """
  rank = array_ops.rank(logits)
  last_dim_size = array_ops.slice(
      array_ops.shape(logits), [math_ops.subtract(rank, 1)], [1])
  output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))

  # Set output shape if known.
  static_shape = logits.get_shape()
  if static_shape is None or static_shape.dims is None:
    return output
  dims = static_shape.as_list()
  outer_dims = dims[:-1]
  if any(d is None for d in outer_dims):
    return output
  flat_size = 1
  for d in outer_dims:
    flat_size *= d
  output.set_shape([flat_size, dims[-1]])
  return output
Example #12
Source File: rnn_cell.py From Multiview2Novelview with MIT License | 6 votes |
def _make_tf_features(self, input_feat):
  """Make the frequency features.

  The input columns are cut into windows `self._feature_size` columns wide.
  Consecutive windows start `self._frequency_skip` columns apart.

  Args:
    input_feat: input Tensor, 2D, batch x num_units.

  Returns:
    A list of frequency features, with each element containing:
    - A 2D, batch x output_dim, Tensor representing the time-frequency feature
      for that frequency index. Here output_dim is feature_size.

  Raises:
    ValueError: if input_size cannot be inferred from static shape inference.
  """
  input_size = input_feat.get_shape().with_rank(2)[-1].value
  if input_size is None:
    raise ValueError("Cannot infer input_size from static shape inference.")
  num_feats = int((input_size - self._feature_size) / (
      self._frequency_skip)) + 1
  return [
      array_ops.slice(input_feat, [0, f * self._frequency_skip],
                      [-1, self._feature_size])
      for f in range(num_feats)
  ]
Example #13
Source File: rnn_cell.py From lambda-packs with MIT License | 6 votes |
def _make_tf_features(self, input_feat):
  """Make the frequency features.

  The input columns are cut into windows `self._feature_size` columns wide.
  Consecutive windows start `self._frequency_skip` columns apart.

  Args:
    input_feat: input Tensor, 2D, batch x num_units.

  Returns:
    A list of frequency features, with each element containing:
    - A 2D, batch x output_dim, Tensor representing the time-frequency feature
      for that frequency index. Here output_dim is feature_size.

  Raises:
    ValueError: if input_size cannot be inferred from static shape inference.
  """
  input_size = input_feat.get_shape().with_rank(2)[-1].value
  if input_size is None:
    raise ValueError("Cannot infer input_size from static shape inference.")
  num_feats = int((input_size - self._feature_size) / (
      self._frequency_skip)) + 1
  return [
      array_ops.slice(input_feat, [0, f * self._frequency_skip],
                      [-1, self._feature_size])
      for f in range(num_feats)
  ]
Example #14
Source File: rnn_cell.py From lambda-packs with MIT License | 6 votes |
def _attention(self, query, attn_states):
  """Computes one attention read over `attn_states`.

  `attn_states` is reshaped to [-1, attn_length, 1, attn_size] (presumably it
  is [batch, attn_length, attn_size] -- confirm), scored against
  `_linear(query, self._attn_vec_size, True)`, and softmax-weighted.

  Args:
    query: tensor fed to `_linear`.
    attn_states: the stored attention states.

  Returns:
    A `(new_attns, new_attn_states)` tuple: the attention-weighted sum of the
    states reshaped to [-1, attn_size], and `attn_states` with index 0 of
    axis 1 dropped.
  """
  with vs.variable_scope("attention"):
    attn_w = vs.get_variable(
        "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
    attn_v = vs.get_variable("attn_v", [self._attn_vec_size])
    hidden = array_ops.reshape(
        attn_states, [-1, self._attn_length, 1, self._attn_size])
    hidden_features = nn_ops.conv2d(hidden, attn_w, [1, 1, 1, 1], "SAME")
    projected_query = array_ops.reshape(
        _linear(query, self._attn_vec_size, True),
        [-1, 1, 1, self._attn_vec_size])
    scores = math_ops.reduce_sum(
        attn_v * math_ops.tanh(hidden_features + projected_query), [2, 3])
    weights = nn_ops.softmax(scores)
    weighted_hidden = array_ops.reshape(
        weights, [-1, self._attn_length, 1, 1]) * hidden
    context = math_ops.reduce_sum(weighted_hidden, [1, 2])
    new_attns = array_ops.reshape(context, [-1, self._attn_size])
    new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
    return new_attns, new_attn_states
Example #15
Source File: rnn_cell.py From lambda-packs with MIT License | 6 votes |
def _get_input_for_group(self, inputs, group_id, group_size):
  """Slices inputs into groups to prepare for processing by cell's groups.

  Args:
    inputs: cell input or its previous state,
      a Tensor, 2D, [batch x num_units]
    group_id: group id, a Scalar, for which to prepare input
    group_size: size of the group

  Returns:
    subset of inputs corresponding to group "group_id",
    a Tensor, 2D, [batch x num_units/number_of_groups]. The row extent of the
    slice is fixed to `self._batch_size`.
  """
  first_column = group_id * group_size
  op_name = "GLSTM_group%d_input_generation" % group_id
  return array_ops.slice(
      input_=inputs,
      begin=[0, first_column],
      size=[self._batch_size, group_size],
      name=op_name)
Example #16
Source File: decisions_to_data.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def soft_inference_graph(self, data):
  """Routes `data` stochastically through the tree and slices the path.

  Runs on the device chosen by `self.device_assigner` for `self.layer_num`.

  Args:
    data: input tensor passed to `stochastic_hard_routing_function`.

  Returns:
    A [-1, self.params.num_leaves] slice of the unpacked path that starts at
    column `num_nodes - num_leaves - 1`.
  """
  params = self.params
  device = self.device_assigner.get_device(self.layer_num)
  with ops.device(device):
    path_probability, path = (
        self.training_ops.stochastic_hard_routing_function(
            data,
            self.tree_parameters,
            self.tree_thresholds,
            tree_depth=params.hybrid_tree_depth,
            random_seed=params.base_random_seed))
    unpacked = self.training_ops.unpack_path(path, path_probability)
    first_column = params.num_nodes - params.num_leaves - 1
    output = array_ops.slice(unpacked, [0, first_column],
                             [-1, params.num_leaves])
    return output
Example #17
Source File: decisions_to_data.py From lambda-packs with MIT License | 6 votes |
def inference_graph(self, data):
  """Routes `data` with the k-feature routing op and slices the result.

  Runs under `ops.device(self.device_assigner)`.

  Args:
    data: input tensor passed to `k_feature_routing_function`.

  Returns:
    A [-1, self.params.num_leaves] slice of the routing probabilities that
    starts at column `num_nodes - num_leaves - 1`.
  """
  params = self.params
  with ops.device(self.device_assigner):
    routing_probabilities = gen_training_ops.k_feature_routing_function(
        data,
        self.tree_parameters,
        self.tree_thresholds,
        max_nodes=params.num_nodes,
        num_features_per_node=params.num_features_per_node,
        layer_num=0,
        random_seed=params.base_random_seed)
    first_column = params.num_nodes - params.num_leaves - 1
    output = array_ops.slice(routing_probabilities, [0, first_column],
                             [-1, params.num_leaves])
    return output
Example #18
Source File: decisions_to_data.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def inference_graph(self, data):
  """Routes `data` with the k-feature routing op and slices the result.

  Runs on the device chosen by `self.device_assigner` for `self.layer_num`.

  Args:
    data: input tensor passed to `k_feature_routing_function`.

  Returns:
    A [-1, self.params.num_leaves] slice of the routing probabilities that
    starts at column `num_nodes - num_leaves - 1`.
  """
  params = self.params
  device = self.device_assigner.get_device(self.layer_num)
  with ops.device(device):
    routing_probabilities = self.training_ops.k_feature_routing_function(
        data,
        self.tree_parameters,
        self.tree_thresholds,
        max_nodes=params.num_nodes,
        num_features_per_node=params.num_features_per_node,
        layer_num=0,
        random_seed=params.base_random_seed)
    first_column = params.num_nodes - params.num_leaves - 1
    output = array_ops.slice(routing_probabilities, [0, first_column],
                             [-1, params.num_leaves])
    return output
Example #19
Source File: decisions_to_data.py From lambda-packs with MIT License | 6 votes |
def soft_inference_graph(self, data):
  """Routes `data` stochastically through the tree and slices the path.

  Runs under `ops.device(self.device_assigner)`.

  Args:
    data: input tensor passed to `stochastic_hard_routing_function`.

  Returns:
    A [-1, self.params.num_leaves] slice of the unpacked path that starts at
    column `num_nodes - num_leaves - 1`.
  """
  params = self.params
  with ops.device(self.device_assigner):
    path_probability, path = (
        gen_training_ops.stochastic_hard_routing_function(
            data,
            self.tree_parameters,
            self.tree_thresholds,
            tree_depth=params.hybrid_tree_depth,
            random_seed=params.base_random_seed))
    unpacked = gen_training_ops.unpack_path(path, path_probability)
    first_column = params.num_nodes - params.num_leaves - 1
    output = array_ops.slice(unpacked, [0, first_column],
                             [-1, params.num_leaves])
    return output
Example #20
Source File: tensor_forest.py From lambda-packs with MIT License | 6 votes |
def _gini(self, class_counts):
  """Calculate the Gini impurity.

  If c(i) denotes the i-th class count and c = sum_i c(i) then
  score = 1 - sum_i ( c(i) / c )^2

  Column 0 of `class_counts` is skipped (presumably a total-count column --
  confirm), and each remaining count is smoothed by adding 1 before scoring.

  Args:
    class_counts: A 2-D tensor of per-class counts, usually a slice or
      gather from variables.node_sums.

  Returns:
    A 1-D tensor of the Gini impurities for each row in the input.
  """
  per_class = array_ops.slice(class_counts, [0, 1], [-1, -1])
  smoothed = 1.0 + per_class  # add-one smoothing
  total = math_ops.reduce_sum(smoothed, 1)
  squared_total = math_ops.reduce_sum(math_ops.square(smoothed), 1)
  return 1.0 - squared_total / (total * total)
Example #21
Source File: tensor_forest.py From lambda-packs with MIT License | 6 votes |
def _weighted_gini(self, class_counts):
  """Our split score is the Gini impurity times the number of examples.

  If c(i) denotes the i-th class count and c = sum_i c(i) then
  score = c * (1 - sum_i ( c(i) / c )^2 )
        = c - sum_i c(i)^2 / c

  Column 0 of `class_counts` is left out, and the remaining counts are
  smoothed by adding 1 before scoring.

  Args:
    class_counts: A 2-D tensor of per-class counts, usually a slice or
      gather from variables.node_sums.

  Returns:
    A 1-D tensor of the Gini impurities for each row in the input.
  """
  per_class = array_ops.slice(class_counts, [0, 1], [-1, -1])
  smoothed = 1.0 + per_class  # add-one smoothing
  total = math_ops.reduce_sum(smoothed, 1)
  squared_total = math_ops.reduce_sum(math_ops.square(smoothed), 1)
  return total - squared_total / total
Example #22
Source File: tensor_forest.py From lambda-packs with MIT License | 6 votes |
def _variance(self, sums, squares):
  """Calculate the variance for each row of the input tensors.

  Variance is V = E[x^2] - (E[x])^2.

  Both `sums` and `squares` are divided by column 0 of `sums`; column 0 itself
  also takes part in the final row sum. Rows whose count is zero divide by
  zero.

  Args:
    sums: A tensor containing output sums, usually a slice from
      variables.node_sums. Should contain the number of examples seen
      in index 0 so we can calculate expected value.
    squares: Same as sums, but sums of squares.

  Returns:
    A 1-D tensor of the variances for each row in the input.
  """
  count = array_ops.slice(sums, [0, 0], [-1, 1])
  mean = sums / count
  mean_of_squares = squares / count
  return math_ops.reduce_sum(mean_of_squares - math_ops.square(mean), 1)
Example #23
Source File: tensor_forest.py From lambda-packs with MIT License | 6 votes |
def average_impurity(self):
  """Constructs a TF graph for evaluating the average leaf impurity of a tree.

  If in regression mode, this is the leaf variance. If in classification mode,
  this is the gini impurity.

  NOTE(review): the body below always scores the leaves with
  `self._weighted_gini`; no regression/variance branch is visible here.

  Returns:
    The last op in the graph: the per-leaf scores from `self._weighted_gini`,
    or a one-element float32 tensor holding 10000000. when the tree has no
    leaves yet.
  """
  # Rows whose first `tree` column equals LEAF_NODE are leaves.
  children = array_ops.squeeze(array_ops.slice(
      self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
  is_leaf = math_ops.equal(constants.LEAF_NODE, children)
  leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
                                               squeeze_dims=[1]))
  counts = array_ops.gather(self.variables.node_sums, leaves)
  gini = self._weighted_gini(counts)

  # Guard against step 1, when there often are no leaves yet.
  def impurity():
    return gini

  # Since average impurity can be used for loss, when there's no data just
  # return a big number so that loss always decreases.
  def big():
    # This branch only runs when `leaves` is empty, so the per-leaf `gini`
    # (and therefore `ones_like(gini)`) is empty too and would never carry
    # the big value. Emit an explicit one-element tensor instead.
    return array_ops.ones([1], dtype=dtypes.float32) * 10000000.

  return control_flow_ops.cond(math_ops.greater(
      array_ops.shape(leaves)[0], 0), impurity, big)
Example #24
Source File: tfexample_decoder.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def tensors_to_item(self, keys_to_tensors):
  """Builds a SparseTensor whose last index column is replaced by ids.

  Every index column of the sparse `indices` tensor except the last one is
  kept; the last column is replaced by the (int64-cast) values of `indices`.
  The dense shape comes from `keys_to_tensors[self._shape_key]` (densified if
  it is sparse) when `self._shape_key` is set, else from `self._shape`, else
  from `indices.dense_shape`.

  Args:
    keys_to_tensors: mapping from feature keys to parsed tensors.

  Returns:
    A SparseTensor holding the values of the sparse `values` tensor, or a
    dense Tensor (filled with `self._default_value`) when `self._densify`
    is set.
  """
  sparse_indices = keys_to_tensors[self._indices_key]
  sparse_values = keys_to_tensors[self._values_key]

  if self._shape_key:
    dense_shape = keys_to_tensors[self._shape_key]
    if isinstance(dense_shape, sparse_tensor.SparseTensor):
      dense_shape = sparse_ops.sparse_tensor_to_dense(dense_shape)
  elif self._shape:
    dense_shape = self._shape
  else:
    dense_shape = sparse_indices.dense_shape

  index_rank = array_ops.shape(sparse_indices.indices)[1]
  id_values = math_ops.to_int64(sparse_indices.values)
  # Keep all index columns but the last; the ids take the last column's place.
  leading_columns = array_ops.slice(
      sparse_indices.indices, [0, 0], array_ops.stack([-1, index_rank - 1]))
  id_column = array_ops.reshape(id_values, [-1, 1])
  combined_indices = array_ops.concat([leading_columns, id_column], 1)

  result = sparse_tensor.SparseTensor(
      combined_indices, sparse_values.values, dense_shape)
  if not self._densify:
    return result
  return sparse_ops.sparse_tensor_to_dense(result, self._default_value)
Example #25
Source File: core_rnn_cell_impl.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def __call__(self, inputs, state, scope=None):
  """Run this multi-layer cell on inputs, starting from state.

  Everything runs under the variable scope `scope` (default
  "multi_rnn_cell"), with each sub-cell in its own "cell_<i>" scope feeding
  its output to the next one. With `self._state_is_tuple`, `state` must be a
  sequence indexed per cell; otherwise `state` is one tensor whose columns
  are handed out to the cells in order, `cell.state_size` columns each.

  Returns:
    `(output, new_states)`: the last cell's output and the per-cell new states,
    either as a tuple or concatenated along axis 1.

  Raises:
    ValueError: if `self._state_is_tuple` is set but `state` is not a sequence.
  """
  with vs.variable_scope(scope or "multi_rnn_cell"):
    layer_input = inputs
    column_offset = 0
    next_states = []
    for idx, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % idx):
        if not self._state_is_tuple:
          cell_state = array_ops.slice(
              state, [0, column_offset], [-1, cell.state_size])
          column_offset += cell.state_size
        elif nest.is_sequence(state):
          cell_state = state[idx]
        else:
          raise ValueError(
              "Expected state to be a tuple of length %d, but received: %s" %
              (len(self.state_size), state))
        layer_input, cell_new_state = cell(layer_input, cell_state)
        next_states.append(cell_new_state)

  if self._state_is_tuple:
    next_states = tuple(next_states)
  else:
    next_states = array_ops.concat(next_states, 1)
  return layer_input, next_states
Example #26
Source File: rnn_cell.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _attention(self, query, attn_states):
  """Computes one attention read over `attn_states`.

  `attn_states` is reshaped to [-1, attn_length, 1, attn_size] (presumably it
  is [batch, attn_length, attn_size] -- confirm), scored against
  `_linear(query, self._attn_vec_size, True)`, and softmax-weighted.

  Args:
    query: tensor fed to `_linear`.
    attn_states: the stored attention states.

  Returns:
    A `(new_attns, new_attn_states)` tuple: the attention-weighted sum of the
    states reshaped to [-1, attn_size], and `attn_states` with index 0 of
    axis 1 dropped.
  """
  with vs.variable_scope("attention"):
    attn_w = vs.get_variable(
        "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
    attn_v = vs.get_variable("attn_v", [self._attn_vec_size])
    hidden = array_ops.reshape(
        attn_states, [-1, self._attn_length, 1, self._attn_size])
    hidden_features = nn_ops.conv2d(hidden, attn_w, [1, 1, 1, 1], "SAME")
    projected_query = array_ops.reshape(
        _linear(query, self._attn_vec_size, True),
        [-1, 1, 1, self._attn_vec_size])
    scores = math_ops.reduce_sum(
        attn_v * math_ops.tanh(hidden_features + projected_query), [2, 3])
    weights = nn_ops.softmax(scores)
    weighted_hidden = array_ops.reshape(
        weights, [-1, self._attn_length, 1, 1]) * hidden
    context = math_ops.reduce_sum(weighted_hidden, [1, 2])
    new_attns = array_ops.reshape(context, [-1, self._attn_size])
    new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
    return new_attns, new_attn_states
Example #27
Source File: head.py From lambda-packs with MIT License | 6 votes |
def _split_logits(self, logits):
  """Splits logits for heads.

  Consecutive column ranges of `logits` go to the heads in order. Each range
  is `head.logits_dimension` columns wide.

  Args:
    logits: the logits tensor.

  Returns:
    A list of logits for the individual heads.
  """
  per_head = []
  offset = 0
  for head in self._heads:
    width = head.logits_dimension
    per_head.append(array_ops.slice(logits, [0, offset], [-1, width]))
    offset += width
  return per_head
Example #28
Source File: array_grad.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _PadGrad(op, grad):
  """Gradient for Pad.

  Pad introduces values around the original tensor, so the gradient with
  respect to the input is the slice of `grad` covering the unpadded region.

  Args:
    op: the Pad op. `op.inputs[0]` is the input `x`, and `op.inputs[1]` is the
      [Rank(x), 2] paddings tensor. A third input, if present (presumably a
      constant pad value -- confirm), receives no gradient.
    grad: gradient with respect to the Pad output.

  Returns:
    The gradient for `x`, followed by `None` for each remaining input.
  """
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  # slice() needs `begin` and `size` to share a dtype. `begin` inherits the
  # paddings' dtype (possibly int64), while shape() defaults to int32.
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  if len(op.inputs) == 3:
    return x_grad, None, None
  return x_grad, None


# ReverseSequence is just a permutation. The gradient permutes back.
Example #29
Source File: nn_ops.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _flatten_outer_dims(logits):
  """Flattens logits' outer dimensions and keep its last dimension.

  Args:
    logits: tensor to flatten (presumably of rank >= 1; not checked here).

  Returns:
    `logits` reshaped to [-1, last_dim]. When every static dimension except
    the last is known, their product is also set as the static leading size.
  """
  rank = array_ops.rank(logits)
  last_dim_size = array_ops.slice(
      array_ops.shape(logits), [math_ops.subtract(rank, 1)], [1])
  output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))

  # Set output shape if known.
  static_shape = logits.get_shape()
  if static_shape is None or static_shape.dims is None:
    return output
  dims = static_shape.as_list()
  outer_dims = dims[:-1]
  if any(d is None for d in outer_dims):
    return output
  flat_size = 1
  for d in outer_dims:
    flat_size *= d
  output.set_shape([flat_size, dims[-1]])
  return output
Example #30
Source File: array_grad.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  The input matrix receives `grad` with its diagonal replaced by zeros. The
  diagonal input receives the diagonal part of `grad`.

  Args:
    op: the MatrixSetDiag op (`op.inputs[0]` the input, `op.inputs[1]` the
      diagonal).
    grad: gradient with respect to the op's output.

  Returns:
    A `(grad_input, grad_diag)` pair.
  """
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  static_shapes_known = (batch_shape.is_fully_defined() and
                         matrix_shape.is_fully_defined())
  if static_shapes_known:
    # Diagonal length is min(rows, cols) for every batch entry.
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    # Derive the same shape from the runtime shape of `grad`.
    with ops.colocate_with(grad):
      dynamic_shape = array_ops.shape(grad)
      dynamic_rank = array_ops.rank(grad)
      dynamic_batch = array_ops.slice(dynamic_shape, [0], [dynamic_rank - 2])
      dynamic_matrix = array_ops.slice(dynamic_shape, [dynamic_rank - 2], [2])
      shortest_side = math_ops.reduce_min(dynamic_matrix)
      diag_shape = array_ops.concat([dynamic_batch, [shortest_side]], 0)
  zero_diag = array_ops.zeros(diag_shape, dtype=grad.dtype)
  grad_input = array_ops.matrix_set_diag(grad, zero_diag)
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)