Python jax.numpy() Examples

The following are 18 code examples of jax.numpy(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module jax, or try the search function.
Example #1
Source File: jax.py    From trax with Apache License 2.0 6 votes vote down vote up
def nested_stack(objs, axis=0, np_module=np):
  """Combines a list of nested structures by stacking corresponding leaves.

  Args:
    objs: List of nested structures (dicts/lists/tuples of arrays) to stack.
    axis: Axis passed to `np_module.stack` for every group of leaves.
    np_module: Module providing `stack` - typically numpy or jax.numpy.

  Returns:
    An object with the same nested structure as each element of `objs`, with
    leaves stacked together into arrays. Nones are propagated, i.e. if each
    element of the stacked sequence is None, the output will be None.
  """
  def _stack_leaves(leaves):
    return np_module.stack(leaves, axis=axis)

  # Zipping turns N structures into one structure whose leaves are tuples of
  # N arrays; mapping with level=1 stops at those tuples instead of descending
  # into them.
  zipped = nested_zip(objs)
  return nested_map(_stack_leaves, zipped, level=1)
Example #2
Source File: jax_backend.py    From TensorNetwork with Apache License 2.0 6 votes vote down vote up
def __init__(self, dtype: Optional[np.dtype] = None) -> None:
    """Initialise the backend and import JAX on first construction.

    `jax` is imported here rather than at module import time. The imported
    module, together with `jax.numpy` and `jax.scipy`, is published through
    the module-level globals `libjax`, `jnp` and `jsp`.

    Args:
      dtype: Optional dtype. Stored in `self._dtype` as `np.dtype(dtype)`,
        or as `None` when not given.

    Raises:
      ImportError: If `jax` cannot be imported.
    """
    # pylint: disable=global-variable-undefined
    global libjax  # Jax module
    global jnp  # jax.numpy module
    global jsp  # jax.scipy module
    super(JaxBackend, self).__init__()
    try:
      # Deferred import: a failure is re-raised with an explanatory message.
      #pylint: disable=import-outside-toplevel
      import jax
    except ImportError:
      raise ImportError("Jax not installed, please switch to a different "
                        "backend or install Jax.")
    # Bind the module-level aliases declared `global` above.
    libjax = jax
    jnp = libjax.numpy
    jsp = libjax.scipy
    self.name = "jax"
    self._dtype = np.dtype(dtype) if dtype is not None else None
Example #3
Source File: signal.py    From SymJAX with Apache License 2.0 6 votes vote down vote up
def fourier_complex_morlet(bandwidths, centers, N):
    """Complex Morlet wavelet built directly in the Fourier domain.

    A Gaussian envelope centred on `centers` is evaluated on `N` frequency
    samples spanning ``[0, 2*pi]`` and then zeroed for every frequency
    above ``pi``.

    Parameters
    ----------

    bandwidths: array
        the bandwidth of the wavelet

    centers: array
        the centers of the wavelet

    N:
        number of frequency samples, forwarded to ``T.linspace``

    Returns
    -------

    the envelope multiplied by a float32 mask that is 1 where the
    frequency is ``<= pi`` and 0 elsewhere
    """
    omega = T.linspace(0, 2 * numpy.pi, N)
    # keep only the first half of the spectrum (frequencies up to pi)
    low_half_mask = (omega <= numpy.pi).astype("float32")
    gaussian = T.exp(-0.25 * (omega - centers) ** 2 * bandwidths ** 2)
    return gaussian * low_half_mask
Example #4
Source File: base.py    From SymJAX with Apache License 2.0 6 votes vote down vote up
def update(self, update_value):
        """Assign a new value to the variable.

        The value is obtained with ``symjax.current_graph().get(update_value)``.
        If its shape differs from ``self.shape`` a warning is issued and the
        value is reshaped; if its dtype (or Python type, when it has no
        ``dtype`` attribute) differs from ``self.dtype`` a warning is issued
        and the value is cast. The result is stored in ``self._value``.

        Parameters
        ----------

        update_value:
            the object handed to the current graph's ``get`` to produce the
            new value
        """
        new_value = symjax.current_graph().get(update_value)

        if self.shape != jax.numpy.shape(new_value):
            # trailing space keeps the two message halves from running together
            warnings.warn(
                "Variable and update {} {} ".format(self, new_value)
                + "are not the same shape... attempting to reshape"
            )
            new_value = jax.numpy.reshape(new_value, self.shape)

        # plain Python scalars carry no dtype, so compare against their type
        if hasattr(new_value, "dtype"):
            ntype = new_value.dtype
        else:
            ntype = type(new_value)
        if self.dtype != ntype:
            warnings.warn(
                "Variable and update {} {} ".format(self, new_value)
                + "are not the same dtype... attempting to cast"
            )

            new_value = jax.numpy.astype(new_value, self.dtype)

        self._value = new_value
Example #5
Source File: signal.py    From SymJAX with Apache License 2.0 5 votes vote down vote up
def freq_to_mel(f, option="linear"):
    """Convert frequencies to the mel scale.

    With ``option == "linear"`` the mapping is linear below 1000 Hz and
    logarithmic from 1000 Hz upward; any other value of `option` uses
    ``2595 * log10(1 + f / 700)``.
    """
    if option != "linear":
        return 2595 * T.log10(1 + f / 700)

    # slope of the linear region
    hz_per_mel = 200.0 / 3

    # constants describing the log-scale region
    knee_hz = 1000.0  # beginning of log region (Hz)
    knee_mel = knee_hz / hz_per_mel  # same point expressed in mels
    log_step = numpy.log(6.4) / 27.0  # step size for log region

    log_region = knee_mel + T.log(f / knee_hz) / log_step
    linear_region = f / hz_per_mel
    return T.where(f >= knee_hz, log_region, linear_region)
Example #6
Source File: jax_backend.py    From TensorNetwork with Apache License 2.0 5 votes vote down vote up
def expm(self, matrix: Tensor) -> Tensor:
    """Return the matrix exponential of a square matrix.

    Args:
      matrix: Object exposing `.shape`; must be 2-dimensional and square.

    Returns:
      The result of `jsp.linalg.expm(matrix)`.

    Raises:
      ValueError: If `matrix` is not 2-dimensional, or is not square.
    """
    shape = matrix.shape
    if len(shape) != 2:
      raise ValueError("input to numpy backend method `expm` has shape {}."
                       " Only matrices are supported.".format(shape))
    rows, cols = shape[0], shape[1]
    if rows != cols:
      raise ValueError("input to numpy backend method `expm` only supports"
                       " N*N matrix, {x}*{y} matrix is given".format(
                           x=rows, y=cols))
    # pylint: disable=no-member
    return jsp.linalg.expm(matrix)
Example #7
Source File: jax_backend.py    From TensorNetwork with Apache License 2.0 5 votes vote down vote up
def inv(self, matrix: Tensor) -> Tensor:
    """Return the inverse of a matrix.

    Args:
      matrix: Object exposing `.shape`; must be exactly 2-dimensional.

    Returns:
      The result of `jnp.linalg.inv(matrix)`.

    Raises:
      ValueError: If `matrix` is not 2-dimensional.
    """
    # Reject every non-matrix rank (scalars and vectors too), matching the
    # `!= 2` check used by `expm` and the wording of the error message.
    if len(matrix.shape) != 2:
      raise ValueError("input to numpy backend method `inv` has shape {}."
                       " Only matrices are supported.".format(matrix.shape))
    return jnp.linalg.inv(matrix)
Example #8
Source File: base.py    From SymJAX with Apache License 2.0 5 votes vote down vote up
def update_numpydoc(docstr, fun, op):
    """Strip from a numpy docstring the parameters that `op` does not know.

    The text between the "Parameters" underline and the "Returns" heading is
    split into one entry per parameter; an entry is kept only when its
    (first) name appears in ``op.__code__.co_varnames``.

    Parameters
    ----------
    docstr : str
        the numpy docstring to filter
    fun
        not used by this function
    op
        the replacement callable; when it has no ``__code__`` attribute the
        docstring is returned untouched

    Returns
    -------
    str
        the docstring with the unsupported parameter entries removed
    """
    if not hasattr(op, "__code__"):
        return docstr

    indent = "    "
    # Some numpy docstrings carry one extra indentation level on every line;
    # drop the first run of four spaces from each line in that case.
    if docstr[:4] == indent:
        docstr = "\n".join(
            line.replace(indent, "", 1) for line in docstr.split("\n")
        )

    header_pos = docstr.find("Parameters")
    # index of the newline that ends the dashed underline
    begin_idx = docstr.find("--\n", header_pos) + 2
    end_idx = docstr.find("Returns", begin_idx)

    accepted = op.__code__.co_varnames
    # Fold each indented description line into its parameter line so that a
    # plain split on newlines yields one entry per parameter.
    folded = docstr[begin_idx:end_idx].replace("\n" + indent, "@@")
    kept = [
        entry
        for entry in folded.split("\n")
        if entry[: entry.find(" : ")].split(", ")[0] in accepted
    ]
    parameters = "\n".join(kept).replace("@@", "\n" + indent)
    return docstr[: begin_idx + 1] + parameters + docstr[end_idx - 2 :]
Example #9
Source File: base.py    From SymJAX with Apache License 2.0 5 votes vote down vote up
def isvar(item):
    """Check whether an item (possibly a nested list/tuple) contains a variable.

    A leaf counts as a variable when it is an instance of ``Tensor`` or its
    type is exactly ``Constant`` or ``OpTuple``, and it is not callable.
    Slices never count as variables.

    Returns
    -------
    bool
        ``True`` if `item`, or any element nested inside it, is a variable.
    """
    if isinstance(item, slice):
        return False

    # nested lists/tuples: recurse, and answer with a real bool rather than a
    # numpy count so every branch has the same return type
    if isinstance(item, (list, tuple)):
        return any(isvar(value) for value in item)

    # otherwise check that it is a Tensor-like object and not a callable
    is_symbolic = isinstance(item, Tensor) or type(item) in [Constant, OpTuple]
    return is_symbolic and not callable(item)
Example #10
Source File: backend.py    From BERT with Apache License 2.0 5 votes vote down vote up
def dataset_as_numpy(*args, **kwargs):
  """Forwards all arguments to the active backend's "dataset_as_numpy" entry."""
  converter = backend()["dataset_as_numpy"]
  return converter(*args, **kwargs)


# For numpy and random modules, we need to call "backend()" lazily, only when
# the function is called -- so that it can be set by gin configs.
# (Otherwise, backend() is called on import before gin-config is parsed.)
# To do that, we make objects that encapsulate these modules.
Example #11
Source File: signal.py    From SymJAX with Apache License 2.0 5 votes vote down vote up
def tukey(M, alpha=0.5):
    r"""Return a Tukey window, also known as a tapered cosine window.

    Parameters
    ----------
    M : int
        Number of points in the output window. If zero or less, an empty
        array is returned.
    alpha : float, optional
        Shape parameter of the Tukey window, representing the fraction of the
        window inside the cosine tapered region.
        If zero (or less), the Tukey window is equivalent to a rectangular
        window. If one (or more), the Tukey window is equivalent to a Hann
        window.

    Returns
    -------
    w : ndarray
        The window, with the maximum value normalized to 1 (though the value 1
        does not appear if `M` is even).

    References
    ----------
    .. [1] Harris, Fredric J. (Jan 1978). "On the use of Windows for Harmonic
           Analysis with the Discrete Fourier Transform". Proceedings of the
           IEEE 66 (1): 51-83. :doi:`10.1109/PROC.1978.10837`
    .. [2] Wikipedia, "Window function",
           https://en.wikipedia.org/wiki/Window_function#Tukey_window
    """
    n = T.arange(0, M)

    # Degenerate cases are answered directly: the general formula divides by
    # `alpha` and by `M - 1`, which would produce NaNs here.
    if M <= 1 or alpha <= 0:
        return T.ones(n.shape)

    # For alpha >= 1 the two tapered segments below would overlap (and the
    # result would have more than M samples), so build the Hann window itself.
    if alpha >= 1.0:
        return 0.5 * (1 - T.cos(2.0 * numpy.pi * n / (M - 1)))

    # rising taper, flat top, falling taper
    width = int(numpy.floor(alpha * (M - 1) / 2.0))
    n1 = n[0 : width + 1]
    n2 = n[width + 1 : M - width - 1]
    n3 = n[M - width - 1 :]

    w1 = 0.5 * (1 + T.cos(numpy.pi * (-1 + 2.0 * n1 / alpha / (M - 1))))
    w2 = T.ones(n2.shape)
    w3 = 0.5 * (1 + T.cos(numpy.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / (M - 1))))

    return T.concatenate((w1, w2, w3))
Example #12
Source File: signal.py    From SymJAX with Apache License 2.0 5 votes vote down vote up
def littewood_paley_normalization(filter_bank, down=None, up=None):
    """Divide a filter bank by its Littlewood-Paley sum inside a frequency band.

    The sum of the absolute values of `filter_bank` over its first axis is
    computed per frequency. Frequencies are sampled on ``[0, 2*pi]`` with as
    many points as that sum has entries along its first axis. Inside
    ``[down, up]`` the bank is divided by the sum; outside it is divided by 1,
    i.e. left unchanged.

    Parameters
    ----------
    filter_bank : array
        the filters, stacked along the first axis
    down : scalar, optional
        lower edge of the normalized band; defaults to 0
    up : scalar, optional
        upper edge of the normalized band; defaults to pi

    Returns
    -------
    the normalized filter bank
    """
    lp = T.abs(filter_bank).sum(0)
    freq = T.linspace(0, 2 * numpy.pi, lp.shape[0])
    down = 0 if down is None else down
    # honour a caller-supplied `up`; only fall back to pi when it is omitted
    up = numpy.pi if up is None else up
    in_band = T.logical_and(freq >= down, freq <= up)
    lp = T.where(in_band, lp, 1)
    return filter_bank / lp
Example #13
Source File: ops_special.py    From SymJAX with Apache License 2.0 5 votes vote down vote up
def _extract_image_patches(
    image, window_shape, hop=1, data_format="NCHW", mode="valid"
):
    if mode == "same":
        p1 = window_shape[0] - 1
        p2 = window_shape[1] - 1
        image = jnp.pad(
            image, [(0, 0), (0, 0), (p1 // 2, p1 - p1 // 2), (p2 // 2, p2 - p2 // 2)]
        )
    if not hasattr(hop, "__len__"):
        hop = (hop, hop)
    if data_format == "NCHW":

        # compute the number of windows in both dimensions
        N = (
            (image.shape[2] - window_shape[0]) // hop[0] + 1,
            (image.shape[3] - window_shape[1]) // hop[1] + 1,
        )

        # compute the base indices of a 2d patch
        patch = jnp.arange(numpy.prod(window_shape)).reshape(window_shape)
        offset = jnp.expand_dims(jnp.arange(window_shape[0]), 1)
        patch_indices = patch + offset * (image.shape[3] - window_shape[1])

        # create all the shifted versions of it
        ver_shifts = jnp.reshape(
            jnp.arange(N[0]) * hop[0] * image.shape[3], (-1, 1, 1, 1)
        )
        hor_shifts = jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        all_cols = patch_indices + jnp.reshape(jnp.arange(N[1]) * hop[1], (-1, 1, 1))
        indices = patch_indices + ver_shifts + hor_shifts

        # now extract shape (1, 1, H'W'a'b')
        flat_indices = jnp.reshape(indices, [1, 1, -1])
        # shape is now (N, C, W*H)
        flat_image = jnp.reshape(image, (image.shape[0], image.shape[1], -1))
        # shape is now (N, C)
        patches = jnp.take_along_axis(flat_image, flat_indices, 2)
        return jnp.reshape(patches, image.shape[:2] + N + tuple(window_shape))
    else:
        error 
Example #14
Source File: jax.py    From trax with Apache License 2.0 5 votes vote down vote up
def _to_numpy(x):
  """Converts non-NumPy tensors to NumPy arrays."""
  return x if isinstance(x, np.ndarray) else x.numpy() 
Example #15
Source File: backend.py    From BERT with Apache License 2.0 5 votes vote down vote up
def backend(name="jax"):
  """Selects the backend object for `name`.

  A truthy module-level `override_backend_name` takes precedence over the
  `name` argument. "numpy" selects `_NUMPY_BACKEND`; every other name selects
  `_JAX_BACKEND`.
  """
  effective = override_backend_name or name
  return _NUMPY_BACKEND if effective == "numpy" else _JAX_BACKEND
Example #16
Source File: signal.py    From SymJAX with Apache License 2.0 4 votes vote down vote up
def power_to_db(S, ref=1.0, amin=1e-10, top_db=80.0):
    """Convert a power spectrogram (amplitude squared) to decibel (dB) units.

    https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#power_to_db.
    This computes the scaling ``10 * log10(S / ref)`` in a numerically
    stable way.

    Parameters
    ----------
    S : numpy.ndarray
        input power

    ref : scalar or callable
        If scalar, the amplitude `abs(S)` is scaled relative to `ref`:
        `10 * log10(S / ref)`.
        Zeros in the output correspond to positions where `S == ref`.

        If callable, the reference value is computed as `ref(S)`.

    amin : float > 0 [scalar]
        minimum threshold for `abs(S)` and `ref`

    top_db : float >= 0 [scalar]
        threshold the output at `top_db` below the peak:
        ``max(10 * log10(S)) - top_db``. ``None`` disables the thresholding.

    Returns
    -------
    S_db : numpy.ndarray
        ``S_db ~= 10 * log10(S) - 10 * log10(ref)``

    Raises
    ------
    ValueError
        if `top_db` is negative

    See Also
    --------
    perceptual_weighting
    db_to_power
    amplitude_to_db
    db_to_amplitude
    """
    # validate before building any expression so bad input fails early
    if top_db is not None and top_db < 0:
        raise ValueError("top_db must be non-negative")

    # a callable reference is evaluated on the input; a scalar one is used
    # through its magnitude
    ref_value = ref(S) if callable(ref) else numpy.abs(ref)

    # both operands are floored at `amin` to keep the logarithm finite
    log_spec = 10.0 * T.log10(T.maximum(amin, S) / T.maximum(amin, ref_value))
    if top_db is None:
        return log_spec
    return T.maximum(log_spec, log_spec.max() - top_db)


# Now some filter-bank and additional Time-Frequency Repr. 
Example #17
Source File: signal.py    From SymJAX with Apache License 2.0 4 votes vote down vote up
def stft(signal, window, hop, apod=T.ones, nfft=None, mode="valid"):
    """
    Compute the Short-Time-Fourier-Transform of a signal given the
    window length, the hop and additional parameters.

    Parameters
    ----------

        signal: array
            the signal (possibly stacked of signals); must have 3 dimensions

        window: int
            the window length to be considered for the fft

        hop: int
            the amount by which the window is moved

        apod: func
            a function that takes an integer as input and returns
            the apodization window of the same length

        nfft: int (optional)
            the number of bins that the fft on the window will use.
            If not given it is set the same as window. Must be at least
            `window`.

        mode: 'valid', 'same' or 'full'
            the padding of the input signals along the last axis; any value
            other than 'same' or 'full' applies no padding

    Returns
    -------

        output: complex array
            the complex stft
    """
    assert signal.ndim == 3
    nfft = window if nfft is None else nfft

    # amount of padding on each side of the last axis, if any
    if mode == "same":
        head = (window + 1) // 2
        time_pad = [head, window + 1 - head]
    elif mode == "full":
        time_pad = [window - 1, window - 1]
    else:
        time_pad = None

    if time_pad is None:
        psignal = signal
    else:
        psignal = T.pad(signal, [[0, 0], [0, 0], time_pad])

    taper = apod(window).reshape((1, 1, -1))
    frames = T.extract_signal_patches(psignal, window, hop) * taper

    assert nfft >= window
    # zero-pad every frame up to the fft length
    padded = T.pad(frames, [[0, 0], [0, 0], [0, 0], [0, nfft - window]])
    spectrum = fft(padded)

    n_bins = int(numpy.ceil(nfft / 2))
    return spectrum[..., :n_bins].transpose([0, 1, 3, 2])
Example #18
Source File: signal.py    From SymJAX with Apache License 2.0 4 votes vote down vote up
def morlet(M, s, w=5):
    """
    Complex Morlet wavelet (complete version).

    Parameters
    ----------
    M : int
        Length of the wavelet.
    s : float
        Scaling factor, windowed from ``-s*2*pi`` to ``+s*2*pi``.
    w : float, optional
        Omega0. Default is 5

    Returns
    -------
    morlet : (M,) ndarray

    See Also
    --------
    morlet2 : Implementation of Morlet wavelet, compatible with `cwt`.
    scipy.signal.gausspulse

    Notes
    -----
    This function always evaluates the complete version::

        pi**-0.25 * (exp(1j*w*x) - exp(-0.5*(w**2))) * exp(-0.5*(x**2))

    which differs from the standard version::

        pi**-0.25 * exp(1j*w*x) * exp(-0.5*(x**2))

    by a correction term that improves admissibility. For `w` greater
    than 5, the correction term is negligible.
    Note that the energy of the return wavelet is not normalised
    according to `s`.
    The fundamental frequency of this wavelet in Hz is given
    by ``f = 2*s*w*r / M`` where `r` is the sampling rate.
    """
    half_span = 2 * numpy.pi
    x = T.linspace(-half_span, half_span, M) * s

    # complex exponential exp(1j*w*x) written with cos/sin
    carrier = T.cos(w * x) + 1j * T.sin(w * x)

    # subtract the correction term for admissibility
    corrected = carrier - T.exp(-0.5 * (w ** 2))

    # localize the wave with a Gaussian envelope to obtain a wavelet
    gaussian = T.exp(-0.5 * (x ** 2))
    return corrected * gaussian * numpy.pi ** (-0.25)