Hierarchical computations

In this lesson, we extend what we learned about Basic Computation to hierarchical datasets. By the end of the lesson, we will be able to:

  • Apply basic arithmetic and label-aware reductions to xarray DataTree objects

  • Apply arbitrary functions to all nodes of a tree

import xarray as xr
import numpy as np

xr.set_options(keep_attrs=True, display_expand_attrs=False, display_expand_data=False)
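Setting keep_attrs=True tells xarray to carry attributes through operations that might otherwise drop them; the exact default depends on the operation and the xarray version. This keeps metadata such as the root node's name attribute attached in the computations below. Here is a minimal check on a hypothetical single Dataset, with made-up values and attributes:

```python
import xarray as xr

# Hypothetical Dataset with a dataset-level and a variable-level attribute.
ds = xr.Dataset(
    {"air": ("x", [280.0, 290.0], {"units": "degK"})},
    attrs={"name": "demo"},
)

# Inside this context, arithmetic keeps both kinds of attributes.
with xr.set_options(keep_attrs=True):
    celsius = ds - 273.15

print(celsius.attrs)
print(celsius["air"].attrs)
```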

Example dataset

First, we load the NMC reanalysis air temperature dataset and resample it into a hierarchy of temporal resolutions (daily, weekly, and monthly means). An attributes-only dataset serves as the root node:

ds = xr.tutorial.open_dataset("air_temperature")

ds_daily = ds.resample(time="D").mean("time")
ds_weekly = ds.resample(time="W").mean("time")
ds_monthly = ds.resample(time="ME").mean("time")

tree = xr.DataTree.from_dict(
    {
        "daily": ds_daily,
        "weekly": ds_weekly,
        "monthly": ds_monthly,
        "": xr.Dataset(attrs={"name": "NMC reanalysis temporal pyramid"}),
    }
)
tree
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
Attributes: (1)

Arithmetic

Because DataTree extends the Dataset API, arithmetic on a DataTree is applied automatically to every variable in every node. For example, we can convert the temperatures from Kelvin to degrees Celsius:

tree - 273.15
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
Attributes: (1)

Indexing

Just like arithmetic, indexing is forwarded to the dataset of each node. The only difference is that nodes lacking the indexed coordinate or dimension (here, the root node) are skipped instead of raising an error:

tree.isel(lat=slice(None, 10))
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
Attributes: (1)
tree.sel(time="2013-11")
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
Attributes: (1)

Reductions

In the same way, we can reduce all nodes in the tree at once. Here we average over the spatial dimensions:

tree.mean(dim=["lat", "lon"])
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
Attributes: (1)

Applying functions designed for Dataset with map_over_datasets

What if we wanted to convert the data to log space? For a Dataset or DataArray, we could simply use xarray.ufuncs.log(), but it does not support DataTree objects yet:

xr.ufuncs.log(tree)
<xarray.Dataset> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

Note that the result is an empty Dataset rather than a DataTree.

To map a function over all nodes, we can use xarray.map_over_datasets() or the equivalent method xarray.DataTree.map_over_datasets():

tree.map_over_datasets(xr.ufuncs.log)
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
Attributes: (1)

We can also use a custom function for more complex operations, such as subtracting the mean of each group (here, each day of the month):

def demean(ds):
    return ds.groupby("time.day") - ds.groupby("time.day").mean()

Applying it to the tree raises an error, though:

tree.map_over_datasets(demean)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/dataset.py in ?(self, name)
   1154             variable = self._variables[name]
   1155         except KeyError:
-> 1156             _, name, variable = _get_virtual_variable(self._variables, name, self.sizes)
   1157 

KeyError: 'time.day'

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/dataset.py:1261, in Dataset.__getitem__(self, key)
   1260 try:
-> 1261     return self._construct_dataarray(key)
   1262 except KeyError as e:

File ~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/dataset.py:1156, in Dataset._construct_dataarray(self, name)
   1155 except KeyError:
-> 1156     _, name, variable = _get_virtual_variable(self._variables, name, self.sizes)
   1158 needed_dims = set(variable.dims)

File ~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/dataset_utils.py:82, in _get_virtual_variable(variables, key, dim_sizes)
     81 ref_name, var_name = split_key
---> 82 ref_var = variables[ref_name]
     84 if _contains_datetime_like_objects(ref_var):

KeyError: 'time'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_2148/422171771.py in ?()
----> 1 tree.map_over_datasets(demean)

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/datatree.py in ?(self, func, kwargs, *args)
   1531         map_over_datasets
   1532         """
   1533         # TODO this signature means that func has no way to know which node it is being called upon - change?
   1534         # TODO fix this typing error
-> 1535         return map_over_datasets(func, self, *args, kwargs=kwargs)

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/datatree_mapping.py in ?(func, kwargs, *args)
    116             if not isinstance(arg, DataTree):
    117                 node_dataset_args.insert(i, arg)
    118 
    119         func_with_error_context = _handle_errors_with_path_context(path)(func)
--> 120         results = func_with_error_context(*node_dataset_args, **kwargs)
    121         out_data_objects[path] = results
    122 
    123     num_return_values = _check_all_return_values(out_data_objects)

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/datatree_mapping.py in ?(*args, **kwargs)
    152                 # Add the context information to the error message
    153                 add_note(
    154                     e, f"Raised whilst mapping function over node with path {path!r}"
    155                 )
--> 156                 raise

/tmp/ipykernel_2148/373546963.py in ?(ds)
      1 def demean(ds):
----> 2     return ds.groupby("time.day") - ds.groupby("time.day").mean()

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/util/deprecation_helpers.py in ?(*args, **kwargs)
    114                 kwargs.update(zip_args)
    115 
    116                 return func(*args[:-n_extra_args], **kwargs)
    117 
--> 118             return func(*args, **kwargs)

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/dataset.py in ?(self, group, squeeze, restore_coord_dims, eagerly_compute_group, **groupers)
   9961             _validate_groupby_squeeze,
   9962         )
   9963 
   9964         _validate_groupby_squeeze(squeeze)
-> 9965         rgroupers = _parse_group_and_groupers(
   9966             self, group, groupers, eagerly_compute_group=eagerly_compute_group
   9967         )
   9968 

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/groupby.py in ?(obj, group, groupers, eagerly_compute_group)
    430             grouper_mapping = {g: UniqueGrouper() for g in group_iter}
    431         elif groupers:
    432             grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers)
    433 
--> 434         rgroupers = tuple(
    435             ResolvedGrouper(
    436                 grouper, group, obj, eagerly_compute_group=eagerly_compute_group
    437             )

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/groupby.py in ?(.0)
    434 def _parse_group_and_groupers(
--> 435     obj: T_Xarray,
    436     group: GroupInput,
    437     groupers: dict[str, Grouper],
    438     *,

<string> in ?(self, grouper, group, obj, eagerly_compute_group)

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/groupby.py in ?(self)
    323         from xarray.groupers import BinGrouper, UniqueGrouper
    324 
    325         self.grouper = copy.deepcopy(self.grouper)
    326 
--> 327         self.group = _resolve_group(self.obj, self.group)
    328 
    329         if self.eagerly_compute_group:
    330             raise ValueError(

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/groupby.py in ?(obj, group)
    490                 "`group` must be an xarray.DataArray or the "
    491                 "name of an xarray variable or dimension. "
    492                 f"Received {group!r} instead."
    493             )
--> 494         group_da: DataArray = obj[group]
    495         if group_da.name not in obj._indexes and group_da.name in obj.dims:
    496             # DummyGroups should not appear on groupby results
    497             newgroup = _DummyGroup(obj, group_da.name, group_da.coords)

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/datatree.py in ?(self, key)
    311     def __getitem__(self, key) -> DataArray | Dataset:
    312         # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes
    313         # For now just call Dataset.__getitem__
--> 314         return Dataset.__getitem__(self, key)

~/work/xarray-tutorial/xarray-tutorial/.pixi/envs/default/lib/python3.12/site-packages/xarray/core/dataset.py in ?(self, key)
   1270 
   1271                 # If someone attempts `ds['foo' , 'bar']` instead of `ds[['foo', 'bar']]`
   1272                 if isinstance(key, tuple):
   1273                     message += f"\nHint: use a list to select multiple variables, for example `ds[{list(key)}]`"
-> 1274                 raise KeyError(message) from e
   1275 
   1276         if utils.iterable_of_hashable(key):
   1277             return self._copy_listed(key)

KeyError: "No variable named 'time.day'. Variables on the dataset include []"
Raised whilst mapping function over node with path '.'

The error occurs because the root node (path '.') has no variables and, in particular, no "time" coordinate. To avoid it, the function has to return such nodes unchanged instead of grouping them:

def demean(ds):
    if "time" not in ds.coords:
        return ds
    return ds.groupby("time.day") - ds.groupby("time.day").mean()


tree.map_over_datasets(demean)
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
Attributes: (1)