Python jax.numpy.concatenate() Examples

The following are 14 code examples of jax.numpy.concatenate(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax.numpy, or try the search function.
Example #1
Source File: modules.py (from jaxnet, Apache License 2.0; 6 votes)
def GRUCell(carry_size, param_init):
    """Build a gated recurrent unit (GRU) cell.

    Args:
        carry_size: width of the recurrent state (the carry).
        param_init: initializer handed to ``parameter`` for every kernel.

    Returns:
        A pair ``(gru_cell, carry_init)``. ``gru_cell(carry, x)`` returns
        ``(new_carry, new_carry)``; ``carry_init(batch_size)`` returns a zero
        carry of shape ``(batch_size, carry_size)``.
    """

    @parametrized
    def gru_cell(carry, x):
        # Every kernel maps the axis-1 concatenation [x, carry] to carry_size.
        kernel_shape = (x.shape[1] + carry_size, carry_size)

        def param(name):
            return parameter(kernel_shape, param_init, name)

        x_and_carry = np.concatenate((x, carry), axis=1)
        update_gate = sigmoid(np.dot(x_and_carry, param('update_kernel')))
        reset_gate = sigmoid(np.dot(x_and_carry, param('reset_kernel')))

        # The candidate state only sees the part of the carry let through by the reset gate.
        x_and_reset_carry = np.concatenate((x, reset_gate * carry), axis=1)
        candidate = np.tanh(np.dot(x_and_reset_carry, param('compute_kernel')))

        # Blend the candidate with the previous carry using the update gate.
        new_carry = update_gate * candidate + (1 - update_gate) * carry
        return new_carry, new_carry

    def carry_init(batch_size):
        return np.zeros((batch_size, carry_size))

    return gru_cell, carry_init
Example #2
Source File: test_core.py (from jaxnet, Apache License 2.0; 6 votes)
def test_submodule_order():
    """Check parameter ordering when a submodule is called more than once.

    ``p`` is called first and again near the end; the parameter order must be
    decided by its first call, so ``p``'s value ends up as ``params.parameter0``.
    """
    @parametrized
    def net():
        p = Parameter(lambda key: jnp.zeros((1,)))
        a = p()
        b = parameter((2,), zeros)
        c = parameter((3,), zeros)
        d = parameter((4,), zeros)
        e = parameter((5,), zeros)
        f = parameter((6,), zeros)

        # must not mess up order (decided by first submodule call):
        k = p()

        # Each pair sums to length 7 (1+6, 2+5, 3+4); k has shape (1,) and broadcasts.
        return jnp.concatenate([a, f]) + jnp.concatenate([b, e]) + jnp.concatenate([c, d]) + k

    params = net.init_parameters(key=PRNGKey(0))

    # presumably jaxnet names unnamed parameters parameter0, parameter1, ... in
    # creation order; p's initializer yields zeros((1,)) — confirm naming scheme.
    assert jnp.zeros((1,)) == params.parameter0
    out = net.apply(params)
    assert (7,) == out.shape 
Example #3
Source File: proportion_test.py (from numpyro, Apache License 2.0; 6 votes)
def make_dataset(rng_key) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Make a simulated dataset in which potential customers who get a sales
    call have a roughly 2 percentage-point higher chance (8.4% vs 6.1%) of
    making another purchase.

    Returns:
        ``(design_matrix, made_purchase)``. ``design_matrix`` has one row per
        customer and three columns: an intercept of ones, a got-called
        indicator and an is-female indicator. Called customers come first.
    """
    called_key, uncalled_key, sex_key = random.split(rng_key, 3)

    num_calls = 51342
    num_no_calls = 48658
    num_customers = num_calls + num_no_calls

    purchased_called = dist.Bernoulli(0.084).sample(called_key, sample_shape=(num_calls,))
    purchased_uncalled = dist.Bernoulli(0.061).sample(uncalled_key, sample_shape=(num_no_calls,))
    made_purchase = jnp.concatenate([purchased_called, purchased_uncalled])

    is_female = dist.Bernoulli(0.5).sample(sex_key, sample_shape=(num_customers,))
    # Ordering matches made_purchase: the called block first, then the uncalled block.
    got_called = jnp.concatenate([jnp.ones(num_calls), jnp.zeros(num_no_calls)])

    columns = [
        jnp.ones((num_customers, 1)),
        got_called.reshape(-1, 1),
        is_female.reshape(-1, 1),
    ]
    design_matrix = jnp.hstack(columns)
    return design_matrix, made_purchase
Example #4
Source File: tabular_irl.py (from imitation, MIT License; 5 votes)
def _flatten(self, matrix_tups):
        """Flatten everything and concatenate it together."""
        out_vecs = [v.flatten() for t in matrix_tups for v in t]
        return jnp.concatenate(out_vecs) 
Example #5
Source File: tabular_irl.py (from imitation, MIT License; 5 votes)
def _flatten_batch(self, matrix_tups):
        """Flatten all except leading dim & concatenate results together in channel dim.

        (Channel dim is whatever the dim after the leading dim is)."""
        out_vecs = []
        for t in matrix_tups:
            for v in t:
                new_shape = (v.shape[0],)
                if len(v.shape) > 1:
                    new_shape = new_shape + (np.prod(v.shape[1:]),)
                out_vecs.append(v.reshape(new_shape))
        return jnp.concatenate(out_vecs, axis=1) 
Example #6
Source File: ops.py (from funsor, Apache License 2.0; 5 votes)
def _cat(dim, *x):
    if len(x) == 1:
        return x[0]
    return np.concatenate(x, axis=dim) 
Example #7
Source File: pixelcnn.py (from jaxnet, Apache License 2.0; 5 votes)
def concat_elu(x, axis=-1):
    """Apply ELU to ``x`` and ``-x`` joined along ``axis``.

    The output is twice as long on ``axis``: the first half is ``elu(x)``
    and the second half is ``elu(-x)``.
    """
    mirrored = jnp.concatenate((x, -x), axis)
    return elu(mirrored)
Example #8
Source File: jax.py (from deepx, MIT License; 5 votes)
def concatenate(self, tensors, axis=-1):
        """Coerce each tensor to the backend float type and join them along ``axis``.

        ``axis`` is converted with ``int()`` before being used.
        """
        coerced = [self.coerce(tensor, dtype=self.floatx()) for tensor in tensors]
        target_axis = int(axis)
        return np.concatenate(coerced, axis=target_axis)
Example #9
Source File: jax.py (from deepx, MIT License; 5 votes)
def concat(self, values, axis=-1):
        """Alias for ``self.concatenate``; forwards ``values`` and ``axis`` unchanged."""
        joined = self.concatenate(values, axis=axis)
        return joined
Example #10
Source File: jax_backend.py (from pyhf, Apache License 2.0; 5 votes)
def concatenate(self, sequence, axis=0):
        """
        Join a sequence of arrays along an existing axis.

        Args:
            sequence: sequence of tensors to join; they must agree in every
                dimension except ``axis``
            axis: dimension along which to concatenate (default ``0``)

        Returns:
            output: the concatenated tensor

        """
        joined = np.concatenate(sequence, axis=axis)
        return joined
Example #11
Source File: util.py (from numpyro, Apache License 2.0; 5 votes)
def _ravel_list(*leaves):
    """Flatten ``leaves`` into one 1-D array and return a function that undoes it.

    Returns:
        ``(flat, unravel_list)``. ``flat`` is every leaf raveled to 1-D and
        concatenated in order, or ``jnp.array([])`` when there are no leaves.
        ``unravel_list(arr)`` cuts a flat array with the same layout back into
        a list of arrays, restoring each leaf's shape and recorded dtype.
    """
    # One metadata record per leaf, built positionally from (raveled values,
    # shape, size, canonicalized dtype); later read back as .flat/.shape/.size/
    # .dtype — presumably pytree_metadata's field order; confirm at its definition.
    # NOTE(review): the results are iterated at top level below, so this assumes
    # each leaf is a plain array rather than a nested pytree.
    leaves_metadata = tree_map(lambda l: pytree_metadata(
        jnp.ravel(l), jnp.shape(l), jnp.size(l), canonicalize_dtype(lax.dtype(l))), leaves)
    # Start offset of each leaf in the flat array: cumsum over (0, size_0, size_1, ...).
    leaves_idx = jnp.cumsum(jnp.array((0,) + tuple(d.size for d in leaves_metadata)))

    def unravel_list(arr):
        # Slice each leaf's segment, restore its shape, then cast to its dtype.
        return [jnp.reshape(lax.dynamic_slice_in_dim(arr, leaves_idx[i], m.size),
                            m.shape).astype(m.dtype)
                for i, m in enumerate(leaves_metadata)]

    # jnp.concatenate rejects an empty list, hence the explicit empty-array fallback.
    flat = jnp.concatenate([m.flat for m in leaves_metadata]) if leaves_metadata else jnp.array([])
    return flat, unravel_list 
Example #12
Source File: transforms.py (from numpyro, Apache License 2.0; 5 votes)
def inv(self, y):
        """Map a lower-triangular matrix ``y`` back to an unconstrained vector.

        The result concatenates, along the last axis, the entries that
        ``matrix_to_tril_vec(y, diagonal=-1)`` extracts (presumably those
        below the main diagonal), followed by the log of ``y``'s diagonal.
        """
        below_diag = matrix_to_tril_vec(y, diagonal=-1)
        log_diag = jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))
        return jnp.concatenate([below_diag, log_diag], axis=-1)
Example #13
Source File: transforms.py (from numpyro, Apache License 2.0; 5 votes)
def __call__(self, x):
        z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1)
        return jnp.cumsum(z, axis=-1) 
Example #14
Source File: transforms.py (from numpyro, Apache License 2.0; 5 votes)
def inv(self, y):
        """Recover the unconstrained vector from an ordered vector ``y``.

        Keeps ``y[..., 0]`` and replaces every later element with the log of
        its gap to the preceding element (last axis).
        """
        head = y[..., :1]
        log_gaps = jnp.log(y[..., 1:] - y[..., :-1])
        return jnp.concatenate([head, log_gaps], axis=-1)