
Handling dask arrays

apply_ufunc is a more advanced wrapper designed to apply functions that expect and return NumPy (or other) arrays. For example, this includes all of SciPy’s API. Since apply_ufunc operates on lower-level NumPy or Dask objects, it skips the overhead of using Xarray objects, making it a good choice for performance-critical functions.

apply_ufunc can be a little tricky to get right since it operates at a lower level than map_blocks. On the other hand, Xarray uses apply_ufunc internally to implement much of its API, meaning that it is quite powerful!
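To make the "lower level" distinction concrete, here is a small sketch on a made-up dask-backed array (the array, its dimension names, and the variable names are illustrative). apply_ufunc hands the function the underlying dask array, while xr.map_blocks calls it once per block with a small xarray object. For an element-wise function like this one, both should give the same values.

```python
import dask.array as da
import numpy as np
import xarray as xr

# Hypothetical dask-backed DataArray split into 2 blocks of shape (1, 3)
arr = xr.DataArray(
    da.from_array(np.arange(6.0).reshape(2, 3), chunks=(1, 3)),
    dims=("y", "x"),
)


def squared_error(x, y):
    return (x - y) ** 2


# apply_ufunc passes the raw dask array to squared_error
via_apply_ufunc = xr.apply_ufunc(squared_error, arr, 1, dask="allowed")

# map_blocks passes one DataArray per block, followed by the extra positional args
via_map_blocks = xr.map_blocks(squared_error, arr, args=[1])

print(via_apply_ufunc.values)
print(via_map_blocks.values)
```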

Learning goals:

  • Learn that apply_ufunc automates aspects of applying computation functions that are designed for pure arrays (like NumPy arrays) to xarray objects

Setup

import dask
import numpy as np
import xarray as xr

First, let's set up a LocalCluster using dask.distributed.

You can use any kind of dask cluster. This step is completely independent of xarray. While not strictly necessary, the dashboard provides a nice learning tool.

from dask.distributed import Client

client = Client()
client

Client-ef096be6-017b-11ee-8b1d-6045bdc9e3c7

Connection method: Cluster object Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status


Click the Dashboard link above, or click the "Search" button in the dashboard.

Let’s test that the dashboard is working.

import dask.array

dask.array.ones((1000, 4), chunks=(2, 1)).compute()  # should see activity in dashboard
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       ...,
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])
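Because the cluster is independent of xarray, the same lazy computations also run on dask's built-in local schedulers without any Client. A minimal sketch on a made-up array, selecting the scheduler with dask.config.set (the array and variable names are illustrative):

```python
import dask
import dask.array as da
import xarray as xr

# Small hypothetical dask-backed DataArray with 10 chunks along "time"
arr = xr.DataArray(da.ones((100, 4), chunks=(10, 4)), dims=("time", "x"))

# The same task graph computed with two different local schedulers
with dask.config.set(scheduler="threads"):
    total_threads = float(arr.sum().compute())

with dask.config.set(scheduler="synchronous"):  # single-threaded, handy for debugging
    total_sync = float(arr.sum().compute())

print(total_threads, total_sync)  # 400.0 400.0
```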

Let’s open a dataset. We specify chunks so that the DataArrays are backed by dask arrays.

ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 100})
ds
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 dask.array<chunksize=(100, 25, 53), meta=np.ndarray>
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
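To confirm that a variable is dask-backed and see how it was split, you can inspect .data and .chunks. A sketch on a hypothetical in-memory stand-in for ds.air (smaller than the real data), chunked along time with the same request as the call above:

```python
import dask.array
import numpy as np
import xarray as xr

# Hypothetical stand-in for ds.air: 292 time steps instead of 2920
air = xr.DataArray(
    np.zeros((292, 5, 7), dtype="float32"),
    dims=("time", "lat", "lon"),
)

chunked = air.chunk({"time": 100})  # same chunking request as open_dataset above

print(type(chunked.data))  # <class 'dask.array.core.Array'>
print(chunked.chunks)  # ((100, 100, 92), (5,), (7,))
```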

A simple example

All the concepts from applying NumPy functions carry over.

However, the handling of dask arrays must be explicitly enabled.

# Expect an error here
def squared_error(x, y):
    return (x - y) ** 2


xr.apply_ufunc(squared_error, ds.air, 1)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 6
      2 def squared_error(x, y):
      3     return (x - y) ** 2
----> 6 xr.apply_ufunc(squared_error, ds.air, 1)

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:1204, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1202 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1203 elif any(isinstance(a, DataArray) for a in args):
-> 1204     return apply_dataarray_vfunc(
   1205         variables_vfunc,
   1206         *args,
   1207         signature=signature,
   1208         join=join,
   1209         exclude_dims=exclude_dims,
   1210         keep_attrs=keep_attrs,
   1211     )
   1212 # feed Variables directly through apply_variable_ufunc
   1213 elif any(isinstance(a, Variable) for a in args):

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:315, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    310 result_coords, result_indexes = build_output_coords_and_indexes(
    311     args, signature, exclude_dims, combine_attrs=keep_attrs
    312 )
    314 data_vars = [getattr(a, "variable", a) for a in args]
--> 315 result_var = func(*data_vars)
    317 out: tuple[DataArray, ...] | DataArray
    318 if signature.num_outputs > 1:

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:692, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    690 if any(is_duck_dask_array(array) for array in input_data):
    691     if dask == "forbidden":
--> 692         raise ValueError(
    693             "apply_ufunc encountered a dask array on an "
    694             "argument, but handling for dask arrays has not "
    695             "been enabled. Either set the ``dask`` argument "
    696             "or load your data into memory first with "
    697             "``.load()`` or ``.compute()``"
    698         )
    699     elif dask == "parallelized":
    700         numpy_func = func

ValueError: apply_ufunc encountered a dask array on an argument, but handling for dask arrays has not been enabled. Either set the ``dask`` argument or load your data into memory first with ``.load()`` or ``.compute()``
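The error message also suggests a second route: load the data into memory first. Once the input is NumPy-backed there is no dask array left, so the default setting (dask="forbidden") no longer gets in the way. A small sketch with a made-up array; this is fine for data that fits in memory, but it gives up laziness:

```python
import dask.array as da
import numpy as np
import xarray as xr


def squared_error(x, y):
    return (x - y) ** 2


# Hypothetical small dask-backed DataArray
lazy = xr.DataArray(da.ones((4, 3), chunks=(2, 3)), dims=("time", "x"))

# Default dask="forbidden": a dask-backed input raises ValueError
message = ""
try:
    xr.apply_ufunc(squared_error, lazy, 1)
except ValueError as err:
    message = str(err)

# After .compute() the input is NumPy-backed, so no dask handling is needed
eager = xr.apply_ufunc(squared_error, lazy.compute(), 1)

print(message)
print(type(eager.data))  # <class 'numpy.ndarray'>
print(eager.values)  # all zeros, since (1 - 1) ** 2 == 0
```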

There are two options for the dask kwarg that enable dask handling (the default, dask="forbidden", raises the error above).

  1. dask="allowed": Dask arrays are passed to the user function as-is. This is a good choice if your function can handle dask arrays and won’t call compute explicitly.

  2. dask="parallelized": This applies the user function over blocks of the dask array using dask.array.blockwise. This is useful when your function cannot handle dask arrays natively (e.g. the SciPy API).
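The difference between the two options shows up with a function that is not dask-aware. In this sketch the hypothetical dask_unaware_square calls np.asarray, which forces a dask input into memory. With dask="allowed" that eager computation happens inside the function; with dask="parallelized" the function only ever sees NumPy blocks, so the result stays lazy:

```python
import dask.array as da
import numpy as np
import xarray as xr


def dask_unaware_square(x):
    # np.asarray pulls a dask array into memory, so this function is NOT dask-aware
    return np.asarray(x) ** 2


# Hypothetical dask-backed DataArray with 3 blocks along "time"
arr = xr.DataArray(
    da.from_array(np.arange(12.0).reshape(3, 4), chunks=(1, 4)),
    dims=("time", "x"),
)

# "allowed": the whole dask array is passed in and computed eagerly by np.asarray
allowed = xr.apply_ufunc(dask_unaware_square, arr, dask="allowed")

# "parallelized": the function is applied block by block and the result stays lazy
parallel = xr.apply_ufunc(
    dask_unaware_square,
    arr,
    dask="parallelized",
    output_dtypes=[arr.dtype],
)

print(type(allowed.data))  # <class 'numpy.ndarray'>
print(type(parallel.data))  # <class 'dask.array.core.Array'>
```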

Since squared_error can handle dask arrays without computing them, we specify dask="allowed".

sqer = xr.apply_ufunc(
    squared_error,
    ds.air,
    1,
    dask="allowed",
)
sqer  # dask-backed DataArray! with nice metadata!
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<pow, shape=(2920, 25, 53), dtype=float32, chunksize=(100, 25, 53), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00

Let’s again use the wrapper trick to understand what squared_error receives.

We see that it receives a dask array.

def wrapper(x, y):
    print(f"received x of type {type(x)}, shape {x.shape}")
    print(f"received y of type {type(y)}")
    return squared_error(x, y)


xr.apply_ufunc(wrapper, ds.air, 1, dask="allowed")
received x of type <class 'dask.array.core.Array'>, shape (2920, 25, 53)
received y of type <class 'int'>
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<pow, shape=(2920, 25, 53), dtype=float32, chunksize=(100, 25, 53), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00

Reductions and core dimensions

squared_error operated on a per-element basis. How about a reduction like np.mean?

Such functions involve the concept of “core dimensions”. One way to think about core dimensions is to consider the smallest dimensionality of data necessary to apply the function.

When using more complex operations that consider some array values collectively, it’s important to understand the idea of core dimensions. Usually, they correspond to the fundamental dimensions over which an operation is defined, e.g., the summed axis in np.sum. A good clue that core dimensions are needed is the presence of an axis argument on the corresponding NumPy function.

With apply_ufunc, core dimensions are recognized by name and then moved to the last axes of every input argument before the function is applied. This means that for functions that accept an axis argument, you usually need to set axis=-1.
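A quick way to see the transposition is to record the shape the function receives. This sketch uses a small NumPy-backed array (no dask needed) and a hypothetical helper, mean_last_axis; "time" is the first dimension of the DataArray but arrives as the last axis of the raw array:

```python
import numpy as np
import xarray as xr

seen_shapes = []


def mean_last_axis(x):
    # Record the shape of the raw array apply_ufunc passes in,
    # then reduce over the last axis, where the core dimension now sits
    seen_shapes.append(x.shape)
    return x.mean(axis=-1)


arr = xr.DataArray(np.arange(15.0).reshape(5, 3), dims=("time", "lat"))
result = xr.apply_ufunc(mean_last_axis, arr, input_core_dims=[["time"]])

print(seen_shapes)  # [(3, 5)]: "time" was moved from axis 0 to axis -1
print(result.dims)  # ('lat',)
```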

Let’s use dask.array.mean as an example of a function that can handle dask arrays and uses an axis kwarg.

def time_mean(da):
    return xr.apply_ufunc(
        dask.array.mean,
        da,
        input_core_dims=[["time"]],
        dask="allowed",
        kwargs={"axis": -1},  # core dimensions are moved to the end
    )


time_mean(ds.air)
<xarray.DataArray 'air' (lat: 25, lon: 53)>
dask.array<mean_agg-aggregate, shape=(25, 53), dtype=float32, chunksize=(25, 53), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

Again, this is identical to the built-in mean:

ds.air.mean("time").identical(time_mean(ds.air))
True
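Several core dimensions work the same way: all of them are moved to the end, in the order listed, so an axis tuple such as (-2, -1) covers them. A sketch of an area mean on synthetic data; the spatial_mean helper and the random array are hypothetical:

```python
import dask.array
import numpy as np
import xarray as xr


def spatial_mean(da_):
    return xr.apply_ufunc(
        dask.array.mean,
        da_,
        input_core_dims=[["lat", "lon"]],  # both moved to the last two axes
        dask="allowed",
        kwargs={"axis": (-2, -1)},
    )


rng = np.random.default_rng(0)
air = xr.DataArray(
    rng.random((6, 4, 5)),
    dims=("time", "lat", "lon"),
).chunk({"time": 2})

result = spatial_mean(air)
print(result.dims)  # ('time',)
```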

Automatically parallelizing dask-unaware functions

A very useful apply_ufunc feature is the ability to apply arbitrary functions in parallel to each block. This ability can be activated using dask="parallelized". Again, xarray needs extra metadata, so depending on the function, extra arguments such as output_dtypes and output_sizes may be necessary.
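output_sizes matters when the function creates a dimension dask cannot infer. In this sketch the hypothetical min_max reduces "time" and returns a new trailing "stat" axis of length 2. Recent xarray versions expect output_sizes inside dask_gufunc_kwargs; the core dimension "time" is kept in a single chunk, which dask="parallelized" requires by default:

```python
import dask.array as da
import numpy as np
import xarray as xr


def min_max(x):
    # Plain NumPy: reduce the last axis and return a new trailing axis of length 2
    return np.stack([x.min(axis=-1), x.max(axis=-1)], axis=-1)


# Hypothetical data; "time" (axis 1) is a single chunk, "lat" is split in two
arr = xr.DataArray(
    da.from_array(np.arange(24.0).reshape(4, 6), chunks=(2, 6)),
    dims=("lat", "time"),
)

result = xr.apply_ufunc(
    min_max,
    arr,
    input_core_dims=[["time"]],
    output_core_dims=[["stat"]],
    dask="parallelized",
    output_dtypes=[arr.dtype],
    dask_gufunc_kwargs={"output_sizes": {"stat": 2}},  # size of the new dimension
)

print(result.dims)  # ('lat', 'stat')
print(result.shape)  # (4, 2)
print(result.values[0])  # [0. 5.]
```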

We will use scipy.integrate.trapz as an example of a function that cannot handle dask arrays and requires a core dimension. If we call trapz with a dask array, we get a NumPy array back; that is, the values have been eagerly computed. This is undesirable behaviour. (In recent SciPy releases, trapz is deprecated in favour of scipy.integrate.trapezoid.)

import scipy as sp
import scipy.integrate

sp.integrate.trapz(ds.air.data, axis=ds.air.get_axis_num("lon"))  # does NOT return a dask array
/usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/dask/array/core.py:1711: FutureWarning: The `numpy.trapz` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
  warnings.warn(
array([[12588.54  , 12582.26  , 12671.649 , ..., 15374.26  , 15430.039 ,
        15493.165 ],
       [12571.841 , 12567.279 , 12654.569 , ..., 15355.915 , 15413.14  ,
        15477.346 ],
       [12584.62  , 12537.54  , 12644.909 , ..., 15347.77  , 15399.9   ,
        15460.965 ],
       ...,
       [12709.4795, 12638.4795, 12810.2295, ..., 15416.831 , 15459.581 ,
        15510.4795],
       [12726.679 , 12634.4795, 12794.63  , ..., 15401.4795, 15454.13  ,
        15511.4795],
       [12767.33  , 12630.78  , 12754.531 , ..., 15446.33  , 15495.53  ,
        15538.18  ]], dtype=float32)
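
Wrapping trapz with apply_ufunc and dask="parallelized" instead applies it to each block and keeps the result lazy:

xr.apply_ufunc(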
    sp.integrate.trapz,
    ds,
    input_core_dims=[["lon"]],
    kwargs={"axis": -1},
    dask="parallelized",
)
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat) float32 dask.array<chunksize=(100, 25), meta=np.ndarray>
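The call above passes only axis, so trapz assumes unit spacing between points. A sketch of passing the actual coordinate values instead, on a made-up array with a stand-in lon coordinate. It uses scipy.integrate.trapezoid, the newer name for trapz; with x supplied, the result should match xarray's built-in DataArray.integrate:

```python
import dask.array as da
import numpy as np
import scipy.integrate
import xarray as xr

lon = np.linspace(200.0, 330.0, 53)  # stand-in for the real lon coordinate
rng = np.random.default_rng(0)
air = xr.DataArray(
    da.from_array(rng.random((10, 53)), chunks=(5, 53)),  # "lon" kept in one chunk
    dims=("time", "lon"),
    coords={"lon": lon},
)

integral = xr.apply_ufunc(
    scipy.integrate.trapezoid,
    air,
    input_core_dims=[["lon"]],
    kwargs={"x": lon, "axis": -1},  # integrate against the coordinate values
    dask="parallelized",
    output_dtypes=[air.dtype],
)

print(integral.dims)  # ('time',)
```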

client.close()

More

  1. https://docs.xarray.dev/en/stable/examples/apply_ufunc_vectorize_1d.html

  2. https://docs.dask.org/en/latest/array-best-practices.html