A gentle introduction#
map_blocks is inspired by the dask.array function of the same name and lets you map a function on blocks of the xarray object (including Datasets!).

At compute time, your function will receive a chunk of an xarray object with concrete (computed) values along with appropriate metadata. This function should return an xarray object.
Setup#
First, let's set up a LocalCluster using dask.distributed.

You can use any kind of dask cluster. This step is completely independent of xarray. While not strictly necessary, the dashboard provides a nice learning tool.
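While any dask cluster works, a minimal sketch of this step (assuming the default Client() / LocalCluster settings) looks like:

import xarray as xr
from dask.distributed import Client

client = Client()  # start a LocalCluster with default settings and connect to it
client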
Client: Cluster object | Cluster type: distributed.LocalCluster
LocalCluster: 4 workers | 4 threads | 15.62 GiB total memory
Dashboard: http://127.0.0.1:8787/status
👆 Click the Dashboard link above, or click the "Search" button in the dashboard.

Let's test that the dashboard is working.
import dask.array
dask.array.ones((1000, 4), chunks=(2, 1)).compute() # should see activity in dashboard
array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
...,
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], shape=(1000, 4))
Let's open a dataset. We specify chunks so that we create dask arrays for the DataArrays.

Depending on the function we want to apply to the chunks, it is vital to set the chunks correctly. Our goal is to compute the mean along the time dimension, so we do not chunk the time dimension at all (indicated by "time": -1). We deliberately set the lat and lon chunks to something smaller than the size of their respective dimensions (otherwise we could end up with a single big chunk for the entire ds).
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": -1, "lat": 5, "lon": 10})
ds
<xarray.Dataset> Size: 31MB
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float64 31MB dask.array<chunksize=(2920, 5, 10), meta=np.ndarray>
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day). These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
Simple example#
Here is an example:
def time_mean(obj: xr.Dataset):
# use xarray's convenient API here
# you could convert to a pandas dataframe and use pandas' extensive API
# or use .plot() and plt.savefig to save visualizations to disk in parallel.
return obj.mean("time")
ds.map_blocks(time_mean) # this is lazy!
<xarray.Dataset> Size: 11kB
Dimensions:  (lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Data variables:
    air      (lat, lon) float64 11kB dask.array<chunksize=(5, 10), meta=np.ndarray>
# this triggers the actual computation
ds.map_blocks(time_mean).compute()
<xarray.Dataset> Size: 11kB
Dimensions:  (lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
Data variables:
    air      (lat, lon) float64 11kB 260.4 260.2 259.9 ... 297.3 297.3 297.3
# this calculates the values and returns True if the computation works as expected
ds.map_blocks(time_mean).equals(ds.mean("time"))
True
Exercises#
Exercise 1
When opening the dataset, set the chunks for the time dimension to anything smaller than its size (< 2920), e.g., "time": 100, and keep the other chunk sizes the same:
ds = xr.tutorial.open_dataset(
"air_temperature",
chunks={"time": 100, "lat": 5, "lon": 10},
)
Now run the notebook again. The result of ds.map_blocks(time_mean) is no longer equivalent to ds.mean("time"). Why does ds.map_blocks(time_mean) return a different result this time?
Solution
Quoting from the documentation of map_blocks: "The function will receive a subset or 'block' of obj (see below), corresponding to one chunk along each chunked dimension."

ds.mean("time") computes the mean over the entire time dimension. In our example, ds.map_blocks(time_mean) passes individual blocks of ds to time_mean. Once the time dimension is chunked, each block covers only a single chunk along that dimension, so time_mean computes the mean over that chunk's time steps rather than over the entire time dimension. Therefore we do not get an identical result.

You can also modify the function to show the shape of the chunks passed to time_mean. Compare the output of the modified function with ds.chunks to find out how they relate to each other!
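For example, a minimal sketch of such a modification (the function name here is made up; with a multi-process cluster the printed sizes may show up in the worker logs rather than in the notebook):

def time_mean_with_sizes(obj: xr.Dataset):
    # report the sizes of the block that map_blocks passes in,
    # then compute the mean as before; compare the output with ds.chunks
    print(dict(obj.sizes))
    return obj.mean("time")

ds.map_blocks(time_mean_with_sizes).compute()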
Exercise 2
Try applying the following function with map_blocks. Specify scale as an argument and offset as a kwarg.

The docstring should help: https://docs.xarray.dev/en/stable/generated/xarray.map_blocks.html
def time_mean_scaled(obj, scale, offset):
return obj.mean("lat") * scale + offset
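One way to call it (a sketch; the scale and offset values below are arbitrary) is to pass scale through the args parameter and offset through kwargs:

# scale is passed positionally via args, offset by name via kwargs
ds.map_blocks(time_mean_scaled, args=[10], kwargs={"offset": 273.15})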
More advanced functions#
map_blocks needs to know exactly what the returned object looks like. It does so by passing a 0-shaped xarray object to the function and examining the result. This approach does not work in all cases. For such advanced use cases, map_blocks allows a template kwarg. See https://docs.xarray.dev/en/stable/user-guide/dask.html#map-blocks for more details.
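As a minimal sketch (reusing ds and time_mean from above): any lazy object with the same structure, dtypes, and chunks as the expected result can be passed as the template; the lazy xarray reduction itself should work here:

# the lazy reduction has the dims, coords, and chunks the result should have,
# so it can serve as the template for map_blocks
template = ds.mean("time")
ds.map_blocks(time_mean, template=template).compute()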
client.close()