Creating custom accessors

Introduction

An accessor is a way of attaching custom functions to xarray objects so that they can be called as if they were methods, while keeping a clear separation between the “core” xarray API and your custom API. Accessors let you extend xarray’s functionality (which is why you’ll sometimes see them referred to as extensions) and customize it while limiting naming conflicts and minimizing the chances of your code breaking when you upgrade xarray.

If you’ve used rioxarray (e.g. da.rio.crs) or hvplot (e.g. ds.hvplot()), you may have already used an xarray accessor without knowing it!

The Xarray documentation has more technical details; this tutorial provides example custom accessors and shows how to use them.

Why create a custom accessor

  • You can easily create a custom suite of tools that work on Xarray objects

  • It keeps your workflows cleaner and simpler

  • Your project-specific code is easy to share

  • It’s easy to implement: you don’t need to integrate any code into Xarray

  • Input checks and documentation are easier to maintain because you only have to write them once!

Easy steps to create your own accessor

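  1. Create your custom class, including the mandatory __init__ method, which receives the xarray object the accessor is attached to

  2. Decorate the class with xr.register_dataarray_accessor("name") or xr.register_dataset_accessor("name"), where "name" is the namespace you will use to reach your functions

  3. Use your custom functions through that namespace, e.g. da.name.my_function()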
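Put together, the three steps look like the minimal skeleton below. The accessor name "example" and the describe method are hypothetical placeholders; the only contract xarray imposes is that the class can be instantiated with the xarray object as its single argument.

```python
import numpy as np
import xarray as xr


# Step 2: register the class under a namespace; xarray calls ExampleAccessor(obj)
@xr.register_dataarray_accessor("example")
class ExampleAccessor:
    # Step 1: __init__ receives the DataArray the accessor is attached to
    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    # A hypothetical custom method operating on the wrapped object
    def describe(self):
        return f"{self._obj.name}: dims={self._obj.dims}, size={self._obj.size}"


# Step 3: call the custom method through the accessor namespace
da = xr.DataArray(np.zeros((2, 3)), dims=("y", "x"), name="demo")
print(da.example.describe())  # demo: dims=('y', 'x'), size=6
```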

Example 1: accessing scipy functionality

For example, imagine you’re a statistician who regularly uses a special skewness function that acts on DataArrays but is only of interest to people in your specific field.

You can create a method which applies this skewness function to an xarray object and then register the method under a custom stats accessor like this:

import xarray as xr
from scipy.stats import skew

xr.set_options(display_expand_attrs=False, display_expand_coords=False)


@xr.register_dataarray_accessor("stats")
class StatsAccessor:
    def __init__(self, da):
        self._da = da

    def skewness(self, dim):
        return self._da.reduce(func=skew, dim=dim)

Now we can conveniently access this functionality via the stats accessor:

ds = xr.tutorial.load_dataset("air_temperature")
ds["skewair"] = ds['air'].stats.skewness(dim="time")
ds
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates: (3)
Data variables:
    air      (time, lat, lon) float32 241.2 242.5 243.5 ... 296.5 296.2 295.7
    skewair  (lat, lon) float32 -0.2931 -0.2827 -0.2719 ... -0.1893 -0.1869
Attributes: (5)

Notice how the presence of .stats clearly differentiates our new “accessor method” from core xarray methods.
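Because xr.tutorial.load_dataset needs network access, here is a self-contained check on synthetic data (random gamma-distributed values, an assumption made purely for reproducibility) that the accessor returns the same numbers as calling scipy.stats.skew on the underlying NumPy array along the matching axis. It also shows that xarray caches the accessor instance on each object, so repeated da.stats lookups return the same instance.

```python
import numpy as np
import xarray as xr
from scipy.stats import skew


@xr.register_dataarray_accessor("stats")
class StatsAccessor:
    def __init__(self, da):
        self._da = da

    def skewness(self, dim):
        # DataArray.reduce translates `dim` into an axis number and calls skew(data, axis=...)
        return self._da.reduce(func=skew, dim=dim)


rng = np.random.default_rng(0)
da = xr.DataArray(rng.gamma(2.0, size=(100, 4, 5)), dims=("time", "lat", "lon"))

result = da.stats.skewness(dim="time")
expected = skew(da.values, axis=0)  # "time" is axis 0

print(result.dims)                           # ('lat', 'lon')
print(np.allclose(result.values, expected))  # True
print(da.stats is da.stats)                  # True: the accessor is cached on the object
```

The caching is worth keeping in mind when writing accessors: state stored on self persists for the lifetime of that particular xarray object.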

Example 2: creating your own workflows

Perhaps you find yourself running similar code for multiple xarray objects or across related projects. Packaging that code into an extension makes it easy to repeat the same operation while reducing the likelihood of human-introduced errors.
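One concrete way an extension reduces errors is by checking its inputs before doing any work. Below is a minimal sketch of such a check as a standalone function (the name validate_dataset and the ValueError it raises are illustrative choices, not part of xarray). Note the use of any(...): the dataset should be rejected if any required name is missing, whereas a check built on all(...) would only fail when every required name is missing.

```python
import numpy as np
import xarray as xr


def validate_dataset(ds, req_dims=(), req_vars=()):
    """Raise ValueError listing any required dimensions or variables missing from ds."""
    missing_dims = [d for d in req_dims if d not in ds.dims]
    missing_vars = [v for v in req_vars if v not in ds.variables]
    if missing_dims or missing_vars:
        raise ValueError(
            f"missing dimensions {missing_dims} and/or variables {missing_vars}"
        )


ds = xr.Dataset({"vx1996": (("ny", "nx"), np.zeros((2, 3)))})

validate_dataset(ds, req_dims=["ny", "nx"], req_vars=["vx1996"])  # passes silently

try:
    # 'time' is absent even though 'ny' is present, so this must fail
    validate_dataset(ds, req_dims=["ny", "time"])
except ValueError as err:
    print(err)  # missing dimensions ['time'] and/or variables []
```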

Here we wrap the reorganization of InSAR ice velocity data illustrated in this tutorial into a custom Xarray extension that is easy to re-apply each time you begin working with a new InSAR velocity dataset. Please see the linked tutorial for details on the data, applications, and each step in this process.

import xarray as xr


@xr.register_dataset_accessor("insar_vel")
class InsarReorg:
    """
    An extension for an Xarray Dataset that prepares InSAR data for analysis.

    Reorganizes the data from its native structure (one variable per quantity per year)
    into x and y velocity and error variables along a time dimension.
    """

    # ----------------------------------------------------------------------
    # Constructors

    def __init__(self, xrds):
        self._xrds = xrds

    # ----------------------------------------------------------------------
    # Methods

    def _validate(self, req_dim=None, req_vars=None):
        '''
        Make sure the xarray dataset has the correct dimensions and variables.

        Running this check before a computation confirms that the dataset has all the
        dimensions and variables it needs, saving time and headache later if something
        is missing and the computation fails partway through.

        Parameters
        ----------
        req_dim : list of str
            List of all required dimension names
        req_vars : dict of {str: list of str}
            Required variable names (keys) mapped to their expected dimensions
            (values); only the keys are checked here
        '''

        # Fail if ANY required name is missing (not only when all of them are)
        if req_dim is not None:
            if any(dim not in self._xrds.dims for dim in req_dim):
                raise AttributeError("Required dimensions are missing")
        if req_vars is not None:
            if any(var not in self._xrds.variables for var in req_vars):
                raise AttributeError("Required variables are missing")

    # ----------------------------------------------------------------------
    # Functions

    def change_vars_to_coords(self, req_dim=None, req_vars=None):
        """
        Turn the xaxis and yaxis variables into coordinates named x and y.

        Parameters
        ----------
        req_dim : list of str, optional
            Required dimension names (default: ['ny', 'nx']).
        req_vars : dict of {str: list of str}, optional
            Required variable names mapped to their dimensions
            (default: {'xaxis': ['nx'], 'yaxis': ['ny']}).
        """
        if req_dim is None:
            req_dim = ['ny', 'nx']
        if req_vars is None:
            req_vars = {'xaxis': ['nx'], 'yaxis': ['ny']}
        self._validate(req_dim, req_vars)

        # Return a new dataset rather than reassigning self._xrds, so the accessor
        # (which xarray caches on the original object) keeps pointing at unmodified data
        ds = self._xrds.swap_dims({'ny': 'yaxis', 'nx': 'xaxis'})
        return ds.rename({'xaxis': 'x', 'yaxis': 'y'})

    def reorg_dataset(self):
        """
        Reorganize the data by time for each of the desired end variables (here vx, vy, err).
        """
        return xr.merge([self.reorg_var_time(var) for var in ['vx', 'vy', 'err']])

    def reorg_var_time(self, reorg_var):
        """
        Stack all time steps of a single variable along a new time dimension.

        Find the original variables that are time steps of reorg_var (e.g. vx1996,
        vx2000, ... for 'vx'), add a time dimension to each, rename each to reorg_var,
        and concatenate them into a single DataArray.
        """
        # names of the original per-year variables for this quantity
        to_reorg = [var for var in self._xrds.data_vars if reorg_var in var]

        # add a (coordinate-less) time dimension and drop the year from each name
        das_to_reorg = [
            self._xrds[var].expand_dims('time').rename(reorg_var) for var in to_reorg
        ]

        return xr.concat(das_to_reorg, dim='time')

ds = xr.tutorial.open_dataset('ASE_ice_velocity.nc')
ds = ds.insar_vel.change_vars_to_coords()
ds
<xarray.Dataset>
Dimensions:  (y: 800, x: 500)
Coordinates: (2)
Data variables: (12/30)
    vx1996   (y, x) float32 ...
    vy1996   (y, x) float32 ...
    err1996  (y, x) float32 ...
    vx2000   (y, x) float32 ...
    vy2000   (y, x) float32 ...
    err2000  (y, x) float32 ...
    ...       ...
    vx2011   (y, x) float32 ...
    vy2011   (y, x) float32 ...
    err2011  (y, x) float32 ...
    vx2012   (y, x) float32 ...
    vy2012   (y, x) float32 ...
    err2012  (y, x) float32 ...
Attributes: (21)
ds = ds.insar_vel.reorg_dataset()
ds
<xarray.Dataset>
Dimensions:  (x: 500, y: 800, time: 10)
Coordinates: (2)
Dimensions without coordinates: time
Data variables:
    vx       (time, y, x) float32 nan nan nan nan nan ... nan nan nan nan nan
    vy       (time, y, x) float32 nan nan nan nan nan ... nan nan nan nan nan
    err      (time, y, x) float32 nan nan nan nan nan ... nan nan nan nan nan
Attributes: (2)
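The reorganized dataset has no coordinate values along time ("Dimensions without coordinates: time"). One possible extension, sketched here on a tiny synthetic dataset whose variable names mimic the vx<year>/vy<year>/err<year> pattern above (the values are made up), is to parse the year from each variable name and attach it as a time coordinate. The accessor name insar_demo and its methods are hypothetical, and the four-digit-year naming scheme is an assumption, not part of the original workflow.

```python
import re

import numpy as np
import xarray as xr


@xr.register_dataset_accessor("insar_demo")
class InsarDemo:
    """Hypothetical accessor that stacks per-year variables along a labeled time axis."""

    def __init__(self, ds):
        self._ds = ds

    def stack_var(self, prefix):
        # Match names such as "vx1996" and capture the 4-digit year
        pattern = re.compile(rf"^{re.escape(prefix)}(\d{{4}})$")
        pieces = []
        for name in self._ds.data_vars:
            match = pattern.match(str(name))
            if match is None:
                continue
            year = int(match.group(1))
            # New length-1 time dimension whose coordinate value is the year
            pieces.append(self._ds[name].expand_dims(time=[year]).rename(prefix))
        return xr.concat(pieces, dim="time").sortby("time")

    def reorg(self, prefixes=("vx", "vy", "err")):
        return xr.merge([self.stack_var(p) for p in prefixes])


# Tiny synthetic dataset; years deliberately out of order
coords = {"y": [0.0, 1.0], "x": [0.0, 1.0, 2.0]}
data_vars = {}
for i, year in enumerate([2000, 1996]):
    for name in ("vx", "vy", "err"):
        data_vars[f"{name}{year}"] = (("y", "x"), np.full((2, 3), float(i)))
ds = xr.Dataset(data_vars, coords=coords)

out = ds.insar_demo.reorg()
print(out["vx"].dims)      # ('time', 'y', 'x')
print(out["time"].values)  # [1996 2000]
```

Anchoring the regular expression (^ and $) keeps a prefix such as "vx" from matching unrelated names that merely contain it.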

Example 3: creating your own workflows with locally stored corrections

Consider someone who frequently converts their elevations to be relative to the geoid (rather than the ellipsoid) using a custom, local conversion (otherwise, we’d recommend using an established conversion library like pyproj to switch between datums).

An accessor provides an elegant way to build (once) and apply (as often as needed!) this custom conversion on top of the existing xarray ecosystem without the need to copy-paste the code into the start of each project. By standardizing our approach and adding a few sanity checks within the accessor, we also eliminate the risk of accidentally applying the correction multiple times.
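The "apply it only once" guard can be sketched independently of any geoid file. In this minimal sketch a constant offset stands in for the geoid grid (purely for illustration), the accessor name offsetxr and the method apply_offset are hypothetical, and the bookkeeping uses an offset_names attribute like the accessor below. An explicit raise is used instead of assert because assertions are skipped when Python runs with -O.

```python
import numpy as np
import xarray as xr


@xr.register_dataset_accessor("offsetxr")
class OffsetXR:
    """Hypothetical accessor that subtracts an offset at most once per label."""

    def __init__(self, ds):
        self._ds = ds

    def apply_offset(self, offset, label="geoid_offset"):
        # Normalize the bookkeeping attribute (absent, a string, or a list) to a list
        applied = self._ds.attrs.get("offset_names", [])
        applied = [applied] if isinstance(applied, str) else list(applied)
        if label in applied:
            raise ValueError(f"{label!r} has already been applied")

        # Work on a shallow copy so the caller's dataset is left untouched
        out = self._ds.copy()
        out["elevation"] = out["elevation"] - offset
        out.attrs["offset_names"] = applied + [label]
        return out


ds = xr.Dataset({"elevation": (("y", "x"), np.full((2, 2), 100.0))})
corrected = ds.offsetxr.apply_offset(30.0)  # 30.0 stands in for a geoid grid
print(corrected.attrs["offset_names"])      # ['geoid_offset']
print(float(corrected["elevation"][0, 0]))  # 70.0

try:
    corrected.offsetxr.apply_offset(30.0)
except ValueError as err:
    print(err)  # 'geoid_offset' has already been applied
```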

import xarray as xr


@xr.register_dataset_accessor("geoidxr")
class GeoidXR:
    """
    An extension for an Xarray Dataset that calculates geoidal elevations from a local source file.
    """

    # ----------------------------------------------------------------------
    # Constructors

    def __init__(self, xrds):
        self._xrds = xrds
        # Running this check on init confirms that the dataset has all the dimensions and
        # variables specific to this workflow, saving time and headache later if they were
        # missing and the computation failed partway through.
        self._validate(req_dim=['x', 'y', 'dtime'], req_vars={'elevation': ['x', 'y', 'dtime']})

    # ----------------------------------------------------------------------
    # Methods

    def _validate(self, req_dim=None, req_vars=None):
        '''
        Make sure the xarray dataset has the correct dimensions and variables.

        Parameters
        ----------
        req_dim : list of str
            List of all required dimension names
        req_vars : dict of {str: list of str}
            Required variable names (keys) mapped to their expected dimensions
            (values); only the keys are checked here
        '''

        # Fail if ANY required name is missing (not only when all of them are)
        if req_dim is not None:
            if any(dim not in self._xrds.dims for dim in req_dim):
                raise AttributeError("Required dimensions are missing")
        if req_vars is not None:
            if any(var not in self._xrds.variables for var in req_vars):
                raise AttributeError("Required variables are missing")

    def to_geoid(self, req_dim=None, req_vars=None, source=None):
        """
        Get the geoid layer from your local file (passed in as "source") and subtract it
        from all elevation values.
        Adds 'geoid_offset' to the "offset_names" attribute so you know the geoid offset was applied.

        Parameters
        ----------
        req_dim : list of str, optional
            Required dimension names (default: ['dtime', 'x', 'y']).
        req_vars : dict of {str: list of str}, optional
            Required variable names mapped to their dimensions
            (default: {'elevation': ['x', 'y', 'dtime']}).
        source : str
            Full path to your source file containing geoid offsets
        """
        if source is None:
            raise ValueError("Provide the path to your geoid file as source=...")
        if req_dim is None:
            req_dim = ['dtime', 'x', 'y']
        if req_vars is None:
            req_vars = {'elevation': ['x', 'y', 'dtime']}

        # check that you haven't already run this function (and are thus applying the offset twice)
        offsets = self._xrds.attrs.get('offset_names', [])
        offsets = [offsets] if isinstance(offsets, str) else list(offsets)
        if 'geoid_offset' in offsets:
            raise ValueError("You've already applied the geoid offset!")

        self._validate(req_dim, req_vars)

        # read in your geoid values ('geoid_varname' is a placeholder for the variable name in your file)
        # WARNING: this implementation assumes your geoid values are in the same CRS and grid as the data you are applying
        # them to. If not, you will need to reproject and/or resample them to match the data to which you are applying them.
        # That step is not included here to emphasize the accessor aspect of the workflow.
        with xr.open_dataset(source) as src:
            geoid = src['geoid_varname'].load()

        # As noted above, this step will fail or produce unreliable results if your data is not properly gridded.
        # Work on a shallow copy so the original dataset is not modified in place.
        ds = self._xrds.copy()
        ds['elevation'] = ds['elevation'] - geoid
        ds.attrs['offset_names'] = offsets + ['geoid_offset']

        return ds

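Now, each time we want to convert our ellipsoid-referenced elevations to the geoid, we only have to run one line of code, and the accessor will check for the required dimensions and variables and refuse to apply the offset twice. Imagine the possibilities (and decrease in frustration)!

The path below is a placeholder. Because ds here is still the reorganized InSAR dataset from Example 2, which has neither an elevation variable nor a dtime dimension, the validation in __init__ raises an AttributeError, which xarray re-raises as a RuntimeError when the accessor is first accessed: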

ds = ds.geoidxr.to_geoid(source='/Path/to/Custom/source/file.nc')
ds
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
----> 1 ds = ds.geoidxr.to_geoid(source='/Path/to/Custom/source/file.nc')
      2 ds

RuntimeError: error initializing 'geoidxr' accessor.
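As the error above shows, xarray reports a failure inside an accessor's __init__ as a RuntimeError. It does this for AttributeError specifically, because attribute lookup on xarray objects would otherwise swallow it; the original AttributeError stays attached as the exception context. The minimal sketch below reproduces this with two hypothetical accessors (strict_attr and strict_value) and shows one possible design alternative: raising ValueError from the validation, which xarray does not wrap, so the specific message reaches the user directly.

```python
import numpy as np
import xarray as xr


@xr.register_dataset_accessor("strict_attr")
class StrictAttr:
    def __init__(self, ds):
        if "elevation" not in ds.variables:
            raise AttributeError("Required variables are missing")
        self._ds = ds


@xr.register_dataset_accessor("strict_value")
class StrictValue:
    def __init__(self, ds):
        if "elevation" not in ds.variables:
            raise ValueError("Required variables are missing: ['elevation']")
        self._ds = ds


ds = xr.Dataset({"vx": (("y", "x"), np.zeros((2, 2)))})

try:
    ds.strict_attr
except RuntimeError as err:
    wrapped = err
    # The original AttributeError is preserved as the exception context
    print(type(err.__context__).__name__)  # AttributeError

try:
    ds.strict_value
except ValueError as err:
    direct = err
    print(err)  # Required variables are missing: ['elevation']
```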