
Parallel computing with Dask#

This notebook demonstrates one of Xarray’s most powerful features: the ability to wrap dask arrays and allow users to seamlessly execute analysis code in parallel.

By the end of this notebook, you will:

  1. Learn that Xarray DataArrays and Datasets are “dask collections”, i.e. you can execute top-level dask functions such as dask.visualize(xarray_object)

  2. Learn that all xarray built-in operations can transparently use dask

Important: Using Dask does not always make your computations run faster! Performance will depend on the computational infrastructure you’re using (for example, how many CPU cores), how the data you’re working with is structured and stored, and the algorithms and code you’re running. Be sure to review the Dask best-practices if you’re new to Dask!

Reading data#

The chunks argument to both open_dataset and open_mfdataset allows you to read datasets as dask arrays.

import xarray as xr
ds = xr.tutorial.open_dataset(
    "air_temperature",
    chunks={  # this tells xarray to open the dataset as a dask array
        "lat": 25,
        "lon": 25,
        "time": -1,
    },
)
ds.air
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<open_dataset-6d930c2810aaa0006ac7c82c2d96f77aair, shape=(2920, 25, 53), dtype=float32, chunksize=(2920, 25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]

The representation (“repr” in Python parlance) for the air DataArray shows the very nice HTML dask array repr. You can access the underlying chunk sizes using .chunks:

ds.air.chunks
((2920,), (25,), (25, 25, 3))

Tip: The variables in a Dataset need not all have the same chunk sizes along common dimensions.
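If the on-disk chunking is not convenient for your analysis, you can rechunk after opening. A minimal sketch (the chunk sizes and the rechunked name are purely illustrative, not a recommendation):

# Rechunk so each block holds 100 time steps and the full spatial extent.
# Good chunk sizes depend on your data, storage format, and workflow.
rechunked = ds.air.chunk({"time": 100, "lat": -1, "lon": -1})
rechunked.chunks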

Lazy computation#

Xarray seamlessly wraps dask so all computation is deferred until explicitly requested.

mean = ds.air.mean("time")

Dask actually constructs a graph of the required computation. Here it’s pretty simple: the full array is subdivided into 3 chunks along the lon dimension. Dask will load each of these subarrays in a separate thread using the default single-machine scheduler. You can visualize dask ‘task graphs’, which represent the requested computation:

mean.data.visualize(optimize_graph=True)
(Output: rendered dask task graph for the mean computation.)

Getting concrete values from dask arrays#

At some point, you will want to actually get concrete values (usually a numpy array) from dask.

There are two ways to compute values on dask arrays.

  1. .compute() returns a new xarray object backed by numpy arrays, leaving the original untouched.

  2. .load() replaces the dask arrays in the xarray object with numpy arrays, in place. ds.load() is equivalent to ds = ds.compute().
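A minimal sketch of the difference (the lazy_max name is just for illustration; it avoids touching the mean object we reuse below):

lazy_max = ds.air.max("time")      # still lazy: a dask-backed DataArray
max_computed = lazy_max.compute()  # returns a new, numpy-backed DataArray; lazy_max stays lazy
lazy_max.load()                    # computes in place; lazy_max itself is now numpy-backed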

Distributed Clusters#

As your data volumes grow and algorithms get more complex, it can be hard to print out task graph representations and understand what Dask is doing behind the scenes. Luckily, you can use Dask’s ‘Distributed’ scheduler to get very useful diagnostic information.

First let’s set up a LocalCluster using dask.distributed.

You can use any kind of Dask cluster. This step is completely independent of xarray. While not strictly necessary, the dashboard provides a nice learning tool.

from dask.distributed import Client

# This piece of code is just for getting a correct dashboard link on mybinder.org or other JupyterHub deployments
import dask
import os

if os.environ.get('JUPYTERHUB_USER'):
    dask.config.set(**{"distributed.dashboard.link": "/user/{JUPYTERHUB_USER}/proxy/{port}/status"})

client = Client(local_directory='/tmp')
client

Client: Client-8fc2396e-39de-11ed-8b8d-00224843d23f
Connection method: Cluster object
Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status

☝️ Click the Dashboard link above.

👈 Or click the “Search” 🔍 button in the dask-labextension dashboard.

NOTE: if using the dask-labextension, you should disable the ‘Simple’ JupyterLab interface (View -> Simple Interface), so that you can drag and rearrange whichever dashboards you want. The Workers and Task Stream panels are good ones for checking that the dashboard is working!

import dask.array

dask.array.ones((1000, 4), chunks=(2, 1)).compute()  # should see activity in dashboard
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       ...,
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])

Examining a DataArray with dask#

Let’s go back to our xarray Dataset. In addition to computing the mean, other operations such as indexing will automatically use whichever Dask cluster we are connected to!

ds.air.isel(lon=1, lat=20)
<xarray.DataArray 'air' (time: 2920)>
dask.array<getitem, shape=(2920,), dtype=float32, chunksize=(2920,), chunktype=numpy.ndarray>
Coordinates:
    lat      float32 25.0
    lon      float32 202.5
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]

and more complicated operations…

timeseries = ds.air.rolling(time=5).mean().isel(lon=1, lat=20)  # no activity on dashboard
timeseries  # contains dask array
<xarray.DataArray 'air' (time: 2920)>
dask.array<getitem, shape=(2920,), dtype=float64, chunksize=(2920,), chunktype=numpy.ndarray>
Coordinates:
    lat      float32 25.0
    lon      float32 202.5
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]
timeseries = ds.air.rolling(time=5).mean()  # no activity on dashboard
timeseries  # contains dask array
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
dask.array<truediv, shape=(2920, 25, 53), dtype=float64, chunksize=(2920, 25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]
computed = mean.compute()  # activity on dashboard
computed  # has real numpy values
<xarray.DataArray 'air' (lat: 25, lon: 53)>
array([[260.37564, 260.1826 , 259.88593, ..., 250.81511, 251.93733,
        253.43741],
       [262.7337 , 262.7936 , 262.7489 , ..., 249.75496, 251.5852 ,
        254.35849],
       [264.7681 , 264.3271 , 264.0614 , ..., 250.60707, 253.58247,
        257.71475],
       ...,
       [297.64932, 296.95294, 296.62912, ..., 296.81033, 296.28793,
        295.81622],
       [298.1287 , 297.93646, 297.47006, ..., 296.8591 , 296.77686,
        296.44348],
       [298.36594, 298.38593, 298.11386, ..., 297.33777, 297.28104,
        297.30502]], dtype=float32)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

Note that mean still contains a dask array:

mean
<xarray.DataArray 'air' (lat: 25, lon: 53)>
dask.array<mean_agg-aggregate, shape=(25, 53), dtype=float32, chunksize=(25, 25), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

But if we call .load(), mean will now contain a numpy array:

mean.load()
<xarray.DataArray 'air' (lat: 25, lon: 53)>
array([[260.37564, 260.1826 , 259.88593, ..., 250.81511, 251.93733,
        253.43741],
       [262.7337 , 262.7936 , 262.7489 , ..., 249.75496, 251.5852 ,
        254.35849],
       [264.7681 , 264.3271 , 264.0614 , ..., 250.60707, 253.58247,
        257.71475],
       ...,
       [297.64932, 296.95294, 296.62912, ..., 296.81033, 296.28793,
        295.81622],
       [298.1287 , 297.93646, 297.47006, ..., 296.8591 , 296.77686,
        296.44348],
       [298.36594, 298.38593, 298.11386, ..., 297.33777, 297.28104,
        297.30502]], dtype=float32)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

Let’s check that again…

mean
<xarray.DataArray 'air' (lat: 25, lon: 53)>
array([[260.37564, 260.1826 , 259.88593, ..., 250.81511, 251.93733,
        253.43741],
       [262.7337 , 262.7936 , 262.7489 , ..., 249.75496, 251.5852 ,
        254.35849],
       [264.7681 , 264.3271 , 264.0614 , ..., 250.60707, 253.58247,
        257.71475],
       ...,
       [297.64932, 296.95294, 296.62912, ..., 296.81033, 296.28793,
        295.81622],
       [298.1287 , 297.93646, 297.47006, ..., 296.8591 , 296.77686,
        296.44348],
       [298.36594, 298.38593, 298.11386, ..., 297.33777, 297.28104,
        297.30502]], dtype=float32)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0

Tip: .persist() loads the values into distributed RAM. This is useful if you will be repeatedly using a dataset for computation but it is too large to load into local memory. You will see a persistent task on the dashboard.

See https://docs.dask.org/en/latest/api.html#dask.persist for more.
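As a rough sketch of how .persist() might be used here (the air_persisted name is illustrative; this assumes the distributed client above is still active):

# Keep the chunked air temperature in cluster memory so repeated
# computations do not re-read the data from disk each time.
air_persisted = ds.air.persist()
air_persisted.mean("time").compute()  # reuses the persisted chunks
air_persisted.std("time").compute()   # also reuses the persisted chunks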

Extracting underlying data: .values vs .data#

There are two ways to pull out the underlying data in an xarray object.

  1. .values will always return a NumPy array. For dask-backed xarray objects, this means that .compute() will always be called.

  2. .data will return the underlying Dask array without triggering any computation.
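For example:

type(ds.air.data)    # dask.array.core.Array  -- no computation is triggered
type(ds.air.values)  # numpy.ndarray          -- the full array is computed first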

Xarray data structures are first-class dask collections#

This means you can do things like dask.compute(xarray_object), dask.visualize(xarray_object), and dask.persist(xarray_object). This works for both DataArrays and Datasets.
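For example, a quick sketch using a fresh lazy computation (the variable names are just for illustration; dask.visualize needs graphviz installed to render the graph):

import dask

lazy_zonal_mean = ds.air.mean("lon")           # a dask-backed DataArray
dask.visualize(lazy_zonal_mean)                # works because DataArrays are dask collections
(zonal_mean,) = dask.compute(lazy_zonal_mean)  # dask.compute returns a tuple of results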

Exercise#

Visualize the task graph for a few different computations on ds.air!
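If you want a starting point, here is one possibility (the seasonal_mean name is just an example; any chunked computation on ds.air will do):

# One possible computation: a seasonal mean of ds.air
seasonal_mean = ds.air.groupby("time.season").mean()
dask.visualize(seasonal_mean, optimize_graph=True)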

Finally, gracefully shut down our connection to the Dask cluster. This becomes more important when you are running on large HPC or cloud servers rather than a laptop!

client.close()