A gentle introduction
Many, but not all, useful array methods are wrapped by Xarray and accessible as methods on Xarray objects. For example, DataArray.mean calls numpy.nanmean.
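To see this wrapping in action, here is a minimal sketch on a small synthetic DataArray (the values and the dimension names `y` and `x` are made up for illustration). It compares the method with numpy.nanmean called on the bare array. Both skip the missing value. Xarray's internals may differ between versions, so the sketch only checks that the numbers agree.

```python
import numpy as np
import xarray as xr

# Small synthetic array with one missing value (hypothetical data).
da = xr.DataArray(
    np.array([[1.0, 2.0, np.nan], [4.0, 5.0, 6.0]]),
    dims=("y", "x"),
)

# The wrapped method: for float data it skips NaN by default.
xr_mean = da.mean()

# The same reduction applied to the bare NumPy array.
np_mean = np.nanmean(da.data)

print(float(xr_mean), np_mean)
```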
A very common use case is to apply functions that expect and return NumPy (or other) arrays to Xarray objects. For example, this would include much of SciPy's API. Applying many of these functions to Xarray objects involves a series of repeated steps.
apply_ufunc provides a convenient wrapper function that generalizes the steps involved in applying such functions to Xarray objects.
Tip
Xarray uses apply_ufunc internally to implement much of its API, meaning that it is quite powerful!
Our goal is to learn how apply_ufunc automates aspects of applying computation functions that are designed for pure arrays (like NumPy arrays) to Xarray objects, including:
Propagating dimension names, coordinate variables, and (optionally) attributes.
Handling Dataset input by looping over data variables.
Passing arbitrary positional and keyword arguments.
Tip
We'll reduce the length of error messages using %xmode minimal. See the IPython documentation for details.
Setup
%xmode minimal
import numpy as np
import xarray as xr
# limit the amount of information printed to screen
xr.set_options(display_expand_data=False)
np.set_printoptions(threshold=10, edgeitems=2)
Exception reporting mode: Minimal
Let's load a dataset:
ds = xr.tutorial.load_dataset("air_temperature")
ds
<xarray.Dataset> Size: 31MB Dimensions: (lat: 25, time: 2920, lon: 53) Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 Data variables: air (time, lat, lon) float64 31MB 241.2 242.5 243.5 ... 296.2 295.7 Attributes: Conventions: COARDS title: 4x daily NMC reanalysis (1948) description: Data is from NMC initialized reanalysis\n(4x/day). These a... platform: Model references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
A simple example: pure numpy
Simple functions that act independently on each value should work without any additional arguments.
Consider the following squared_error function:
def squared_error(x, y):
    # elementwise squared difference; works on NumPy arrays and scalars alike
    return (x - y) ** 2
Tip
This function uses only arithmetic operations. For such simple functions, you can pass Xarray objects directly and receive Xarray objects back. Try squared_error(ds.air, 1). We use it here only as a very simple example.
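The claim in the tip can be checked with a minimal sketch. It uses a small hypothetical DataArray standing in for ds.air, so the tutorial dataset is not needed. Because squared_error only uses arithmetic operators, Xarray's own operator overloads handle the metadata.

```python
import numpy as np
import xarray as xr

def squared_error(x, y):
    return (x - y) ** 2

# Hypothetical small DataArray standing in for ds.air.
da = xr.DataArray(
    np.array([[1.0, 2.0], [3.0, 4.0]]),
    dims=("lat", "lon"),
    coords={"lat": [10.0, 20.0], "lon": [100.0, 110.0]},
    name="air",
)

# Arithmetic on a DataArray returns a DataArray, keeping dims and coords.
result = squared_error(da, 1)
print(type(result).__name__, result.dims)
print(result.values)
```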
We can apply squared_error manually by extracting the underlying NumPy array:
numpy_result = squared_error(ds.air.data, 1)
numpy_result
array([[[57696.04 , 58322.25 , ..., 54990.25 , 56453.76 ],
[58951.84 , 59292.25 , ..., 54896.49 , 56786.89 ],
...,
[86966.01 , 87143.04 , ..., 86966.01 , 86553.64 ],
[87196.1841, 87491.7241, ..., 87491.7241, 87379.36 ]],
[[58129.21 , 58418.89 , ..., 54102.76 , 55131.04 ],
[58854.76 , 59097.61 , ..., 53592.25 , 55084.09 ],
...,
[87143.04 , 87438.49 , ..., 86730.25 , 86494.81 ],
[87196.1841, 87734.44 , ..., 87261.16 , 87379.36 ]],
...,
[[59922.1441, 59433.5641, ..., 59044.1401, 59433.5641],
[61946.2321, 61647.9241, ..., 58317.4201, 59190.0241],
...,
[87196.1841, 87728.5161, ..., 86488.9281, 86077.6921],
[88084.3041, 88440.8121, ..., 86724.3601, 86547.7561]],
[[59579.9281, 59190.0241, ..., 57835.4401, 57979.8241],
[61946.2321, 61647.9241, ..., 57259.7041, 57931.6761],
...,
[87078.1081, 87550.8921, ..., 86842.1961, 86547.7561],
[88024.9561, 88262.4681, ..., 87137.1361, 86842.1961]]])
To convert this result back to a DataArray, we could propagate the metadata manually:
xr.DataArray(
data=numpy_result,
# propagate all the Xarray metadata manually
dims=ds.air.dims,
coords=ds.air.coords,
attrs=ds.air.attrs,
name=ds.air.name,
)
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB 5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04 Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16 322.1 ]
A shorter version uses DataArray.copy:
ds.air.copy(data=numpy_result)
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB 5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04 Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 Attributes: long_name: 4xDaily Air temperature at sigma level 995 units: degK precision: 2 GRIB_id: 11 GRIB_name: TMP var_desc: Air temperature dataset: NMC Reanalysis level_desc: Surface statistic: Individual Obs parent_stat: Other actual_range: [185.16 322.1 ]
Caution
Using DataArray.copy works for such simple cases but doesn't generalize well. For example, consider a function that removes one dimension and adds a new one: its result no longer has the shape of the original array, so it cannot simply be copied into it.
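Here is a minimal sketch of that failure mode on a hypothetical (time, lat, lon) array. np.quantile with a list of quantiles drops the reduced axis (lon) and prepends a new quantile axis. The resulting shape no longer matches, and DataArray.copy rejects the new data with a ValueError. Even if the sizes happened to match, the dimension names would be wrong.

```python
import numpy as np
import xarray as xr

# Hypothetical 3-D array standing in for ds.air: (time, lat, lon).
da = xr.DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("time", "lat", "lon"))

# Removes the lon axis and adds a new leading axis of length 2 (one per quantile).
quantiles = np.quantile(da.data, [0.25, 0.75], axis=-1)  # shape (2, 2, 3)

failed = False
try:
    da.copy(data=quantiles)  # copy() requires the new data to keep the original shape
except ValueError as err:
    failed = True
    print("copy failed:", err)
```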
apply_ufunc
apply_ufunc can handle more complicated functions. Here's how to use it with squared_error:
xr.apply_ufunc(squared_error, ds.air, 1)
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB 5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04 Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
How does apply_ufunc work?
This line
xr.apply_ufunc(squared_error, ds.air, 1)
is equivalent to squared_error(ds.air.data, 1), with automatic propagation of Xarray metadata such as dimension names and coordinate values.
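A minimal sketch of that equivalence follows. It uses a small hypothetical DataArray in place of ds.air. The numbers from both calls are identical. Only the apply_ufunc result carries the dimension names and coordinates.

```python
import numpy as np
import xarray as xr

def squared_error(x, y):
    return (x - y) ** 2

# Hypothetical stand-in for ds.air.
da = xr.DataArray(
    np.array([[270.0, 280.0], [290.0, 300.0]]),
    dims=("lat", "lon"),
    coords={"lat": [10.0, 20.0], "lon": [100.0, 110.0]},
)

via_apply = xr.apply_ufunc(squared_error, da, 1)  # DataArray with metadata
via_numpy = squared_error(da.data, 1)             # bare NumPy array

# Same numbers ...
print(np.array_equal(via_apply.data, via_numpy))
# ... but apply_ufunc also carried over dimension names and coordinates.
print(via_apply.dims, list(via_apply.coords))
```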
To illustrate how apply_ufunc works, let us write a small wrapper function that lets us examine what data the applied function receives and returns.
Tip
This trick is very useful for debugging.
def wrapper(x, y):
    # report what apply_ufunc hands to the function, then do the real work
    print(f"received x of type {type(x)}, shape {x.shape}")
    print(f"received y of type {type(y)}")
    return squared_error(x, y)
xr.apply_ufunc(wrapper, ds.air, 1)
received x of type <class 'numpy.ndarray'>, shape (2920, 25, 53)
received y of type <class 'int'>
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB 5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04 Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
We see that wrapper receives the underlying NumPy array (ds.air.data) and the integer 1.
Essentially, apply_ufunc does the following:
extracts the underlying array data (.data),
passes it to the user function,
receives the returned values, and
wraps them back up as a DataArray.
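The four steps above can be sketched by hand. The function toy_apply_ufunc below is hypothetical and deliberately simplified. It handles only a single DataArray whose shape the function leaves unchanged. This is not how Xarray implements apply_ufunc, which also deals with alignment, broadcasting, core dimensions, Datasets and more. For this simple case, though, it gives the same result.

```python
import numpy as np
import xarray as xr

def toy_apply_ufunc(func, da, *args):
    """Hypothetical, simplified stand-in for xr.apply_ufunc.

    Only handles one DataArray whose shape func does not change.
    """
    data = da.data           # 1. extract the underlying array data
    out = func(data, *args)  # 2-3. pass it to the function and receive the result
    return xr.DataArray(     # 4. wrap the result back up as a DataArray
        out, dims=da.dims, coords=da.coords, name=da.name
    )

def squared_error(x, y):
    return (x - y) ** 2

da = xr.DataArray(np.array([1.0, 2.0, 3.0]), dims="time", coords={"time": [0, 1, 2]})
result = toy_apply_ufunc(squared_error, da, 1)
print(result.values)
```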
Tip
apply_ufunc takes at least one DataArray or Dataset as input and, for a function with a single output, returns a DataArray or Dataset.
Handling attributes
By default, attributes are omitted since they may no longer be accurate:
result = xr.apply_ufunc(wrapper, ds.air, 1)
result.attrs
received x of type <class 'numpy.ndarray'>, shape (2920, 25, 53)
received y of type <class 'int'>
{}
To propagate attributes, pass keep_attrs=True:
result = xr.apply_ufunc(wrapper, ds.air, 1, keep_attrs=True)
result.attrs
received x of type <class 'numpy.ndarray'>, shape (2920, 25, 53)
received y of type <class 'int'>
{'long_name': '4xDaily Air temperature at sigma level 995',
'units': 'degK',
'precision': 2,
'GRIB_id': 11,
'GRIB_name': 'TMP',
'var_desc': 'Air temperature',
'dataset': 'NMC Reanalysis',
'level_desc': 'Surface',
'statistic': 'Individual Obs',
'parent_stat': 'Other',
'actual_range': array([185.16, 322.1 ], dtype=float32)}
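Keeping attributes means you take responsibility for their accuracy. For squared_error, the copied "units" entry is now stale, because squaring a temperature difference changes its units. A minimal sketch of one option follows, on a hypothetical DataArray. Keep the attributes, then fix the affected ones by hand. The new unit string and long_name below are our own illustrative choices.

```python
import numpy as np
import xarray as xr

def squared_error(x, y):
    return (x - y) ** 2

da = xr.DataArray(
    np.array([270.0, 280.0]),
    dims="time",
    attrs={"long_name": "Air temperature", "units": "degK"},  # hypothetical attrs
)

result = xr.apply_ufunc(squared_error, da, 1, keep_attrs=True)
kept = dict(result.attrs)  # snapshot of what keep_attrs=True copied over

# The copied "units" no longer describes the data, so update it explicitly.
result.attrs["units"] = "degK**2"
result.attrs["long_name"] = "Squared error of air temperature"
print(result.attrs)
```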
Handling datasets
apply_ufunc easily handles both DataArrays and Datasets. When passed a Dataset, apply_ufunc loops over the data variables and passes each one in turn to squared_error. So squared_error always receives a single NumPy array.
To illustrate this, let's create a new Dataset with two arrays. We'll add a new 2D array air2 with dimensions (time, lat).
ds2 = ds.copy()
ds2["air2"] = ds2.air.isel(lon=0) ** 2
We see that wrapper is called twice, once per data variable:
xr.apply_ufunc(wrapper, ds2, 1)
received x of type <class 'numpy.ndarray'>, shape (2920, 25, 53)
received y of type <class 'int'>
received x of type <class 'numpy.ndarray'>, shape (2920, 25)
received y of type <class 'int'>
<xarray.Dataset> Size: 32MB Dimensions: (time: 2920, lat: 25, lon: 53) Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 Data variables: air (time, lat, lon) float64 31MB 5.77e+04 5.832e+04 ... 8.684e+04 air2 (time, lat) float64 584kB 3.384e+09 3.533e+09 ... 7.853e+09
xr.apply_ufunc(squared_error, ds2, 1)
<xarray.Dataset> Size: 32MB Dimensions: (time: 2920, lat: 25, lon: 53) Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00 Data variables: air (time, lat, lon) float64 31MB 5.77e+04 5.832e+04 ... 8.684e+04 air2 (time, lat) float64 584kB 3.384e+09 3.533e+09 ... 7.853e+09
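The per-variable loop can also be observed without printing. This minimal sketch builds a hypothetical two-variable Dataset mirroring ds2. It records the shape of every array the wrapped function receives. Each data variable arrives as its own NumPy array, and the output Dataset keeps each variable's dimensions.

```python
import numpy as np
import xarray as xr

def squared_error(x, y):
    return (x - y) ** 2

seen_shapes = []  # shapes of the arrays the function is called with

def recording_wrapper(x, y):
    seen_shapes.append(x.shape)
    return squared_error(x, y)

# Hypothetical Dataset mirroring ds2: one 3-D and one 2-D data variable.
ds = xr.Dataset(
    {
        "air": (("time", "lat", "lon"), np.ones((4, 3, 2))),
        "air2": (("time", "lat"), np.full((4, 3), 2.0)),
    }
)

out = xr.apply_ufunc(recording_wrapper, ds, 1)
print(seen_shapes)  # one entry per data variable
print(list(out.data_vars))
```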
Passing positional and keyword arguments
See also
See the Python tutorial on defining functions for more on positional and keyword arguments.
squared_error takes two arguments named x and y. In xr.apply_ufunc(squared_error, ds.air, 1), the value 1 for y was passed positionally.
To use the keyword-argument form instead, pass it using the kwargs keyword argument to apply_ufunc, which the API documentation describes as:
kwargs (dict, optional) – Optional keyword arguments passed directly on to call func.
xr.apply_ufunc(squared_error, ds.air, kwargs={"y": 1})
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB 5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04 Coordinates: * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0 * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0 * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
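Positional and keyword arguments can also be mixed. In this minimal sketch, scaled_squared_error is a hypothetical variant of squared_error with an extra keyword-only option. The y argument is passed either positionally or through kwargs, while scale always goes through kwargs. Both calls give the same result.

```python
import numpy as np
import xarray as xr

def scaled_squared_error(x, y, scale=1.0):
    """Hypothetical variant of squared_error with an extra keyword argument."""
    return scale * (x - y) ** 2

da = xr.DataArray(np.array([1.0, 2.0, 3.0]), dims="time")

# y passed positionally, scale passed through kwargs ...
a = xr.apply_ufunc(scaled_squared_error, da, 1, kwargs={"scale": 2.0})
# ... or both passed through kwargs.
b = xr.apply_ufunc(scaled_squared_error, da, kwargs={"y": 1, "scale": 2.0})

print(a.values, b.values)
```

As the docstring says, values in kwargs are passed directly on to the function. Unlike the positional Xarray arguments, they are not unwrapped to arrays first. This makes kwargs a natural place for scalars and other plain options.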