Handling dask arrays#

We have previously worked through applying functions to NumPy arrays contained in Xarray objects. apply_ufunc also makes it easy to perform many of the steps involved in applying functions that expect and return Dask arrays.

Learning goals:

  • Learn that apply_ufunc can automate aspects of applying computation functions on dask arrays

  • It is possible to automatically parallelize certain operations by providing dask="parallelized"

  • In some cases, extra information needs to be provided such as sizes of any new dimensions added, or data types for output variables.

  • Learn that all the concepts from the numpy lessons carry over, such as automatic vectorization and specifying input and output core dimensions.

Tip

We’ll reduce the length of error messages using %xmode minimal. See the IPython documentation for details.

Setup#

%xmode minimal

import dask
import numpy as np
import xarray as xr

# limit the amount of information printed to screen
xr.set_options(display_expand_data=False)
np.set_printoptions(threshold=10, edgeitems=2)
Exception reporting mode: Minimal

First, let’s set up a LocalCluster using dask.distributed.

You can use any kind of dask cluster. This step is completely independent of xarray. While not strictly necessary, the dashboard provides a nice learning tool.

from dask.distributed import Client

client = Client()
client

Client

Client-4aaa3f95-03d6-11ef-8dc7-000d3a3cdb9c

Connection method: Cluster object
Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status


Click the Dashboard link above, or click the "Search" button in the dashboard.

Let’s test that the dashboard is working.

import dask.array

dask.array.ones((1000, 4), chunks=(2, 1)).compute()  # should see activity in dashboard
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       ...,
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])

Let’s open a dataset. We specify chunks so that the DataArrays are backed by dask arrays.

ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 100})
ds
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 dask.array<chunksize=(100, 25, 53), meta=np.ndarray>
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
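xr.tutorial.open_dataset downloads the file on first use. If that is not possible, a dask-backed stand-in can be built from synthetic data with .chunk. This is a minimal sketch: the random values and the absence of coordinates are assumptions, chosen only to mirror the dimension names, sizes, dtype and chunking of ds.air.

```python
import dask.array
import numpy as np
import xarray as xr

# Synthetic stand-in for ds.air: same dims, sizes and dtype, random values
rng = np.random.default_rng(0)
air = xr.DataArray(
    rng.normal(280.0, 10.0, size=(2920, 25, 53)).astype("float32"),
    dims=("time", "lat", "lon"),
    name="air",
)

# .chunk wraps the NumPy data in a dask array with 100 time steps per block
air = air.chunk({"time": 100})

print(type(air.data))  # a dask array, not a NumPy array
print(air.chunks[0][:3], len(air.chunks[0]))  # chunk sizes along time
```

Because 2920 is not a multiple of 100, the last block along time holds only 20 time steps.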

A simple example#

All the concepts from applying numpy functions carry over.

However, handling of dask arrays must be explicitly activated.

There are three options for the dask kwarg.

    dask : {"forbidden", "allowed", "parallelized"}, default: "forbidden"
        How to handle applying to objects containing lazy data in the form of
        dask arrays:

        - 'forbidden' (default): raise an error if a dask array is encountered.
        - 'allowed': pass dask arrays directly on to ``func``. Prefer this option if
          ``func`` natively supports dask arrays.
        - 'parallelized': automatically parallelize ``func`` if any of the
          inputs are a dask array by using :py:func:`dask.array.apply_gufunc`. Multiple output
          arguments are supported. Only use this option if ``func`` does not natively
          support dask arrays (e.g. converts them to numpy arrays).

We will work through the following two:

  1. dask="allowed": Dask arrays are passed to the user function. This is a good choice if your function can handle dask arrays and won’t compute the result unless explicitly requested.

  2. dask="parallelized": This applies the user function over blocks of the dask array using dask.array.apply_gufunc. This is useful when your function cannot handle dask arrays natively (e.g. much of the SciPy API).

# Expect an error here
def squared_error(x, y):
    return (x - y) ** 2


xr.apply_ufunc(squared_error, ds.air, 1)
ValueError: apply_ufunc encountered a chunked array on an argument, but handling for chunked arrays has not been enabled. Either set the ``dask`` argument or load your data into memory first with ``.load()`` or ``.compute()``
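The error message names two ways out: enable dask handling, or load the data into memory first. A minimal sketch on a small synthetic array (an assumption standing in for ds.air) shows that both give the same values, but only the first stays lazy:

```python
import dask.array
import numpy as np
import xarray as xr


def squared_error(x, y):
    return (x - y) ** 2


# Small synthetic dask-backed DataArray (stand-in for ds.air)
da = xr.DataArray(np.arange(12.0).reshape(3, 4), dims=("time", "lon")).chunk({"time": 1})

# Option 1: pass the dask array straight through to the function (stays lazy)
lazy = xr.apply_ufunc(squared_error, da, 1, dask="allowed")

# Option 2: load into memory first; the default dask="forbidden" is then fine
eager = xr.apply_ufunc(squared_error, da.compute(), 1)

print(type(lazy.data).__name__, type(eager.data).__name__)
```

Option 2 gives up laziness and loads everything into memory, which is only reasonable when the data fits comfortably.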

A good thing to check is whether the applied function (here squared_error) can handle pure dask arrays. To do this, call squared_error(ds.air.data, 1) and make sure of the following:

  1. That you don’t see any activity on the dask dashboard

  2. That the returned result is a dask array.

squared_error(ds.air.data, 1)
            Array                        Chunk
Bytes       14.76 MiB                    517.58 kiB
Shape       (2920, 25, 53)               (100, 25, 53)
Dask graph  30 chunks in 4 graph layers
Data type   float32 numpy.ndarray
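Watching the dashboard works, but the same check can be automated. A minimal sketch, assuming that temporarily setting dask's "scheduler" config option to a callable routes every computation through that callable. The name forbid_compute is hypothetical, and a lazily created array of ones stands in for ds.air.data:

```python
import dask
import dask.array


def forbid_compute(dsk, keys, **kwargs):
    """Hypothetical scheduler that refuses to compute anything."""
    raise RuntimeError("dask tried to compute!")


def squared_error(x, y):
    return (x - y) ** 2


data = dask.array.ones((2920, 25, 53), chunks=(100, 25, 53), dtype="float32")

# Any computation triggered inside this block would call forbid_compute and raise
with dask.config.set(scheduler=forbid_compute):
    result = squared_error(data, 1)

# Sanity check: an explicit .compute() under the same setting does raise
try:
    with dask.config.set(scheduler=forbid_compute):
        result.compute()
    computing_raises = False
except RuntimeError:
    computing_raises = True

print(type(result).__name__, result.chunksize, computing_raises)
```

If squared_error quietly converted its input to NumPy, the first with block would raise instead of returning a dask array.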

Since squared_error can handle dask arrays without computing them, we specify dask="allowed".

sqer = xr.apply_ufunc(
    squared_error,
    ds.air,
    1,
    dask="allowed",
)
sqer  # dask-backed DataArray! with nice metadata!
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<chunksize=(100, 25, 53), meta=np.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00

Understanding what’s happening#

Let’s again use the wrapper trick to understand what squared_error receives.

We see that it receives a dask array (analogous to the numpy array in the previous example).

def wrapper(x, y):
    print(f"received x of type {type(x)}, shape {x.shape}")
    print(f"received y of type {type(y)}")
    return squared_error(x, y)


xr.apply_ufunc(wrapper, ds.air, 1, dask="allowed")
received x of type <class 'dask.array.core.Array'>, shape (2920, 25, 53)
received y of type <class 'int'>
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<chunksize=(100, 25, 53), meta=np.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00

Core dimensions#

squared_error operated on a per-element basis. How about a reduction like np.mean?

Such functions involve the concept of “core dimensions”. This concept is independent of the underlying array type, and is a property of the applied function. See the core dimensions with NumPy tutorial for more.

Exercise 28

Use dask.array.mean as an example of a function that can handle dask arrays and uses an axis kwarg.

Again, the result is identical to that of the built-in mean.

def time_mean(da):
    return xr.apply_ufunc(
        dask.array.mean,
        da,
        input_core_dims=[["time"]],
        dask="allowed",
        kwargs={"axis": -1},  # core dimensions are moved to the end
    )


ds.air.mean("time").identical(time_mean(ds.air))
True
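The comment "core dimensions are moved to the end" can be checked with the wrapper trick. A minimal sketch on a small synthetic array (the dims time, lat, lon and their sizes are assumptions; the list seen_shapes exists only for illustration):

```python
import dask.array
import numpy as np
import xarray as xr

seen_shapes = []


def mean_wrapper(x, axis):
    # record the shape apply_ufunc hands us, then defer to dask's mean
    seen_shapes.append(x.shape)
    return dask.array.mean(x, axis=axis)


da = xr.DataArray(
    np.random.default_rng(0).random((6, 3, 4)), dims=("time", "lat", "lon")
).chunk({"time": 2})

result = xr.apply_ufunc(
    mean_wrapper,
    da,
    input_core_dims=[["time"]],
    dask="allowed",
    kwargs={"axis": -1},
)
print(seen_shapes)  # "time" (size 6) is now the last axis
print(result.dims)
```

With dask="allowed" the wrapper runs once, on the whole transposed dask array, which is why axis=-1 targets time.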

Automatically parallelizing dask-unaware functions#

Basics#

Not all functions can handle dask arrays appropriately by default.

A very useful apply_ufunc feature is the ability to apply arbitrary functions in parallel to each block. This ability can be activated using dask="parallelized".

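We will use scipy.integrate.trapz as an example of a function that cannot handle dask arrays and requires a core dimension. (As the warning below notes, trapz is deprecated in favour of scipy.integrate.trapezoid, which takes the same arguments; trapz was removed in SciPy 1.14, so substitute trapezoid on newer versions.) If we call trapz with a dask array, we get a NumPy array back; that is, the values have been eagerly computed. This is undesirable behaviour.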

import scipy as sp
import scipy.integrate

sp.integrate.trapz(
    ds.air.data, axis=ds.air.get_axis_num("lon")
)  # does NOT return a dask array, you should see activity on the dashboard
/tmp/ipykernel_3527/2065701532.py:4: DeprecationWarning: 'scipy.integrate.trapz' is deprecated in favour of 'scipy.integrate.trapezoid' and will be removed in SciPy 1.14.0
  sp.integrate.trapz(
array([[12588.54  , 12582.26  , ..., 15430.039 , 15493.165 ],
       [12571.841 , 12567.279 , ..., 15413.14  , 15477.346 ],
       ...,
       [12726.679 , 12634.4795, ..., 15454.13  , 15511.4795],
       [12767.33  , 12630.78  , ..., 15495.53  , 15538.18  ]],
      dtype=float32)
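Eager computation like this typically happens because the function converts its input with np.asarray (or similar), which forces dask to compute. A minimal sketch with a hypothetical dask-unaware function, unaware_trapezoid (unit spacing assumed), shows the effect:

```python
import dask.array
import numpy as np


def unaware_trapezoid(y, axis=-1):
    """Hypothetical dask-unaware trapezoidal rule with unit spacing."""
    y = np.asarray(y)  # this line triggers computation of a dask input
    y = np.moveaxis(y, axis, -1)
    return 0.5 * (y[..., 1:] + y[..., :-1]).sum(axis=-1)


lazy = dask.array.ones((40, 5, 6), chunks=(10, 5, 6))
out = unaware_trapezoid(lazy, axis=-1)
print(type(out).__name__, out.shape)  # a NumPy array: already computed
```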

Let’s activate automatic parallelization by using apply_ufunc with dask="parallelized".

integrated = xr.apply_ufunc(
    sp.integrate.trapz,
    ds,
    input_core_dims=[["lon"]],
    kwargs={"axis": -1},
    dask="parallelized",
)
integrated
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat) float32 dask.array<chunksize=(100, 25), meta=np.ndarray>

And make sure the returned data is a dask array.

integrated.air.data
            Array                        Chunk
Bytes       285.16 kiB                   9.77 kiB
Shape       (2920, 25)                   (100, 25)
Dask graph  30 chunks in 5 graph layers
Data type   float32 numpy.ndarray

Now you have control over executing this parallel computation.

# Dask -> Numpy array of integrated values
parallelized_results = integrated.compute()
parallelized_results
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat) float32 1.259e+04 1.258e+04 ... 1.55e+04 1.554e+04
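dask="parallelized" only changes when and where the function runs, so the result should agree with an eager, in-memory computation. A minimal sketch of that check on a synthetic Dataset, using a hand-written trapezoidal rule (an assumption, to avoid depending on a particular SciPy version):

```python
import dask.array
import numpy as np
import xarray as xr


def trapezoid_last_axis(y, axis=-1):
    # Trapezoidal rule with unit spacing; NumPy-only, so dask-unaware
    y = np.moveaxis(np.asarray(y), axis, -1)
    return 0.5 * (y[..., 1:] + y[..., :-1]).sum(axis=-1)


rng = np.random.default_rng(42)
synthetic = xr.Dataset(
    {"air": (("time", "lat", "lon"), rng.random((50, 4, 7)))}
).chunk({"time": 10})

parallel = xr.apply_ufunc(
    trapezoid_last_axis,
    synthetic,
    input_core_dims=[["lon"]],
    kwargs={"axis": -1},
    dask="parallelized",
)
eager = xr.apply_ufunc(
    trapezoid_last_axis,
    synthetic.compute(),  # in-memory copy: no dask involved
    input_core_dims=[["lon"]],
    kwargs={"axis": -1},
)

xr.testing.assert_allclose(parallel.compute(), eager)
print(parallel.air.data.chunksize)  # one output block per input block
```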

Understanding dask="parallelized"#

It is very important to understand what dask="parallelized" does. Fully understanding it requires a few core concepts.

See also

With dask="parallelized", apply_ufunc calls dask.array.apply_gufunc. See the dask documentation on generalized ufuncs and apply_gufunc for more.
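For intuition, here is a minimal sketch of calling dask.array.apply_gufunc directly, roughly what apply_ufunc does after unwrapping the Xarray object. The NumPy-only reduction last_axis_range is hypothetical. The signature string "(i)->()" declares one core dimension on the input that is removed on the output; core dimensions are the trailing axes:

```python
import dask.array
import numpy as np


def last_axis_range(x):
    # NumPy-only reduction over the last (core) axis
    return x.max(axis=-1) - x.min(axis=-1)


values = np.random.default_rng(0).random((100, 8, 5))
x = dask.array.from_array(values, chunks=(25, 8, 5))

# "(i)->()": one core dimension in, none out; output_dtypes skips dtype inference
y = dask.array.apply_gufunc(last_axis_range, "(i)->()", x, output_dtypes=x.dtype)

print(y.shape, y.chunksize)
```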

Embarrassingly parallel or blockwise operations#

dask="parallelized" works well for “blockwise” or “embarrassingly parallel” operations (Wikipedia).

These are operations where one block or chunk of the output array corresponds to one block or chunk of the input array. Specifically, what matters is how the array is chunked along the core dimension. Importantly, no communication between blocks is necessary to create the output, which makes parallelization quite simple or “embarrassing”.

Let’s look at the dask repr for ds.air and note that the chunk size is (100, 25, 53) for an array with shape (2920, 25, 53). This means that each block or chunk of the array contains all lat, lon points and a subset of time points.

ds.air.data
            Array                        Chunk
Bytes       14.76 MiB                    517.58 kiB
Shape       (2920, 25, 53)               (100, 25, 53)
Dask graph  30 chunks in 2 graph layers
Data type   float32 numpy.ndarray

The core dimension for trapz is lon, and there is only one chunk along lon. This means that integrating along lon is a “blockwise” or “embarrassingly parallel” operation and dask="parallelized" works quite well.
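Whether an operation along a core dimension is blockwise can be read off the chunk structure. A minimal sketch with a hypothetical helper, is_single_chunk, using DataArray.chunks (a tuple of per-dimension chunk sizes) on a lazily allocated array shaped and chunked like ds.air:

```python
import dask.array
import xarray as xr


def is_single_chunk(da, dim):
    """Hypothetical helper: True if `da` has exactly one chunk along `dim`."""
    return len(da.chunks[da.get_axis_num(dim)]) == 1


air = xr.DataArray(
    dask.array.zeros((2920, 25, 53), chunks=(100, 25, 53), dtype="float32"),
    dims=("time", "lat", "lon"),
)

print(is_single_chunk(air, "lon"))  # core dimension in one chunk: blockwise
print(is_single_chunk(air.chunk(lon=4), "lon"))  # lon split into 14 chunks
```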

Caution

Question: Do you understand why integrating ds along lon is an “embarrassingly parallel” operation when ds has a single chunk along lon?

Exercise 29

Apply the trapz integration above to ds after rechunking it to a different chunk size along lon, for example with ds.chunk(lon=4). What happens?
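One way to explore the exercise is this minimal sketch on synthetic data, where the hand-written trapezoid_last_axis is an assumption standing in for trapz. With dask="parallelized", Xarray refuses a core dimension that spans several chunks and raises a ValueError suggesting two fixes: rechunk to a single chunk along lon, or pass allow_rechunk=True in dask_gufunc_kwargs.

```python
import numpy as np
import xarray as xr


def trapezoid_last_axis(y, axis=-1):
    # Trapezoidal rule with unit spacing along `axis` (NumPy-only)
    y = np.moveaxis(np.asarray(y), axis, -1)
    return 0.5 * (y[..., 1:] + y[..., :-1]).sum(axis=-1)


def integrate_lon(obj, **dask_gufunc_kwargs):
    return xr.apply_ufunc(
        trapezoid_last_axis,
        obj,
        input_core_dims=[["lon"]],
        kwargs={"axis": -1},
        dask="parallelized",
        dask_gufunc_kwargs=dask_gufunc_kwargs,
    )


da = xr.DataArray(np.random.default_rng(1).random((20, 3, 9)), dims=("time", "lat", "lon"))
split = da.chunk({"time": 5, "lon": 4})  # lon now spans 3 chunks

try:
    integrate_lon(split)
    raised = False
except ValueError as err:
    raised = True
    print(err)

# Fix 1: a single chunk along the core dimension
fixed = integrate_lon(split.chunk(lon=-1)).compute()
# Fix 2: let dask merge the lon chunks itself (may increase memory use)
rechunked = integrate_lon(split, allow_rechunk=True).compute()
```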

Understanding execution#

We are layering many concepts together here, so it is important to understand how the function is executed and what input it receives. Again we will use our wrapper trick.

def integrate_wrapper(array, **kwargs):
    print(f"received array of type {type(array)}, shape {array.shape}")
    result = sp.integrate.trapz(array, **kwargs)
    print(f"returning array of type {type(result)}, shape {result.shape}")
    return result


integrated = xr.apply_ufunc(
    integrate_wrapper,
    ds,
    input_core_dims=[["lon"]],
    kwargs={"axis": -1},
    dask="parallelized",
)
integrated
received array of type <class 'numpy.ndarray'>, shape (1, 1, 1)
returning array of type <class 'numpy.ndarray'>, shape (1, 1)
/tmp/ipykernel_3527/2804912483.py:3: DeprecationWarning: 'scipy.integrate.trapz' is deprecated in favour of 'scipy.integrate.trapezoid' and will be removed in SciPy 1.14.0
  result = sp.integrate.trapz(array, **kwargs)
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat) float32 dask.array<chunksize=(100, 25), meta=np.ndarray>

Note that we received an Xarray object back (integrated) but our wrapper function was called with a numpy array of shape (1,1,1).

Important

The full 3D array has not yet been passed to integrate_wrapper, yet dask needs to know the shape and dtype of the result. This is key.

The integrate_wrapper function is treated like a black box, and its effect on the inputs has to either be described through additional keyword arguments, or inferred by passing dummy inputs.

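To do so, dask.array.apply_gufunc calls the user function with dummy inputs (here a numpy array of shape (1,1,1)) and inspects the returned value to infer the output dtype. The output shape follows from the core dimensions: lon is an input core dimension but not an output one, so that dimension is removed (and indeed the dummy call returned a numpy array of shape (1,1)).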

Since no errors were raised, we proceed as-is.
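The dummy call happens while the graph is built, before any real block is processed. A minimal sketch that records the shapes the wrapped function receives, first during apply_ufunc and then during .compute(). The array sizes are assumptions, and the synchronous scheduler is used so that every call happens in this process:

```python
import numpy as np
import xarray as xr

calls = []  # shapes seen by the wrapped function, in call order


def recording_trapezoid(y, axis=-1):
    calls.append(np.shape(y))
    y = np.moveaxis(np.asarray(y), axis, -1)
    return 0.5 * (y[..., 1:] + y[..., :-1]).sum(axis=-1)


da = xr.DataArray(np.ones((12, 3, 5)), dims=("time", "lat", "lon")).chunk({"time": 4})

lazy = xr.apply_ufunc(
    recording_trapezoid,
    da,
    input_core_dims=[["lon"]],
    kwargs={"axis": -1},
    dask="parallelized",
)
# Calls made while building the graph: dummy inputs only, no real block yet
shapes_at_build = list(calls)

# The synchronous scheduler runs every task here, so `calls` sees each block
result = lazy.compute(scheduler="synchronous")
shapes_at_compute = calls[len(shapes_at_build):]

print("while building:", shapes_at_build)
print("while computing:", shapes_at_compute)
```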

Let’s compute the array to get real values.

integrated.compute()
/tmp/ipykernel_3527/2804912483.py:3: DeprecationWarning: 'scipy.integrate.trapz' is deprecated in favour of 'scipy.integrate.trapezoid' and will be removed in SciPy 1.14.0
[the same DeprecationWarning is printed 30 times in total]
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat) float32 1.259e+04 1.258e+04 ... 1.55e+04 1.554e+04

We see that integrate_wrapper is called many times: once per block of the array, which is 30 blocks here (ds.air.data.numblocks is (30, 1, 1), and ds.air.data.npartitions gives the total of 30).

Our function is independently executed on each block of the array, and then the results are concatenated to form the final result.
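The block count comes straight from the chunking. A minimal sketch on a lazily allocated array with the same shape and chunks as ds.air.data: numblocks gives the number of blocks along each axis, and npartitions gives their total.

```python
import dask.array

data = dask.array.zeros((2920, 25, 53), chunks=(100, 25, 53), dtype="float32")

print(data.numblocks)      # blocks along (time, lat, lon)
print(data.npartitions)    # total number of blocks
print(data.chunks[0][-1])  # 2920 = 29 * 100 + 20, so the last time block is shorter
```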

Conceptually, there is a two-way flow of information between various packages when executing integrated.compute():

xarray.apply_ufunc ↔ dask.array.apply_gufunc ↔ integrate_wrapper ↔ scipy.integrate.trapz ↔ ds.air.data

When executed:

  1. Xarray loops over all data variables.

  2. Xarray unwraps the underlying dask array (e.g. ds.air) and passes that to dask’s apply_gufunc.

  3. apply_gufunc calls integrate_wrapper on each block of the array.

  4. For each block, integrate_wrapper calls scipy.integrate.trapz and returns one block of the output array.

  5. dask stitches all the output blocks to form the output array.

  6. xarray.apply_ufunc wraps the output array with Xarray metadata to give the final result.

Phew!
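Steps 3 to 5 can be mimicked by hand to make them concrete. This is a conceptual sketch only, not how dask is implemented. It iterates over the blocks of a dask array with Array.blocks, applies a NumPy-only function to each computed block, and concatenates the pieces along the loop dimension; the function and the sizes are assumptions.

```python
import dask.array
import numpy as np


def trapezoid_last_axis(y):
    # Trapezoidal rule with unit spacing along the last axis (NumPy-only)
    return 0.5 * (y[..., 1:] + y[..., :-1]).sum(axis=-1)


data = dask.array.from_array(
    np.random.default_rng(3).random((30, 4, 6)), chunks=(10, 4, 6)
)

# Steps 3-4: call the function once per block (3 blocks along time here)
pieces = [
    trapezoid_last_axis(data.blocks[i, 0, 0].compute())
    for i in range(data.numblocks[0])
]

# Step 5: stitch the per-block outputs back together along the loop dimension
stitched = np.concatenate(pieces, axis=0)
print(len(pieces), stitched.shape)
```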

More complex situations#

Here we quickly demonstrate that all the concepts from the numpy material earlier carry over.

Xarray needs a lot of extra metadata, so depending on the function, extra arguments such as output_dtypes and output_sizes may be necessary for supporting dask arrays. We demonstrate this below.

Adding new dimensions#

We use the expand_dims example, which adds a new dimension of size 1 to the input.

def add_new_dim(array):
    # new core dimensions are appended at the end, so expand along the last axis
    return np.expand_dims(array, axis=-1)

When automatically parallelizing with dask, we need to provide some more information about the outputs.

  1. When adding a new dimension, we need to provide its size in dask_gufunc_kwargs using the key output_sizes.

  2. We usually also need to provide the data type (dtype) of the returned array using output_dtypes. The dtype of the input is often a good guess.

def add_new_dim(array):
    return np.expand_dims(array, axis=-1)


xr.apply_ufunc(
    add_new_dim,  # first the function
    ds.air.chunk({"time": 2, "lon": 2}),
    output_core_dims=[["newdim"]],
    dask="parallelized",
)
ValueError: dimension 'newdim' in 'output_core_dims' needs corresponding (dim, size) in 'output_sizes'

Provide the size of the newly added dimension newdim in output_sizes as part of the dask_gufunc_kwargs keyword argument:

dask_gufunc_kwargs (dict, optional) – Optional keyword arguments passed to dask.array.apply_gufunc() if dask='parallelized'. Possible keywords are output_sizes, allow_rechunk and meta.

The syntax is

dask_gufunc_kwargs={
    "output_sizes": {"newdim": 1}
}
xr.apply_ufunc(
    add_new_dim,  # first the function
    ds.air.chunk({"time": 2, "lon": 2}),
    output_core_dims=[["newdim"]],
    dask="parallelized",
    dask_gufunc_kwargs={"output_sizes": {"newdim": 1}},
)
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53, newdim: 1)>
dask.array<chunksize=(2, 25, 2, 1), meta=np.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Dimensions without coordinates: newdim
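To double-check the result, here is a minimal sketch on synthetic data that also passes output_dtypes (point 2 above) and compares the computed values with np.expand_dims applied to the in-memory array. The sizes and chunks are assumptions.

```python
import numpy as np
import xarray as xr


def add_new_dim(array):
    # new core dimensions are appended at the end, hence axis=-1
    return np.expand_dims(array, axis=-1)


da = xr.DataArray(
    np.arange(24, dtype="float32").reshape(4, 2, 3), dims=("time", "lat", "lon")
).chunk({"time": 2, "lon": 2})

expanded = xr.apply_ufunc(
    add_new_dim,
    da,
    output_core_dims=[["newdim"]],
    dask="parallelized",
    dask_gufunc_kwargs={"output_sizes": {"newdim": 1}},
    output_dtypes=[da.dtype],  # skip dtype inference: the input dtype is a good guess
)

print(expanded.dims, expanded.dtype)
```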

Dimensions that change size#

We will now repeat the interpolation example from earlier with "lat" as the output core dimension. See the numpy notebook on complex output for more.

# Expect an error here: handling of dask arrays has not been enabled
newlat = np.linspace(15, 75, 100)

xr.apply_ufunc(
    np.interp,
    newlat,
    ds.air.lat,
    ds.air.chunk({"time": 2, "lon": 2}),
    input_core_dims=[["lat"], ["lat"], ["lat"]],
    output_core_dims=[["lat"]],
    exclude_dims={"lat"},
)

We will first add dask="parallelized" and provide output_sizes in dask_gufunc_kwargs.

newlat = np.linspace(15, 75, 100)

xr.apply_ufunc(
    np.interp,  # first the function
    newlat,
    ds.air.lat,
    ds.air.chunk({"time": 2, "lon": 2}),
    input_core_dims=[["lat"], ["lat"], ["lat"]],
    output_core_dims=[["lat"]],
    exclude_dims={"lat"},
    # The following are dask-specific
    dask="parallelized",
    dask_gufunc_kwargs=dict(output_sizes={"lat": len(newlat)}),
)
ValueError: `dtype` inference failed in `apply_gufunc`.

Please specify the dtype explicitly using the `output_dtypes` kwarg.

Original error is below:
------------------------
ValueError('object too deep for desired array')

Traceback:
---------
  File "/home/runner/micromamba/envs/xarray-tutorial/lib/python3.11/site-packages/dask/array/core.py", line 463, in apply_infer_dtype
    o = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/micromamba/envs/xarray-tutorial/lib/python3.11/site-packages/numpy/lib/function_base.py", line 1599, in interp
    return interp_func(x, xp, fp, left, right)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This error means that we need to provide output_dtypes.

output_dtypes (list of dtype, optional) – Optional list of output dtypes. Only used if dask='parallelized' or vectorize=True.

newlat = np.linspace(15, 75, 100)

xr.apply_ufunc(
    np.interp,  # first the function
    newlat,
    ds.air.lat,
    ds.air.chunk({"time": 100, "lon": -1}),
    input_core_dims=[["lat"], ["lat"], ["lat"]],
    output_core_dims=[["lat"]],
    exclude_dims={"lat"},
    # The following are dask-specific
    dask="parallelized",
    dask_gufunc_kwargs=dict(output_sizes={"lat": len(newlat)}),
    output_dtypes=[ds.air.dtype],
)
<xarray.DataArray (time: 2920, lon: 53, lat: 100)>
dask.array<chunksize=(100, 53, 100), meta=np.ndarray>
Coordinates:
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Dimensions without coordinates: lat

Tip

Dask can sometimes figure out the output sizes and dtypes. The usual workflow is to read the error messages and iteratively pass more information to apply_ufunc.

Automatic Vectorizing#

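Automatic vectorizing with vectorize=True also carries over! Here it is in fact required: np.interp only accepts 1D arrays, so computing the previous result, where np.interp receives 3D blocks, would fail with the same "object too deep for desired array" error.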

interped = xr.apply_ufunc(
    np.interp,  # first the function
    newlat,
    ds.air.lat,
    ds.chunk({"time": 100, "lon": -1}),
    input_core_dims=[["lat"], ["lat"], ["lat"]],
    output_core_dims=[["lat"]],
    exclude_dims={"lat"},  # dimensions allowed to change size. Must be set!
    dask="parallelized",
    dask_gufunc_kwargs=dict(output_sizes={"lat": len(newlat)}),
    output_dtypes=[ds.air.dtype],
    vectorize=True,
)
interped
<xarray.Dataset>
Dimensions:  (time: 2920, lon: 53, lat: 100)
Coordinates:
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Dimensions without coordinates: lat
Data variables:
    air      (time, lon, lat) float32 dask.array<chunksize=(100, 53, 100), meta=np.ndarray>
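A minimal sketch on synthetic data that shows why vectorize=True matters here. Without it, supplying output_dtypes lets the lazy graph be built, but computing it raises a ValueError because np.interp receives 3D blocks. With it, the result matches an explicit loop over the non-core dimensions. The sizes, coordinates, and the helper interp_lat are assumptions.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(7)
lat = np.linspace(15.0, 75.0, 9)
newlat = np.linspace(15.0, 75.0, 20)
air = xr.DataArray(
    rng.random((6, 4, 9)), dims=("time", "lon", "lat"), coords={"lat": lat}
).chunk({"time": 3, "lon": -1})  # lat (the core dimension) stays in one chunk


def interp_lat(vectorize):
    return xr.apply_ufunc(
        np.interp,
        newlat,
        air.lat,
        air,
        input_core_dims=[["lat"], ["lat"], ["lat"]],
        output_core_dims=[["lat"]],
        exclude_dims={"lat"},
        dask="parallelized",
        dask_gufunc_kwargs=dict(output_sizes={"lat": len(newlat)}),
        output_dtypes=[air.dtype],
        vectorize=vectorize,
    )


# Without vectorize, np.interp receives a 3D block and fails at compute time
try:
    interp_lat(vectorize=False).compute()
    plain_fails = False
except ValueError:
    plain_fails = True

interped = interp_lat(vectorize=True).compute()

# Reference: explicit loop over the (time, lon) loop dimensions
expected = np.empty((6, 4, 20))
for t in range(6):
    for j in range(4):
        expected[t, j] = np.interp(newlat, lat, air.values[t, j])
```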

Again, it is important to understand the conceptual flow of information between the various packages when executing interped.compute(), which looks like

xarray.apply_ufunc ↔ dask.array.apply_gufunc ↔ numpy.vectorize ↔ numpy.interp

When executed:

  1. Xarray loops over all data variables.

  2. Xarray unwraps the underlying dask array (e.g. ds.air) and passes that to dask’s apply_gufunc.

  3. apply_gufunc calls the vectorized function on each block of the array.

  4. For each block, numpy.vectorize handles looping over the loop dimensions and passes 1D vectors along the core dimension to numpy.interp.

  5. The 1D results for each block are concatenated by numpy.vectorize to create one output block.

  6. dask stitches all the output blocks to form the output array.

  7. xarray.apply_ufunc wraps the output array with Xarray metadata to give the final result.

Phew!
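Steps 4 and 5 can be reproduced with numpy.vectorize alone. This is a minimal sketch on one synthetic block. The gufunc signature string below is written by hand for illustration; Xarray builds its own from the core dimensions.

```python
import numpy as np

lat = np.linspace(15.0, 75.0, 9)
newlat = np.linspace(15.0, 75.0, 20)
block = np.random.default_rng(5).random((3, 4, 9))  # one (time, lon, lat) block

# "(n),(m),(m)->(n)": 1D inputs along the core dimension, 1D output of size n
vec_interp = np.vectorize(np.interp, signature="(n),(m),(m)->(n)")

# Loops over the (time, lon) loop dimensions and stacks the 1D results
out_block = vec_interp(newlat, lat, block)
print(out_block.shape)
```

The Python-level loop inside numpy.vectorize is convenient but not fast; dask's parallelism across blocks is what recovers performance.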

Clean up the cluster#

client.close();