Handling dask arrays#
apply_ufunc is a more advanced wrapper that is designed to apply functions that expect and return NumPy (or other) arrays. For example, this would include all of SciPy's API. Since apply_ufunc operates on lower-level NumPy or Dask objects, it skips the overhead of using Xarray objects, making it a good choice for performance-critical functions.

apply_ufunc can be a little tricky to get right since it operates at a lower level than map_blocks. On the other hand, Xarray uses apply_ufunc internally to implement much of its API, meaning that it is quite powerful!
Learning goals:

- Learn that apply_ufunc automates aspects of applying computation functions that are designed for pure arrays (like numpy arrays) on xarray objects
Setup#
import dask
import numpy as np
import xarray as xr
First let's set up a LocalCluster using dask.distributed.

You can use any kind of dask cluster. This step is completely independent of xarray. While not strictly necessary, the dashboard provides a nice learning tool.
from dask.distributed import Client
client = Client()
client
Client: Client-ef096be6-017b-11ee-8b1d-6045bdc9e3c7
Connection method: Cluster object | Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status

LocalCluster: Workers: 2 | Total threads: 2 | Total memory: 6.78 GiB
Status: running | Using processes: True
👆 Click the Dashboard link above, or click the "Search" button in the dashboard.

Let's test that the dashboard is working.
import dask.array
dask.array.ones((1000, 4), chunks=(2, 1)).compute() # should see activity in dashboard
array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
...,
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
Let's open a dataset. We specify chunks so that we create dask arrays for the DataArrays.
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 100})
ds
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 dask.array<chunksize=(100, 25, 53), meta=np.ndarray>
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day). These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
A simple example#
All the concepts from applying numpy functions carry over. However, the handling of dask arrays needs to be explicitly activated.
# Expect an error here
def squared_error(x, y):
    return (x - y) ** 2


xr.apply_ufunc(squared_error, ds.air, 1)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[5], line 6
2 def squared_error(x, y):
3 return (x - y) ** 2
----> 6 xr.apply_ufunc(squared_error, ds.air, 1)
File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:1204, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
1202 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
1203 elif any(isinstance(a, DataArray) for a in args):
-> 1204 return apply_dataarray_vfunc(
1205 variables_vfunc,
1206 *args,
1207 signature=signature,
1208 join=join,
1209 exclude_dims=exclude_dims,
1210 keep_attrs=keep_attrs,
1211 )
1212 # feed Variables directly through apply_variable_ufunc
1213 elif any(isinstance(a, Variable) for a in args):
File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:315, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
310 result_coords, result_indexes = build_output_coords_and_indexes(
311 args, signature, exclude_dims, combine_attrs=keep_attrs
312 )
314 data_vars = [getattr(a, "variable", a) for a in args]
--> 315 result_var = func(*data_vars)
317 out: tuple[DataArray, ...] | DataArray
318 if signature.num_outputs > 1:
File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:692, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
690 if any(is_duck_dask_array(array) for array in input_data):
691 if dask == "forbidden":
--> 692 raise ValueError(
693 "apply_ufunc encountered a dask array on an "
694 "argument, but handling for dask arrays has not "
695 "been enabled. Either set the ``dask`` argument "
696 "or load your data into memory first with "
697 "``.load()`` or ``.compute()``"
698 )
699 elif dask == "parallelized":
700 numpy_func = func
ValueError: apply_ufunc encountered a dask array on an argument, but handling for dask arrays has not been enabled. Either set the ``dask`` argument or load your data into memory first with ``.load()`` or ``.compute()``
There are two options for the dask kwarg (a small sketch contrasting them follows this list).

- dask="allowed": Dask arrays are passed to the user function. This is a good choice if your function can handle dask arrays and won't call compute explicitly.
- dask="parallelized": This applies the user function over blocks of the dask array using dask.array.blockwise. This is useful when your function cannot handle dask arrays natively (e.g. scipy API).
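To make the contrast concrete before the worked examples below, here is a minimal sketch (ours, using a trivial passthrough function that is not part of this tutorial) of the same call under each mode:

def passthrough(x):
    # Trivial identity function for illustration only
    return x


# dask="allowed": passthrough receives the full dask array; the result stays lazy
xr.apply_ufunc(passthrough, ds.air, dask="allowed")

# dask="parallelized": passthrough is applied blockwise to numpy chunks;
# we declare the output dtype so xarray can build the lazy result
xr.apply_ufunc(passthrough, ds.air, dask="parallelized", output_dtypes=[ds.air.dtype])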
Since squared_error can handle dask arrays without computing them, we specify dask="allowed".
sqer = xr.apply_ufunc(
    squared_error,
    ds.air,
    1,
    dask="allowed",
)
sqer  # dask-backed DataArray! with nice metadata!
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<pow, shape=(2920, 25, 53), dtype=float32, chunksize=(100, 25, 53), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Let's again use the wrapper trick to understand what squared_error receives. We see that it receives a dask array.
def wrapper(x, y):
    print(f"received x of type {type(x)}, shape {x.shape}")
    print(f"received y of type {type(y)}")
    return squared_error(x, y)


xr.apply_ufunc(wrapper, ds.air, 1, dask="allowed")
received x of type <class 'dask.array.core.Array'>, shape (2920, 25, 53)
received y of type <class 'int'>
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<pow, shape=(2920, 25, 53), dtype=float32, chunksize=(100, 25, 53), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Reductions and core dimensions#
squared_error operated on a per-element basis. How about a reduction like np.mean?

Such functions involve the concept of "core dimensions". One way to think about core dimensions is to consider the smallest dimensionality of data necessary to apply the function.

When using more complex operations that consider some array values collectively, it's important to understand the idea of core dimensions. Usually, they correspond to the fundamental dimensions over which an operation is defined, e.g., the summed axis in np.sum. A good clue that core dimensions are needed is the presence of an axis argument on the corresponding NumPy function.
With apply_ufunc, core dimensions are recognized by name, and then moved to the last dimension of any input arguments before applying the given function. This means that for functions that accept an axis argument, you usually need to set axis=-1.
Let's use dask.array.mean as an example of a function that can handle dask arrays and uses an axis kwarg:
def time_mean(da):
    return xr.apply_ufunc(
        dask.array.mean,
        da,
        input_core_dims=[["time"]],
        dask="allowed",
        kwargs={"axis": -1},  # core dimensions are moved to the end
    )


time_mean(ds.air)
<xarray.DataArray 'air' (lat: 25, lon: 53)>
dask.array<mean_agg-aggregate, shape=(25, 53), dtype=float32, chunksize=(25, 53), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
Again, this is identical to the built-in mean:
ds.air.mean("time").identical(time_mean(ds.air))
True
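As a quick extension (a sketch we added, not part of the original tutorial): multiple core dimensions are moved to the end in the order listed, so a reduction over both lat and lon passes axis=(-2, -1). The spatial_mean name below is ours.

def spatial_mean(da):
    # Both core dimensions end up as the last two axes, in the listed order
    return xr.apply_ufunc(
        dask.array.mean,
        da,
        input_core_dims=[["lat", "lon"]],
        dask="allowed",
        kwargs={"axis": (-2, -1)},
    )


ds.air.mean(["lat", "lon"]).identical(spatial_mean(ds.air))  # expect True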
Automatically parallelizing dask-unaware functions#
A very useful apply_ufunc feature is the ability to apply arbitrary functions in parallel to each block. This ability can be activated using dask="parallelized". Again, xarray needs a lot of extra metadata, so depending on the function, extra arguments such as output_dtypes and output_sizes may be necessary.
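For instance, here is a minimal sketch of supplying output_dtypes (the double_it function is hypothetical, added for illustration): with dask="parallelized", xarray builds the output lazily without running your function, so it cannot inspect the result's dtype and we declare it up front.

def double_it(arr):
    # A plain numpy function, unaware of dask; note it changes the dtype
    return (arr * 2).astype("float64")


doubled = xr.apply_ufunc(
    double_it,
    ds.air,
    dask="parallelized",
    output_dtypes=[np.float64],  # tell xarray the dtype of the lazy output
)
doubled.dtype  # float64, without computing anything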
We will use scipy.integrate.trapz as an example of a function that cannot handle dask arrays and requires a core dimension. If we call trapz with a dask array, we get a numpy array back; that is, the values have been eagerly computed. This is undesirable behaviour.
import scipy as sp
import scipy.integrate
sp.integrate.trapz(ds.air.data, axis=ds.air.get_axis_num("lon")) # does NOT return a dask array
/usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/dask/array/core.py:1711: FutureWarning: The `numpy.trapz` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
warnings.warn(
array([[12588.54 , 12582.26 , 12671.649 , ..., 15374.26 , 15430.039 ,
15493.165 ],
[12571.841 , 12567.279 , 12654.569 , ..., 15355.915 , 15413.14 ,
15477.346 ],
[12584.62 , 12537.54 , 12644.909 , ..., 15347.77 , 15399.9 ,
15460.965 ],
...,
[12709.4795, 12638.4795, 12810.2295, ..., 15416.831 , 15459.581 ,
15510.4795],
[12726.679 , 12634.4795, 12794.63 , ..., 15401.4795, 15454.13 ,
15511.4795],
[12767.33 , 12630.78 , 12754.531 , ..., 15446.33 , 15495.53 ,
15538.18 ]], dtype=float32)
xr.apply_ufunc(
    sp.integrate.trapz,
    ds,
    input_core_dims=[["lon"]],
    kwargs={"axis": -1},
    dask="parallelized",
)
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat) float32 dask.array<chunksize=(100, 25), meta=np.ndarray>
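A quick sanity check (a sketch we added, not part of the original output): the parallelized result stays lazy until computed, and its values should match the eager computation above.

lazy = xr.apply_ufunc(
    sp.integrate.trapz,
    ds.air,
    input_core_dims=[["lon"]],
    kwargs={"axis": -1},
    dask="parallelized",
)
print(type(lazy.data))  # dask.array.core.Array: nothing computed yet
np.allclose(lazy.compute(), sp.integrate.trapz(ds.air.data, axis=-1))  # expect True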
client.close()