Python xarray.ones_like() Examples

The following are 13 code examples of xarray.ones_like(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module xarray, or try the search function.
Example #1
Source File: times.py (from aospy, Apache License 2.0) — 6 votes
def add_uniform_time_weights(ds):
    """Append uniform time weights to a Dataset.

    All DataArrays with a time coordinate require a time weights coordinate.
    For Datasets read in without a time bounds coordinate or explicit
    time weights built in, aospy adds uniform time weights at each point
    in the time coordinate.

    The weights are ones shaped like ``ds[TIME_STR]``. Their ``units``
    attribute is the interval part of the time coordinate's units string
    (e.g. ``'days'`` from ``'days since 2000-01-01'``). Any ``calendar``
    attribute carried over from the time coordinate is removed.

    Parameters
    ----------
    ds : Dataset
        Input data. It is modified in place (a ``TIME_WEIGHTS_STR`` variable
        is added) and also returned.

    Returns
    -------
    Dataset
    """
    time = ds[TIME_STR]
    # 'days since 2000-01-01' -> 'days'
    unit_interval = time.attrs['units'].split('since')[0].strip()
    time_weights = xr.ones_like(time)
    time_weights.attrs['units'] = unit_interval
    # A calendar is meaningless on weights. Not every time coordinate has
    # one, so remove it if present instead of raising KeyError.
    time_weights.attrs.pop('calendar', None)
    ds[TIME_WEIGHTS_STR] = time_weights
    return ds
Example #2
Source File: core.py (from esmlab, Apache License 2.0) — 5 votes
def compute_time_bound_diff(self, ds):
        """Return the length of each time interval as a float64 DataArray.

        If ``self.time_bound`` is None, the result is an array of ones shaped
        like ``ds[self.time_coord_name]``.

        Otherwise, the values are index 0 along the second axis of
        ``self.time_bound.diff(dim=self.tb_dim)``. In that case the result
        is named ``self.tb_name + '_diff'``, has empty attrs, and has any
        ``self.tb_dim`` coordinate dropped.
        """
        diff = xr.ones_like(ds[self.time_coord_name], dtype=np.float64)
        if self.time_bound is None:
            return diff

        diff.name = self.tb_name + '_diff'
        diff.attrs = {}
        bound_steps = self.time_bound.diff(dim=self.tb_dim)
        diff.data = bound_steps[:, 0]
        if self.tb_dim in diff.coords:
            diff = diff.drop(self.tb_dim)
        return diff
Example #3
Source File: stats.py (from climpred, MIT License) — 5 votes
def decorrelation_time(da, r=20, dim='time'):
    """Calculate the decorrelation time of a time series.

    .. math::
        \\tau_{d} = 1 + 2 * \\sum_{k=1}^{r-1}(\\alpha_{k})^{k}

    where :math:`\\alpha_{k}` is ``autocorr(da, dim=dim, lag=k)``.

    Args:
        da (xarray object): Time series.
        r (optional int): Exclusive upper bound on the lags summed, so lags
            ``1 .. r-1`` are used. With ``r <= 1`` no lags are summed and the
            result is 1 wherever the first step along ``dim`` is not NaN.
        dim (optional str): Time dimension for xarray object.

    Returns:
        Decorrelation time of time series. It is NaN wherever the first
        element along ``dim`` is NaN.

    Reference:
        * Storch, H. v, and Francis W. Zwiers. Statistical Analysis in Climate
          Research. Cambridge ; New York: Cambridge University Press, 1999.,
          p.373

    """
    first_step = da.isel({dim: 0})
    # 1 where the first time step is valid, NaN where it is missing.
    one = xr.ones_like(first_step).where(first_step.notnull())
    if r <= 1:
        # Empty sum; xr.concat would raise on an empty list.
        return one
    return one + 2 * xr.concat(
        [autocorr(da, dim=dim, lag=i) ** i for i in range(1, r)], 'it'
    ).sum('it')


# --------------------------------------------#
# Diagnostic Potential Predictability (DPP)
# Functions related to DPP from Boer et al.
# --------------------------------------------# 
Example #4
Source File: test_PredictionEnsemble_math.py (from climpred, MIT License) — 5 votes
def test_hindcastEnsemble_plus_broadcast(hind_ds_initialized_3d, operator):
    """Check that HindcastEnsemble arithmetic (+ - * /) broadcasts.

    Applying ``operator`` with an all-ones field that lacks the ``init`` and
    ``lead`` dimensions (e.g. an offset or an area weight) must give the same
    result as applying it with the scalar ``1``.
    """
    ensemble = HindcastEnsemble(hind_ds_initialized_3d)
    # NOTE(review): `operator` is presumably the name of a binary operator
    # function supplied by the test parametrization.
    binary_op = eval(operator)
    ones_field = xr.ones_like(
        hind_ds_initialized_3d.isel(init=1, lead=1, drop=True)
    )
    with_field = binary_op(ensemble, ones_field)
    with_scalar = binary_op(ensemble, 1)
    assert_PredictionEnsemble(with_field, with_scalar)
Example #5
Source File: test_PredictionEnsemble_math.py (from climpred, MIT License) — 5 votes
def test_PerfectModelEnsemble_plus_broadcast(PM_ds_initialized_3d, operator):
    """Check that PerfectModelEnsemble arithmetic (+ - * /) broadcasts.

    Applying ``operator`` with an all-ones field that lacks the ``init`` and
    ``lead`` dimensions (e.g. an offset or an area weight) must give the same
    result as applying it with the scalar ``1``.
    """
    ensemble = PerfectModelEnsemble(PM_ds_initialized_3d)
    # NOTE(review): `operator` is presumably the name of a binary operator
    # function supplied by the test parametrization.
    binary_op = eval(operator)
    ones_field = xr.ones_like(
        PM_ds_initialized_3d.isel(init=1, lead=1, drop=True)
    )
    with_field = binary_op(ensemble, ones_field)
    with_scalar = binary_op(ensemble, 1)
    assert_PredictionEnsemble(with_field, with_scalar)
Example #6
Source File: vector_calc.py (from ECCOv4-py, MIT License) — 5 votes
def get_latitude_masks(lat_val,yc,grid):
    """Compute maskW/S which grabs vector field grid cells along specified latitude
    band and corrects the sign associated with X-Y LLC grid

    A cell-centre mask is built that is 1 where ``yc >= lat_val`` and 0
    elsewhere. Its differences along the grid's 'X' and 'Y' axes
    (``grid.diff`` with ``boundary='fill'``) are non-zero only on edges
    whose two neighbouring cells fall on opposite sides of ``lat_val``.

    This mirrors the MATLAB function gcmfaces/gcmfaces_calc/gcmfaces_lines_zonal.m

    Parameters
    ----------

    lat_val : int
        latitude at which to compute mask
    yc : xarray DataArray
        Contains latitude values at cell centers
    grid : xgcm Grid object
        llc grid object generated via get_llc_grid

    Returns
    -------

    maskWedge, maskSedge : xarray DataArray
        contains masks of latitude band at grid cell west and south grid edges
    """
    # 1 on/north of lat_val, 0 elsewhere ("where" fills False with 0;
    # NaN latitudes compare False and therefore also become 0).
    center_mask = xr.ones_like(yc).where(yc >= lat_val, 0)

    west_edge_mask = grid.diff(center_mask, 'X', boundary='fill')
    south_edge_mask = grid.diff(center_mask, 'Y', boundary='fill')
    return west_edge_mask, south_edge_mask
Example #7
Source File: scalar_calc.py (from ECCOv4-py, MIT License) — 5 votes
def get_latitude_mask(lat_val,yc,grid):
    """Compute maskCedge which grabs the grid cell center points along
    the desired latitude

    A 0/1 cell-centre mask (1 where ``yc >= lat_val``) is built and passed
    to ``get_edge_mask``, which marks the 0-cells that border 1-cells.

    This mirrors the MATLAB function  gcmfaces/gcmfaces_calc/gcmfaces_lines_zonal.m

    Parameters
    ----------

    lat_val : int
        latitude at which to compute mask
    yc : xarray DataArray
        Contains latitude values at cell centers
    grid : xgcm Grid object
        llc grid object generated via get_llc_grid

    Returns
    -------

    maskCedge : xarray DataArray
        contains mask of latitude at grid cell tracer points
    """
    # "where" replaces cells failing the test with 0, leaving a 0/1 mask.
    in_band = yc >= lat_val
    lat_maskC = xr.ones_like(yc).where(in_band, 0)
    return get_edge_mask(lat_maskC, grid)
Example #8
Source File: scalar_calc.py (from ECCOv4-py, MIT License) — 5 votes
def get_edge_mask(maskC,grid):
    """From a given mask with points at cell centers, compute the
    boundary between 1's and 0's

    Parameters
    ----------

    maskC : xarray DataArray
        containing 1's at interior points, 0's outside. We want the
        boundary between them
    grid : xgcm Grid object

    Returns
    -------

    maskCedge : xarray DataArray
        with same dimensions as input maskC, with 1's at cells where maskC
        is 0 but the re-interpolated neighbourhood is positive, 0 elsewhere
    """
    # Interpolating the mask onto the west (X) and south (Y) cell edges
    # yields 0.5 where a 1-cell meets a 0-cell.
    edge_values = {
        'X': grid.interp(maskC, 'X', boundary='fill'),
        'Y': grid.interp(maskC, 'Y', boundary='fill'),
    }
    # Bring the edge values back to cell centres.
    centre_values = grid.interp_2d_vector(edge_values, boundary='fill')

    # Boundary = outside cells (mask 0) that pick up a positive value.
    near_interior = (centre_values['X'] + centre_values['Y']) > 0
    outside = maskC == 0.
    return xr.ones_like(maskC).where(near_interior & outside, 0)
Example #9
Source File: test_utils.py (from xarrayutils, MIT License) — 5 votes
def test_xr_linregress(chunks, dim, variant, dtype, nans, parameter, ni):
    """Compare ``xr_linregress`` with ``_linregress_ufunc`` point by point.

    Two random arrays with the same dimensions in different orders are
    optionally filled entirely with NaN (``nans == "all"``) or seeded with
    10 random NaNs each (any other truthy ``nans``). One or both arrays are
    optionally chunked according to ``variant``. The regression along
    ``dim`` is then checked at every position of the remaining two
    dimensions.
    """
    a = xr.DataArray(np.random.rand(6, 8, 5), dims=["x", "time", "y"])
    b = xr.DataArray(np.random.rand(6, 5, 8), dims=["x", "y", "time"])
    if nans == "all":
        a = xr.ones_like(a) * np.nan
        b = xr.ones_like(b) * np.nan
    elif nans:
        # Add NaNs at random positions. Draw flat indices over the real
        # array size so NaNs can land anywhere, not just in a fixed prefix.
        for arr in (a, b):
            flat_idx = np.random.randint(0, arr.size, 10)
            arr.data[np.unravel_index(flat_idx, arr.shape)] = np.nan

    if chunks is not None:
        # variant 0: chunk a only; 1: chunk b only; 2: chunk both
        if variant in (0, 2):
            a = a.chunk(chunks)
        if variant in (1, 2):
            b = b.chunk(chunks)

    reg = xr_linregress(a, b, dim=dim)

    # The two non-regression dimensions, in a deterministic order.
    dims = [d for d in a.dims if d != dim]
    for ii in range(a.sizes[dims[0]]):
        for jj in range(a.sizes[dims[1]]):
            pos = {dims[0]: ii, dims[1]: jj}

            expected = _linregress_ufunc(a.isel(**pos), b.isel(**pos), nanmask=True)
            reg_sub = reg.isel(**pos)

            np.testing.assert_allclose(reg_sub[parameter].data, expected[ni])
Example #10
Source File: test_plotting.py (from xarrayutils, MIT License) — 5 votes
def test_linear_piecewise_scale(cut, scale, axis, scaled_half):
    """Draw a contour plot, apply ``linear_piecewise_scale`` and, for a
    non-zero ``scale``, check that the chosen axis uses a "function" scale.
    """
    z_values = xr.DataArray(np.arange(100), dims=["x"])
    x_values = xr.DataArray(np.arange(50), dims=["z"])
    field = z_values * xr.ones_like(x_values)
    plt.contourf(x_values, z_values, field)

    linear_piecewise_scale(cut, scale, axis=axis, scaled_half=scaled_half)

    # This only checks the scale type; a more definitive check is still needed.
    if scale == 0:
        return
    if axis == "x":
        assert plt.gca().get_xscale() == "function"
    elif axis == "y":
        assert plt.gca().get_yscale() == "function"
Example #11
Source File: test_seaice.py (from xclim, Apache License 2.0) — 5 votes
def values(self, areacello):
        """Return ``areacello`` together with a synthetic sea-ice concentration.

        The concentration is 10 where ``lat <= 0`` and 50 everywhere else.
        It is concatenated with itself along ``time``, and its ``units`` and
        ``standard_name`` attrs are set to ``'%'`` and
        ``'sea_ice_area_fraction'``.
        """
        template = xr.ones_like(areacello)
        concentration = template.where(template.lat > 0, 10).where(
            template.lat <= 0, 50
        )
        sic = xr.concat([concentration, concentration], dim="time")
        sic.attrs.update(
            {"units": "%", "standard_name": "sea_ice_area_fraction"}
        )
        return areacello, sic
Example #12
Source File: core.py (from esmlab, Apache License 2.0) — 4 votes
def compute_ann_mean(self, weights=None, method=None):
        """Calculate weighted annual means of the dataset's variables.

        Parameters
        ----------
        weights : xarray.DataArray, numpy.ndarray, ``da.Array`` or list, optional
            Per-time-step weights. They must have the same length as the time
            coordinate of ``self._ds_time_computed``. ``da.Array`` is
            presumably ``dask.array.Array``. A value of any other type
            (including ``None``) is ignored, and ``self.time_bound_diff`` is
            used as the weights instead.
        method : optional
            Passed through unchanged to ``self.compute_resample_times``.

        Returns
        -------
        The result of ``self.restore_dataset`` applied to the annual means.

        Raises
        ------
        ValueError
            If ``weights`` and the time coordinate differ in length.
        AssertionError
            From ``np.testing.assert_allclose`` if the normalized weights do
            not sum to 1.
        """
        # groupby key for the year component, e.g. 'time.year'
        time_dot_year = '.'.join([self.time_coord_name, 'year'])

        if isinstance(weights, (xr.DataArray, np.ndarray, da.Array, list)):
            if len(weights) != len(self._ds_time_computed[self.time_coord_name]):
                raise ValueError(
                    'weights and dataset time coordinate values must be of the same length'
                )
            else:
                # Use the time coordinate as a template, then overwrite its
                # values with the caller's weights.
                dt = xr.ones_like(self._ds_time_computed[self.time_coord_name])
                dt.data = weights
                # Normalized over the whole record. weighted_mean_arr below
                # divides by the per-year weight sum, so this scaling cancels.
                wgts = dt / dt.sum(xr.ALL_DIMS)
                np.testing.assert_allclose(wgts.sum(xr.ALL_DIMS), 1.0)

        else:
            # Default weights: self.time_bound_diff, normalized within each year.
            dt = self.time_bound_diff
            wgts = dt.groupby(time_dot_year) / dt.groupby(time_dot_year).sum(xr.ALL_DIMS)
            np.testing.assert_allclose(wgts.groupby(time_dot_year).sum(xr.ALL_DIMS), 1.0)

        wgts = wgts.rename('weights')

        # NOTE(review): static_variables presumably lists time-invariant
        # variables; they are dropped and not averaged.
        dset = self._ds_time_computed.drop(self.static_variables)

        def weighted_mean_arr(darr, wgts=None):
            # Weighted annual ('A' resample) mean of one variable.
            # if NaN are present, we need to use individual weights
            cond = darr.isnull()
            # 1.0 for valid samples, 0.0 for NaNs; used to sum only the
            # weights of valid samples in the denominator.
            ones = xr.where(cond, 0.0, 1.0)
            # Years whose plain mean is NaN stay NaN in the output.
            mask = (
                darr.resample({self.time_coord_name: 'A'}).mean(dim=self.time_coord_name).notnull()
            )
            da_sum = (
                (darr * wgts).resample({self.time_coord_name: 'A'}).sum(dim=self.time_coord_name)
            )
            ones_out = (
                (ones * wgts).resample({self.time_coord_name: 'A'}).sum(dim=self.time_coord_name)
            )
            # Zero total weight -> NaN, avoiding division by zero.
            ones_out = ones_out.where(ones_out > 0.0)
            da_weighted_mean = da_sum / ones_out
            return da_weighted_mean.where(mask)

        # Applied per data variable (Dataset.apply is the older name of Dataset.map).
        computed_dset = dset.apply(weighted_mean_arr, wgts=wgts)

        computed_dset = self.compute_resample_times(
            ds=computed_dset,
            temporary_time_coord_name='year',
            time_dot=time_dot_year,
            method=method,
        )

        return self.restore_dataset(computed_dset)
Example #13
Source File: test_core.py (from esmlab, Apache License 2.0) — 4 votes
def test_esmlab_accessor():
    """Exercise the ``esmlab`` Dataset accessor on a tiny two-day dataset.

    The visible checks cover:

    * time and time-bound attributes;
    * variable classification;
    * the default time-bound difference;
    * error handling in ``compute_time_var``, ``decode_arbitrary_time`` and
      ``set_time``;
    * decoding a numeric time value to a ``cftime.DatetimeNoLeap``.
    """
    ds = xr.Dataset(
        {
            'temp': xr.DataArray(
                [1, 2],
                dims=['time'],
                coords={'time': pd.date_range(start='2000', periods=2, freq='1D')},
            )
        }
    )
    attrs = {'calendar': 'noleap', 'units': 'days since 2000-01-01 00:00:00'}
    ds.time.attrs = attrs
    esm = ds.esmlab.set_time(time_coord_name='time')
    xr.testing._assert_internal_invariants(esm._ds_time_computed)
    # Time and Time bound Attributes
    expected = dict(esm.time_attrs)
    # The accessor's time attrs are expected to include 'bounds': None.
    # NOTE(review): this mutates `attrs`, which was assigned to ds.time.attrs
    # above; presumably xarray stored a copy, so ds is unaffected. Confirm.
    attrs['bounds'] = None
    assert expected == attrs
    assert esm.time_bound_attrs == {}

    assert esm.variables == ['temp']
    assert esm.static_variables == []

    # Time bound diff
    # Expected to be float64 ones shaped like the time coordinate.
    expected = xr.ones_like(ds.time, dtype='float64')
    xr.testing.assert_equal(expected, esm.time_bound_diff)

    # Compute time var
    with pytest.raises(ValueError):
        esm.compute_time_var(midpoint=True, year_offset=2100)

    # Decode arbitrary time value
    # Passing an already-decoded datetime value is expected to raise.
    with pytest.raises(ValueError):
        esm.decode_arbitrary_time(ds.time.data[0], units=attrs['units'], calendar=attrs['calendar'])

    # 30 days since 2000-01-01 in the noleap calendar -> 2000-01-31
    res = esm.decode_arbitrary_time(
        np.array([30]), units=attrs['units'], calendar=attrs['calendar']
    )
    assert res[0] == cftime.DatetimeNoLeap(2000, 1, 31, 0, 0, 0, 0, 0, 31)

    data = xr.DataArray(
        [1, 2],
        dims=['time'],
        coords={'time': pd.date_range(start='2000', freq='1D', periods=2)},
        attrs={'calendar': 'standard', 'units': 'days since 2001-01-01 00:00:00'},
        name='rand',
    ).to_dataset()

    # Replace the time coordinate with cftime values; both time accessors
    # are then expected to raise.
    data['time'] = xr.cftime_range(start='2000', freq='1D', periods=2)

    with pytest.raises(ValueError):
        data.esmlab.set_time().get_time_decoded()

    with pytest.raises(ValueError):
        data.esmlab.set_time().get_time_undecoded()

    # A dataset with only x/y dims (no time) must be rejected by set_time.
    data = xr.DataArray(
        [[1, 2], [7, 8]], dims=['x', 'y'], coords={'x': [1, 2], 'y': [2, 3]}, name='rand'
    ).to_dataset()
    with pytest.raises(ValueError):
        data.esmlab.set_time('time-bound-coord')