Python sonnet.Conv2D() Examples

The following are 30 code examples of `sonnet.Conv2D()`. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `sonnet`, or try the search function.
Example #1
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_raises(self,
                                                  use_edges,
                                                  use_receiver_nodes,
                                                  use_sender_nodes,
                                                  use_globals,
                                                  field):
    """An exception should be raised if the inputs have incompatible shapes.

    Args:
      use_edges: Forwarded to `blocks.EdgeBlock`.
      use_receiver_nodes: Forwarded to `blocks.EdgeBlock`.
      use_sender_nodes: Forwarded to `blocks.EdgeBlock`.
      use_globals: Forwarded to `blocks.EdgeBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
    """
    input_graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of `field` only. This assumes those axes differ in
    # size in `_get_shaped_input_graph`, so the field no longer lines up with
    # the others when the block combines them.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals
    )
    with self.assertRaisesRegexp(
        tf.errors.InvalidArgumentError, "Dimensions of inputs should match"):
      network(input_graph)
Example #2
Source File: modules_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_partial_outputs_raises(self):
    """An error should be raised if partial outputs have incompatible shapes.

    A Conv2D with stride [1, 2] is swapped in first for the edge model and
    then for the node model. Each resulting `modules.GraphNetwork` is
    expected to raise a ValueError matching "in both shapes must be equal".
    """
    input_graph = self._get_shaped_input_graph()
    edge_model_fn, node_model_fn, global_model_fn = self._get_shaped_model_fns()
    # Striding by 2 along the second spatial dimension is expected to shrink
    # that dimension of this model's output, so it no longer matches the
    # partial outputs of the other blocks. A single partial is reused because
    # each GraphNetwork calls it to construct its own fresh module.
    strided_model_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3], stride=[1, 2])

    graph_network = modules.GraphNetwork(
        strided_model_fn, node_model_fn, global_model_fn)
    with self.assertRaisesRegexp(ValueError, "in both shapes must be equal"):
      graph_network(input_graph)

    graph_network = modules.GraphNetwork(
        edge_model_fn, strided_model_fn, global_model_fn)
    with self.assertRaisesRegexp(ValueError, "in both shapes must be equal"):
      graph_network(input_graph)
Example #3
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_no_raise(self,
                                                    use_edges,
                                                    use_nodes,
                                                    use_globals,
                                                    field):
    """No exception should occur if a differently shaped field is not used.

    Args:
      use_edges: Forwarded to `blocks.GlobalBlock`.
      use_nodes: Forwarded to `blocks.GlobalBlock`.
      use_globals: Forwarded to `blocks.GlobalBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
        Per this test's purpose, the `use_*` flags are expected to exclude
        it from the block's inputs.
    """
    input_graph = self._get_shaped_input_graph()
    # Perturb only `field`; the block should never look at it.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.GlobalBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals
    )
    self._assert_build_and_run(network, input_graph)
Example #4
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_raises(self,
                                                  use_edges,
                                                  use_nodes,
                                                  use_globals,
                                                  field):
    """An exception should be raised if the inputs have incompatible shapes.

    Args:
      use_edges: Forwarded to `blocks.GlobalBlock`.
      use_nodes: Forwarded to `blocks.GlobalBlock`.
      use_globals: Forwarded to `blocks.GlobalBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
    """
    input_graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of `field` only. This assumes those axes differ in
    # size in `_get_shaped_input_graph`.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.GlobalBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals
    )
    with self.assertRaisesRegexp(ValueError, "in both shapes must be equal"):
      network(input_graph)
Example #5
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_raises(self,
                                                  use_received_edges,
                                                  use_sent_edges,
                                                  use_nodes,
                                                  use_globals,
                                                  field):
    """An exception should be raised if the inputs have incompatible shapes.

    Args:
      use_received_edges: Forwarded to `blocks.NodeBlock`.
      use_sent_edges: Forwarded to `blocks.NodeBlock`.
      use_nodes: Forwarded to `blocks.NodeBlock`.
      use_globals: Forwarded to `blocks.NodeBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
    """
    input_graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of `field` only. This assumes those axes differ in
    # size in `_get_shaped_input_graph`.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.NodeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals
    )
    with self.assertRaisesRegexp(ValueError, "in both shapes must be equal"):
      network(input_graph)
Example #6
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_no_raise(self,
                                                    use_edges,
                                                    use_receiver_nodes,
                                                    use_sender_nodes,
                                                    use_globals,
                                                    field):
    """No exception should occur if a differently shaped field is not used.

    Args:
      use_edges: Forwarded to `blocks.EdgeBlock`.
      use_receiver_nodes: Forwarded to `blocks.EdgeBlock`.
      use_sender_nodes: Forwarded to `blocks.EdgeBlock`.
      use_globals: Forwarded to `blocks.EdgeBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
        Per this test's purpose, the `use_*` flags are expected to exclude
        it from the block's inputs.
    """
    input_graph = self._get_shaped_input_graph()
    # Perturb only `field`; the block should never look at it.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals
    )
    self._assert_build_and_run(network, input_graph)
Example #7
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_raises(self,
                                                  use_edges,
                                                  use_receiver_nodes,
                                                  use_sender_nodes,
                                                  use_globals,
                                                  field):
    """An exception should be raised if the inputs have incompatible shapes.

    Args:
      use_edges: Forwarded to `blocks.EdgeBlock`.
      use_receiver_nodes: Forwarded to `blocks.EdgeBlock`.
      use_sender_nodes: Forwarded to `blocks.EdgeBlock`.
      use_globals: Forwarded to `blocks.EdgeBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
    """
    input_graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of `field` only. This assumes those axes differ in
    # size in `_get_shaped_input_graph`.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals
    )
    with self.assertRaisesRegexp(ValueError, "in both shapes must be equal"):
      network(input_graph)
Example #8
Source File: modules_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_partial_outputs_raises(self):
    """An error should be raised if partial outputs have incompatible shapes.

    A Conv2D with stride [1, 2] is swapped in first for the edge model and
    then for the node model. Each resulting `modules.GraphNetwork` is
    expected to raise `tf.errors.InvalidArgumentError` matching
    "Dimensions of inputs should match".
    """
    input_graph = self._get_shaped_input_graph()
    edge_model_fn, node_model_fn, global_model_fn = self._get_shaped_model_fns()
    # Striding by 2 along the second spatial dimension is expected to shrink
    # that dimension of this model's output, so it no longer matches the
    # partial outputs of the other blocks. A single partial is reused because
    # each GraphNetwork calls it to construct its own fresh module.
    strided_model_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3], stride=[1, 2])

    graph_network = modules.GraphNetwork(
        strided_model_fn, node_model_fn, global_model_fn)
    with self.assertRaisesRegexp(
        tf.errors.InvalidArgumentError, "Dimensions of inputs should match"):
      graph_network(input_graph)

    graph_network = modules.GraphNetwork(
        edge_model_fn, strided_model_fn, global_model_fn)
    with self.assertRaisesRegexp(
        tf.errors.InvalidArgumentError, "Dimensions of inputs should match"):
      graph_network(input_graph)
Example #9
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_no_raise(self,
                                                    use_edges,
                                                    use_nodes,
                                                    use_globals,
                                                    field):
    """No exception should occur if a differently shaped field is not used.

    Args:
      use_edges: Forwarded to `blocks.GlobalBlock`.
      use_nodes: Forwarded to `blocks.GlobalBlock`.
      use_globals: Forwarded to `blocks.GlobalBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
        Per this test's purpose, the `use_*` flags are expected to exclude
        it from the block's inputs.
    """
    input_graph = self._get_shaped_input_graph()
    # Perturb only `field`; the block should never look at it.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.GlobalBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals
    )
    self._assert_build_and_run(network, input_graph)
Example #10
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_raises(self,
                                                  use_edges,
                                                  use_nodes,
                                                  use_globals,
                                                  field):
    """An exception should be raised if the inputs have incompatible shapes.

    Args:
      use_edges: Forwarded to `blocks.GlobalBlock`.
      use_nodes: Forwarded to `blocks.GlobalBlock`.
      use_globals: Forwarded to `blocks.GlobalBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
    """
    input_graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of `field` only. This assumes those axes differ in
    # size in `_get_shaped_input_graph`.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.GlobalBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals
    )
    with self.assertRaisesRegexp(
        tf.errors.InvalidArgumentError,
        "Dimensions of inputs should match"):
      network(input_graph)
Example #11
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_raises(self,
                                                  use_received_edges,
                                                  use_sent_edges,
                                                  use_nodes,
                                                  use_globals,
                                                  field):
    """An exception should be raised if the inputs have incompatible shapes.

    Args:
      use_received_edges: Forwarded to `blocks.NodeBlock`.
      use_sent_edges: Forwarded to `blocks.NodeBlock`.
      use_nodes: Forwarded to `blocks.NodeBlock`.
      use_globals: Forwarded to `blocks.NodeBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
    """
    input_graph = self._get_shaped_input_graph()
    # Swap axes 1 and 2 of `field` only. This assumes those axes differ in
    # size in `_get_shaped_input_graph`.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.NodeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals
    )
    with self.assertRaisesRegexp(
        tf.errors.InvalidArgumentError,
        "Dimensions of inputs should match"):
      network(input_graph)
Example #12
Source File: blocks_test.py    From graph_nets with Apache License 2.0 6 votes vote down vote up
def test_incompatible_higher_rank_inputs_no_raise(self,
                                                    use_edges,
                                                    use_receiver_nodes,
                                                    use_sender_nodes,
                                                    use_globals,
                                                    field):
    """No exception should occur if a differently shaped field is not used.

    Args:
      use_edges: Forwarded to `blocks.EdgeBlock`.
      use_receiver_nodes: Forwarded to `blocks.EdgeBlock`.
      use_sender_nodes: Forwarded to `blocks.EdgeBlock`.
      use_globals: Forwarded to `blocks.EdgeBlock`.
      field: Name of the graph field whose tensor gets axes 1 and 2 swapped.
        Per this test's purpose, the `use_*` flags are expected to exclude
        it from the block's inputs.
    """
    input_graph = self._get_shaped_input_graph()
    # Perturb only `field`; the block should never look at it.
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals
    )
    self._assert_build_and_run(network, input_graph)
Example #13
Source File: dpf_kitti.py    From differentiable-particle-filters with MIT License 6 votes vote down vote up
def custom_build(self, inputs):
    """Build function intended to be wrapped into a Sonnet module.

    Layers, in order:
      * four ReLU-activated `snt.Conv2D` layers with 16 output channels each
        (kernel/stride listed in `conv_specs` below);
      * `tf.nn.dropout` with `self.placeholders['keep_prob']` as its second
        argument;
      * `snt.BatchFlatten`;
      * a 128-unit `snt.Linear` followed by ReLU.

    Args:
      inputs: Tensor accepted by `snt.Conv2D` (presumably an NHWC image
        batch; confirm against the caller).

    Returns:
      The ReLU activations of the final 128-unit linear layer.
    """
    # (kernel_shape, stride) for each convolution, applied in order.
    conv_specs = (
        ([7, 7], [1, 1]),
        ([5, 5], [1, 2]),
        ([5, 5], [1, 2]),
        ([5, 5], [2, 2]),
    )
    features = inputs
    for kernel, strides in conv_specs:
      conv = snt.Conv2D(output_channels=16, kernel_shape=kernel, stride=strides)
      features = tf.nn.relu(conv(features))

    features = tf.nn.dropout(features, self.placeholders['keep_prob'])
    features = snt.BatchFlatten()(features)
    return tf.nn.relu(snt.Linear(128)(features))
Example #14
Source File: bounds_test.py    From interval-bound-propagation with Apache License 2.0 6 votes vote down vote up
def testConv2dIntervalBounds(self):
    """Interval bounds through a 2x2 VALID Conv2D with unit weights, bias 2.

    The 2x2 input [1, 2, 3, 4] collapses to a single output, sum(z) + 2.
    Widening every input by +/-1 therefore gives bounds
    [(10 - 4) + 2, (10 + 4) + 2] = [8, 16].
    """
    initializers = {
        'w': tf.constant_initializer(1.),
        'b': tf.constant_initializer(2.),
    }
    conv = snt.Conv2D(output_channels=1, kernel_shape=(2, 2), padding='VALID',
                      stride=1, use_bias=True, initializers=initializers)
    z = tf.reshape(tf.constant([1, 2, 3, 4], dtype=tf.float32), [1, 2, 2, 1])
    conv(z)  # Connect once so the module's variables exist before wrapping.
    wrapper = ibp.LinearConv2dWrapper(conv)
    output_bounds = wrapper.propagate_bounds(ibp.IntervalBounds(z - 1., z + 1.))
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      lower, upper = sess.run([output_bounds.lower, output_bounds.upper])
      self.assertAlmostEqual(8., lower.item())
      self.assertAlmostEqual(16., upper.item())
Example #15
Source File: modules_test.py    From graph_nets with Apache License 2.0 5 votes vote down vote up
def _get_shaped_model_fns(self):
    """Returns `(edge_model_fn, node_model_fn, global_model_fn)`.

    Each is a `functools.partial` of `snt.Conv2D` with a 3x3 kernel. They
    differ only in output channels: 10 for edges, 8 for nodes and 7 for
    globals.
    """
    def conv3x3(channels):
      return functools.partial(
          snt.Conv2D, output_channels=channels, kernel_shape=[3, 3])

    return conv3x3(10), conv3x3(8), conv3x3(7)
Example #16
Source File: classifier_mnist.py    From kfac with Apache License 2.0 5 votes vote down vote up
def _build(self, inputs):
    """Maps inputs reshaped to [28, 28, 1] per example to 10 outputs.

    Pipeline: `snt.BatchReshape([28, 28, 1])`, then two stages of
    `snt.Conv2D` (kernel 5, SAME padding; 2 then 4 channels) ->
    `_NONLINEARITY` -> `tf.nn.pool` (2x2 window, stride 2, `_POOL` type),
    then `snt.BatchFlatten`, `snt.Linear(32)` and `snt.Linear(10)`. There is
    no activation between the two linear layers.

    When `FLAGS.l2_reg` is truthy, every conv and linear layer is given
    `FLAGS.l2_reg * tf.nn.l2_loss(param)` as the regularizer for both its
    'w' and 'b' parameters.

    Args:
      inputs: Tensor that `snt.BatchReshape([28, 28, 1])` can reshape.

    Returns:
      The output of the final 10-unit linear layer.
    """
    regularizers = None
    if FLAGS.l2_reg:
      def l2_penalty(param):
        # FLAGS.l2_reg is looked up each time the penalty is evaluated.
        return FLAGS.l2_reg * tf.nn.l2_loss(param)
      regularizers = {'w': l2_penalty, 'b': l2_penalty}

    pool_kwargs = dict(window_shape=(2, 2), pooling_type=_POOL,
                       padding=snt.SAME, strides=(2, 2))
    to_image = snt.BatchReshape([28, 28, 1])

    first_conv = snt.Conv2D(2, 5, padding=snt.SAME, regularizers=regularizers)
    hidden = tf.nn.pool(_NONLINEARITY(first_conv(to_image(inputs))),
                        **pool_kwargs)

    second_conv = snt.Conv2D(4, 5, padding=snt.SAME, regularizers=regularizers)
    hidden = tf.nn.pool(_NONLINEARITY(second_conv(hidden)), **pool_kwargs)

    hidden = snt.BatchFlatten()(hidden)
    hidden = snt.Linear(32, regularizers=regularizers)(hidden)
    return snt.Linear(10, regularizers=regularizers)(hidden)
Example #17
Source File: modules_test.py    From graph_nets with Apache License 2.0 5 votes vote down vote up
def _get_shaped_model_fns(self):
    """Returns `(edge_model_fn, node_model_fn, global_model_fn)`.

    All three are `functools.partial`s of `snt.Conv2D` with a 3x3 kernel,
    producing 10 (edge), 8 (node) and 7 (global) output channels.
    """
    channels_per_role = (10, 8, 7)
    return tuple(
        functools.partial(snt.Conv2D, output_channels=channels,
                          kernel_shape=[3, 3])
        for channels in channels_per_role)
Example #18
Source File: nn.py    From magenta with Apache License 2.0 5 votes vote down vote up
def _build(self, x):
    """Maps `x` to `(mu, sigma)`, presumably the parameters of a latent Gaussian.

    Applies one ReLU-activated `snt.Conv2D` per entry of `self.layers`,
    flattens the result, then applies one `snt.Linear` per entry of
    `self.padding_linear_layers` (with no activation between them). A final
    linear layer with `2 * self.n_latent` outputs is split in half.

    Args:
      x: Tensor accepted by `snt.Conv2D`. After the conv stack, dims 1-3
        must be statically known, because they are read via `get_shape()`.

    Returns:
      A `(mu, sigma)` pair, each with `self.n_latent` columns. `sigma` is
      passed through `tf.nn.softplus`.
    """
    h = x
    # Each spec is passed positionally to snt.Conv2D, presumably as
    # (output_channels, kernel_shape, stride).
    for spec in self.layers:
      conv = snt.Conv2D(spec[0], spec[1], spec[2])
      h = tf.nn.relu(conv(h))

    h_shape = h.get_shape().as_list()
    h = tf.reshape(h, [-1, h_shape[1] * h_shape[2] * h_shape[3]])
    for num_units in self.padding_linear_layers:
      h = snt.Linear(num_units)(h)
    pre_z = snt.Linear(2 * self.n_latent)(h)
    mu = pre_z[:, :self.n_latent]
    sigma = tf.nn.softplus(pre_z[:, self.n_latent:])
    return mu, sigma
Example #19
Source File: nn.py    From magenta with Apache License 2.0 5 votes vote down vote up
def _build(self, x):
    """Maps `x` to `self.output_size` logits.

    Applies one ReLU-activated `snt.Conv2D` per entry of `self.layers`,
    flattens the result and applies a single `snt.Linear`.

    Args:
      x: Tensor accepted by `snt.Conv2D`. After the conv stack, dims 1-3
        must be statically known, because they are read via `get_shape()`.

    Returns:
      The linear layer's output (`self.output_size` columns).
    """
    h = x
    # Each spec is passed positionally to snt.Conv2D, presumably as
    # (output_channels, kernel_shape, stride).
    for spec in self.layers:
      conv = snt.Conv2D(spec[0], spec[1], spec[2])
      h = tf.nn.relu(conv(h))

    h_shape = h.get_shape().as_list()
    h = tf.reshape(h, [-1, h_shape[1] * h_shape[2] * h_shape[3]])
    logits = snt.Linear(self.output_size)(h)
    return logits
Example #20
Source File: blocks_test.py    From graph_nets with Apache License 2.0 5 votes vote down vote up
def test_compatible_higher_rank_no_raise(self):
    """No exception should occur with higher-rank tensors.

    The same axis swap (axes 1 and 2) is applied to the graph's fields via
    `input_graph.map`. A GlobalBlock using a 3x3 Conv2D is then built and
    run via `self._assert_build_and_run`.
    """
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.map(lambda v: tf.transpose(v, [0, 2, 1, 3]))
    network = blocks.GlobalBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]))
    self._assert_build_and_run(network, input_graph)
Example #21
Source File: model.py    From interval-bound-propagation with Apache License 2.0 5 votes vote down vote up
def _inputs_for_observed_module(self, subgraph):
    """Extracts input tensors from a connected Sonnet module.

    This default implementation supports common layer types, but should be
    overridden if custom layer types are to be supported.

    Args:
      subgraph: `snt.ConnectedSubGraph` specifying the Sonnet module being
        connected, and its inputs and outputs.

    Returns:
      A one-element tuple holding the module's input tensor, or None if the
      module is not one of the supported types.
    """
    m = subgraph.module
    # Only support a few operations for now.
    supported_types = (snt.BatchReshape, snt.Linear, snt.Conv1D, snt.Conv2D,
                       snt.BatchNorm, layers.ImageNorm)
    if not isinstance(m, supported_types):
      return None

    # `subgraph.inputs` is keyed 'input_batch' for snt.BatchNorm and 'inputs'
    # for every other supported module.
    input_key = 'input_batch' if isinstance(m, snt.BatchNorm) else 'inputs'
    return (subgraph.inputs[input_key],)
Example #22
Source File: blocks_test.py    From graph_nets with Apache License 2.0 5 votes vote down vote up
def test_compatible_higher_rank_no_raise(self):
    """No exception should occur with higher-rank tensors.

    The same axis swap (axes 1 and 2) is applied to the graph's fields via
    `input_graph.map`. A NodeBlock using a 3x3 Conv2D is then built and run
    via `self._assert_build_and_run`.
    """
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.map(lambda v: tf.transpose(v, [0, 2, 1, 3]))
    network = blocks.NodeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]))
    self._assert_build_and_run(network, input_graph)
Example #23
Source File: model.py    From interval-bound-propagation with Apache License 2.0 5 votes vote down vote up
def _wrapper_for_observed_module(self, subgraph):
    """Creates a wrapper for a connected Sonnet module.

    This default implementation supports common layer types, but should be
    overridden if custom layer types are to be supported.

    Args:
      subgraph: `snt.ConnectedSubGraph` specifying the Sonnet module being
        connected, and its inputs and outputs.

    Returns:
      `ibp.VerifiableWrapper` for the Sonnet module.

    Raises:
      ValueError: If the module is not one of the supported types.
    """
    m = subgraph.module
    if isinstance(m, snt.BatchReshape):
      # The wrapper needs the per-example output shape (batch dim dropped).
      shape = subgraph.outputs.get_shape()[1:].as_list()
      return verifiable_wrapper.BatchReshapeWrapper(m, shape)
    elif isinstance(m, snt.Linear):
      return verifiable_wrapper.LinearFCWrapper(m)
    elif isinstance(m, snt.Conv1D):
      return verifiable_wrapper.LinearConv1dWrapper(m)
    elif isinstance(m, snt.Conv2D):
      return verifiable_wrapper.LinearConv2dWrapper(m)
    elif isinstance(m, layers.ImageNorm):
      return verifiable_wrapper.ImageNormWrapper(m)
    elif isinstance(m, snt.BatchNorm):
      return verifiable_wrapper.BatchNormWrapper(m)
    # An explicit raise (rather than `assert`) keeps this check active under
    # `python -O`. Otherwise an unsupported module would silently be wrapped
    # as BatchNorm.
    raise ValueError(
        'Cannot create a verifiable wrapper for {}.'.format(m))
Example #24
Source File: blocks_test.py    From graph_nets with Apache License 2.0 5 votes vote down vote up
def test_compatible_higher_rank_no_raise(self):
    """No exception should occur with higher-rank tensors.

    The same axis swap (axes 1 and 2) is applied to the graph's fields via
    `input_graph.map`. An EdgeBlock using a 3x3 Conv2D is then built and run
    via `self._assert_build_and_run`.
    """
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.map(lambda v: tf.transpose(v, [0, 2, 1, 3]))
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]))
    self._assert_build_and_run(network, input_graph)
Example #25
Source File: verifiable_wrapper.py    From interval-bound-propagation with Apache License 2.0 5 votes vote down vote up
def __init__(self, module):
    """Wraps an `snt.Conv2D` module.

    Args:
      module: The module to wrap; must be an instance of `snt.Conv2D`.

    Raises:
      ValueError: If `module` is not an `snt.Conv2D`.
    """
    is_conv2d = isinstance(module, snt.Conv2D)
    if not is_conv2d:
      message = 'Cannot wrap {} with a LinearConv2dWrapper.'.format(module)
      raise ValueError(message)
    super(LinearConv2dWrapper, self).__init__(module)
Example #26
Source File: meta_test.py    From learning-to-learn with Apache License 2.0 5 votes vote down vote up
def testConvolutional(self):
    """Tests L2L applied to a problem whose only variables come from a conv.

    The problem is the sum of a single-channel `snt.Conv2D` (named "conv")
    applied to random normal input. Its weights "conv/w" are assigned to a
    "KernelDeepLSTM" meta-optimizer net whose kernel shape matches the
    conv's kernel.
    """
    kernel_size = 4

    def convolutional_problem():
      conv = snt.Conv2D(output_channels=1, kernel_shape=kernel_size, stride=1,
                        name="conv")
      return tf.reduce_sum(conv(tf.random_normal((100, 100, 3, 10))))

    conv_net_options = {
        "kernel_shape": [kernel_size] * 2,
        "layers": (5,),
    }
    optimizer = meta.MetaOptimizer(
        conv={"net": "KernelDeepLSTM", "net_options": conv_net_options})
    minimize_ops = optimizer.meta_minimize(
        convolutional_problem, 3, net_assignments=[("conv", ["conv/w"])])
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      train(sess, minimize_ops, 1, 2)
Example #27
Source File: fastlin_test.py    From interval-bound-propagation with Apache License 2.0 5 votes vote down vote up
def testConv2dSymbolicBounds(self):
    """Symbolic bounds through a 2x2 VALID Conv2D, concretized as intervals.

    The input interval [z - 1, z + 1] is converted to `ibp.SymbolicBounds`,
    propagated through the wrapped conv (unit weights, bias 2) and converted
    back to `ibp.IntervalBounds`. The single output is sum(z) + 2 = 12, so
    the expected bounds are [8, 16].
    """
    initializers = {
        'w': tf.constant_initializer(1.),
        'b': tf.constant_initializer(2.),
    }
    conv = snt.Conv2D(output_channels=1, kernel_shape=(2, 2), padding='VALID',
                      stride=1, use_bias=True, initializers=initializers)
    z = tf.reshape(tf.constant([1, 2, 3, 4], dtype=tf.float32), [1, 2, 2, 1])
    conv(z)  # Connect once so the module's variables exist before wrapping.
    wrapper = ibp.LinearConv2dWrapper(conv)
    symbolic_in = ibp.SymbolicBounds.convert(
        ibp.IntervalBounds(z - 1., z + 1.))
    interval_out = ibp.IntervalBounds.convert(
        wrapper.propagate_bounds(symbolic_in))
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      lower, upper = sess.run([interval_out.lower, interval_out.upper])
      self.assertAlmostEqual(8., lower.item())
      self.assertAlmostEqual(16., upper.item())
Example #28
Source File: blocks_test.py    From graph_nets with Apache License 2.0 5 votes vote down vote up
def test_compatible_higher_rank_no_raise(self):
    """No exception should occur with higher-rank tensors.

    The same axis swap (axes 1 and 2) is applied to the graph's fields via
    `input_graph.map`. A GlobalBlock using a 3x3 Conv2D is then built and
    run via `self._assert_build_and_run`.
    """
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.map(lambda v: tf.transpose(v, [0, 2, 1, 3]))
    network = blocks.GlobalBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]))
    self._assert_build_and_run(network, input_graph)
Example #29
Source File: crown_test.py    From interval-bound-propagation with Apache License 2.0 5 votes vote down vote up
def testConv2dBackwardBounds(self):
    """Backward (CROWN-style) bounds through a 2x2 VALID Conv2D.

    IBP bounds are propagated forward first. Then an identity spec from
    `_generate_identity_spec` is propagated through the same wrapper and
    concretized. With unit weights and bias 2 over inputs [1, 2, 3, 4] +/- 1,
    the expected bounds are [8, 16].
    """
    initializers = {
        'w': tf.constant_initializer(1.),
        'b': tf.constant_initializer(2.),
    }
    conv = snt.Conv2D(output_channels=1, kernel_shape=(2, 2), padding='VALID',
                      stride=1, use_bias=True, initializers=initializers)
    z = tf.reshape(tf.constant([1, 2, 3, 4], dtype=tf.float32), [1, 2, 2, 1])
    conv(z)  # Connect once so the module's variables exist before wrapping.
    wrapper = ibp.LinearConv2dWrapper(conv)
    # The forward IBP pass must run before the backward pass below.
    wrapper.propagate_bounds(ibp.IntervalBounds(z - 1., z + 1.))
    identity_spec = _generate_identity_spec([wrapper], shape=(1, 1, 1, 1, 1))
    concrete = wrapper.propagate_bounds(identity_spec).concretize()
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      lower, upper = sess.run([concrete.lower, concrete.upper])
      self.assertAlmostEqual(8., lower.item())
      self.assertAlmostEqual(16., upper.item())
Example #30
Source File: blocks_test.py    From graph_nets with Apache License 2.0 5 votes vote down vote up
def test_compatible_higher_rank_no_raise(self):
    """No exception should occur with higher-rank tensors.

    The same axis swap (axes 1 and 2) is applied to the graph's fields via
    `input_graph.map`. A NodeBlock using a 3x3 Conv2D is then built and run
    via `self._assert_build_and_run`.
    """
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.map(lambda v: tf.transpose(v, [0, 2, 1, 3]))
    network = blocks.NodeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]))
    self._assert_build_and_run(network, input_graph)