Python tensorflow.assert_equal() Examples
The following are 30 code examples of tensorflow.assert_equal(). Each example is taken from an open-source project; the project, source file, and license are listed above the code. You may also want to check out the other available functions and classes of the tensorflow module.
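Most of the examples below share the same graph-mode pattern: build the assertion with tf.assert_equal() and force it to run by gating downstream ops with tf.control_dependencies(). The following is a minimal sketch of that pattern, assuming TF 1.x-style graph execution via tf.compat.v1; the tensor names and shapes are illustrative and do not come from any of the projects below.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.placeholder(tf.int64, shape=[None])
predictions = tf.placeholder(tf.int64, shape=[None])

# Build the assertion op: it raises InvalidArgumentError at run time
# if the two batch sizes differ.
assert_op = tf.assert_equal(
    tf.shape(labels)[0], tf.shape(predictions)[0],
    message='labels and predictions must have the same batch size')

# Gate the computation on the assertion so it actually executes.
with tf.control_dependencies([assert_op]):
  accuracy = tf.reduce_mean(tf.cast(tf.equal(labels, predictions), tf.float32))

with tf.Session() as sess:
  print(sess.run(accuracy, feed_dict={labels: [1, 2, 3], predictions: [1, 2, 0]}))

In TensorFlow 2.x eager execution the same check is available as tf.debugging.assert_equal and raises InvalidArgumentError immediately, so the control-dependency gating is only needed in graph mode.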
Example #1
Source File: anchor_generator.py From Person-Detection-and-Tracking with MIT License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors_list,
                                      feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors_list: A list of box_list.BoxList object holding anchors generated.
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  actual_num_anchors = 0
  for num_anchors_per_location, feature_map_shape, anchors in zip(
      self.num_anchors_per_location(), feature_map_shape_list, anchors_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
    actual_num_anchors += anchors.num_boxes()
  return tf.assert_equal(expected_num_anchors, actual_num_anchors)
Example #2
Source File: anchor_generator.py From object_detector_app with MIT License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors: box_list.BoxList object holding anchors generated
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  for num_anchors_per_location, feature_map_shape in zip(
      self.num_anchors_per_location(), feature_map_shape_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
  return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #3
Source File: anchor_generator.py From Hands-On-Machine-Learning-with-OpenCV-4 with MIT License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors: box_list.BoxList object holding anchors generated
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  for num_anchors_per_location, feature_map_shape in zip(
      self.num_anchors_per_location(), feature_map_shape_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
  return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #4
Source File: anchor_generator.py From tensorflow with BSD 2-Clause "Simplified" License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors: box_list.BoxList object holding anchors generated
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  for num_anchors_per_location, feature_map_shape in zip(
      self.num_anchors_per_location(), feature_map_shape_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
  return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #5
Source File: anchor_generator.py From DOTA_models with Apache License 2.0 | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors: box_list.BoxList object holding anchors generated
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  for num_anchors_per_location, feature_map_shape in zip(
      self.num_anchors_per_location(), feature_map_shape_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
  return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #6
Source File: anchor_generator.py From yolo_v2 with Apache License 2.0 | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors: box_list.BoxList object holding anchors generated
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  for num_anchors_per_location, feature_map_shape in zip(
      self.num_anchors_per_location(), feature_map_shape_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
  return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #7
Source File: anchor_generator.py From Traffic-Rule-Violation-Detection-System with MIT License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors: box_list.BoxList object holding anchors generated
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  for num_anchors_per_location, feature_map_shape in zip(
      self.num_anchors_per_location(), feature_map_shape_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
  return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #8
Source File: anchor_generator.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors_list,
                                      feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors_list: A list of box_list.BoxList object holding anchors generated.
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  actual_num_anchors = 0
  for num_anchors_per_location, feature_map_shape, anchors in zip(
      self.num_anchors_per_location(), feature_map_shape_list, anchors_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
    actual_num_anchors += anchors.num_boxes()
  return tf.assert_equal(expected_num_anchors, actual_num_anchors)
Example #9
Source File: utils.py From zhusuan with MIT License | 6 votes |
def __call__(self, x, y):
    '''
    Return K(x, y), where x and y are possibly batched.

    :param x: shape [..., n_x, n_covariates]
    :param y: shape [..., n_y, n_covariates]
    :return: Tensor with shape [..., n_x, n_y]
    '''
    batch_shape = tf.shape(x)[:-2]
    rank = x.shape.ndims
    assert_ops = [
        tf.assert_greater_equal(
            rank, 2, message='RBFKernel: rank(x) should be static and >=2'),
        tf.assert_equal(
            rank, tf.rank(y),
            message='RBFKernel: x and y should have the same rank')]
    with tf.control_dependencies(assert_ops):
        x = tf.expand_dims(x, rank - 1)
        y = tf.expand_dims(y, rank - 2)
        k_scale = tf.reshape(self.k_scale, [1] * rank + [-1])
        ret = tf.exp(
            -tf.reduce_sum(tf.square(x - y) / k_scale, axis=-1) / 2)
    return ret
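As a standalone illustration of the broadcasting trick used above (not zhusuan's implementation, and simplified to a single scalar length scale instead of the per-dimension k_scale), the following sketch computes the same [..., n_x, n_y] kernel matrix and runs eagerly under TF 2.x:

import tensorflow as tf

def rbf_kernel(x, y, length_scale=1.0):
  """Pairwise RBF kernel: x [..., n_x, d], y [..., n_y, d] -> [..., n_x, n_y]."""
  x = tf.expand_dims(x, axis=-2)          # [..., n_x, 1, d]
  y = tf.expand_dims(y, axis=-3)          # [..., 1, n_y, d]
  sq_dist = tf.reduce_sum(tf.square(x - y), axis=-1)
  return tf.exp(-sq_dist / (2.0 * length_scale ** 2))

x = tf.random.normal([4, 5, 3])   # batch of 4, 5 points, 3 covariates
y = tf.random.normal([4, 7, 3])   # batch of 4, 7 points, 3 covariates
print(rbf_kernel(x, y).shape)     # (4, 5, 7)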
Example #10
Source File: anchor_generator.py From HereIsWally with MIT License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list): """Assert that correct number of anchors was generated. Args: anchors: box_list.BoxList object holding anchors generated feature_map_shape_list: list of (height, width) pairs in the format [(height_0, width_0), (height_1, width_1), ...] that the generated anchors must align with. Returns: Op that raises InvalidArgumentError if the number of anchors does not match the number of expected anchors. """ expected_num_anchors = 0 for num_anchors_per_location, feature_map_shape in zip( self.num_anchors_per_location(), feature_map_shape_list): expected_num_anchors += (num_anchors_per_location * feature_map_shape[0] * feature_map_shape[1]) return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #11
Source File: pyramid_network.py From FastMaskRCNN with Apache License 2.0 | 6 votes |
def _filter_negative_samples(labels, tensors):
    """Keeps only samples with non-negative labels.

    Params:
    -----
    labels: of shape (N,)
    tensors: a list of tensors, each of shape (N, .., ..); the first axis is
      the sample number

    Returns:
    -----
    tensors: filtered tensors
    """
    # return tensors
    keeps = tf.where(tf.greater_equal(labels, 0))
    keeps = tf.reshape(keeps, [-1])

    filtered = []
    for t in tensors:
        tf.assert_equal(tf.shape(t)[0], tf.shape(labels)[0])
        f = tf.gather(t, keeps)
        filtered.append(f)

    return filtered
Example #12
Source File: anchor_generator.py From garbage-object-detection-tensorflow with MIT License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors: box_list.BoxList object holding anchors generated
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  for num_anchors_per_location, feature_map_shape in zip(
      self.num_anchors_per_location(), feature_map_shape_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
  return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #13
Source File: anchor_generator.py From ros_people_object_detection_tensorflow with Apache License 2.0 | 6 votes |
def _assert_correct_number_of_anchors(self, anchors_list,
                                      feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors_list: A list of box_list.BoxList object holding anchors generated.
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  actual_num_anchors = 0
  for num_anchors_per_location, feature_map_shape, anchors in zip(
      self.num_anchors_per_location(), feature_map_shape_list, anchors_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
    actual_num_anchors += anchors.num_boxes()
  return tf.assert_equal(expected_num_anchors, actual_num_anchors)
Example #14
Source File: anchor_generator.py From cartoonify with MIT License | 6 votes |
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
  """Assert that correct number of anchors was generated.

  Args:
    anchors: box_list.BoxList object holding anchors generated
    feature_map_shape_list: list of (height, width) pairs in the format
      [(height_0, width_0), (height_1, width_1), ...] that the generated
      anchors must align with.

  Returns:
    Op that raises InvalidArgumentError if the number of anchors does not
      match the number of expected anchors.
  """
  expected_num_anchors = 0
  for num_anchors_per_location, feature_map_shape in zip(
      self.num_anchors_per_location(), feature_map_shape_list):
    expected_num_anchors += (num_anchors_per_location *
                             feature_map_shape[0] *
                             feature_map_shape[1])
  return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
Example #15
Source File: transformer_memory_test.py From BERT with Apache License 2.0 | 6 votes |
def testReset(self):
  batch_size = 2
  key_depth = 3
  val_depth = 5
  memory_size = 4
  memory = transformer_memory.TransformerMemory(
      batch_size, key_depth, val_depth, memory_size)
  vals = tf.random_uniform([batch_size, memory_size, val_depth], minval=1.0)
  logits = tf.random_uniform([batch_size, memory_size], minval=1.0)
  update_op = memory.set(vals, logits)
  reset_op = memory.reset([1])
  mem_vals, mem_logits = memory.get()
  assert_op1 = tf.assert_equal(mem_vals[0], vals[0])
  assert_op2 = tf.assert_equal(mem_logits[0], logits[0])
  with tf.control_dependencies([assert_op1, assert_op2]):
    all_zero1 = tf.reduce_sum(tf.abs(mem_vals[1]))
    all_zero2 = tf.reduce_sum(tf.abs(mem_logits[1]))
  with self.test_session() as session:
    session.run(tf.global_variables_initializer())
    session.run(update_op)
    session.run(reset_op)
    zero1, zero2 = session.run([all_zero1, all_zero2])
  self.assertAllEqual(0, zero1)
  self.assertAllEqual(0, zero2)
Example #16
Source File: shape_utils.py From Traffic-Rule-Violation-Detection-System with MIT License | 5 votes |
def assert_shape_equal_along_first_dimension(shape_a, shape_b):
  """Asserts that shape_a and shape_b are the same along the 0th-dimension.

  If the shapes are static, raises a ValueError when the shapes mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static and a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if isinstance(shape_a[0], int) and isinstance(shape_b[0], int):
    if shape_a[0] != shape_b[0]:
      raise ValueError('Unequal first dimension {}, {}'.format(
          shape_a[0], shape_b[0]))
    else:
      return tf.no_op()
  else:
    return tf.assert_equal(shape_a[0], shape_b[0])
Example #17
Source File: digraph_ops.py From Gun-Detector with Apache License 2.0 | 5 votes |
def CombineArcAndRootPotentials(arcs, roots):
  """Combines arc and root potentials into a single set of potentials.

  Args:
    arcs: [B,N,N] tensor of batched arc potentials.
    roots: [B,N] matrix of batched root potentials.

  Returns:
    [B,N,N] tensor P of combined potentials where
      P_{b,s,t} = s == t ? roots[b,t] : arcs[b,s,t]
  """
  # All arguments must have statically-known rank.
  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
  check.Eq(roots.get_shape().ndims, 2, 'roots must be a matrix')

  # All arguments must share the same type.
  dtype = arcs.dtype.base_dtype
  check.Same([dtype, roots.dtype.base_dtype], 'dtype mismatch')

  roots_shape = tf.shape(roots)
  arcs_shape = tf.shape(arcs)
  batch_size = roots_shape[0]
  num_tokens = roots_shape[1]
  with tf.control_dependencies([
      tf.assert_equal(batch_size, arcs_shape[0]),
      tf.assert_equal(num_tokens, arcs_shape[1]),
      tf.assert_equal(num_tokens, arcs_shape[2])]):
    return tf.matrix_set_diag(arcs, roots)
Example #18
Source File: shape_utils.py From Traffic-Rule-Violation-Detection-System with MIT License | 5 votes |
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static and a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if (all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b)):
    if shape_a != shape_b:
      raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    else:
      return tf.no_op()
  else:
    return tf.assert_equal(shape_a, shape_b)
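Because the branch taken depends on whether every entry of the shape lists is a Python int, the same call either fails immediately at graph-construction time or returns an op that fails at run time. A hedged usage sketch, assuming the assert_shape_equal function above is in scope (in the object_detection code the shape lists usually come from shape_utils.combined_static_and_dynamic_shape; here they are built by hand):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

a = tf.placeholder(tf.float32, shape=[None, 4])
b = tf.placeholder(tf.float32, shape=[None, 4])

# Mixed static/dynamic shapes: the batch dimension is a tensor, the rest ints.
shape_a = [tf.shape(a)[0], 4]
shape_b = [tf.shape(b)[0], 4]

# Not every entry is a Python int, so this returns a tf.assert_equal() op
# that raises InvalidArgumentError at run time if the shapes differ.
assert_op = assert_shape_equal(shape_a, shape_b)

# With fully static, unequal shapes the check fails right away instead:
# assert_shape_equal([8, 4], [8, 5])  # raises ValueError('Unequal shapes ...')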
Example #19
Source File: gan_losses.py From federated with Apache License 2.0 | 5 votes |
def _wass_disc_loss_fn(real_images,
                       gen_images,
                       discriminator: tf.keras.Model,
                       grad_penalty_lambda=0.0):
  """Calculate the Wasserstein (discriminator) loss."""
  if grad_penalty_lambda < 0.0:
    raise ValueError('grad_penalty_lambda must be greater than or equal to 0.0')

  # For calculating the discriminator loss, it's desirable to have equal-sized
  # contributions from both the real and fake data. Also, it's necessary if
  # computing the Wasserstein gradient penalty (where a difference is taken b/w
  # the real and fake data). So we assert batch_size equality here.
  with tf.control_dependencies(
      [tf.assert_equal(tf.shape(real_images)[0], tf.shape(gen_images)[0])]):
    disc_gen_output = discriminator(gen_images, training=True)
    score_on_generated = tf.reduce_mean(disc_gen_output)

    disc_real_output = discriminator(real_images, training=True)
    score_on_real = tf.reduce_mean(disc_real_output)

    disc_loss = score_on_generated - score_on_real

    # Add gradient penalty, if indicated.
    if grad_penalty_lambda > 0.0:
      disc_loss += _wass_grad_penalty_term(real_images, gen_images,
                                           discriminator, grad_penalty_lambda)

    # Now add discriminator model regularization losses in.
    if discriminator.losses:
      disc_loss += tf.add_n(discriminator.losses)

    return disc_loss
Example #20
Source File: common_attention.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def conv_elems_1d(x, factor, out_depth=None):
  """Decrease the length and change the dimensionality.

  Merge/restore/compress factors positions of dim depth of the input into a
  single position of dim out_depth.
  This is basically just a strided convolution without overlap
  between each strides. The original length has to be divided by factor.

  Args:
    x (tf.Tensor): shape [batch_size, length, depth]
    factor (int): Length compression factor.
    out_depth (int): Output depth

  Returns:
    tf.Tensor: shape [batch_size, length//factor, out_depth]
  """
  out_depth = out_depth or x.get_shape().as_list()[-1]
  # with tf.control_dependencies(  # Dynamic assertion
  #     [tf.assert_equal(tf.shape(x)[1] % factor, 0)]):
  x = tf.expand_dims(x, 1)  # [batch_size, 1, length, depth]
  x = tf.layers.conv2d(
      inputs=x,
      filters=out_depth,
      kernel_size=(1, factor),
      strides=(1, factor),
      padding="valid",
      data_format="channels_last",
  )  # [batch_size, 1, length//factor, out_depth]
  x = tf.squeeze(x, 1)  # [batch_size, length//factor, depth]
  return x
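Note that the dynamic divisibility assertion above is left commented out in this version. If one wanted to enforce it, a small helper along the following lines could gate x before the convolution; check_divisible_length is a hypothetical name and not part of tensor2tensor:

import tensorflow.compat.v1 as tf

def check_divisible_length(x, factor):
  """Returns x gated on a run-time check that its length (axis 1) is divisible by factor."""
  assert_op = tf.assert_equal(
      tf.shape(x)[1] % factor, 0,
      message="conv_elems_1d: length must be divisible by factor")
  with tf.control_dependencies([assert_op]):
    return tf.identity(x)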
Example #21
Source File: shape_utils.py From training_results_v0.5 with Apache License 2.0 | 5 votes |
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static and a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if (all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b)):
    if shape_a != shape_b:
      raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    else:
      return tf.no_op()
  else:
    return tf.assert_equal(shape_a, shape_b)
Example #22
Source File: estimator.py From kfac with Apache License 2.0 | 5 votes |
def _check_batch_sizes(self, factor):
  """Checks that the batch size(s) for a factor matches the reference value."""
  # Should make these messages use quote characters instead of parentheses
  # when the bug with quote character rendering in assertion messages is
  # fixed. See b/129476712
  if self._batch_size is None:
    batch_size = self.factors[0].batch_size()
    string = ("Batch size {} for factor (" + factor.name + ") of type "
              + utils.cls_name(factor) + " did not match value {} used by "
              "factor (" + self.factors[0].name + ") of type "
              + utils.cls_name(self.factors[0]) + ".")
  else:
    batch_size = self._batch_size
    string = ("Batch size {} for factor (" + factor.name + ") of type "
              + utils.cls_name(factor) + " did not match value {} which was "
              "passed to optimizer/estimator.")

  factor_batch_size = factor.batch_size()

  if isinstance(batch_size, int) and isinstance(factor_batch_size, int):
    if factor_batch_size != batch_size:
      raise ValueError(string.format(factor_batch_size, batch_size))
    return factor.check_partial_batch_sizes()
  else:
    message = string.format("(x)", "(y)")
    with tf.control_dependencies([factor.check_partial_batch_sizes()]):
      return tf.assert_equal(factor_batch_size, batch_size, message=message)
Example #23
Source File: cutout_ops.py From addons with Apache License 2.0 | 5 votes |
def _norm_params(images, mask_size, data_format):
    mask_size = tf.convert_to_tensor(mask_size)
    if tf.executing_eagerly():
        tf.assert_equal(
            tf.reduce_any(mask_size % 2 != 0),
            False,
            "mask_size should be divisible by 2",
        )
    if tf.rank(mask_size) == 0:
        mask_size = tf.stack([mask_size, mask_size])
    data_format = keras_utils.normalize_data_format(data_format)
    image_height, image_width = _get_image_wh(images, data_format)
    return mask_size, data_format, image_height, image_width
Example #24
Source File: helpers.py From Feed-Forward-Style-Transfer with MIT License | 5 votes |
def get_style_loss_for_layer(variable_img, style_img, layer):
    """Compute style loss for a layer-out op (l) given the variable vgg-out op (x)
    and the style vgg-out op (s).

    Args:
        variable_img: 4D tensor representing the variable image vgg encodings
        style_img: 4D tensor representing the style image vgg encodings
        layer: string representing the vgg layer

    Returns:
        loss: float tensor representing the style loss for the given layer
    """
    with tf.name_scope('get_style_loss_for_layer'):
        # Compute gram matrices using the activated filter maps of the art and generated images
        x_layer_maps = getattr(variable_img, layer)
        s_layer_maps = getattr(style_img, layer)
        x_layer_gram = convert_to_gram(x_layer_maps)
        s_layer_gram = convert_to_gram(s_layer_maps)

        # Make sure the feature grams have the same dimensions
        assert_equal_shapes = tf.assert_equal(x_layer_gram.get_shape(), s_layer_gram.get_shape())
        with tf.control_dependencies([assert_equal_shapes]):
            # Compute and return the normalized gram loss using the gram matrices
            shape = x_layer_maps.get_shape().as_list()
            size = reduce(lambda a, b: a * b, shape) ** 2
            gram_loss = get_l2_norm_loss(x_layer_gram - s_layer_gram)
            return gram_loss / size
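convert_to_gram and get_l2_norm_loss are helpers from the same project and are not shown on this page. For context, a Gram matrix over convolutional feature maps is commonly computed as below; this is a hedged sketch of the usual formulation, not the project's exact helper:

import tensorflow.compat.v1 as tf

def gram_matrix(feature_maps):
  """Computes [batch, channels, channels] Gram matrices from [batch, h, w, channels] maps."""
  shape = tf.shape(feature_maps)
  flat = tf.reshape(feature_maps, [shape[0], -1, shape[3]])  # [b, h*w, c]
  return tf.matmul(flat, flat, transpose_a=True)             # [b, c, c]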
Example #25
Source File: util.py From shortest-path with The Unlicense | 5 votes |
def dynamic_assert_shape(tensor, shape, name=None):
    """
    Check that a tensor has a shape given by a list of constants and tensor values.

    This function will place an operation into your graph that gets executed at
    runtime. This is helpful because often tensors have many dynamic sized
    dimensions that you cannot otherwise compare / assert are as you expect.

    For example, measure a dimension at run time:
        batch_size = tf.shape(my_tensor)[0]

    then assert another tensor does indeed have the right shape:
        other_tensor = dynamic_assert_shape(other_tensor, [batch_size, 16])

    You should use this as an inline identity function so that the operation it
    generates gets added and executed in the graph.

    Returns: the argument `tensor` unchanged
    """
    if global_args["use_assert"]:
        tensor_shape = tf.shape(tensor)
        tensor_shape = tf.cast(tensor_shape, tf.int64)

        expected_shape = tf.convert_to_tensor(shape)
        expected_shape = tf.cast(expected_shape, tf.int64)

        t_name = "tensor" if tf.executing_eagerly() else tensor.name

        if isinstance(shape, list):
            assert len(tensor.shape) == len(shape), \
                f"Tensor shape {tensor_shape} and expected shape {expected_shape} have different lengths"

        assert_op = tf.assert_equal(
            tensor_shape, expected_shape,
            message=f"Asserting shape of {t_name}",
            summarize=10, name=name)

        with tf.control_dependencies([assert_op]):
            return tf.identity(tensor, name="dynamic_assert_shape")

    else:
        return tensor
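Because the helper returns its input (possibly wrapped in tf.identity), it is intended to be used inline. A hedged usage sketch, assuming the dynamic_assert_shape function above and its global_args["use_assert"] flag are in scope and enabled; the tensors and sizes are illustrative:

import tensorflow as tf

# Illustrative tensors; in the project these come from the model pipeline.
my_tensor = tf.zeros([8, 32])
other_tensor = tf.ones([8, 16])

batch_size = tf.shape(my_tensor)[0]

# Passes other_tensor through unchanged when its runtime shape is
# [batch_size, 16]; otherwise the generated assert raises InvalidArgumentError.
other_tensor = dynamic_assert_shape(other_tensor, [batch_size, 16])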
Example #26
Source File: shape_utils.py From object_centric_VAD with MIT License | 5 votes |
def expand_first_dimension(inputs, dims):
  """Expands `K-d` tensor along first dimension to be a `(K+n-1)-d` tensor.

  Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
  [dims[0], dims[1], ..., dims[-1], D1, ..., D(k-1)].

  Example:
  `inputs` is a tensor with shape [50, 20, 20, 3].
  new_tensor = expand_first_dimension(inputs, [10, 5]).
  new_tensor.shape -> [10, 5, 20, 20, 3].

  Args:
    inputs: a tensor with shape [D0, D1, ..., D(K-1)].
    dims: List with new dimensions to expand first axis into. The length of
      `dims` is typically 2 or larger.

  Returns:
    a tensor with shape [dims[0], dims[1], ..., dims[-1], D1, ..., D(k-1)].
  """
  inputs_shape = combined_static_and_dynamic_shape(inputs)
  expanded_shape = tf.stack(dims + inputs_shape[1:])

  # Verify that it is possible to expand the first axis of inputs.
  assert_op = tf.assert_equal(
      inputs_shape[0], tf.reduce_prod(tf.stack(dims)),
      message=('First dimension of `inputs` cannot be expanded into provided '
               '`dims`'))

  with tf.control_dependencies([assert_op]):
    inputs_reshaped = tf.reshape(inputs, expanded_shape)

  return inputs_reshaped
Example #27
Source File: shape_utils.py From object_centric_VAD with MIT License | 5 votes |
def assert_shape_equal_along_first_dimension(shape_a, shape_b):
  """Asserts that shape_a and shape_b are the same along the 0th-dimension.

  If the shapes are static, raises a ValueError when the shapes mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static and a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if isinstance(shape_a[0], int) and isinstance(shape_b[0], int):
    if shape_a[0] != shape_b[0]:
      raise ValueError('Unequal first dimension {}, {}'.format(
          shape_a[0], shape_b[0]))
    else:
      return tf.no_op()
  else:
    return tf.assert_equal(shape_a[0], shape_b[0])
Example #28
Source File: shape_utils.py From object_centric_VAD with MIT License | 5 votes |
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static and a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if (all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b)):
    if shape_a != shape_b:
      raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    else:
      return tf.no_op()
  else:
    return tf.assert_equal(shape_a, shape_b)
Example #29
Source File: digraph_ops.py From DOTA_models with Apache License 2.0 | 5 votes |
def CombineArcAndRootPotentials(arcs, roots):
  """Combines arc and root potentials into a single set of potentials.

  Args:
    arcs: [B,N,N] tensor of batched arc potentials.
    roots: [B,N] matrix of batched root potentials.

  Returns:
    [B,N,N] tensor P of combined potentials where
      P_{b,s,t} = s == t ? roots[b,t] : arcs[b,s,t]
  """
  # All arguments must have statically-known rank.
  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
  check.Eq(roots.get_shape().ndims, 2, 'roots must be a matrix')

  # All arguments must share the same type.
  dtype = arcs.dtype.base_dtype
  check.Same([dtype, roots.dtype.base_dtype], 'dtype mismatch')

  roots_shape = tf.shape(roots)
  arcs_shape = tf.shape(arcs)
  batch_size = roots_shape[0]
  num_tokens = roots_shape[1]
  with tf.control_dependencies([
      tf.assert_equal(batch_size, arcs_shape[0]),
      tf.assert_equal(num_tokens, arcs_shape[1]),
      tf.assert_equal(num_tokens, arcs_shape[2])]):
    return tf.matrix_set_diag(arcs, roots)
Example #30
Source File: embracenet.py From embracenet with MIT License | 5 votes |
def add_modality(self, input_data, input_size, bypass_docking=False):
  """
  Add a modality to EmbraceNet.

  Args:
    input_data: An input data to feed into EmbraceNet. Must be a 2-D tensor of
      shape [batch_size, input_size].
    input_size: The second dimension of input_data.
    bypass_docking: Bypass docking step, i.e., connect the input data directly
      to the embracement layer. If True, input_data must have a shape of
      [batch_size, embracement_size].
  """

  # check input data
  tf_assertions = []
  tf_assertions.append(tf.assert_rank(input_data, 2))
  tf_assertions.append(tf.assert_equal(tf.shape(input_data)[0], self.batch_size))
  with tf.control_dependencies(tf_assertions):
    input_data = tf.identity(input_data)

  with tf.variable_scope('embracenet'):
    # construct docking layer
    modality_index = len(self.graph.modalities)
    modality_graph = EmbraceNetObject()
    modality_feeds = EmbraceNetObject()

    with tf.variable_scope('docking/%d' % modality_index):
      docking_input = input_data

      if (bypass_docking):
        modality_graph.docking_output = docking_input
      else:
        docking_output = tf.layers.dense(docking_input, units=self.embracement_size,
                                         kernel_initializer=None, bias_initializer=None)
        docking_output = tf.nn.relu(docking_output)
        modality_graph.docking_output = docking_output

    # finalize
    self.graph.modalities.append(modality_graph)
    self.feeds.modalities.append(modality_feeds)