Python dask.array.compute() Examples

The following are 30 code examples of dask.array.compute(). You can vote up the examples you like or vote down the ones you don't, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module dask.array, or try the search function.
Example #1
Source File:    From dask-ml with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def fit(self, X, y=None):
        """Fit k-means clustering to ``X`` and return ``self``.

        Stores ``cluster_centers_``, ``labels_``, ``inertia_`` (as a Python
        scalar), ``n_iter_`` and ``n_features_in_`` on the estimator. ``y`` is
        accepted but not used in the visible code.
        """
        X = self._check_array(X)
        # NOTE(review): this k_means( call is never closed in this copy — its
        # argument list was lost; restore it from the upstream dask-ml source.
        labels, centroids, inertia, n_iter = k_means(
        self.cluster_centers_ = centroids
        self.labels_ = labels
        # inertia is lazy here (it exposes .compute()); .item() converts the
        # computed result to a Python scalar (assumes a single-element result).
        self.inertia_ = inertia.compute().item()
        self.n_iter_ = n_iter
        # number of columns seen during fit
        self.n_features_in_ = X.shape[1]
        return self 
Example #2
Source File:    From pyresample with GNU Lesser General Public License v3.0 6 votes vote down vote up
def get_geostationary_angle_extent(geos_area):
    """Return the maximum Earth viewing angles ``(xmax, ymax)`` for *geos_area*.

    The equatorial and polar radii come from the area's projection parameters
    and, like the projection's ``h`` value, are divided by 1000 before use.
    The scaled equatorial radius is added to the scaled ``h`` (presumably
    turning height above the surface into distance from the Earth's centre —
    confirm). Each angle is ``arccos(sqrt(1 - r**2 / dist**2))``, where ``r``
    is the equatorial radius for x and the polar radius for y.
    """
    radius_a, radius_b = proj4_radius_parameters(geos_area.proj_dict)
    eq_radius = radius_a / 1000.0
    pol_radius = radius_b / 1000.0
    dist = geos_area.proj_dict['h'] / 1000.0 + eq_radius

    def _max_angle(radius):
        # Evaluated as 1 - r**2 / dist**2 (not via r / dist) so rounding
        # matches the closed-form expression exactly.
        return np.arccos(np.sqrt(1 - radius ** 2 / (dist ** 2)))

    return _max_angle(eq_radius), _max_angle(pol_radius)
Example #3
Source File:    From pyresample with GNU Lesser General Public License v3.0 6 votes vote down vote up
def query_no_distance(target_lons, target_lats,
                      valid_output_index, kdtree, neighbours, epsilon, radius):
    """Query the kdtree. No distances are returned.

    The flattened target lon/lat arrays are indexed with the flattened
    ``valid_output_index`` (presumably a boolean mask — confirm), converted to
    cartesian coordinates and queried; only the index array of the query
    result is returned, the distance array is discarded.
    """
    voi = valid_output_index
    # flatten so it lines up with the flattened lon/lat arrays below
    voir = da.ravel(voi)
    target_lons_valid = da.ravel(target_lons)[voir]
    target_lats_valid = da.ravel(target_lats)[voir]

    coords = lonlat2xyz(target_lons_valid, target_lats_valid)
    # NOTE(review): this kdtree.query( call is never closed in this copy — its
    # arguments were lost; restore them from the upstream pyresample source.
    distance_array, index_array = kdtree.query(

    return index_array 
Example #4
Source File:    From pyresample with GNU Lesser General Public License v3.0 6 votes vote down vote up
def _get_corner_dask(stride, valid, in_x, in_y, index_array):
    """Get closest set of coordinates from the *valid* locations.

    For every row of ``valid`` the first column flagged valid is selected
    (``np.argmax`` along axis 1; assumes ``valid`` is boolean). ``in_x``,
    ``in_y`` and the computed ``index_array`` are then sampled at
    ``[stride, column]``. Rows without any valid column get ``np.nan`` for x
    and y; their index entry comes from column 0, because argmax returns 0
    for such rows.

    Returns:
        tuple: ``(x, y, idx)``.
    """
    first_valid_col = np.argmax(valid, axis=1)
    row_has_none = np.invert(np.max(valid, axis=1))

    # index_array is materialised eagerly before the (not yet daskified)
    # fancy indexing below.
    index_values = index_array.compute()

    def _take(values):
        # TODO: daskify
        return values[stride, first_valid_col]

    x_corner = da.where(row_has_none, np.nan, _take(in_x))
    y_corner = da.where(row_has_none, np.nan, _take(in_y))
    idx_corner = _take(index_values)
    return x_corner, y_corner, idx_corner
Example #5
Source File:    From nbodykit with GNU General Public License v3.0 6 votes vote down vote up
def persist(self, columns=None):
        """
        Return a CatalogSource, where the selected columns are
        computed and persist in memory.

        Parameters
        ----------
        columns : list of str, optional
            Names of the columns to compute; defaults to ``self.columns``.

        Returns
        -------
        ArrayCatalog
            A new catalog holding the computed column data, created with this
            catalog's ``comm``.
        """
        import dask.array as da
        if columns is None:
            columns = self.columns

        # Collect the lazy columns so they are all computed in one dask call.
        r = {key: self[key] for key in columns}

        # dask's compute returns a tuple with one entry per positional
        # argument; [0] unwraps the single dict passed in.
        r = da.compute(r)[0]

        from nbodykit.source.catalog.array import ArrayCatalog
        return ArrayCatalog(r, comm=self.comm)
Example #6
Source File:    From pyresample with GNU Lesser General Public License v3.0 6 votes vote down vote up
def test_compute_indices(self, mock_setattr):
        """Test running .compute() for indices."""
        # NOTE(review): this copy is incomplete — the import below and the
        # XArrayResamplerBilinear(...) call are never closed, and no code that
        # triggers the computation is visible before the assertions. Restore
        # from the upstream pyresample source.
        from pyresample.bilinear.xarr import (XArrayResamplerBilinear,

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,

        # Set indices to Numpy arrays
        for idx in CACHE_INDICES:
            setattr(resampler, idx, np.array([]))
        # None of the indices should have been reassigned
        # NOTE(review): no assertion follows this comment in this copy.

        # Set indices to a Mock object
        arr = mock.MagicMock()
        for idx in CACHE_INDICES:
            setattr(resampler, idx, arr)
        # All the indices should have been reassigned
        self.assertEqual(mock_setattr.call_count, len(CACHE_INDICES))
        # The compute should have been called the same amount of times
        self.assertEqual(arr.compute.call_count, len(CACHE_INDICES)) 
Example #7
Source File:    From nbodykit with GNU General Public License v3.0 6 votes vote down vote up
def test_slice(comm):
    """Slice a catalog by range and by integer list; bad keys raise KeyError."""
    cat = UniformCatalog(nbar=2e-4, BoxSize=512., seed=42, comm=comm)
    cat['NZ'] = 1

    # A leading slice keeps every column, the catalog type and the data.
    head = cat[:10]
    assert all(name in head for name in cat.columns)
    assert isinstance(head, cat.__class__)
    assert len(head) == 10
    assert_array_equal(head['Position'], cat['Position'].compute()[:10])

    # Selecting rows with a list of integers.
    chosen = cat[[0, 1, 2]]
    assert_array_equal(chosen['Position'], cat['Position'].compute()[[0, 1, 2]])

    # Float row indices are rejected.
    with pytest.raises(KeyError):
        cat[[0.0, 1.0, 2.0]]

    # So is an unknown column name.
    with pytest.raises(KeyError):
        cat['BAD_COLUMN']
Example #8
Source File:    From pyresample with GNU Lesser General Public License v3.0 6 votes vote down vote up
def _create_resample_kdtree(self):
        """Set up kd tree on input."""
        # NOTE(review): this copy is incomplete — the right-hand side of the
        # line-continued assignment below and the rest of the final
        # da.compute( call are missing. Restore from the upstream source.
        # Get input information
        valid_input_index, source_lons, source_lats = \

        # FIXME: Is dask smart enough to only compute the pixels we end up
        #        using even with this complicated indexing
        input_coords = lonlat2xyz(source_lons, source_lats)
        valid_input_index = da.ravel(valid_input_index)
        input_coords = input_coords[valid_input_index, :]
        input_coords = input_coords.compute()
        # Build kd-tree on input
        # NOTE(review): np.float was a deprecated alias of the builtin float and
        # was removed in NumPy 1.24; float or np.float64 is the portable spelling.
        input_coords = input_coords.astype(np.float)
        valid_input_index, input_coords = da.compute(valid_input_index,
        return valid_input_index, KDTree(input_coords) 
Example #9
Source File:    From satpy with GNU General Public License v3.0 6 votes vote down vote up
# NOTE(review): this copy of compute_writer_results is incomplete — the
# docstring is never closed, the `if not results:` and `if delayeds:` branches
# have no bodies, the first argument of the append(...) call is missing, and
# the loop over targets stops after its `if`. Restore from the upstream
# satpy source before relying on it.
def compute_writer_results(results):
    """Compute all the given dask graphs `results` so that the files are saved.

        results (iterable): Iterable of dask graphs resulting from calls to
                            `scn.save_datasets(..., compute=False)`
    if not results:

    sources, targets, delayeds = split_results(results)

    # one or more writers have targets that we need to close in the future
    if targets:
        delayeds.append(, targets, compute=False))

    if delayeds:

    if targets:
        for target in targets:
            if hasattr(target, 'close'):
Source File:    From pyresample with GNU Lesser General Public License v3.0 6 votes vote down vote up
def test_get_valid_input_index_dask(self):
        """Test finding valid indices for reduced input data."""
        from pyresample.bilinear.xarr import _get_valid_input_index_dask

        # Do not reduce data
        vii, lons, lats = _get_valid_input_index_dask(self.source_def,
                                                      False, self.radius)
        self.assertEqual(vii.shape, (self.source_def.size, ))
        # np.bool was only a deprecated alias of the builtin bool and was
        # removed in NumPy 1.24; np.bool_ names the boolean scalar type in
        # every NumPy version, and assertEqual reports both values on failure.
        self.assertEqual(vii.dtype, np.bool_)
        # No data has been reduced, whole input is used
        # NOTE(review): no assertion follows this comment in this copy —
        # confirm against upstream whether one is missing.

        # Reduce data
        vii, lons, lats = _get_valid_input_index_dask(self.source_def,
                                                      True, self.radius)
        # 2700 valid input points
        self.assertEqual(vii.compute().sum(), 2700)
Example #11
Source File:    From pyresample with GNU Lesser General Public License v3.0 6 votes vote down vote up
def test_solve_quadratic(self):
        """Test solving quadratic equation."""
        # NOTE(review): the import below is never closed in this copy; the
        # missing name(s) presumably include _calc_abc_dask, which is used
        # later but not imported anywhere visible.
        from pyresample.bilinear.xarr import (_solve_quadratic_dask,

        res = _solve_quadratic_dask(1, 0, 0).compute()
        self.assertEqual(res, 0.0)
        res = _solve_quadratic_dask(1, 2, 1).compute()
        # NOTE(review): the result above is overwritten without being checked.
        res = _solve_quadratic_dask(1, 2, 1, min_val=-2.).compute()
        self.assertEqual(res, -1.0)
        # Test that small adjustments work
        pt_1, pt_2, pt_3, pt_4 = self.pts_vert_parallel
        # copy so the in-place nudge below does not modify the shared fixture
        pt_1 = self.pts_vert_parallel[0].copy()
        pt_1[0][0] += 1e-7
        res = _calc_abc_dask(pt_1, pt_2, pt_3, pt_4, 0.0, 0.0)
        res = _solve_quadratic_dask(res[0], res[1], res[2]).compute()
        self.assertAlmostEqual(res[0], 0.5, 5)
        res = _calc_abc_dask(pt_1, pt_3, pt_2, pt_4, 0.0, 0.0)
        res = _solve_quadratic_dask(res[0], res[1], res[2]).compute()
        self.assertAlmostEqual(res[0], 0.5, 5) 
Example #12
Source File:    From pyresample with GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_lonlat2xyz(self):
        """Check that lonlat2xyz maps target lon/lats to an (N, 3) cartesian array."""
        from pyresample.bilinear.xarr import lonlat2xyz
        from pyresample import CHUNK_SIZE

        expected_first_row = [3188578.91069278, -612099.36103276, 5481596.63569999]

        lons, lats = self.target_def.get_lonlats(chunks=CHUNK_SIZE)
        xyz = lonlat2xyz(lons, lats)
        self.assertEqual(xyz.shape, (self.target_def.size, 3))

        first_row = xyz.compute()[0, :]
        self.assertTrue(np.allclose(first_row, expected_first_row))
Example #13
Source File:    From pyresample with GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_get_average(self):
        """Test averaging bucket resampling."""
        # NOTE(review): this copy is incomplete — both da.from_array( calls are
        # never closed, and the results of the last two get_average calls are
        # never checked. Restore from the upstream pyresample source.
        data = da.from_array(np.array([[2., 4.], [3., np.nan]]),
        # Without pre-calculated indices
        # (CustomScheduler(max_computes=0) presumably fails the test if
        # get_average triggers any dask computation — confirm)
        with dask.config.set(scheduler=CustomScheduler(max_computes=0)):
            result = self.resampler.get_average(data)
        result = result.compute()
        self.assertEqual(np.nanmax(result), 3.)
        # Use a fill value other than np.nan
        with dask.config.set(scheduler=CustomScheduler(max_computes=0)):
            result = self.resampler.get_average(data, fill_value=-1)
        result = result.compute()
        self.assertEqual(np.max(result), 3.)
        self.assertEqual(np.min(result), -1)

        # Test masking all-NaN bins
        data = da.from_array(np.array([[np.nan, np.nan], [np.nan, np.nan]]),
        with dask.config.set(scheduler=CustomScheduler(max_computes=0)):
            result = self.resampler.get_average(data, mask_all_nan=True)
        # By default all-NaN bins have a value of NaN
        with dask.config.set(scheduler=CustomScheduler(max_computes=0)):
            result = self.resampler.get_average(data)
Example #14
Source File:    From pyresample with GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_resample_bucket_fractions(self):
        """Test fraction calculations for categorical data."""
        # NOTE(review): this copy is incomplete — the da.from_array( call and
        # the second get_fractions( call are never closed, and the final
        # get_fractions result is never checked. Restore from upstream.
        data = da.from_array(np.array([[2, 4], [2, 2]]),
        categories = [1, 2, 3, 4]
        with dask.config.set(scheduler=CustomScheduler(max_computes=0)):
            result = self.resampler.get_fractions(data, categories=categories)
        # one result entry per requested category
        self.assertEqual(set(categories), set(result.keys()))
        res = result[1].compute()
        self.assertTrue(np.nanmax(res) == 0.)
        res = result[2].compute()
        self.assertTrue(np.nanmax(res) == 1.)
        self.assertTrue(np.nanmin(res) == 0.5)
        res = result[3].compute()
        self.assertTrue(np.nanmax(res) == 0.)
        res = result[4].compute()
        self.assertTrue(np.nanmax(res) == 0.5)
        self.assertTrue(np.nanmin(res) == 0.)
        # There should be NaN values
        # NOTE(review): no assertion follows this comment in this copy.

        # Use a fill value
        with dask.config.set(scheduler=CustomScheduler(max_computes=0)):
            result = self.resampler.get_fractions(data, categories=categories,

        # There should not be any NaN values
        # (checked via the minimum equalling the -1 fill value)
        for i in categories:
            res = result[i].compute()
            self.assertTrue(np.min(res) == -1)

        # No categories given, need to compute the data once to get
        # the categories
        with dask.config.set(scheduler=CustomScheduler(max_computes=1)):
            result = self.resampler.get_fractions(data, categories=None) 
Example #15
Source File:    From pyresample with GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_get_sample_from_bil_info(self):
        """Test bilinear interpolation as a whole."""
        # NOTE(review): the XArrayResamplerBilinear(...) call and the last
        # get_sample_from_bil_info(...) call are never closed in this copy.
        from pyresample.bilinear.xarr import XArrayResamplerBilinear

        resampler = XArrayResamplerBilinear(self.source_def, self.target_def,
        _ = resampler.get_bil_info()

        # Sample from data1
        res = resampler.get_sample_from_bil_info(self.data1)
        res = res.compute()
        # Check couple of values
        self.assertEqual(res.values[1, 1], 1.)
        self.assertTrue(np.isnan(res.values[0, 3]))
        # Check that the values haven't gone down or up a lot
        self.assertAlmostEqual(np.nanmin(res.values), 1.)
        self.assertAlmostEqual(np.nanmax(res.values), 1.)
        # Check that dimensions are the same
        self.assertEqual(res.dims, self.data1.dims)

        # Sample from data1, custom fill value
        res = resampler.get_sample_from_bil_info(self.data1, fill_value=-1.0)
        res = res.compute()
        self.assertEqual(np.nanmin(res.values), -1.)

        # Sample from integer data
        res = resampler.get_sample_from_bil_info(self.data1.astype(np.uint8),
        res = res.compute()
        # Six values are expected to be filled with zeros, which is the
        # default fill_value for integer data.
        # NOTE(review): an earlier version of this comment said five values
        # while the assertion checks six — confirm which count is intended.
        self.assertEqual(np.sum(res == 0), 6) 
Example #16
Source File:    From pyresample with GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_get_bounding_corners_dask(self):
        """Test finding surrounding bounding corners."""
        # NOTE(review): this copy is incomplete — the pyresample.bilinear.xarr
        # import and the _get_input_xy_dask(...) and
        # _get_bounding_corners_dask(...) calls are never closed. Restore from
        # the upstream source.
        import dask.array as da
        from pyresample.bilinear.xarr import (_get_input_xy_dask,
        from pyresample._spatial_mp import Proj
        from pyresample import CHUNK_SIZE

        proj = Proj(self.target_def.proj_str)
        out_x, out_y = self.target_def.get_proj_coords(chunks=CHUNK_SIZE)
        out_x = da.ravel(out_x)
        out_y = da.ravel(out_y)
        in_x, in_y = _get_input_xy_dask(self.source_def, proj,
        pt_1, pt_2, pt_3, pt_4, ia_ = _get_bounding_corners_dask(
            in_x, in_y, out_x, out_y,

        self.assertTrue(pt_1.shape == pt_2.shape ==
                        pt_3.shape == pt_4.shape ==
                        (self.target_def.size, 2))
        self.assertTrue(ia_.shape == (self.target_def.size, 4))

        # Check which of the locations has four valid X/Y pairs by
        # finding where there are non-NaN values
        # (a NaN in any corner makes that row's sum NaN)
        res = da.sum(pt_1 + pt_2 + pt_3 + pt_4, axis=1).compute()
        self.assertEqual(np.sum(~np.isnan(res)), 10) 
Example #17
Source File:    From pyresample with GNU Lesser General Public License v3.0 5 votes vote down vote up
def test_get_corner_dask(self):
        """Test finding the closest corners."""
        # NOTE(review): this copy is incomplete — the pyresample.bilinear.xarr
        # import and the _get_input_xy_dask(...) and _get_corner_dask(...)
        # calls are never closed. Restore from the upstream source.
        import dask.array as da
        from pyresample.bilinear.xarr import (_get_corner_dask,
        from pyresample import CHUNK_SIZE
        from pyresample._spatial_mp import Proj

        proj = Proj(self.target_def.proj_str)
        in_x, in_y = _get_input_xy_dask(self.source_def, proj,
        out_x, out_y = self.target_def.get_proj_coords(chunks=CHUNK_SIZE)
        out_x = da.ravel(out_x)
        out_y = da.ravel(out_y)

        # Some copy&paste from the code to get the input
        # (each output coordinate repeated once per neighbour, one row per output point)
        out_x_tile = np.reshape(np.tile(out_x, self.neighbours),
                                (self.neighbours, out_x.size)).T
        out_y_tile = np.reshape(np.tile(out_y, self.neighbours),
                                (self.neighbours, out_y.size)).T
        x_diff = out_x_tile - in_x
        y_diff = out_y_tile - in_y
        stride = np.arange(x_diff.shape[0])

        # Use lower left source pixels for testing
        valid = (x_diff > 0) & (y_diff > 0)
        x_3, y_3, idx_3 = _get_corner_dask(stride, valid, in_x, in_y,

        self.assertTrue(x_3.shape == y_3.shape == idx_3.shape ==
                        (self.target_def.size, ))
        # Four locations have no data to the lower left of them (the
        # bottom row of the area)
        self.assertEqual(np.sum(np.isnan(x_3.compute())), 4) 
Example #18
Source File:    From dask-ml with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
# NOTE(review): this copy of init_pp is incomplete — the docstring is never
# closed, the x_squared_norms line has what looks like the tail of a logging
# call fused onto it, and the _k_init( call is never closed. Restore from the
# upstream dask-ml source.
def init_pp(X, n_clusters, random_state):
    """K-means initialization using k-means++

    This uses scikit-learn's implementation.
    x_squared_norms = row_norms(X, squared=True).compute()"Initializing with k-means++")
    with _timer("initialization of %2d centers" % n_clusters, _logger=logger):
        # XXX: Using a private scikit-learn API
        centers = _k_init(
            X, n_clusters, random_state=random_state, x_squared_norms=x_squared_norms

    return centers 
Example #19
Source File:    From satpy with GNU General Public License v3.0 5 votes vote down vote up
# NOTE(review): this copy of save_dataset is incomplete — the docstring is
# never closed, its section headers and the :func: targets appear to be
# missing, and the %-format argument on the last line is empty ``(, )``.
# Restore from the upstream satpy source.
def save_dataset(self, dataset, filename=None, fill_value=None,
                     compute=True, **kwargs):
        """Save the ``dataset`` to a given ``filename``.

        This method must be overloaded by the subclass.

            dataset (xarray.DataArray): Dataset to save using this writer.
            filename (str): Optionally specify the filename to save this
                            dataset to. If not provided then `filename`
                            which can be provided to the init method will be
                            used and formatted by dataset attributes.
            fill_value (int or float): Replace invalid values in the dataset
                                       with this fill value if applicable to
                                       this writer.
            compute (bool): If `True` (default), compute and save the dataset.
                            If `False` return either a :doc:`dask:delayed`
                            object or tuple of (source, target). See the
                            return values below for more information.
            **kwargs: Other keyword arguments for this particular writer.

            Value returned depends on `compute`. If `compute` is `True` then
            the return value is the result of computing a
            :doc:`dask:delayed` object or running :func:``.
            If `compute` is `False` then the returned value is either a
            :doc:`dask:delayed` object that can be computed using
            `delayed.compute()` or a tuple of (source, target) that should be
            passed to :func:``. If target is provided then the
            caller is responsible for calling `target.close()` if the target
            has this method.

        raise NotImplementedError(
            "Writer '%s' has not implemented dataset saving" % (, )) 
Example #20
Source File:    From satpy with GNU General Public License v3.0 5 votes vote down vote up
def save_dataset(self, dataset, filename=None, fill_value=None,
                     overlay=None, decorate=None, compute=True, **kwargs):
        """Save the ``dataset`` to a given ``filename``.

        This method creates an enhanced image using :func:`get_enhanced_image`.
        The image is then passed to :meth:`save_image`. See both of these
        functions for more details on the arguments passed to this method.

        The dataset is squeezed before enhancement. ``fill_value`` is passed
        both to the enhancement step and to :meth:`save_image`; ``filename``,
        ``compute`` and any extra ``**kwargs`` go to :meth:`save_image`, whose
        return value is returned unchanged.
        """
        img = get_enhanced_image(dataset.squeeze(), enhance=self.enhancer, overlay=overlay,
                                 decorate=decorate, fill_value=fill_value)
        return self.save_image(img, filename=filename, compute=compute, fill_value=fill_value, **kwargs)
Example #21
Source File:    From satpy with GNU General Public License v3.0 5 votes vote down vote up
# NOTE(review): this copy of save_image is incomplete — the docstring is never
# closed, its section headers and the :func: targets appear to be missing, a
# sentence in the filename description is cut off, and the %-format argument
# on the last line is empty ``(,)``. Restore from the upstream satpy source.
def save_image(self, img, filename=None, compute=True, **kwargs):
        """Save Image object to a given ``filename``.

            img (trollimage.xrimage.XRImage): Image object to save to disk.
            filename (str): Optionally specify the filename to save this
                            dataset to. It may include string formatting
                            patterns that will be filled in by dataset
            compute (bool): If `True` (default), compute and save the dataset.
                            If `False` return either a :doc:`dask:delayed`
                            object or tuple of (source, target). See the
                            return values below for more information.
            **kwargs: Other keyword arguments to pass to this writer.

            Value returned depends on `compute`. If `compute` is `True` then
            the return value is the result of computing a
            :doc:`dask:delayed` object or running :func:``.
            If `compute` is `False` then the returned value is either a
            :doc:`dask:delayed` object that can be computed using
            `delayed.compute()` or a tuple of (source, target) that should be
            passed to :func:``. If target is provided then the
            caller is responsible for calling `target.close()` if the target
            has this method.

        raise NotImplementedError("Writer '%s' has not implemented image saving" % (,)) 
Example #22
Source File:    From satpy with GNU General Public License v3.0 5 votes vote down vote up
def test_GenericImageFileHandler(self):
        """Test direct use of the reader."""
        # NOTE(review): this copy is incomplete — the 'start_time' value, the
        # attribute assigned in Foo.__init__, and the object indexed inside the
        # np.isnan(...) assertion are missing. Restore from the upstream satpy
        # source.
        from satpy.readers.generic_image import GenericImageFileHandler
        from satpy.readers.generic_image import mask_image_data

        fname = os.path.join(self.base_dir, 'test_rgba.tif')
        fname_info = {'start_time':}
        ftype_info = {}
        reader = GenericImageFileHandler(fname, fname_info, ftype_info)

        class Foo(object):
            """Mock class for dataset id"""
            def __init__(self):
       = 'image'

        foo = Foo()
        self.assertEqual(reader.finfo['filename'], fname)
        self.assertEqual(reader.area, self.area_def)
        self.assertEqual(reader.get_area_def(None), self.area_def)

        dataset = reader.get_dataset(foo, None)
        self.assertTrue(isinstance(dataset, xr.DataArray))
        self.assertTrue('crs' in dataset.attrs)
        self.assertTrue('transform' in dataset.attrs)
        self.assertTrue(np.all(np.isnan([:, :10, :10].compute())))

        # Test masking of floats
        data = self.scn['rgba']
        # mask_image_data is expected to reject float input
        self.assertRaises(ValueError, mask_image_data, data / 255.)
        data = data.astype(np.uint32)
        self.assertTrue(data.bands.size == 4)
        data = mask_image_data(data)
        # after masking, one band fewer remains (4 -> 3)
        self.assertTrue(data.bands.size == 3) 
Example #23
Source File:    From dask-ml with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _fit_array(self, X):
        """Compute per-column fill statistics for a Dask array ``X``.

        Only the ``"mean"`` and ``"constant"`` strategies are supported for
        arrays. ``"mean"`` uses the column-wise NaN-ignoring mean;
        ``"constant"`` repeats ``self.fill_value`` once per column with
        ``X``'s dtype. The result is stored on ``self.statistics_``.

        Raises
        ------
        ValueError
            If ``self.strategy`` is neither ``"mean"`` nor ``"constant"``.
        """
        if self.strategy not in {"mean", "constant"}:
            msg = "Can only use strategy='mean' or 'constant' with Dask Array."
            raise ValueError(msg)

        if self.strategy == "mean":
            statistics = da.nanmean(X, axis=0).compute()
        else:
            # Only the constant strategy uses fill_value; without this `else`
            # the computed column means would always be overwritten.
            statistics = np.full(X.shape[1], self.fill_value, dtype=X.dtype)

        (self.statistics_,) = da.compute(statistics)
Example #24
Source File:    From dask-ml with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _fit_frame(self, X):
        """Compute per-column fill statistics for a DataFrame ``X``.

        ``"mean"`` and ``"median"`` use the column means/quantiles,
        ``"constant"`` repeats ``self.fill_value`` once per column, and any
        other strategy (presumably ``"most_frequent"`` — confirm against the
        validated strategy list) uses each column's most frequent value.
        The result is stored on ``self.statistics_`` as a Series indexed by
        column name.
        """
        if self.strategy == "mean":
            avg = X.mean(axis=0).values
        elif self.strategy == "median":
            avg = X.quantile().values
        elif self.strategy == "constant":
            avg = np.full(len(X.columns), self.fill_value)
        else:
            # Without this `else`, "constant" was overwritten by the
            # most-frequent values and other strategies left `avg` unbound.
            avg = [X[col].value_counts().nlargest(1).index for col in X.columns]
            avg = np.concatenate(*dask.compute(avg))

        self.statistics_ = pd.Series(dask.compute(avg)[0], index=X.columns)
Example #25
Source File:    From dask-ml with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _check_array(self, X):
        # NOTE(review): this copy is incomplete — the check_array( call and the
        # msg = ( expression below are never closed. Restore from the upstream
        # dask-ml source.
        # Convert pandas / dask DataFrames to arrays first.
        if isinstance(X, pd.DataFrame):
            X = X.values

        if isinstance(X, dd.DataFrame):
            # lengths=True computes the partition lengths so the array has
            # known chunk sizes along the first axis.
            X = X.to_dask_array(lengths=True)

        X = check_array(

        # Promote integer input to floating point of the same width.
        if X.dtype == "int32":
            X = X.astype("float32")
        elif X.dtype == "int64":
            X = X.astype("float64")

        # Wrap in-memory NumPy input as a dask array, splitting the rows into
        # roughly one chunk per CPU (at least one row per chunk).
        if isinstance(X, np.ndarray):
            X = da.from_array(X, chunks=(max(1, len(X) // cpu_count()), X.shape[-1]))

        # Both checks are evaluated in one compute call (assuming `compute` is
        # dask's): it returns a 1-tuple holding the (has_nan, has_inf) pair,
        # which `*` unpacks into any().
        bad = (da.isnull(X).any(), da.isinf(X).any())
        if any(*compute(bad)):
            msg = (
                "Input contains NaN, infinity or a value too large for "
            raise ValueError(msg)
        return X 
Example #26
Source File:    From scanpy with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def materialize_as_ndarray(a):
    """Convert distributed arrays to ndarrays.

    Only an object whose type is exactly ``list`` or ``tuple`` is treated as
    a sequence of arrays: if dask is available and any element is a dask
    array, all elements are computed together and dask's result tuple is
    returned; otherwise a tuple of ``np.asarray`` conversions is returned.
    Any other input (including list/tuple subclasses) goes straight through
    ``np.asarray``.
    """
    if type(a) not in (list, tuple):
        return np.asarray(a)
    # `da` is checked first so the isinstance test never runs without dask.
    holds_dask = da is not None and any(isinstance(item, da.Array) for item in a)
    if holds_dask:
        return da.compute(*a, sync=True)
    return tuple(map(np.asarray, a))
Example #27
Source File:    From nbodykit with GNU General Public License v3.0 5 votes vote down vote up
def test_getitem_columns(comm):
    """Select columns by list; an unknown column name raises KeyError."""
    cat = UniformCatalog(nbar=2e-4, BoxSize=512., seed=42, comm=comm)

    # A list containing an unknown column must be rejected.
    with pytest.raises(KeyError):
        cat[['Position', 'BAD_COLUMN']]

    position_only = cat[['Position']]
    for name in position_only:
        assert_array_equal(position_only[name].compute(), cat[name].compute())
Example #28
Source File:    From nbodykit with GNU General Public License v3.0 5 votes vote down vote up
def __getitem__(self, key):
        """Index the column and wrap the result in a ColumnAccessor.

        A dask-array ``key`` is first evaluated through
        ``self.catalog.compute`` because dask indices are not fully supported;
        plain ``da.Array`` indexing is then applied.
        """
        if isinstance(key, da.Array):
            # dask indices are not fully supported, so resolve the key eagerly
            key = self.catalog.compute(key)

        sliced = da.Array.__getitem__(self, key)

        # wrapping is okay because __setitem__ checks for circular references
        return ColumnAccessor(self.catalog, sliced)
Example #29
Source File:    From nbodykit with GNU General Public License v3.0 5 votes vote down vote up
def compute(self):
        """Evaluate this column through its catalog's ``compute`` and return the result."""
        catalog = self.catalog
        return catalog.compute(self)
Example #30
Source File:    From nbodykit with GNU General Public License v3.0 5 votes vote down vote up
def __str__(self):
        """Return the dask array string, plus the first and last values when present.

        The first element is appended when the column is non-empty and the
        last one when it has at least two elements; both are computed eagerly.
        """
        pieces = [da.Array.__str__(self)]
        size = len(self)
        if size > 0:
            pieces.append(" first: %s" % str(self[0].compute()))
        if size > 1:
            pieces.append(" last: %s" % str(self[-1].compute()))
        return "".join(pieces)