Creating custom accessors#
Introduction#
An accessor is a way of attaching a custom function to xarray objects so that it can be called as if it were a method while retaining a clear separation between the “core” xarray API and custom API. It enables you to easily extend (which is why you’ll sometimes see it referred to as an extension) and customize xarray’s functionality while limiting naming conflicts and minimizing the chances of your code breaking with xarray upgrades.
If you’ve used rioxarray (e.g. da.rio.crs) or hvplot (e.g. ds.hvplot()), you may have already used an xarray accessor without knowing it!
The Xarray documentation has some more technical details, and this tutorial provides example custom accessors and their uses.
Why create a custom accessor#
You can easily create a custom suite of tools that work on Xarray objects
It keeps your workflows cleaner and simpler
Your project-specific code is easy to share
It’s easy to implement: you don’t need to integrate any code into Xarray
It makes it easier to perform checks and write code documentation because you only have to create them once!
Easy steps to create your own accessor#
Create your custom class, including the mandatory __init__ method
Add the xr.register_dataarray_accessor() or xr.register_dataset_accessor() decorator to the class
Use your custom functions
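The three steps above can be sketched in a few lines. This is a minimal illustration, not a recommended design: the accessor name "demo" and the method center() are hypothetical names invented for this sketch.

```python
import numpy as np
import xarray as xr


# Steps 1 and 2: define the class and register it under a name.
# "demo" and center() are hypothetical names chosen for this sketch.
@xr.register_dataarray_accessor("demo")
class DemoAccessor:
    def __init__(self, da):
        # xarray passes in the DataArray the accessor is attached to
        self._da = da

    def center(self, dim):
        """Subtract the mean along `dim`."""
        return self._da - self._da.mean(dim)


# Step 3: call the custom method through the accessor namespace.
da = xr.DataArray(np.array([1.0, 2.0, 3.0]), dims="x")
centered = da.demo.center("x")
print(centered.values)
```

The method is reached as da.demo.center(...), so it cannot be mistaken for a core DataArray method.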
Example 1: accessing scipy functionality#
For example, imagine you’re a statistician who regularly uses a special skewness function which acts on dataarrays but is only of interest to people in your specific field.
You can create a method which applies this skewness function to an xarray object and then register the method under a custom stats accessor like this:
import xarray as xr
from scipy.stats import skew

xr.set_options(display_expand_attrs=False, display_expand_coords=False)


@xr.register_dataarray_accessor("stats")
class StatsAccessor:
    def __init__(self, da):
        self._da = da

    def skewness(self, dim):
        return self._da.reduce(func=skew, dim=dim)
Now we can conveniently access this functionality via the stats accessor:
ds = xr.tutorial.load_dataset("air_temperature")
ds["skewair"] = ds['air'].stats.skewness(dim="time")
ds
<xarray.Dataset> Size: 31MB
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates: (3)
Data variables:
    air      (time, lat, lon) float64 31MB 241.2 242.5 243.5 ... 296.2 295.7
    skewair  (lat, lon) float64 11kB -0.2934 -0.2828 -0.272 ... -0.19 -0.187
Attributes: (5)
Notice how the presence of .stats clearly differentiates our new “accessor method” from core xarray methods.
Example 2: creating your own workflows#
Perhaps you find yourself running similar code for multiple xarray objects or across related projects. Packing that code into an extension makes it easy to repeat the same operation while reducing the likelihood of human-introduced errors.
Here we wrap the reorganization of InSAR ice velocity data illustrated in this tutorial into a custom Xarray extension that makes it easy to re-apply each time you begin working with a new InSAR velocity dataset. Please see the linked tutorial for details on the data, applications, and each step in this process.
import xarray as xr


@xr.register_dataset_accessor("insar_vel")
class InsarReorg:
    """
    An extension for an Xarray dataset that will prepare InSAR data for analysis.

    Re-organize the data from its native structure to have x and y velocity and error along a time dimension.
    """

    # ----------------------------------------------------------------------
    # Constructors

    def __init__(self, xrds):
        self._xrds = xrds

    # ----------------------------------------------------------------------
    # Methods

    def _validate(self, req_dim=None, req_vars=None):
        '''
        Make sure the xarray dataset has the correct dimensions and variables.

        Running this function will check that my dataset has all the needed dimensions and variables
        for a given function, saving time and headache later if they were missing and the computation fails
        partway through.

        Parameters
        ----------
        req_dim : list of str
            List of all required dimension names
        req_vars : dict of {str: list of str}
            Required variable names mapped to their dimensions (only the names are checked)
        '''
        if req_dim is not None:
            # fail as soon as any single required dimension is absent
            if any(dim not in self._xrds.dims for dim in req_dim):
                raise AttributeError("Required dimensions are missing")
        if req_vars is not None:
            # iterating over a dict yields its keys, i.e. the variable names
            if any(var not in self._xrds.variables for var in req_vars):
                raise AttributeError("Required variables are missing")
        # print("successfully validated your dataset")

    # ----------------------------------------------------------------------
    # Functions

    def change_vars_to_coords(
        self,
        req_dim=['ny', 'nx'],
        req_vars={'xaxis': ['nx'], 'yaxis': ['ny']},
    ):
        """
        Turn the xaxis and yaxis variables into coordinates.

        Parameters
        ----------
        req_dim : list of str
            List of all required dimension names.
        req_vars : dict of {str: list of str}
            Required variable names mapped to their dimensions
        """
        self._validate(req_dim, req_vars)
        self._xrds = self._xrds.swap_dims({'ny': 'yaxis', 'nx': 'xaxis'})
        self._xrds = self._xrds.rename({'xaxis': 'x', 'yaxis': 'y'})
        return self._xrds

    def reorg_dataset(self):
        """
        Reorganize the data by time for each of the desired end variables (here vx, vy, err)
        """
        reorged = []
        for reorg_var in ['vx', 'vy', 'err']:
            da = self.reorg_var_time(reorg_var)
            reorged.append(da)
        reorged_ds = xr.merge(reorged)
        return reorged_ds

    def reorg_var_time(self, reorg_var):
        """
        Repeat the process for a given variable.

        Figure out which of the original variables are time steps for this variable and turn each one into a dataarray.
        Add a time dimension and update the variable name for each dataarray.
        Combine the modified data arrays into a single dataarray along the time dimension.
        """
        # names of the original variables that are time steps of this variable
        var_ls = list(self._xrds)
        to_reorg = [var for var in var_ls if reorg_var in var]
        # list the arrays from the original dataset that correspond to the variable
        das_to_reorg = [self._xrds[var] for var in to_reorg]
        # add the time dimension
        das_to_reorg = [da.expand_dims('time') for da in das_to_reorg]
        # update variable name to remove time
        das_to_reorg = [da.rename(reorg_var) for da in das_to_reorg]
        # stack the time steps along the new dimension
        return xr.concat(das_to_reorg, dim='time')
ds = xr.tutorial.open_dataset('ASE_ice_velocity.nc')
ds = ds.insar_vel.change_vars_to_coords()
ds
<xarray.Dataset> Size: 48MB
Dimensions:  (y: 800, x: 500)
Coordinates: (2)
Data variables: (12/30)
    vx1996   (y, x) float32 2MB ...
    vy1996   (y, x) float32 2MB ...
    err1996  (y, x) float32 2MB ...
    vx2000   (y, x) float32 2MB ...
    vy2000   (y, x) float32 2MB ...
    err2000  (y, x) float32 2MB ...
    ...       ...
    vx2011   (y, x) float32 2MB ...
    vy2011   (y, x) float32 2MB ...
    err2011  (y, x) float32 2MB ...
    vx2012   (y, x) float32 2MB ...
    vy2012   (y, x) float32 2MB ...
    err2012  (y, x) float32 2MB ...
Attributes: (21)
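change_vars_to_coords runs _validate before touching the data. In a check like this, the choice between any and all decides whether a single missing name is enough to fail. The sketch below uses plain Python lists as hypothetical stand-ins for a dataset's dimension names; it illustrates the logic only and is not part of the accessor.

```python
def missing_names(required, available):
    """Return the required names that are absent from `available`."""
    return [name for name in required if name not in available]


available = ["x", "y"]          # names present on a hypothetical dataset
required = ["x", "y", "dtime"]  # names a method needs

# all(...) is True only when *every* required name is absent
all_missing = all(name not in available for name in required)
# any(...) is True as soon as *one* required name is absent
any_missing = any(name not in available for name in required)

print(all_missing)  # False: "x" and "y" are present
print(any_missing)  # True: "dtime" is absent
print(missing_names(required, available))  # ['dtime']
```

Reporting the missing names in the error message, as missing_names would allow, makes a failed validation easier to act on.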
ds = ds.insar_vel.reorg_dataset()
ds
<xarray.Dataset> Size: 48MB
Dimensions:  (x: 500, y: 800, time: 10)
Coordinates: (2)
Dimensions without coordinates: time
Data variables:
    vx       (time, y, x) float32 16MB nan nan nan nan nan ... nan nan nan nan
    vy       (time, y, x) float32 16MB nan nan nan nan nan ... nan nan nan nan
    err      (time, y, x) float32 16MB nan nan nan nan nan ... nan nan nan nan
Attributes: (2)
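The core of the reorganization (add a time dimension to each yearly array, give the arrays a common name, then concatenate) can be tried on a small synthetic dataset. The dataset below is made up for illustration and only mimics the one-variable-per-year layout; it is not the InSAR data.

```python
import numpy as np
import xarray as xr

# Hypothetical dataset with one variable per year, mimicking the InSAR layout
small = xr.Dataset(
    {
        "vx1996": (("y", "x"), np.zeros((2, 3))),
        "vx2000": (("y", "x"), np.ones((2, 3))),
    }
)

# Give each yearly array a length-1 time dimension and a shared name...
yearly = [small[name].expand_dims("time").rename("vx") for name in ["vx1996", "vx2000"]]
# ...then stack them along that dimension
vx = xr.concat(yearly, dim="time")

print(vx.dims)
print(vx.shape)
```

As in the output above, the new time dimension has no coordinate values; assigning the acquisition years as a time coordinate would be a natural next step.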
Example 3: creating your own workflows with locally stored corrections#
Consider someone who frequently converts their elevations to be relative to the geoid (rather than the ellipsoid) using a custom, local conversion (otherwise, we’d recommend using an established conversion library like pyproj to switch between datums).
An accessor provides an elegant way to build (once) and apply (as often as needed!) this custom conversion on top of the existing xarray ecosystem without the need to copy-paste the code into the start of each project. By standardizing our approach and adding a few sanity checks within the accessor, we also eliminate the risk of accidentally applying the correction multiple times.
import rasterio
import xarray as xr


@xr.register_dataset_accessor("geoidxr")
class GeoidXR:
    """
    An extension for an Xarray dataset that will calculate geoidal elevations from a local source file.
    """

    # ----------------------------------------------------------------------
    # Constructors

    def __init__(self, xrds):
        self._xrds = xrds
        # Running this function on init will check that my dataset has all the needed dimensions and variables
        # as specific to my workflow, saving time and headache later if they were missing and the computation fails
        # partway through.
        self._validate(req_dim=['x', 'y', 'dtime'], req_vars={'elevation': ['x', 'y', 'dtime']})

    # ----------------------------------------------------------------------
    # Methods

    def _validate(self, req_dim=None, req_vars=None):
        '''
        Make sure the xarray dataset has the correct dimensions and variables

        Parameters
        ----------
        req_dim : list of str
            List of all required dimension names
        req_vars : dict of {str: list of str}
            Required variable names mapped to their dimensions (only the names are checked)
        '''
        if req_dim is not None:
            # fail as soon as any single required dimension is absent
            if any(dim not in self._xrds.dims for dim in req_dim):
                raise AttributeError("Required dimensions are missing")
        if req_vars is not None:
            # iterating over a dict yields its keys, i.e. the variable names
            if any(var not in self._xrds.variables for var in req_vars):
                raise AttributeError("Required variables are missing")

    # Notice that 'geoid' has been added to the req_vars list
    def to_geoid(
        self,
        req_dim=['dtime', 'x', 'y'],
        req_vars={'elevation': ['x', 'y', 'dtime', 'geoid']},
        source=None,
    ):
        """
        Get geoid layer from your local file, which is provided to the function as "source",
        and apply the offset to all elevation values.
        Adds 'geoid_offset' to the "offset_names" attribute so you know the geoid offset was applied.

        Parameters
        ----------
        req_dim : list of str
            List of all required dimension names.
        req_vars : dict of {str: list of str}
            Required variable names mapped to their dimensions
        source : str
            Full path to your source file containing geoid offsets
        """
        # check to make sure you haven't already run this function (and are thus applying the offset twice)
        try:
            values = self._xrds.attrs['offset_names']
        except KeyError:
            values = []
        if isinstance(values, str):
            # a single stored name is treated as a one-element list
            values = [values]
        assert 'geoid_offset' not in values, "You've already applied the geoid offset!"
        values = list(values) + ['geoid_offset']
        self._validate(req_dim, req_vars)
        # read in your geoid values
        # WARNING: this implementation assumes your geoid values are in the same CRS and grid as the data you are applying
        # them to. If not, you will need to reproject and/or resample them to match the data to which you are applying them.
        # That step is not included here to emphasize the accessor aspect of the workflow.
        with rasterio.open(source) as src:
            # ASSUMPTION: the geoid offsets are stored in band 1 of the source file, laid out as (y, x);
            # adapt this line to however your own file stores them
            geoid = xr.DataArray(src.read(1), dims=('y', 'x'))
        # As noted above, this step will fail or produce unreliable results if your data is not properly gridded
        self._xrds['elevation'] = self._xrds.elevation - geoid
        self._xrds.attrs['offset_names'] = values
        return self._xrds
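The guard against applying the offset twice relies only on bookkeeping in the dataset attributes, so it can be sketched with a plain dictionary standing in for ds.attrs. The helper name mark_offset_applied is hypothetical and exists only for this illustration.

```python
def mark_offset_applied(attrs, name="geoid_offset", key="offset_names"):
    """Record `name` under attrs[key], refusing to record it twice."""
    applied = attrs.get(key, [])
    if isinstance(applied, str):
        # a single stored name is treated as a one-element list
        applied = [applied]
    if name in applied:
        raise ValueError(f"{name!r} has already been applied")
    attrs[key] = list(applied) + [name]
    return attrs


attrs = {}                      # stands in for ds.attrs
mark_offset_applied(attrs)      # first application succeeds
try:
    mark_offset_applied(attrs)  # second application is rejected
    second_call_rejected = False
except ValueError:
    second_call_rejected = True

print(attrs)
print(second_call_rejected)
```

Raising an explicit exception rather than using assert is one possible design choice here: assertions are skipped when Python runs with the -O flag, so a guard written as an exception stays active.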
Now, each time we want to convert our ellipsoid data to the geoid, we only have to run one line of code, and it will also perform a multitude of checks for us to make sure we’re performing exactly the operation we expect. Imagine the possibilities (and decrease in frustration)!
ds = ds.geoidxr.to_geoid(source='/Path/to/Custom/source/file.nc')
ds
Here the call fails, which is the validation doing its job: ds is still the InSAR velocity dataset from Example 2, which has no elevation variable. _validate raises an AttributeError while the accessor is being initialized, and xarray re-raises it as a RuntimeError because __getattr__ on the data object would otherwise swallow any AttributeError raised during accessor initialization (GH933). The traceback ends with:
RuntimeError: error initializing 'geoidxr' accessor.
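Why the AttributeError surfaces as a RuntimeError can be illustrated in pure Python. The classes below are simplified, hypothetical stand-ins written for this sketch, not xarray's actual implementation: they only reproduce the interaction between a descriptor that builds an accessor on access and a class that defines __getattr__.

```python
class Accessor:
    def __init__(self, obj):
        # mimic a validation failure during accessor initialization
        raise AttributeError("Required variables are missing")


class CachedAccessor:
    """Simplified descriptor that builds the accessor when it is accessed."""

    def __init__(self, name, accessor, wrap):
        self._name, self._accessor, self._wrap = name, accessor, wrap

    def __get__(self, obj, cls):
        if obj is None:
            return self._accessor
        try:
            return self._accessor(obj)
        except AttributeError:
            if self._wrap:
                # re-raise as a different type so __getattr__ cannot hide it
                raise RuntimeError(f"error initializing {self._name!r} accessor.")
            raise


class Unwrapped:
    demo = CachedAccessor("demo", Accessor, wrap=False)

    def __getattr__(self, name):
        # Python calls this fallback when normal lookup raises AttributeError
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")


class Wrapped(Unwrapped):
    demo = CachedAccessor("demo", Accessor, wrap=True)


def error_of(obj):
    """Return the type name and message of the error raised by obj.demo."""
    try:
        obj.demo
    except Exception as exc:
        return type(exc).__name__, str(exc)


unwrapped_error = error_of(Unwrapped())
wrapped_error = error_of(Wrapped())
print(unwrapped_error)
print(wrapped_error)
```

Without the wrapping, the user only sees a misleading "no attribute" message and the validation message is lost; with it, the failure clearly points at accessor initialization.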