Python tensorflow.compat.v1.Assert() Examples
The following are 22 code examples of tensorflow.compat.v1.Assert(), drawn from open-source projects. Each example notes its source file, originating project, and license.
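All of these examples use the graph-mode API: tf.compat.v1.Assert(condition, data, summarize=None, name=None) returns an op that, when executed, prints the tensors in data and raises tf.errors.InvalidArgumentError if condition is false. Because the op only runs if something depends on it, it is almost always wired into the graph with tf.control_dependencies, as most of the examples below do. A minimal sketch of that pattern (the tensor names and values here are illustrative, not taken from any of the examples):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # required when running under TensorFlow 2.x

x = tf.placeholder(tf.float32, shape=[None])
# The assert op fires only when executed, so make the downstream op
# depend on it explicitly via a control dependency.
assert_non_negative = tf.Assert(
    tf.reduce_all(tf.greater_equal(x, 0.0)),
    ['x must be non-negative, got:', x], summarize=10)
with tf.control_dependencies([assert_non_negative]):
    y = tf.sqrt(x)  # created with a control input on the assertion

with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: [1.0, 4.0, 9.0]}))  # [1. 2. 3.]
    # Feeding [-1.0] instead would raise InvalidArgumentError.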
Example #1
Source File: test_debugging.py From incubator-tvm with Apache License 2.0
def test_assert_false():
    g = tf.Graph()
    with g.as_default():
        assert_op = tf.Assert(tf.constant(False), ["it failed"])
        with tf.Session() as sess:
            try:
                print(sess.run(assert_op))
                assert False  # TF should have thrown an exception
            except tf.errors.InvalidArgumentError as e:
                assert "it failed" in e.message

    # In TVM, tf.assert is converted to a no-op which is actually a 0,
    # though it should probably be none or an empty tuple. For the same
    # reason, there should not be an error here, even though the assertion
    # argument is false.
    np.testing.assert_allclose(0, run_relay(g).asnumpy())
Example #2
Source File: test_debugging.py From incubator-tvm with Apache License 2.0
def test_assert_true_var_capture():
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=())

        # It turns out that tf.assert() creates a large and complex subgraph if
        # you capture a variable as part of the error message. So we need to
        # test that, too.
        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed", x])
        with tf.Session() as sess:
            x_value = np.random.rand()
            assert sess.run(assert_op, feed_dict={x: x_value}) is None

    # TODO: The frontend converter notes the output of
    # the graph as a boolean, which is not correct - as you can see above,
    # TF believes that the value of this graph is None.
    np.testing.assert_allclose(True, run_relay(g, None, x_value).asnumpy())
Example #3
Source File: test_debugging.py From incubator-tvm with Apache License 2.0
def test_assert_true():
    g = tf.Graph()
    shape = (1, 2)
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=shape, name="input")
        assert_op = tf.Assert(tf.reduce_all(tf.less_equal(x, x)), ["it failed"])
        with tf.Session() as sess:
            x_value = np.random.rand(*shape)
            assert sess.run(assert_op, feed_dict={x: x_value}) is None

    # In TVM, tf.assert is converted to a no-op which is actually a 0,
    # though it should probably be none or an empty tuple.
    #
    # ToDo: It appears that the frontend converter gets confused here and
    # entirely eliminates all operands from main(). Likely because x <= x
    # is always true, so the placeholder can be eliminated. But TF doesn't
    # do that, it's happening in Relay, and that optimization shouldn't
    # affect the arity of the main function. We should have to pass in
    # x_value here.
    np.testing.assert_allclose(0, run_relay(g, {'input': shape}).asnumpy())
Example #4
Source File: dataloader.py From Object_Detection_Tracking with Apache License 2.0
def pad_to_fixed_size(data, pad_value, output_shape):
  """Pad data to a fixed length at the first dimension.

  Args:
    data: Tensor to be padded to output_shape.
    pad_value: A constant value assigned to the paddings.
    output_shape: The output shape of a 2D tensor.

  Returns:
    The padded tensor with output_shape [max_num_instances, dimension].
  """
  max_num_instances = output_shape[0]
  dimension = output_shape[1]
  data = tf.reshape(data, [-1, dimension])
  num_instances = tf.shape(data)[0]
  assert_length = tf.Assert(
      tf.less_equal(num_instances, max_num_instances), [num_instances])
  with tf.control_dependencies([assert_length]):
    pad_length = max_num_instances - num_instances
  paddings = pad_value * tf.ones([pad_length, dimension])
  padded_data = tf.concat([data, paddings], axis=0)
  padded_data = tf.reshape(padded_data, output_shape)
  return padded_data
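A quick usage sketch for the function above (hypothetical values, not from the source file): three boxes are padded up to a fixed eight rows, and passing more than eight boxes would trip the tf.Assert at run time.

boxes = tf.constant([[0.1, 0.2, 0.3, 0.4],
                     [0.5, 0.6, 0.7, 0.8],
                     [0.0, 0.0, 1.0, 1.0]])
# Pad three boxes to a fixed [8, 4] output; rows 3..7 are filled with -1.
padded = pad_to_fixed_size(boxes, pad_value=-1.0, output_shape=[8, 4])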
Example #5
Source File: receptive_field_computation_test.py From receptive_field with Apache License 2.0
def create_test_network_7():
  """Aligned network for test, with a control dependency.

  The graph is similar to create_test_network_1(), except that it includes an
  assert operation on the left branch.

  Returns:
    g: Tensorflow graph object (Graph proto).
  """
  g = tf.Graph()
  with g.as_default():
    # An 8x8 test image.
    x = tf.placeholder(tf.float32, (1, 8, 8, 1), name='input_image')
    # Left branch.
    l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID')
    l1_shape = tf.shape(l1)
    assert_op = tf.Assert(tf.equal(l1_shape[1], 2), [l1_shape], summarize=4)
    # Right branch.
    l2_pad = tf.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]])
    l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID')
    l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID')
    # Addition.
    with tf.control_dependencies([assert_op]):
      tf.nn.relu(l1 + l3, name='output')
  return g
Example #6
Source File: shape_utils.py From models with Apache License 2.0
def assert_box_normalized(boxes, maximum_normalized_coordinate=1.1):
  """Asserts the input box tensor is normalized.

  Args:
    boxes: a tensor of shape [N, 4] where N is the number of boxes.
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, default to 1.1.

  Returns:
    a tf.Assert op which fails when the input box tensor is not normalized.

  Raises:
    ValueError: When the input box tensor is not normalized.
  """
  box_minimum = tf.reduce_min(boxes)
  box_maximum = tf.reduce_max(boxes)
  return tf.Assert(
      tf.logical_and(
          tf.less_equal(box_maximum, maximum_normalized_coordinate),
          tf.greater_equal(box_minimum, 0)),
      [boxes])
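Note that the function returns the bare tf.Assert op, so the caller must thread it into the graph itself; a sketch of the usual wiring (illustrative, not from the source file):

boxes = tf.random.uniform([10, 4])  # normalized by construction
normalized_check = assert_box_normalized(boxes)
with tf.control_dependencies([normalized_check]):
    boxes = tf.identity(boxes)  # downstream ops now depend on the check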
Example #7
Source File: center_net_meta_arch.py From models with Apache License 2.0
def _get_shape(tensor, num_dims):
  tf.Assert(tensor.get_shape().ndims == num_dims, [tensor])
  return shape_utils.combined_static_and_dynamic_shape(tensor)
Example #8
Source File: densepose_ops.py From models with Apache License 2.0
def to_absolute_coordinates(dp_surface_coords, height, width,
                            check_range=True, scope=None):
  """Converts normalized DensePose coordinates to absolute pixel coordinates.

  This function raises an assertion failed error when the maximum coordinate
  value is larger than 1.01 (in which case coordinates are already absolute).

  Args:
    dp_surface_coords: a tensor of shape [num_instances, num_points, 4] with
      DensePose normalized surface coordinates in (y, x, v, u) format.
    height: Height of image.
    width: Width of image.
    check_range: If True, checks if the coordinates are normalized or not.
    scope: name scope.

  Returns:
    A tensor of shape [num_instances, num_points, 4] with absolute coordinates.
  """
  with tf.name_scope(scope, 'DensePoseToAbsoluteCoordinates'):
    height = tf.cast(height, tf.float32)
    width = tf.cast(width, tf.float32)
    if check_range:
      max_val = tf.reduce_max(dp_surface_coords[:, :, :2])
      max_assert = tf.Assert(tf.greater_equal(1.01, max_val),
                             ['maximum coordinate value is larger than 1.01: ',
                              max_val])
      with tf.control_dependencies([max_assert]):
        width = tf.identity(width)
    return scale(dp_surface_coords, height, width)
Example #9
Source File: densepose_ops.py From models with Apache License 2.0
def to_normalized_coordinates(dp_surface_coords, height, width,
                              check_range=True, scope=None):
  """Converts absolute DensePose coordinates to normalized in range [0, 1].

  This function raises an assertion failed error at graph execution time when
  the maximum coordinate is smaller than 1.01 (which means that coordinates
  are already normalized). The value 1.01 is to deal with small rounding
  errors.

  Args:
    dp_surface_coords: a tensor of shape [num_instances, num_points, 4] with
      DensePose absolute surface coordinates in (y, x, v, u) format.
    height: Height of image.
    width: Width of image.
    check_range: If True, checks if the coordinates are already normalized.
    scope: name scope.

  Returns:
    A tensor of shape [num_instances, num_points, 4] with normalized
    coordinates.
  """
  with tf.name_scope(scope, 'DensePoseToNormalizedCoordinates'):
    height = tf.cast(height, tf.float32)
    width = tf.cast(width, tf.float32)
    if check_range:
      max_val = tf.reduce_max(dp_surface_coords[:, :, :2])
      max_assert = tf.Assert(tf.greater(max_val, 1.01),
                             ['max value is lower than 1.01: ', max_val])
      with tf.control_dependencies([max_assert]):
        width = tf.identity(width)
    return scale(dp_surface_coords, 1.0 / height, 1.0 / width)
Example #10
Source File: box_list_ops.py From models with Apache License 2.0
def to_normalized_coordinates(boxlist, height, width,
                              check_range=True, scope=None):
  """Converts absolute box coordinates to normalized coordinates in [0, 1].

  Usually one uses the dynamic shape of the image or conv-layer tensor:
    boxlist = box_list_ops.to_normalized_coordinates(boxlist,
                                                     tf.shape(images)[1],
                                                     tf.shape(images)[2]),

  This function raises an assertion failed error at graph execution time when
  the maximum coordinate is smaller than 1.01 (which means that coordinates
  are already normalized). The value 1.01 is to deal with small rounding
  errors.

  Args:
    boxlist: BoxList with coordinates in terms of pixel-locations.
    height: Maximum value for height of absolute box coordinates.
    width: Maximum value for width of absolute box coordinates.
    check_range: If True, checks if the coordinates are normalized or not.
    scope: name scope.

  Returns:
    boxlist with normalized coordinates in [0, 1].
  """
  with tf.name_scope(scope, 'ToNormalizedCoordinates'):
    height = tf.cast(height, tf.float32)
    width = tf.cast(width, tf.float32)
    if check_range:
      max_val = tf.reduce_max(boxlist.get())
      max_assert = tf.Assert(tf.greater(max_val, 1.01),
                             ['max value is lower than 1.01: ', max_val])
      with tf.control_dependencies([max_assert]):
        width = tf.identity(width)
    return scale(boxlist, 1 / height, 1 / width)
Example #11
Source File: box_list_ops.py From models with Apache License 2.0
def sort_by_field(boxlist, field, order=SortOrder.descend, scope=None):
  """Sort boxes and associated fields according to a scalar field.

  A common use case is reordering the boxes according to descending scores.

  Args:
    boxlist: BoxList holding N boxes.
    field: A BoxList field for sorting and reordering the BoxList.
    order: (Optional) descend or ascend. Default is descend.
    scope: name scope.

  Returns:
    sorted_boxlist: A sorted BoxList with the field in the specified order.

  Raises:
    ValueError: if specified field does not exist
    ValueError: if the order is not either descend or ascend
  """
  with tf.name_scope(scope, 'SortByField'):
    if order != SortOrder.descend and order != SortOrder.ascend:
      raise ValueError('Invalid sort order')

    field_to_sort = boxlist.get_field(field)
    if len(field_to_sort.shape.as_list()) != 1:
      raise ValueError('Field should have rank 1')

    num_boxes = boxlist.num_boxes()
    num_entries = tf.size(field_to_sort)
    length_assert = tf.Assert(
        tf.equal(num_boxes, num_entries),
        ['Incorrect field size: actual vs expected.', num_entries, num_boxes])

    with tf.control_dependencies([length_assert]):
      _, sorted_indices = tf.nn.top_k(field_to_sort, num_boxes, sorted=True)

    if order == SortOrder.ascend:
      sorted_indices = tf.reverse_v2(sorted_indices, [0])

    return gather(boxlist, sorted_indices)
Example #12
Source File: keypoint_ops.py From models with Apache License 2.0
def to_absolute_coordinates(keypoints, height, width,
                            check_range=True, scope=None):
  """Converts normalized keypoint coordinates to absolute pixel coordinates.

  This function raises an assertion failed error when the maximum keypoint
  coordinate value is larger than 1.01 (in which case coordinates are already
  absolute).

  Args:
    keypoints: A tensor of shape [num_instances, num_keypoints, 2]
    height: Maximum value for y coordinate of absolute keypoint coordinates.
    width: Maximum value for x coordinate of absolute keypoint coordinates.
    check_range: If True, checks if the coordinates are normalized or not.
    scope: name scope.

  Returns:
    tensor of shape [num_instances, num_keypoints, 2] with absolute
    coordinates in terms of the image size.
  """
  with tf.name_scope(scope, 'ToAbsoluteCoordinates'):
    height = tf.cast(height, tf.float32)
    width = tf.cast(width, tf.float32)

    # Ensure range of input keypoints is correct.
    if check_range:
      max_val = tf.reduce_max(keypoints)
      max_assert = tf.Assert(tf.greater_equal(1.01, max_val),
                             ['maximum keypoint coordinate value is larger '
                              'than 1.01: ', max_val])
      with tf.control_dependencies([max_assert]):
        width = tf.identity(width)

    return scale(keypoints, height, width)
Example #13
Source File: keypoint_ops.py From models with Apache License 2.0
def to_normalized_coordinates(keypoints, height, width,
                              check_range=True, scope=None):
  """Converts absolute keypoint coordinates to normalized coordinates in [0, 1].

  Usually one uses the dynamic shape of the image or conv-layer tensor:
    keypoints = keypoint_ops.to_normalized_coordinates(keypoints,
                                                       tf.shape(images)[1],
                                                       tf.shape(images)[2]),

  This function raises an assertion failed error at graph execution time when
  the maximum coordinate is smaller than 1.01 (which means that coordinates
  are already normalized). The value 1.01 is to deal with small rounding
  errors.

  Args:
    keypoints: A tensor of shape [num_instances, num_keypoints, 2].
    height: Maximum value for y coordinate of absolute keypoint coordinates.
    width: Maximum value for x coordinate of absolute keypoint coordinates.
    check_range: If True, checks if the coordinates are normalized.
    scope: name scope.

  Returns:
    tensor of shape [num_instances, num_keypoints, 2] with normalized
    coordinates in [0, 1].
  """
  with tf.name_scope(scope, 'ToNormalizedCoordinates'):
    height = tf.cast(height, tf.float32)
    width = tf.cast(width, tf.float32)

    if check_range:
      max_val = tf.reduce_max(keypoints)
      max_assert = tf.Assert(tf.greater(max_val, 1.01),
                             ['max value is lower than 1.01: ', max_val])
      with tf.control_dependencies([max_assert]):
        width = tf.identity(width)

    return scale(keypoints, 1.0 / height, 1.0 / width)
Example #14
Source File: yellowfin.py From tensor2tensor with Apache License 2.0
def _get_cubic_root(self):
  """Get the cubic root."""
  # We have the equation x^2 D^2 + (1-x)^4 * C / h_min^2
  # where x = sqrt(mu).
  # We substitute x, which is sqrt(mu), with x = y + 1.
  # It gives y^3 + py = q
  # where p = (D^2 h_min^2)/(2*C) and q = -p.
  # We use the Vieta's substitution to compute the root.
  # There is only one real solution y (which is in [0, 1] ).
  # http://mathworld.wolfram.com/VietasSubstitution.html
  assert_array = [
      tf.Assert(
          tf.logical_not(tf.is_nan(self._dist_to_opt_avg)),
          [self._dist_to_opt_avg,]),
      tf.Assert(
          tf.logical_not(tf.is_nan(self._h_min)),
          [self._h_min,]),
      tf.Assert(
          tf.logical_not(tf.is_nan(self._grad_var)),
          [self._grad_var,]),
      tf.Assert(
          tf.logical_not(tf.is_inf(self._dist_to_opt_avg)),
          [self._dist_to_opt_avg,]),
      tf.Assert(
          tf.logical_not(tf.is_inf(self._h_min)),
          [self._h_min,]),
      tf.Assert(
          tf.logical_not(tf.is_inf(self._grad_var)),
          [self._grad_var,])
  ]
  with tf.control_dependencies(assert_array):
    p = self._dist_to_opt_avg**2 * self._h_min**2 / 2 / self._grad_var
    w3 = (-tf.sqrt(p**2 + 4.0 / 27.0 * p**3) - p) / 2.0
    w = tf.sign(w3) * tf.pow(tf.abs(w3), 1.0/3.0)
    y = w - p / 3.0 / w
    x = y + 1
  return x
Example #15
Source File: shape_utils.py From models with Apache License 2.0
def check_min_image_dim(min_dim, image_tensor):
  """Checks that the image width/height are greater than some number.

  This function is used to check that the width and height of an image are
  above a certain value. If the image shape is static, this function will
  perform the check at graph construction time. Otherwise, if the image shape
  varies, an Assertion control dependency will be added to the graph.

  Args:
    min_dim: The minimum number of pixels along the width and height of the
      image.
    image_tensor: The image tensor to check size for.

  Returns:
    If `image_tensor` has dynamic size, return `image_tensor` with an Assert
    control dependency. Otherwise returns image_tensor.

  Raises:
    ValueError: if `image_tensor`'s width or height is smaller than `min_dim`.
  """
  image_shape = image_tensor.get_shape()
  image_height = static_shape.get_height(image_shape)
  image_width = static_shape.get_width(image_shape)
  if image_height is None or image_width is None:
    shape_assert = tf.Assert(
        tf.logical_and(tf.greater_equal(tf.shape(image_tensor)[1], min_dim),
                       tf.greater_equal(tf.shape(image_tensor)[2], min_dim)),
        ['image size must be >= {} in both height and width.'.format(min_dim)])
    with tf.control_dependencies([shape_assert]):
      return tf.identity(image_tensor)

  if image_height < min_dim or image_width < min_dim:
    raise ValueError(
        'image size must be >= %d in both height and width; image dim = %d,%d'
        % (min_dim, image_height, image_width))

  return image_tensor
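The static/dynamic split described in the docstring means the same call either fails immediately at graph construction or defers to run time; a short sketch (the shapes are illustrative):

# Static shape: the check happens immediately, in Python.
img = tf.zeros([1, 224, 224, 3])
img = check_min_image_dim(32, img)  # passes; min_dim=300 would raise ValueError

# Dynamic shape: an Assert control dependency is added instead.
dyn = tf.placeholder(tf.float32, [1, None, None, 3])
dyn = check_min_image_dim(32, dyn)  # checked only when the graph runs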
Example #16
Source File: inputs.py From models with Apache License 2.0
def assert_or_prune_invalid_boxes(boxes):
  """Makes sure boxes have valid sizes (ymax >= ymin, xmax >= xmin).

  When the hardware supports assertions, the function raises an error when
  boxes have an invalid size. If assertions are not supported (e.g. on TPU),
  boxes with invalid sizes are filtered out.

  Args:
    boxes: float tensor of shape [num_boxes, 4]

  Returns:
    boxes: float tensor of shape [num_valid_boxes, 4] with invalid boxes
      filtered out.

  Raises:
    tf.errors.InvalidArgumentError: When we detect boxes with invalid size.
      This is not supported on TPUs.
  """
  ymin, xmin, ymax, xmax = tf.split(
      boxes, num_or_size_splits=4, axis=1)

  height_check = tf.Assert(tf.reduce_all(ymax >= ymin), [ymin, ymax])
  width_check = tf.Assert(tf.reduce_all(xmax >= xmin), [xmin, xmax])

  with tf.control_dependencies([height_check, width_check]):
    boxes_tensor = tf.concat([ymin, xmin, ymax, xmax], axis=1)
    boxlist = box_list.BoxList(boxes_tensor)
    # TODO(b/149221748) Remove pruning when XLA supports assertions.
    boxlist = box_list_ops.prune_small_boxes(boxlist, 0)

  return boxlist.get()
Example #17
Source File: vgg_preprocessing.py From models with Apache License 2.0
def _crop(image, offset_height, offset_width, crop_height, crop_width):
  """Crops the given image using the provided offsets and sizes.

  Note that the method doesn't assume we know the input image size but it does
  assume we know the input image rank.

  Args:
    image: an image of shape [height, width, channels].
    offset_height: a scalar tensor indicating the height offset.
    offset_width: a scalar tensor indicating the width offset.
    crop_height: the height of the cropped image.
    crop_width: the width of the cropped image.

  Returns:
    the cropped (and resized) image.

  Raises:
    InvalidArgumentError: if the rank is not 3 or if the image dimensions are
      less than the crop size.
  """
  original_shape = tf.shape(image)

  rank_assertion = tf.Assert(
      tf.equal(tf.rank(image), 3),
      ['Rank of image must be equal to 3.'])
  with tf.control_dependencies([rank_assertion]):
    cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]])

  size_assertion = tf.Assert(
      tf.logical_and(
          tf.greater_equal(original_shape[0], crop_height),
          tf.greater_equal(original_shape[1], crop_width)),
      ['Crop size greater than the image size.'])

  offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))

  # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
  # define the crop size.
  with tf.control_dependencies([size_assertion]):
    image = tf.slice(image, offsets, cropped_shape)
  return tf.reshape(image, cropped_shape)
Example #18
Source File: preprocess.py From language with Apache License 2.0
def pad_to_length(tensor, target_length):
  """Pads a 1-D Tensor with zeros to the target length."""
  pad_amt = target_length - tf.size(tensor)
  # Assert that pad_amt is non-negative.
  assert_op = tf.Assert(pad_amt >= 0,
                        ["\nERROR: len(tensor) > target_length.", pad_amt])
  with tf.control_dependencies([assert_op]):
    padded = tf.pad(tensor, [[0, pad_amt]])
    padded.set_shape([target_length])
    return padded
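A usage sketch, assuming pad_to_length above is in scope (the values are illustrative):

ids = tf.constant([7, 8, 9], dtype=tf.int32)
padded = pad_to_length(ids, target_length=5)
# padded evaluates to [7, 8, 9, 0, 0] with static shape [5]; a tensor
# longer than target_length trips the Assert at run time.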
Example #19
Source File: imagenet.py From tensor2tensor with Apache License 2.0
def _crop(image, offset_height, offset_width, crop_height, crop_width):
  """Crops the given image using the provided offsets and sizes.

  Note that the method doesn't assume we know the input image size but it does
  assume we know the input image rank.

  Args:
    image: `Tensor` image of shape [height, width, channels].
    offset_height: `Tensor` indicating the height offset.
    offset_width: `Tensor` indicating the width offset.
    crop_height: the height of the cropped image.
    crop_width: the width of the cropped image.

  Returns:
    the cropped (and resized) image.

  Raises:
    InvalidArgumentError: if the rank is not 3 or if the image dimensions are
      less than the crop size.
  """
  original_shape = tf.shape(image)

  rank_assertion = tf.Assert(
      tf.equal(tf.rank(image), 3),
      ["Rank of image must be equal to 3."])
  with tf.control_dependencies([rank_assertion]):
    cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]])

  size_assertion = tf.Assert(
      tf.logical_and(
          tf.greater_equal(original_shape[0], crop_height),
          tf.greater_equal(original_shape[1], crop_width)),
      ["Crop size greater than the image size."])

  offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))

  # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
  # define the crop size.
  with tf.control_dependencies([size_assertion]):
    image = tf.slice(image, offsets, cropped_shape)
  return tf.reshape(image, cropped_shape)
Example #20
Source File: faster_rcnn_resnet_v1_feature_extractor.py From models with Apache License 2.0
def _extract_proposal_features(self, preprocessed_inputs, scope):
  """Extracts first stage RPN features.

  Args:
    preprocessed_inputs: A [batch, height, width, channels] float32 tensor
      representing a batch of images.
    scope: A scope name.

  Returns:
    rpn_feature_map: A tensor with shape [batch, height, width, depth]
    activations: A dictionary mapping feature extractor tensor names to
      tensors

  Raises:
    InvalidArgumentError: If the spatial size of `preprocessed_inputs`
      (height or width) is less than 33.
    ValueError: If the created network is missing the required activation.
  """
  if len(preprocessed_inputs.get_shape().as_list()) != 4:
    raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a '
                     'tensor of shape %s' % preprocessed_inputs.get_shape())
  shape_assert = tf.Assert(
      tf.logical_and(
          tf.greater_equal(tf.shape(preprocessed_inputs)[1], 33),
          tf.greater_equal(tf.shape(preprocessed_inputs)[2], 33)),
      ['image size must at least be 33 in both height and width.'])

  with tf.control_dependencies([shape_assert]):
    # Disables batchnorm for fine-tuning with smaller batch sizes.
    # TODO(chensun): Figure out if it is needed when image
    # batch size is bigger.
    with slim.arg_scope(
        resnet_utils.resnet_arg_scope(
            batch_norm_epsilon=1e-5,
            batch_norm_scale=True,
            activation_fn=self._activation_fn,
            weight_decay=self._weight_decay)):
      with tf.variable_scope(
          self._architecture, reuse=self._reuse_weights) as var_scope:
        _, activations = self._resnet_model(
            preprocessed_inputs,
            num_classes=None,
            is_training=self._train_batch_norm,
            global_pool=False,
            output_stride=self._first_stage_features_stride,
            spatial_squeeze=False,
            scope=var_scope)

  handle = scope + '/%s/block3' % self._architecture
  return activations[handle], activations
Example #21
Source File: image_utils.py From magenta with Apache License 2.0
def _crop(image, offset_height, offset_width, crop_height, crop_width):
  """Crops the given image using the provided offsets and sizes.

  Note that the method doesn't assume we know the input image size but it does
  assume we know the input image rank.

  Args:
    image: an image of shape [height, width, channels].
    offset_height: a scalar tensor indicating the height offset.
    offset_width: a scalar tensor indicating the width offset.
    crop_height: the height of the cropped image.
    crop_width: the width of the cropped image.

  Returns:
    the cropped (and resized) image.

  Raises:
    InvalidArgumentError: if the rank is not 3 or if the image dimensions are
      less than the crop size.
  """
  original_shape = tf.shape(image)

  rank_assertion = tf.Assert(
      tf.equal(tf.rank(image), 3),
      ['Rank of image must be equal to 3.'])
  with tf.control_dependencies([rank_assertion]):
    cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]])

  size_assertion = tf.Assert(
      tf.logical_and(
          tf.greater_equal(original_shape[0], crop_height),
          tf.greater_equal(original_shape[1], crop_width)),
      ['Crop size greater than the image size.'])

  offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))

  # Use tf.strided_slice instead of crop_to_bounding box as it accepts tensors
  # to define the crop size.
  with tf.control_dependencies([size_assertion]):
    image = tf.strided_slice(image, offsets, offsets + cropped_shape,
                             strides=tf.ones_like(offsets))
  return tf.reshape(image, cropped_shape)
Example #22
Source File: lstm_utils.py From magenta with Apache License 2.0
def maybe_split_sequence_lengths(sequence_length, num_splits, total_length):
  """Validates and splits `sequence_length`, if necessary.

  Returned value must be used in graph for all validations to be executed.

  Args:
    sequence_length: A batch of sequence lengths, either sized `[batch_size]`
      and equal to either 0 or `total_length`, or sized
      `[batch_size, num_splits]`.
    num_splits: The scalar number of splits of the full sequences.
    total_length: The scalar total sequence length (potentially padded).

  Returns:
    sequence_length: If input shape was `[batch_size, num_splits]`, returns
      the same Tensor. Otherwise, returns a Tensor of that shape with each
      input length in the batch divided by `num_splits`.

  Raises:
    ValueError: If `sequence_length` is not shaped `[batch_size]` or
      `[batch_size, num_splits]`.
    tf.errors.InvalidArgumentError: If `sequence_length` is shaped
      `[batch_size]` and all values are not either 0 or `total_length`.
  """
  if sequence_length.shape.ndims == 1:
    if total_length % num_splits != 0:
      raise ValueError(
          '`total_length` must be evenly divisible by `num_splits`.')
    with tf.control_dependencies(
        [tf.Assert(
            tf.reduce_all(
                tf.logical_or(tf.equal(sequence_length, 0),
                              tf.equal(sequence_length, total_length))),
            data=[sequence_length])]):
      sequence_length = (
          tf.tile(tf.expand_dims(sequence_length, axis=1), [1, num_splits]) //
          num_splits)
  elif sequence_length.shape.ndims == 2:
    with tf.control_dependencies([
        tf.assert_less_equal(
            sequence_length,
            tf.constant(total_length // num_splits, tf.int32),
            message='Segment length cannot be more than '
                    '`total_length / num_splits`.')]):
      sequence_length = tf.identity(sequence_length)
    sequence_length.set_shape([sequence_length.shape[0], num_splits])
  else:
    raise ValueError(
        'Sequence lengths must be given as a vector or a 2D Tensor whose '
        'second dimension size matches its initial hierarchical split. Got '
        'shape: %s' % sequence_length.shape.as_list())
  return sequence_length