Python tensorflow.tensor_scatter_nd_update() Examples
The following are 16 code examples of `tensorflow.tensor_scatter_nd_update()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow`, or try the search function.
Example #1
Source File: modes.py From spektral with MIT License | 6 votes |
def disjoint_signal_to_batch(X, I):
    """
    Pads a disjoint-mode graph signal into a dense, zero-filled batch.

    Every graph gets a slot of `N_max` rows (the size of the largest graph);
    its nodes are written to the first rows of that slot and the remaining
    rows stay zero.

    :param X: Tensor, node features of shape (nodes, features).
    :param I: Tensor, graph IDs of shape `(N, )`; `tf.math.segment_sum` is used
        on it, which expects sorted segment IDs — TODO confirm callers sort.
    :return batch: Tensor, batched node features of shape (batch, N_max, F)
    """
    graph_ids = tf.cast(I, tf.int32)

    # Node count of each graph, and the offset of each graph's first node in X.
    nodes_per_graph = tf.math.segment_sum(tf.ones_like(graph_ids), graph_ids)
    first_node = tf.cumsum(nodes_per_graph, exclusive=True)

    num_graphs = tf.shape(nodes_per_graph)[0]
    widest = tf.reduce_max(nodes_per_graph)
    total_nodes = tf.shape(graph_ids)[0]
    n_features = tf.shape(X)[-1]

    # Position of each node inside its own graph, shifted into that graph's slot.
    local_pos = tf.range(total_nodes) - tf.gather(first_node, graph_ids)
    flat_pos = local_pos + graph_ids * widest

    buffer = tf.zeros((num_graphs * widest, n_features), dtype=X.dtype)
    buffer = tf.tensor_scatter_nd_update(buffer, tf.expand_dims(flat_pos, -1), X)
    return tf.reshape(buffer, (num_graphs, widest, n_features))
Example #2
Source File: delayedmodels.py From spiking-net-tensorflow with GNU General Public License v3.0 | 6 votes |
def update_active_spikes(self, spikes):
    """
    Record freshly fired spikes in `self.active_spikes` at their arrival slots.

    Parameters:
        spikes (array like): The spikes that have just occurred
    Returns:
        None
    """
    # Nonzero entries of spikes * delays mark positions where a spike meets a
    # nonzero delay; the entry's value is that delay.
    weighted_delays = spikes * self.delays
    hit_positions = tf.where(tf.not_equal(weighted_delays, 0))
    hit_delays = tf.gather_nd(weighted_delays, hit_positions)

    # spike_arrival_step turns each delay into an index along the extra (last)
    # axis of active_spikes; per the original author this accounts for the
    # step size and a circular buffer — not verifiable from this block.
    arrival_slots = tf.reshape(self.spike_arrival_step(hit_delays), [-1, 1])

    # Append the arrival slot as one more index column for each hit.
    scatter_positions = tf.concat([hit_positions, arrival_slots], axis=1)
    self.active_spikes = tf.tensor_scatter_nd_update(
        self.active_spikes,
        scatter_positions,
        tf.ones(scatter_positions.shape[0]),
    )
Example #3
Source File: embeddings.py From mead-baseline with Apache License 2.0 | 6 votes |
def encode(self, x=None):
    """Build a simple Lookup Table over `x` (or `self.x` when `x` is None).

    The row at `Offsets.PAD` of the embedding table is replaced by zeros before
    dropout and lookup, so PAD tokens always embed to a zero vector.

    :param x: An optional input sub-graph to bind to this operation; when
        `None`, the previously bound `self.x` is used.
    :return: The sub-graph output
    """
    if x is not None:
        self.x = x
    # tensor_scatter_nd_update returns a NEW tensor and leaves self.W untouched,
    # so the zeroed table must be the one that feeds dropout and the lookup.
    # (Previously the result was only used as a control dependency and the
    # lookup still read the unmodified self.W.)
    pad_index = tf.constant(Offsets.PAD, dtype=tf.int32, shape=[1, 1])
    weights = tf.tensor_scatter_nd_update(
        self.W, pad_index, tf.zeros(shape=[1, self.dsz], dtype=self.W.dtype)
    )
    # The ablation table (4) in https://arxiv.org/pdf/1708.02182.pdf shows this has a massive impact
    embedding_w_dropout = self.drop(weights, training=TRAIN_FLAG())
    word_embeddings = tf.nn.embedding_lookup(embedding_w_dropout, self.x)
    return word_embeddings
Example #4
Source File: dataset.py From tf2-yolo3 with Apache License 2.0 | 5 votes |
def transform_targets_for_output(y_true, grid_y, grid_x, anchor_idxs, classes):
    """Scatter ground-truth boxes into a dense target grid for one output scale.

    Args:
        y_true: boxes laid out as (N, boxes, (x1, y1, x2, y2, class, best_anchor)).
            Coordinates are presumably normalised to [0, 1] (they are multiplied
            by the grid size to get a cell index) — TODO confirm.
        grid_y: number of grid rows at this scale.
        grid_x: number of grid columns at this scale.
        anchor_idxs: anchor indices handled by this output scale.
        classes: not used by this function; kept for interface compatibility.

    Returns:
        Tensor (N, grid_y, grid_x, len(anchor_idxs), 6) holding
        [x1, y1, x2, y2, obj=1, class] in the cell/anchor responsible for each
        box, zeros elsewhere.
    """
    N = tf.shape(y_true)[0]
    y_true_out = tf.zeros((N, grid_y, grid_x, tf.shape(anchor_idxs)[0], 6))
    anchor_idxs = tf.cast(anchor_idxs, tf.int32)

    # Ordered (columns, rows) so it lines up with box_xy = (x, y).
    grid_wh = tf.stack([grid_x, grid_y], axis=-1)
    grid_size = tf.cast(grid_wh, tf.float32)
    # Largest valid cell index along x and y.
    max_cell = tf.cast(grid_wh, tf.int32) - 1

    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(N):
        for j in tf.range(tf.shape(y_true)[1]):
            # Rows whose x2 is 0 are treated as padding and skipped.
            if tf.equal(y_true[i][j][2], 0):
                continue
            anchor_eq = tf.equal(anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))
            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2.
                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                # A centre at exactly 1.0 (or drifted outside [0, 1]) would map
                # to a cell outside the grid and break the scatter; clamp it
                # onto the nearest border cell.
                grid_xy = tf.clip_by_value(
                    tf.cast(box_xy * grid_size, tf.int32), 0, max_cell)
                # grid[y][x][anchor] = (x1, y1, x2, y2, obj, class)
                indexes = indexes.write(
                    idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
                updates = updates.write(
                    idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
                idx += 1
    return tf.tensor_scatter_nd_update(
        y_true_out, indexes.stack(), updates.stack())
Example #5
Source File: relgraphconv.py From dgl with Apache License 2.0 | 5 votes |
def basis_message_func(self, edges):
    """Message function for basis regularizer.

    Each edge's message is its source feature multiplied by the weight of the
    edge's relation type (`edges.data['type']`). When `num_bases < num_rels`,
    the per-relation weights are first composed from the bases via `w_comp`.
    If `edges.data` has a `'norm'` entry, the messages are multiplied by it.

    Returns:
        dict with key ``'msg'`` holding the per-edge messages.
    """
    if self.num_bases < self.num_rels:
        # generate all weights from bases
        weight = tf.reshape(self.weight, (self.num_bases, self.in_feat * self.out_feat))
        weight = tf.reshape(tf.matmul(self.w_comp, weight),
                            (self.num_rels, self.in_feat, self.out_feat))
    else:
        weight = self.weight

    h = edges.src['h']
    # calculate msg @ W_r before put msg into edge
    # if src is tf.int64 we expect it is an index select
    if h.dtype != tf.int64 and self.low_mem:
        etypes, _ = tf.unique(edges.data['type'])
        # Dynamic size so this also works when the static edge count is unknown.
        num_edges = tf.shape(h)[0]
        # tensor_scatter_nd_update requires the buffer and the updates to share
        # a dtype; sub_msg below has the weight's dtype, which need not be float32.
        msg = tf.zeros([num_edges, self.out_feat], dtype=weight.dtype)
        idx = tf.range(num_edges)
        for etype in etypes:
            loc = (edges.data['type'] == etype)
            w = weight[etype]
            src = tf.boolean_mask(h, loc)
            sub_msg = tf.matmul(src, w)
            indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
            msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
    else:
        msg = utils.bmm_maybe_select(h, weight, edges.data['type'])
    if 'norm' in edges.data:
        msg = msg * edges.data['norm']
    return {'msg': msg}
Example #6
Source File: relgraphconv.py From dgl with Apache License 2.0 | 5 votes |
def bdd_message_func(self, edges):
    """Message function for block-diagonal-decomposition regularizer.

    The relation weight is used as `num_bases` blocks of shape
    (submat_in, submat_out): the source feature is split into `num_bases`
    chunks of length `submat_in`, and each chunk is multiplied by its own block.
    If `edges.data` has a `'norm'` entry, the messages are multiplied by it.

    Raises:
        TypeError: if the source features are 1-D int64 (integer node IDs).

    Returns:
        dict with key ``'msg'`` holding the per-edge messages.
    """
    h = edges.src['h']
    if h.dtype == tf.int64 and len(h.shape) == 1:
        raise TypeError(
            'Block decomposition does not allow integer ID feature.')

    # calculate msg @ W_r before put msg into edge
    if self.low_mem:
        etypes, _ = tf.unique(edges.data['type'])
        # Dynamic size so this also works when the static edge count is unknown.
        num_edges = tf.shape(h)[0]
        # Buffer dtype must match sub_msg (derived from self.weight) for the scatter.
        msg = tf.zeros([num_edges, self.out_feat], dtype=self.weight.dtype)
        idx = tf.range(num_edges)
        for etype in etypes:
            loc = (edges.data['type'] == etype)
            w = tf.reshape(self.weight[etype],
                           (self.num_bases, self.submat_in, self.submat_out))
            src = tf.reshape(tf.boolean_mask(h, loc),
                             (-1, self.num_bases, self.submat_in))
            # Per-block product: chunk b of each row times weight block b.
            sub_msg = tf.einsum('abc,bcd->abd', src, w)
            sub_msg = tf.reshape(sub_msg, (-1, self.out_feat))
            indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
            msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
    else:
        weight = tf.reshape(tf.gather(
            self.weight, edges.data['type']), (-1, self.submat_in, self.submat_out))
        node = tf.reshape(h, (-1, 1, self.submat_in))
        msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat))
    if 'norm' in edges.data:
        msg = msg * edges.data['norm']
    return {'msg': msg}
Example #7
Source File: tensor.py From dgl with Apache License 2.0 | 5 votes |
def scatter_row(data, row_index, value):
    """Return a copy of `data` with the rows selected by `row_index` replaced by `value`.

    Assumes `row_index` is a 1-D index vector (TODO confirm): each entry then
    addresses a whole slice along axis 0 of `data`.
    """
    # Add a trailing axis so every index is a length-1 coordinate.
    scatter_coords = tf.expand_dims(row_index, axis=1)
    return tf.tensor_scatter_nd_update(data, scatter_coords, value)
Example #8
Source File: delayedmodels.py From spiking-net-tensorflow with GNU General Public License v3.0 | 5 votes |
def clear_current_active_spikes(self):
    """
    Zero every nonzero entry of `self.active_spikes` at the current slot of its last axis.

    Parameters:
        None
    Returns:
        None
    """
    # Leading-axis positions whose value at the current slot is nonzero.
    current_slice = self.active_spikes[:, :, self.get_active_spike_idx()]
    nonzero_positions = tf.where(tf.not_equal(current_slice, 0))

    # Append the current slot index as a third coordinate for each position.
    hit_count = nonzero_positions.shape[0]
    slot_column = tf.ones((hit_count, 1), dtype=tf.int64) * self.get_active_spike_idx()
    scatter_positions = tf.concat([nonzero_positions, slot_column], axis=1)

    self.active_spikes = tf.tensor_scatter_nd_update(
        self.active_spikes,
        scatter_positions,
        tf.zeros(scatter_positions.shape[0]),
    )
Example #9
Source File: dataset.py From DirectML with MIT License | 5 votes |
def transform_targets_for_output(y_true, grid_size, anchor_idxs):
    """Scatter ground-truth boxes into a dense target grid for one output scale.

    Args:
        y_true: boxes laid out as (N, boxes, (x1, y1, x2, y2, class, best_anchor)).
            Coordinates are presumably normalised to [0, 1] — TODO confirm.
        grid_size: number of cells along each side of the (square) grid.
        anchor_idxs: anchor indices handled by this output scale.

    Returns:
        Tensor (N, grid_size, grid_size, len(anchor_idxs), 6) holding
        [x1, y1, x2, y2, obj=1, class] in the cell/anchor responsible for each
        box, zeros elsewhere.
    """
    N = tf.shape(y_true)[0]

    y_true_out = tf.zeros(
        (N, grid_size, grid_size, tf.shape(anchor_idxs)[0], 6))

    anchor_idxs = tf.cast(anchor_idxs, tf.int32)
    # Largest valid cell index along either axis.
    max_cell = tf.cast(grid_size, tf.int32) - 1

    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(N):
        for j in tf.range(tf.shape(y_true)[1]):
            # Rows whose x2 is 0 are treated as padding and skipped.
            if tf.equal(y_true[i][j][2], 0):
                continue
            anchor_eq = tf.equal(
                anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))

            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2

                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                grid_xy = tf.cast(box_xy // (1/grid_size), tf.int32)
                # A centre at exactly 1.0 (or outside [0, 1]) would index past
                # the grid and break the scatter; clamp onto the border cell.
                grid_xy = tf.clip_by_value(grid_xy, 0, max_cell)

                # grid[y][x][anchor] = (x1, y1, x2, y2, obj, class)
                indexes = indexes.write(
                    idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
                updates = updates.write(
                    idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
                idx += 1

    return tf.tensor_scatter_nd_update(
        y_true_out, indexes.stack(), updates.stack())
Example #10
Source File: layers.py From deepchem with MIT License | 4 votes |
def call(self, inputs, training=True):
    """
    Run the DAG recurrence for `self.max_atoms` steps.

    parent layers: atom_features, parents, calculation_orders, calculation_masks, n_atoms

    At step `count`, the rows selected by `calculation_masks[:, count]`:
      * gather their target atom's features (via `calculation_orders`),
      * gather the graph features of their `max_atoms - 1` parents
        (`parents[:, count, 1:]`) from `graph_features`,
      * pass the concatenation through `_DAGgraph_step`, and
      * write the result into `graph_features` at (row, parents[:, count, 0]).

    Returns:
      `batch_outputs` of the final step only; the accumulated
      `graph_features` tensor is not returned.
    """
    atom_features = inputs[0]
    # each atom corresponds to a graph, which is represented by the `max_atoms*max_atoms` int32 matrix of index
    # each graph includes `max_atoms` steps (corresponding to rows) of calculating graph features
    parents = tf.cast(inputs[1], dtype=tf.int32)
    # target atoms for each step: (batch_size*max_atoms) * max_atoms
    calculation_orders = inputs[2]
    calculation_masks = inputs[3]

    n_atoms = tf.squeeze(inputs[4])
    # NOTE(review): one more column than max_atoms; presumably a sentinel slot
    # for absent parents — confirm against the featurizer.
    graph_features = tf.zeros((self.max_atoms * self.batch_size, self.max_atoms + 1, self.n_graph_feat))

    for count in range(self.max_atoms):
        # `count`-th step
        # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
        mask = calculation_masks[:, count]
        current_round = tf.boolean_mask(calculation_orders[:, count], mask)
        batch_atom_features = tf.gather(atom_features, current_round)

        # generating index for graph features used in the inputs
        # stack1: row id of each active row, repeated once per parent
        stack1 = tf.reshape(
            tf.stack(
                [tf.boolean_mask(tf.range(n_atoms), mask)] * (self.max_atoms - 1),
                axis=1), [-1])
        # stack2: the parent column ids for those rows
        stack2 = tf.reshape(tf.boolean_mask(parents[:, count, 1:], mask), [-1])
        index = tf.stack([stack1, stack2], axis=1)
        # extracting graph features for parents of the target atoms, then flatten
        # shape: (batch_size*max_atoms) * [(max_atoms-1)*n_graph_features]
        batch_graph_features = tf.reshape(
            tf.gather_nd(graph_features, index),
            [-1, (self.max_atoms - 1) * self.n_graph_feat])

        # concat into the input tensor: (batch_size*max_atoms) * n_inputs
        batch_inputs = tf.concat(
            axis=1, values=[batch_atom_features, batch_graph_features])
        # DAGgraph_step maps from batch_inputs to a batch of graph_features
        # of shape: (batch_size*max_atoms) * n_graph_features
        # representing the graph features of target atoms in each graph
        batch_outputs = _DAGgraph_step(batch_inputs, self.W_list, self.b_list,
                                       self.activation_fn, self.dropouts,
                                       training)

        # index for target atoms
        target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
        target_index = tf.boolean_mask(target_index, mask)
        # Store this step's outputs so later steps can read them as parent features.
        graph_features = tf.tensor_scatter_nd_update(graph_features, target_index, batch_outputs)
    return batch_outputs
Example #11
Source File: ops.py From spektral with MIT License | 4 votes |
def segment_top_k(x, I, ratio, top_k_var):
    """
    Returns indices to get the top K values in x segment-wise, according to
    the segments defined in I. K is not fixed, but it is defined as a ratio of
    the number of elements in each segment (K = ceil(ratio * segment size)).

    Works by scattering x into a dense (n_graphs, max_n_nodes) matrix padded
    with `min(x) - 1`, sorting each row in descending order, and keeping the
    first K entries of each row.

    Side effect: `top_k_var` is overwritten (via `assign`) with the padding
    buffer on every call.

    :param x: a rank 1 Tensor;
    :param I: a rank 1 Tensor with segment IDs for x; `tf.math.segment_sum` is
        applied to it, which expects sorted IDs — TODO confirm callers sort;
    :param ratio: float, ratio of elements to keep for each segment;
    :param top_k_var: a tf.Variable created without shape validation (i.e.,
        `tf.Variable(0.0, validate_shape=False)`);
    :return: a rank 1 Tensor containing the indices to get the top K values of
        each segment in x.
    """
    I = tf.cast(I, tf.int32)
    num_nodes = tf.math.segment_sum(tf.ones_like(I), I)  # Number of nodes in each graph
    cumsum = tf.cumsum(num_nodes)  # Cumulative number of nodes (A, A+B, A+B+C)
    cumsum_start = cumsum - num_nodes  # Start index of each graph
    n_graphs = tf.shape(num_nodes)[0]  # Number of graphs in batch
    max_n_nodes = tf.reduce_max(num_nodes)  # Order of biggest graph in batch
    batch_n_nodes = tf.shape(I)[0]  # Number of overall nodes in batch
    to_keep = tf.math.ceil(ratio * tf.cast(num_nodes, tf.float32))
    to_keep = tf.cast(to_keep, I.dtype)  # Nodes to keep in each graph

    # Flat position of each node in the (n_graphs * max_n_nodes) padded buffer.
    index = tf.range(batch_n_nodes)
    index = (index - tf.gather(cumsum_start, I)) + (I * max_n_nodes)

    y_min = tf.reduce_min(x)
    dense_y = tf.ones((n_graphs * max_n_nodes,))
    # subtract 1 to ensure that filler values do not get picked
    dense_y = dense_y * tf.cast(y_min - 1, dense_y.dtype)
    dense_y = tf.cast(dense_y, top_k_var.dtype)
    # top_k_var is a variable with unknown shape defined elsewhere
    top_k_var.assign(dense_y)
    dense_y = tf.tensor_scatter_nd_update(top_k_var, index[..., None], tf.cast(x, top_k_var.dtype))
    dense_y = tf.reshape(dense_y, (n_graphs, max_n_nodes))

    # Per-row descending order, shifted back to positions in the flat input.
    perm = tf.argsort(dense_y, direction='DESCENDING')
    perm = perm + cumsum_start[:, None]
    perm = tf.reshape(perm, (-1,))

    # Per graph: `to_keep` ones followed by `max_n_nodes - to_keep` zeros, so
    # only the leading K sorted entries of each row survive the mask.
    to_rep = tf.tile(tf.constant([1., 0.]), (n_graphs,))
    rep_times = tf.reshape(tf.concat((to_keep[:, None], (max_n_nodes - to_keep)[:, None]), -1),
                           (-1,))
    mask = repeat(to_rep, rep_times)

    perm = tf.boolean_mask(perm, mask)

    return perm
Example #12
Source File: scatter_elements.py From onnx-tensorflow with Apache License 2.0 | 4 votes |
def version_11(cls, node, **kwargs):
    """Convert an ONNX ScatterElements (opset 11) node to TensorFlow ops.

    Builds a full coordinate tensor for `updates` and overwrites the
    coordinate along `axis` with the ONNX `indices`. The result is then applied
    with `tf.tensor_scatter_nd_update` on `data`.

    Returns:
      A one-element list holding the scattered tensor.
    """
    axis = node.attrs.get("axis", 0)
    data = kwargs["tensor_dict"][node.inputs[0]]
    indices = kwargs["tensor_dict"][node.inputs[1]]
    updates = kwargs["tensor_dict"][node.inputs[2]]

    # process negative axis
    # NOTE(review): for a negative axis this becomes a Tensor, which is later
    # used as a Python list index (`idx_tensors_per_axis[axis]`); that likely
    # only works in eager mode — confirm.
    axis = axis if axis >= 0 else tf.add(tf.rank(data), axis)

    # check whether any indices are out of bounds
    result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
    msg = 'ScatterElements indices are out of bounds, please double check the indices and retry.'
    with tf.control_dependencies(
        [tf.compat.v1.assert_equal(result, True, message=msg)]):
      # process negative indices
      indices = cls.process_neg_idx_along_axis(data, axis, indices)

      # Calculate shape of the tensorflow version of indices tensor.
      sparsified_dense_idx_shape = tf_shape(updates)

      # Move on to convert ONNX indices to tensorflow indices in 2 steps:
      #
      # Step 1:
      #   What would the index tensors look like if updates are all
      #   dense? In other words, produce a coordinate tensor for updates:
      #
      #   coordinate[i, j, k ...] = [i, j, k ...]
      #   where the shape of "coordinate" tensor is same as that of updates.
      #
      # Step 2:
      #   But the coordinate tensor needs some correction because coord
      #   vector at position axis is wrong (since we assumed update is dense,
      #   but it is not at the axis specified).
      #   So we update coordinate vector tensor elements at position=axis with
      #   the sparse coordinate indices.

      idx_tensors_per_axis = tf.meshgrid(
          *list(
              map(lambda x: tf.range(x, dtype=tf.dtypes.int64),
                  sparsified_dense_idx_shape)),
          indexing='ij')
      idx_tensors_per_axis[axis] = indices
      dim_expanded_idx_tensors_per_axis = list(
          map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
      coordinate = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)

      # Now the coordinate tensor is in the shape
      # [updates.shape, updates.rank]
      # we need it flattened into the shape:
      # [product(updates.shape), updates.rank]
      indices = tf.reshape(coordinate, [-1, tf.rank(data)])
      updates = tf.reshape(updates, [-1])

      return [tf.tensor_scatter_nd_update(data, indices, updates)]
Example #13
Source File: gather_elements.py From onnx-tensorflow with Apache License 2.0 | 4 votes |
def version_11(cls, node, **kwargs):
    """Convert an ONNX GatherElements (opset 11) node to TensorFlow ops.

    GatherElements takes two inputs, data and indices, of the same rank r >= 1,
    and an optional attribute axis (default 0, the outer-most axis). Its
    output has the shape of indices and holds one value gathered from data for
    each element of indices.

    Implementation:
      1. Swap `axis` with axis 0, via a permutation built with
         tf.tensor_scatter_nd_update.
      2. Build a full coordinate grid and replace its first coordinate with the
         (swapped) indices.
      3. Apply tf.gather_nd, then transpose back with the same permutation; the
         permutation is a single swap, so it is its own inverse.

    Returns:
      A one-element list holding the gathered tensor.
    """
    axis = node.attrs.get("axis", 0)
    data = kwargs["tensor_dict"][node.inputs[0]]
    indices = kwargs["tensor_dict"][node.inputs[1]]

    # process negative axis
    # NOTE(review): a negative axis becomes a Tensor here; the `axis == 0` test
    # and `tf.constant([[0], [axis]])` below likely only work in eager mode
    # with a Tensor axis — confirm.
    axis = axis if axis >= 0 else tf.add(tf.rank(data), axis)

    # check whether any indices are out of bounds
    result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
    msg = 'GatherElements indices are out of bounds,'\
      ' please double check the indices and retry.'
    with tf.control_dependencies(
        [tf.compat.v1.assert_equal(result, True, message=msg)]):
      # process negative indices
      indices = cls.process_neg_idx_along_axis(data, axis, indices)
      # adapted from reference implementation in onnx/onnx/backend/test/case/node/gatherelements.py
      if axis == 0:
        axis_perm = tf.range(tf.rank(data))
        data_swaped = data
        index_swaped = indices
      else:
        # Identity permutation with entries 0 and `axis` exchanged.
        axis_perm = tf.tensor_scatter_nd_update(tf.range(tf.rank(data)),
                                                tf.constant([[0], [axis]]),
                                                tf.constant([axis, 0]))
        data_swaped = tf.transpose(data, perm=axis_perm)
        index_swaped = tf.transpose(indices, perm=axis_perm)

      # NOTE(review): `shape.as_list()` needs fully-defined static dims;
      # unknown dimensions would break tf.range here.
      idx_tensors_per_axis = tf.meshgrid(*list(
          map(lambda x: tf.range(x, dtype=index_swaped.dtype),
              index_swaped.shape.as_list())),
                                         indexing='ij')
      idx_tensors_per_axis[0] = index_swaped
      dim_expanded_idx_tensors_per_axis = list(
          map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
      index_expanded = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)

      gathered = tf.gather_nd(data_swaped, index_expanded)
      y = tf.transpose(gathered, perm=axis_perm)

      return [y]
Example #14
Source File: metric_utils.py From ULTRA with Apache License 2.0 | 4 votes |
def scatter_to_2d(tensor, segments, pad_value, output_shape=None):
  """Scatters a flattened 1-D `tensor` to 2-D with padding based on `segments`.

  For example: tensor = [1, 2, 3], segments = [0, 1, 0] and pad_value = -1, then
  the returned 2-D tensor is [[1, 3], [2, -1]]. The output_shape is inferred
  when None is provided. In this case, the shape will be dynamic and may not be
  compatible with TPU. For TPU use case, please provide the `output_shape`
  explicitly.

  Args:
    tensor: A 1-D numeric `Tensor`.
    segments: A 1-D int `Tensor` which is the idx output from tf.unique like
      [0, 0, 1, 0, 2]. See tf.unique. The segments may or may not be sorted.
    pad_value: A numeric value to pad the output `Tensor`.
    output_shape: A `Tensor` of size 2 telling the desired shape of the output
      tensor. If None, the output_shape will be inferred and not fixed at
      compilation time. When output_shape is smaller than needed, truncation
      will be applied.

  Returns:
    A 2-D Tensor.
  """
  with tf.compat.v1.name_scope(name='scatter_to_2d'):
    values = tf.convert_to_tensor(value=tensor)
    seg_ids = tf.convert_to_tensor(value=segments)
    values_shape = values.get_shape()
    seg_shape = seg_ids.get_shape()
    values_shape.assert_has_rank(1)
    seg_shape.assert_has_rank(1)
    values_shape.assert_is_compatible_with(seg_shape)

    # Column index = position of each element within its own segment, e.g.
    # segments [0, 0, 0, 1, 2, 2] -> [0, 1, 2, 0, 0, 1].
    col_ids = _in_segment_indices(seg_ids)

    if output_shape is not None:
      # Elements that fall outside the requested shape are redirected to the
      # spare cell [output_shape[0], 0], which the final slice discards.
      in_range = tf.logical_and(
          tf.less(seg_ids, output_shape[0]), tf.less(col_ids, output_shape[1]))
      seg_ids = tf.compat.v1.where(in_range, seg_ids,
                                   output_shape[0] * tf.ones_like(seg_ids))
      col_ids = tf.compat.v1.where(in_range, col_ids, tf.zeros_like(col_ids))
    else:
      output_shape = [
          tf.reduce_max(input_tensor=seg_ids) + 1,
          tf.reduce_max(input_tensor=col_ids) + 1
      ]

    # One spare row and column beyond output_shape absorb the redirected
    # elements; slicing back to output_shape drops them.
    coords = tf.stack([seg_ids, col_ids], axis=1)
    canvas = pad_value * tf.ones(
        shape=(output_shape + tf.ones_like(output_shape)), dtype=values.dtype)
    scattered = tf.tensor_scatter_nd_update(canvas, coords, values)
    return tf.slice(scattered, begin=[0, 0], size=output_shape)
Example #15
Source File: utils.py From ranking with Apache License 2.0 | 4 votes |
def scatter_to_2d(tensor, segments, pad_value, output_shape=None):
  """Scatters a flattened 1-D `tensor` to 2-D with padding based on `segments`.

  For example: tensor = [1, 2, 3], segments = [0, 1, 0] and pad_value = -1, then
  the returned 2-D tensor is [[1, 3], [2, -1]]. The output_shape is inferred
  when None is provided. In this case, the shape will be dynamic and may not be
  compatible with TPU. For TPU use case, please provide the `output_shape`
  explicitly.

  Args:
    tensor: A 1-D numeric `Tensor`.
    segments: A 1-D int `Tensor` which is the idx output from tf.unique like
      [0, 0, 1, 0, 2]. See tf.unique. The segments may or may not be sorted.
    pad_value: A numeric value to pad the output `Tensor`.
    output_shape: A `Tensor` of size 2 telling the desired shape of the output
      tensor. If None, the output_shape will be inferred and not fixed at
      compilation time. When output_shape is smaller than needed, truncation
      will be applied.

  Returns:
    A 2-D Tensor.
  """
  with tf.compat.v1.name_scope(name='scatter_to_2d'):
    values = tf.convert_to_tensor(value=tensor)
    seg_ids = tf.convert_to_tensor(value=segments)
    values_shape = values.get_shape()
    seg_shape = seg_ids.get_shape()
    values_shape.assert_has_rank(1)
    seg_shape.assert_has_rank(1)
    values_shape.assert_is_compatible_with(seg_shape)

    # Column index = position of each element within its own segment, e.g.
    # segments [0, 0, 0, 1, 2, 2] -> [0, 1, 2, 0, 0, 1].
    col_ids = _in_segment_indices(seg_ids)

    if output_shape is not None:
      # Elements that fall outside the requested shape are redirected to the
      # spare cell [output_shape[0], 0], which the final slice discards.
      in_range = tf.logical_and(
          tf.less(seg_ids, output_shape[0]), tf.less(col_ids, output_shape[1]))
      seg_ids = tf.compat.v1.where(in_range, seg_ids,
                                   output_shape[0] * tf.ones_like(seg_ids))
      col_ids = tf.compat.v1.where(in_range, col_ids, tf.zeros_like(col_ids))
    else:
      output_shape = [
          tf.reduce_max(input_tensor=seg_ids) + 1,
          tf.reduce_max(input_tensor=col_ids) + 1
      ]

    # One spare row and column beyond output_shape absorb the redirected
    # elements; slicing back to output_shape drops them.
    coords = tf.stack([seg_ids, col_ids], axis=1)
    canvas = pad_value * tf.ones(
        shape=(output_shape + tf.ones_like(output_shape)), dtype=values.dtype)
    scattered = tf.tensor_scatter_nd_update(canvas, coords, values)
    return tf.slice(scattered, begin=[0, 0], size=output_shape)
Example #16
Source File: box_utils.py From ssd-tf2 with MIT License | 4 votes |
def compute_target(default_boxes, gt_boxes, gt_labels, iou_threshold=0.5):
    """ Compute regression and classification targets
    Args:
        default_boxes: tensor (num_default, 4)
                       of format (cx, cy, w, h)
        gt_boxes: tensor (num_gt, 4)
                  of format (xmin, ymin, xmax, ymax)
        gt_labels: tensor (num_gt,)
        iou_threshold: default boxes whose best IoU (after forced matching)
                       is below this get label 0 (presumably background)

    Returns:
        gt_confs: classification targets, tensor (num_default,)
        gt_locs: regression targets, tensor (num_default, 4)
    """
    # Convert default boxes to format (xmin, ymin, xmax, ymax)
    # in order to compute overlap with gt boxes
    transformed_default_boxes = transform_center_to_corner(default_boxes)
    iou = compute_iou(transformed_default_boxes, gt_boxes)

    # Best gt box for every default box.
    best_gt_iou = tf.math.reduce_max(iou, 1)
    best_gt_idx = tf.math.argmax(iou, 1)

    # Best default box for every gt box.
    best_default_idx = tf.math.argmax(iou, 0)

    # Force-match each gt box to its best default box: that default box is
    # assigned gt index k and IoU 1 so it always survives the threshold.
    # Vectorised equivalent of looping over enumerate(best_default_idx).
    # NOTE: if two gt boxes share the same best default box, which one wins
    # is not specified by tensor_scatter_nd_update.
    # Dynamic count so a variable number of gt boxes works under tf.function.
    num_gt = tf.shape(best_default_idx, out_type=tf.int64)[0]
    forced_rows = tf.expand_dims(best_default_idx, 1)
    best_gt_idx = tf.tensor_scatter_nd_update(
        best_gt_idx,
        forced_rows,
        tf.range(num_gt, dtype=tf.int64))
    best_gt_iou = tf.tensor_scatter_nd_update(
        best_gt_iou,
        forced_rows,
        tf.ones_like(best_default_idx, dtype=best_gt_iou.dtype))

    gt_confs = tf.gather(gt_labels, best_gt_idx)
    gt_confs = tf.where(
        tf.less(best_gt_iou, iou_threshold),
        tf.zeros_like(gt_confs),
        gt_confs)

    gt_boxes = tf.gather(gt_boxes, best_gt_idx)
    gt_locs = encode(default_boxes, gt_boxes)

    return gt_confs, gt_locs