Python tensorflow.python.util.nest.assert_same_structure() Examples

The following are 13 code examples of tensorflow.python.util.nest.assert_same_structure(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.util.nest, or try the search function.
Example #1
Source File: rnn_cell_impl.py — from lambda-packs (MIT License), 6 votes
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell, then add the cell inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      A `(res_outputs, new_state)` tuple: `res_outputs` is the leaf-wise sum of
      `inputs` and the wrapped cell's outputs; `new_state` is the wrapped
      cell's new state, passed through untouched.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value),
        or if a pair of leaves has incompatible shapes.
    """
    cell_outputs, next_state = self._cell(inputs, state, scope=scope)

    # A residual connection pairs every input leaf with an output leaf, so the
    # nesting is checked first and the per-leaf shapes second.
    nest.assert_same_structure(inputs, cell_outputs)

    def _check_shapes(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.map_structure(_check_shapes, inputs, cell_outputs)

    def _add(inp, out):
      return inp + out

    summed = nest.map_structure(_add, inputs, cell_outputs)
    return summed, next_state
Example #2
Source File: rnn_cell.py — from lambda-packs (MIT License), 6 votes
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell and combine its inputs and outputs with `_highway`.

    Each (input, output) leaf pair is passed to `self._highway` instead of
    being summed directly; the exact combination rule lives in that method.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of the combined outputs (nested like `inputs`) and the wrapped
      cell's new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value),
        or if a pair of leaves has incompatible shapes.
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)
    # `_highway` is applied leaf-by-leaf, so inputs and outputs must nest alike.
    nest.assert_same_structure(inputs, outputs)
    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())
    nest.map_structure(assert_shape_match, inputs, outputs)
    res_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (res_outputs, new_state)
Example #3
Source File: rnn_cell.py — from Multiview2Novelview (MIT License), 6 votes
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell and combine its inputs and outputs with `_highway`.
    Each (input, output) leaf pair is passed to `self._highway` instead of
    being summed directly; the exact combination rule lives in that method.
    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.
    Returns:
      Tuple of the combined outputs (nested like `inputs`) and the wrapped
      cell's new state.
    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value),
        or if a pair of leaves has incompatible shapes.
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)
    # `_highway` is applied leaf-by-leaf, so inputs and outputs must nest alike.
    nest.assert_same_structure(inputs, outputs)
    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())
    nest.map_structure(assert_shape_match, inputs, outputs)
    res_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (res_outputs, new_state)
Example #4
Source File: gnmt_model.py — from parallax (Apache License 2.0), 6 votes
def gnmt_residual_fn(inputs, outputs):
  """Residual function that handles different inputs and outputs inner dims.

  Each input leaf may be wider than its matching output leaf on the last axis.
  Only the leading `out_dim` features of every input leaf take part in the
  residual sum; the remaining features are dropped.

  Args:
    inputs: cell inputs, this is actual inputs concatenated with the attention
      vector. A single tensor or a nested structure matching `outputs`.
    outputs: cell outputs

  Returns:
    outputs + actual inputs, nested like `inputs`.

  Raises:
    TypeError / ValueError: from the structure and shape checks if the trimmed
      inputs do not line up with `outputs`.
  """
  def split_input(inp, out):
    # Assumes the last dimension of both tensors is statically known;
    # `as_list()[-1]` is None otherwise and the subtraction below fails.
    out_dim = out.get_shape().as_list()[-1]
    inp_dim = inp.get_shape().as_list()[-1]
    return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)

  def take_actual_input(inp, out):
    # Keep only the first split: the slice whose width matches `out`.
    return split_input(inp, out)[0]

  # Select the first split per leaf. Unpacking the mapped *structure*
  # (`actual_inputs, _ = nest.map_structure(split_input, ...)`) only worked
  # when `inputs` was a single tensor and broke for nested cell inputs.
  actual_inputs = nest.map_structure(take_actual_input, inputs, outputs)

  def assert_shape_match(inp, out):
    inp.get_shape().assert_is_compatible_with(out.get_shape())

  nest.assert_same_structure(actual_inputs, outputs)
  nest.map_structure(assert_shape_match, actual_inputs, outputs)
  return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs)
Example #5
Source File: mod_core_rnn_cell_impl.py — from RGAN (MIT License), 6 votes
def __call__(self, inputs, state, scope=None):
    """Apply the wrapped cell and return its outputs plus its inputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple `(res_outputs, new_state)`, where `res_outputs` holds
      `input + output` for every leaf and `new_state` comes straight from the
      wrapped cell.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value),
        or if a pair of leaves has incompatible shapes.
    """
    result = self._cell(inputs, state, scope=scope)
    wrapped_outputs = result[0]
    wrapped_state = result[1]

    # Residual addition needs matching nesting and compatible leaf shapes.
    nest.assert_same_structure(inputs, wrapped_outputs)
    nest.map_structure(
        lambda x, y: x.get_shape().assert_is_compatible_with(y.get_shape()),
        inputs, wrapped_outputs)

    residual = nest.map_structure(lambda x, y: x + y, inputs, wrapped_outputs)
    return (residual, wrapped_state)
Example #6
Source File: discriminator.py — from cvpr18-caption-eval (MIT License), 6 votes
def __call__(self, inputs, state, scope=None):
    """Run the inner cell and add a residual connection from its inputs.
    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.
    Returns:
      Tuple of (leaf-wise inputs + cell outputs, new cell state).
    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value),
        or if a pair of leaves has incompatible shapes.
    """
    out_vals, out_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, out_vals)

    def _require_compatible(lhs, rhs):
      # Static-shape check only; no ops are added for the comparison itself.
      lhs.get_shape().assert_is_compatible_with(rhs.get_shape())

    nest.map_structure(_require_compatible, inputs, out_vals)
    combined = nest.map_structure(lambda lhs, rhs: lhs + rhs, inputs, out_vals)
    return combined, out_state
Example #7
Source File: bridges.py — from NJUNMT-tf (Apache License 2.0), 6 votes
def assert_state_is_compatible(expected_state, state):
    """Check that `state` matches `expected_state` in structure and shape.

    Args:
        expected_state: The reference state.
        state: The state that must be compatible with :obj:`expected_state`.

    Raises:
      ValueError: if the states are incompatible. The structure check may also
        raise TypeError (see `nest.assert_same_structure`).
    """
    # The two states must nest identically before their leaves can be paired.
    nest.assert_same_structure(expected_state, state)

    # Pair leaves positionally; only leaves whose reference value is a tensor
    # get a shape comparison, all others are skipped.
    leaf_pairs = zip(nest.flatten(expected_state), nest.flatten(state))
    for reference, candidate in leaf_pairs:
        if not tensor_util.is_tensor(reference):
            continue
        with_same_shape(reference, candidate)
Example #8
Source File: gnmt.py — from OpenSeq2Seq (Apache License 2.0), 6 votes
def gnmt_residual_fn(inputs, outputs):
  """Residual function that handles different inputs and outputs inner dims.

  Each input leaf may be wider than its matching output leaf on the last axis.
  Only the leading `out_dim` features of every input leaf take part in the
  residual sum; the remaining features are dropped.

  Args:
    inputs: cell inputs, this is actual inputs concatenated with the attention
      vector. A single tensor or a nested structure matching `outputs`.
    outputs: cell outputs

  Returns:
    outputs + actual inputs, nested like `inputs`.

  Raises:
    TypeError / ValueError: from the structure and shape checks if the trimmed
      inputs do not line up with `outputs`.
  """
  def split_input(inp, out):
    # Assumes the last dimension of both tensors is statically known;
    # `as_list()[-1]` is None otherwise and the subtraction below fails.
    out_dim = out.get_shape().as_list()[-1]
    inp_dim = inp.get_shape().as_list()[-1]
    return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)

  def take_actual_input(inp, out):
    # Keep only the first split: the slice whose width matches `out`.
    return split_input(inp, out)[0]

  # Select the first split per leaf. Unpacking the mapped *structure*
  # (`actual_inputs, _ = nest.map_structure(split_input, ...)`) only worked
  # when `inputs` was a single tensor and broke for nested cell inputs.
  actual_inputs = nest.map_structure(take_actual_input, inputs, outputs)

  def assert_shape_match(inp, out):
    inp.get_shape().assert_is_compatible_with(out.get_shape())

  nest.assert_same_structure(actual_inputs, outputs)
  nest.map_structure(assert_shape_match, actual_inputs, outputs)
  return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs)
Example #9
Source File: gnmt_model.py — from active-qa (Apache License 2.0), 6 votes
def gnmt_residual_fn(inputs, outputs):
  """Residual function that handles different inputs and outputs inner dims.

  Each input leaf may be wider than its matching output leaf on the last axis.
  Only the leading `out_dim` features of every input leaf take part in the
  residual sum; the remaining features are dropped.

  Args:
    inputs: cell inputs, this is actual inputs concatenated with the attention
      vector. A single tensor or a nested structure matching `outputs`.
    outputs: cell outputs

  Returns:
    outputs + actual inputs, nested like `inputs`.

  Raises:
    TypeError / ValueError: from the structure and shape checks if the trimmed
      inputs do not line up with `outputs`.
  """

  def split_input(inp, out):
    # Assumes the last dimension of both tensors is statically known;
    # `as_list()[-1]` is None otherwise and the subtraction below fails.
    out_dim = out.get_shape().as_list()[-1]
    inp_dim = inp.get_shape().as_list()[-1]
    return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)

  def take_actual_input(inp, out):
    # Keep only the first split: the slice whose width matches `out`.
    return split_input(inp, out)[0]

  # Select the first split per leaf. Unpacking the mapped *structure*
  # (`actual_inputs, _ = nest.map_structure(split_input, ...)`) only worked
  # when `inputs` was a single tensor and broke for nested cell inputs.
  actual_inputs = nest.map_structure(take_actual_input, inputs, outputs)

  def assert_shape_match(inp, out):
    inp.get_shape().assert_is_compatible_with(out.get_shape())

  nest.assert_same_structure(actual_inputs, outputs)
  nest.map_structure(assert_shape_match, actual_inputs, outputs)
  return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs)
Example #10
Source File: rnn_cell_impl.py — from Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda (MIT License), 6 votes
def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell and merge its inputs into its outputs.

    The merge uses `self._residual_fn(inputs, outputs)` when that attribute is
    truthy. Otherwise it checks that inputs and outputs nest alike with
    compatible shapes, then adds them leaf by leaf.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of the merged outputs and the wrapped cell's new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type)
        and the default merge is used.
      ValueError: If cell inputs and outputs have different structure (value)
        or incompatible shapes and the default merge is used.
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)

    def _checked_sum(lhs, rhs):
      # Default merge: validate nesting and shapes, then add leaf by leaf.
      nest.assert_same_structure(lhs, rhs)

      def _same_shape(a, b):
        a.get_shape().assert_is_compatible_with(b.get_shape())

      nest.map_structure(_same_shape, lhs, rhs)
      return nest.map_structure(lambda a, b: a + b, lhs, rhs)

    merge = self._residual_fn
    if not merge:
      merge = _checked_sum
    return merge(inputs, outputs), new_state
Example #11
Source File: bridge.py — from tensorflow_end2end_speech_recognition (MIT License), 5 votes
def _create(self):
        """Return the encoder's final state after a structure check.

        `nest.assert_same_structure` raises if the final state does not nest
        like `self.decoder_state_size`.
        """
        final_state = self.encoder_outputs.final_state
        nest.assert_same_structure(final_state, self.decoder_state_size)
        return final_state
Example #12
Source File: dataset_ops.py — from lambda-packs (MIT License), 5 votes
def make_initializer(self, dataset):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` with compatible structure to this iterator.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        element structure (e.g. different sequence types), if any element
        dtype differs, or if any element shape is incompatible.
      ValueError: If the nested element structures have a different number of
        elements or a different nesting (raised by
        `nest.assert_same_structure`).
    """
    # Structure checks run first so that the flattened sequences zipped below
    # are guaranteed to line up element for element.
    nest.assert_same_structure(self._output_types, dataset.output_types)
    nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
    # Dtypes must match exactly.
    for iterator_dtype, dataset_dtype in zip(
        nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
      if iterator_dtype != dataset_dtype:
        raise TypeError(
            "Expected output types %r but got dataset with output types %r." %
            (self._output_types, dataset.output_types))
    # Shapes only need to be compatible (partially-known shapes are allowed).
    for iterator_shape, dataset_shape in zip(
        nest.flatten(self._output_shapes), nest.flatten(dataset.output_shapes)):
      if not iterator_shape.is_compatible_with(dataset_shape):
        raise TypeError("Expected output shapes compatible with %r but got "
                        "dataset with output shapes %r." %
                        (self._output_shapes, dataset.output_shapes))
    return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
                                         self._iterator_resource)
Example #13
Source File: nest_test.py — from deep_image_model (Apache License 2.0), 4 votes
def testAssertSameStructure(self):
    """Table-driven check of nest.assert_same_structure outcomes.

    Each case is (expected_exception, message_regexp, first, second):
      - expected_exception None: the call must succeed;
      - message_regexp None: only the exception type is checked;
      - otherwise both the type and the message are checked.
    Cases run in the listed order.
    """
    # NOTE(review): assertRaisesRegexp is a deprecated alias that was removed
    # in Python 3.12; switch to assertRaisesRegex once Python 2 is not needed.
    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    structure_different_num_elements = ("spam", "eggs")
    structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
    named_type_0 = collections.namedtuple("named_0", ("a", "b"))
    named_type_1 = collections.namedtuple("named_1", ("a", "b"))

    count_msg = "don't have the same number of elements"
    nesting_msg = "don't have the same nested structure"

    cases = [
        # Same shape of nesting, different leaf values: accepted.
        (None, None, structure1, structure2),
        # Strings, floats, arrays and tensors all count as single leaves.
        (None, None, "abc", 1.0),
        (None, None, "abc", np.array([0, 1])),
        (None, None, "abc", tf.constant([0, 1])),
        (ValueError, count_msg, structure1, structure_different_num_elements),
        (ValueError, count_msg, [0, 1], np.array([0, 1])),
        (ValueError, count_msg, 0, [0, 1]),
        # tuple vs list differ in sequence type.
        (TypeError, None, (0, 1), [0, 1]),
        (ValueError, nesting_msg, structure1, structure_different_nesting),
        (TypeError, None, (0, 1), named_type_0("a", "b")),
        (None, None, named_type_0(3, 4), named_type_0("a", "b")),
        (TypeError, None, named_type_0(3, 4), named_type_1(3, 4)),
        (ValueError, nesting_msg, named_type_0(3, 4), named_type_0([3], 4)),
        (ValueError, nesting_msg, [[3], 4], [3, [4]]),
    ]

    for error, message, first, second in cases:
      if error is None:
        nest.assert_same_structure(first, second)
      elif message is None:
        self.assertRaises(error, nest.assert_same_structure, first, second)
      else:
        with self.assertRaisesRegexp(error, message):
          nest.assert_same_structure(first, second)