
A gentle introduction

apply_ufunc is a more advanced wrapper designed to apply functions that expect and return NumPy arrays (or other array types). This includes, for example, most of SciPy’s API. Since apply_ufunc operates on lower-level NumPy or Dask objects, it skips the overhead of using Xarray objects, making it a good choice for performance-critical functions. Xarray uses apply_ufunc internally to implement much of its API, which gives a sense of how powerful it is!

Learning goals:

  • Learn that apply_ufunc automates much of the work of applying computation functions designed for pure arrays (like NumPy arrays) to Xarray objects

Setup

import numpy as np
import xarray as xr

xr.set_options(display_expand_data=False)

Let’s load a dataset

ds = xr.tutorial.load_dataset("air_temperature")
ds
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 241.2 242.5 243.5 ... 296.5 296.2 295.7
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...

A simple example

Simple functions that act independently on each value should work without any additional arguments.

Consider the following squared_error function

def squared_error(x, y):
    return (x - y) ** 2

We can apply this manually by extracting the underlying NumPy array:

numpy_result = squared_error(ds.air.data, 1)
numpy_result
array([[[57696.04 , 58322.25 , 58806.25 , ..., 53731.234, 54990.25 ,
         56453.754],
        [58951.836, 59292.25 , 59389.688, ..., 53731.234, 54896.484,
         56786.883],
        [62001.   , 61901.434, 61449.453, ..., 53916.84 , 55408.453,
         57936.49 ],
        ...,
        [87379.37 , 87143.03 , 87261.16 , ..., 86671.36 , 86494.81 ,
         86259.68 ],
        [86966.01 , 87143.03 , 87491.73 , ..., 86966.01 , 86966.01 ,
         86553.63 ],
        [87196.19 , 87491.73 , 87675.21 , ..., 87556.805, 87491.73 ,
         87379.37 ]],

       [[58129.207, 58418.887, 58612.406, ..., 53361.   , 54102.754,
         55131.035],
        [58854.754, 59097.605, 59146.24 , ..., 52900.   , 53592.25 ,
         55084.09 ],
        [63604.84 , 63448.57 , 63051.207, ..., 52808.035, 54005.113,
         56406.25 ],
        ...,
        [87261.16 , 86966.01 , 87143.03 , ..., 86671.36 , 86494.81 ,
         86312.57 ],
        [87143.03 , 87438.48 , 87491.73 , ..., 86789.164, 86730.25 ,
         86494.81 ],
        [87196.19 , 87734.43 , 87852.95 , ..., 87261.16 , 87261.16 ,
         87379.37 ]],

       [[58225.684, 58177.438, 58225.684, ..., 54428.883, 55272.004,
         56501.29 ],
        [59340.957, 59238.69 , 59049.   , ..., 52578.484, 53361.   ,
         55084.09 ],
        [65127.03 , 64770.25 , 64110.24 , ..., 52992.04 , 53916.84 ,
         56263.84 ],
        ...,
        [86789.164, 86671.36 , 86671.36 , ..., 87196.19 , 86606.61 ,
         86436.   ],
        [87143.03 , 87320.25 , 87196.19 , ..., 87261.16 , 87025.   ,
         86789.164],
        [87261.16 , 87196.19 , 87261.16 , ..., 87616.   , 87616.   ,
         87491.73 ]],

       ...,

       [[58801.395, 58559.156, 58124.387, ..., 59141.37 , 59287.375,
         59482.332],
        [61548.645, 61499.035, 61300.805, ..., 57403.367, 57739.28 ,
         58414.05 ],
        [68481.66 , 68220.22 , 67959.28 , ..., 56829.793, 57931.67 ,
         59628.75 ],
        ...,
        [86312.57 , 86606.61 , 87906.31 , ..., 86724.35 , 86665.46 ,
         86253.82 ],
        [87491.73 , 88143.664, 88381.35 , ..., 86724.35 , 86724.35 ,
         86312.57 ],
        [88321.9  , 88917.28 , 88678.89 , ..., 87078.11 , 86901.15 ,
         86901.15 ]],

       [[59922.14 , 59433.562, 58801.395, ..., 58704.44 , 59044.137,
         59433.562],
        [61946.23 , 61647.92 , 61251.297, ..., 57739.28 , 58317.414,
         59190.02 ],
        [68324.73 , 68011.43 , 67750.89 , ..., 57355.457, 58607.566,
         60461.89 ],
        ...,
        [85667.44 , 85784.54 , 86665.46 , ..., 86488.93 , 86253.82 ,
         86019.03 ],
        [87196.19 , 87728.516, 87965.625, ..., 86606.61 , 86488.93 ,
         86077.68 ],
        [88084.31 , 88440.805, 88500.3  , ..., 86842.195, 86724.35 ,
         86547.76 ]],

       [[59579.926, 59190.02 , 58704.44 , ..., 57931.67 , 57835.434,
         57979.82 ],
        [61946.23 , 61647.92 , 61201.812, ..., 56925.188, 57259.7  ,
         57931.67 ],
        [68638.76 , 68220.22 , 67802.945, ..., 57068.434, 58365.727,
         60167.18 ],
        ...,
        [85725.99 , 85667.44 , 86488.93 , ..., 86606.61 , 86488.93 ,
         86253.82 ],
        [87078.11 , 87550.88 , 87728.516, ..., 86842.195, 86842.195,
         86547.76 ],
        [88024.96 , 88262.47 , 88262.47 , ..., 87314.336, 87137.14 ,
         86842.195]]], dtype=float32)

To convert this result to a DataArray, we could do it manually

xr.DataArray(data=numpy_result, dims=ds.air.dims, coords=ds.air.coords)
<xarray.DataArray (time: 2920, lat: 25, lon: 53)>
5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00

A shorter version uses DataArray.copy

ds.air.copy(data=numpy_result)
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]

Using DataArray.copy works for such simple cases but doesn’t generalize well, because the new data must have exactly the same shape as the original. For example, consider a function that removes one dimension and adds a new one.
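A small illustration of that limitation, using a synthetic DataArray as an assumed stand-in for ds.air (the tutorial file has to be downloaded): copy(data=...) accepts an element-wise result, but raises ValueError once a reduction drops a dimension and the shapes no longer match.

```python
import numpy as np
import xarray as xr

# Small synthetic stand-in for ds.air (assumption: any 3-D float array will do)
da = xr.DataArray(
    np.arange(24, dtype=float).reshape(4, 2, 3),
    dims=("time", "lat", "lon"),
    name="air",
)

# Element-wise result: same shape as the input, so copy(data=...) accepts it
squared = da.copy(data=(da.data - 1) ** 2)

# Reduction result: "time" is gone, so the shape (2, 3) no longer matches (4, 2, 3)
time_mean = np.mean(da.data, axis=0)
try:
    da.copy(data=time_mean)
    copy_failed = False
except ValueError as err:
    copy_failed = True
    print(f"copy(data=...) rejected the reduced array: {err}")
```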

apply_ufunc can handle such cases. Here’s how to use it with squared_error

xr.apply_ufunc(squared_error, ds.air, 1)
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
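Notice that, unlike the ds.air.copy output, this result has no Attributes section: in the xarray version used to render these outputs, apply_ufunc dropped attributes by default. The sketch below, on a hypothetical DataArray with a single units attribute, passes the keep_attrs argument explicitly so the outcome does not depend on that default.

```python
import numpy as np
import xarray as xr


def squared_error(x, y):
    return (x - y) ** 2


# Hypothetical DataArray carrying one attribute, standing in for ds.air
da = xr.DataArray(
    np.arange(6, dtype=float).reshape(2, 3),
    dims=("lat", "lon"),
    name="air",
    attrs={"units": "degK"},
)

via_copy = da.copy(data=squared_error(da.data, 1))                # copy keeps attrs
via_kept = xr.apply_ufunc(squared_error, da, 1, keep_attrs=True)  # attrs kept
via_dropped = xr.apply_ufunc(squared_error, da, 1, keep_attrs=False)  # attrs dropped

print(via_copy.attrs, via_kept.attrs, via_dropped.attrs)
```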

How does apply_ufunc work?

To illustrate how apply_ufunc works, let us write a small wrapper function. This will let us examine what data is received and returned from the applied function.

Tip

This trick is very useful for debugging.

def wrapper(x, y):
    print(f"received x of type {type(x)}, shape {x.shape}")
    print(f"received y of type {type(y)}")
    return squared_error(x, y)


xr.apply_ufunc(wrapper, ds.air, 1)
received x of type <class 'numpy.ndarray'>, shape (2920, 25, 53)
received y of type <class 'int'>
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00

We see that wrapper receives the underlying NumPy array (ds.air.data) and the integer 1.

Essentially, apply_ufunc does the following:

  1. extracts the underlying array data,

  2. passes it to the user function,

  3. receives the returned values, and

  4. wraps the result back up as an Xarray object (a DataArray or Dataset).
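For an element-wise function on a single DataArray, those four steps can be sketched by hand. manual_apply below is a hypothetical helper written for illustration, not xarray's implementation; it skips everything apply_ufunc also handles (alignment, Datasets, core dimensions, Dask), and the small DataArray is assumed data.

```python
import numpy as np
import xarray as xr


def squared_error(x, y):
    return (x - y) ** 2


def manual_apply(func, da, *args):
    """Hypothetical helper: a rough sketch of the four steps for an
    element-wise function on one DataArray. Not xarray's own code."""
    data = da.data                  # 1. extract the underlying array
    result = func(data, *args)      # 2.-3. call the function and receive its result
    # 4. wrap the result back up with the original dims, coords and name
    return xr.DataArray(result, dims=da.dims, coords=da.coords, name=da.name)


da = xr.DataArray(
    np.arange(6, dtype=float).reshape(2, 3),
    dims=("lat", "lon"),
    coords={"lat": [10.0, 20.0], "lon": [0.0, 1.0, 2.0]},
    name="air",
)

manual = manual_apply(squared_error, da, 1)
reference = xr.apply_ufunc(squared_error, da, 1, keep_attrs=False)
xr.testing.assert_equal(manual, reference)  # same values, dims and coords
```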

apply_ufunc handles both DataArrays and Datasets.

When passed a Dataset, apply_ufunc loops over the data variables and passes the underlying array of each one, in turn, to the applied function (here wrapper, which calls squared_error). So squared_error always receives a NumPy array:

xr.apply_ufunc(wrapper, ds, 1)
received x of type <class 'numpy.ndarray'>, shape (2920, 25, 53)
received y of type <class 'int'>
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 5.77e+04 5.832e+04 ... 8.714e+04 8.684e+04
xr.apply_ufunc(squared_error, ds, 1)
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 5.77e+04 5.832e+04 ... 8.714e+04 8.684e+04
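ds has only one data variable, so the loop is easy to miss. With a hypothetical two-variable Dataset, a function that records what it receives shows one call per data variable, each time with a plain NumPy array.

```python
import numpy as np
import xarray as xr

seen = []  # type name of each array the function receives


def recording_squared_error(x, y):
    seen.append(type(x).__name__)
    return (x - y) ** 2


# Hypothetical Dataset with two data variables
ds_small = xr.Dataset(
    {
        "air": (("time", "lat"), np.full((3, 2), 3.0)),
        "rh": (("time", "lat"), np.zeros((3, 2))),
    }
)

out = xr.apply_ufunc(recording_squared_error, ds_small, 1)
print(seen)  # one entry per data variable
```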

Reductions and core dimensions

squared_error operated on a per-element basis. How about a reduction like np.mean?

Such functions involve the concept of “core dimensions”. One way to think about core dimensions is to consider the smallest dimensionality of data necessary to apply the function.

To apply operations that combine several array values at once, you need to understand core dimensions. Usually, they correspond to the fundamental dimensions over which an operation is defined, e.g., the summed axis in np.sum. A good clue that core dimensions are needed is the presence of an axis argument on the corresponding NumPy function.
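The axis clue is worth spelling out: NumPy identifies the reduced dimension by position, while Xarray identifies it by name, and the position changes whenever the array is transposed. A minimal sketch on an assumed synthetic DataArray:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(24, dtype=float).reshape(4, 2, 3), dims=("time", "lat", "lon"))

# NumPy: the reduced dimension is a position...
axis = da.get_axis_num("time")            # 0 for this ordering
by_axis = np.sum(da.data, axis=axis)

# ...Xarray: the reduced dimension is a name
by_name = da.sum("time")

# After a transpose, the position of "time" changes but its name does not
da_t = da.transpose("lat", "lon", "time")
axis_t = da_t.get_axis_num("time")        # now 2
by_axis_t = np.sum(da_t.data, axis=axis_t)
```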

Let’s write a function that computes the mean along time for a provided Xarray object. This function requires one core dimension, time. For ds.air, note that time is the 0th axis.

ds.air.dims
('time', 'lat', 'lon')
np.mean(ds.air, axis=ds.air.get_axis_num("time"))
<xarray.DataArray 'air' (lat: 25, lon: 53)>
260.4 260.2 259.9 259.5 259.0 258.6 ... 298.0 297.9 297.8 297.3 297.3 297.3
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
np.mean(ds.air.data, axis=0)
array([[260.37564, 260.1826 , 259.88593, ..., 250.81511, 251.93733,
        253.43741],
       [262.7337 , 262.7936 , 262.7489 , ..., 249.75496, 251.5852 ,
        254.35849],
       [264.7681 , 264.3271 , 264.0614 , ..., 250.60707, 253.58247,
        257.71475],
       ...,
       [297.64932, 296.95294, 296.62912, ..., 296.81033, 296.28793,
        295.81622],
       [298.1287 , 297.93646, 297.47006, ..., 296.8591 , 296.77686,
        296.44348],
       [298.36594, 298.38593, 298.11386, ..., 297.33777, 297.28104,
        297.30502]], dtype=float32)

Let’s try to use apply_ufunc to replicate np.mean(ds.air.data, axis=0)

xr.apply_ufunc(
    # function to apply
    np.mean,
    # object with data to pass to function
    ds,
    # keyword arguments to pass to np.mean
    kwargs={"axis": 0},
)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [14], in <cell line: 1>()
----> 1 xr.apply_ufunc(
      2     # function to apply
      3     np.mean,
      4     # object with data to pass to function
      5     ds,
      6     # keyword arguments to pass to np.mean
      7     kwargs={"axis": 0},
      8 )

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:1147, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1145 # feed datasets apply_variable_ufunc through apply_dataset_vfunc
   1146 elif any(is_dict_like(a) for a in args):
-> 1147     return apply_dataset_vfunc(
   1148         variables_vfunc,
   1149         *args,
   1150         signature=signature,
   1151         join=join,
   1152         exclude_dims=exclude_dims,
   1153         dataset_join=dataset_join,
   1154         fill_value=dataset_fill_value,
   1155         keep_attrs=keep_attrs,
   1156     )
   1157 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1158 elif any(isinstance(a, DataArray) for a in args):

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:441, in apply_dataset_vfunc(func, signature, join, dataset_join, fill_value, exclude_dims, keep_attrs, *args)
    436 list_of_coords = build_output_coords(
    437     args, signature, exclude_dims, combine_attrs=keep_attrs
    438 )
    439 args = [getattr(arg, "data_vars", arg) for arg in args]
--> 441 result_vars = apply_dict_of_variables_vfunc(
    442     func, *args, signature=signature, join=dataset_join, fill_value=fill_value
    443 )
    445 if signature.num_outputs > 1:
    446     out = tuple(_fast_dataset(*args) for args in zip(result_vars, list_of_coords))

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:385, in apply_dict_of_variables_vfunc(func, signature, join, fill_value, *args)
    383 result_vars = {}
    384 for name, variable_args in zip(names, grouped_by_name):
--> 385     result_vars[name] = func(*variable_args)
    387 if signature.num_outputs > 1:
    388     return _unpack_dict_tuples(result_vars, signature.num_outputs)

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:752, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    750 data = as_compatible_data(data)
    751 if data.ndim != len(dims):
--> 752     raise ValueError(
    753         "applied function returned data with unexpected "
    754         f"number of dimensions. Received {data.ndim} dimension(s) but "
    755         f"expected {len(dims)} dimensions with names: {dims!r}"
    756     )
    758 var = Variable(dims, data, fastpath=True)
    759 for dim, new_size in var.sizes.items():

ValueError: applied function returned data with unexpected number of dimensions. Received 2 dimension(s) but expected 3 dimensions with names: ('time', 'lat', 'lon')

The error here

applied function returned data with unexpected number of dimensions. Received 2 dimension(s) but expected 3 dimensions with names: ('time', 'lat', 'lon')

means that while np.mean did indeed remove one dimension, we did not tell apply_ufunc that this would happen. To fix this, we need to specify the core dimensions of the input.

xr.apply_ufunc(
    np.mean,
    ds,
    # specify core dimensions as a list of lists
    # here 'time' is the core dimension on `ds`
    input_core_dims=[["time"]],
    kwargs={"axis": 0},
)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [15], in <cell line: 1>()
----> 1 xr.apply_ufunc(
      2     np.mean,
      3     ds,
      4     # specify core dimensions as a list of lists
      5     # here 'time' is the core dimension on `ds`
      6     input_core_dims=[["time"]],
      7     kwargs={"axis": 0},
      8 )

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:1147, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1145 # feed datasets apply_variable_ufunc through apply_dataset_vfunc
   1146 elif any(is_dict_like(a) for a in args):
-> 1147     return apply_dataset_vfunc(
   1148         variables_vfunc,
   1149         *args,
   1150         signature=signature,
   1151         join=join,
   1152         exclude_dims=exclude_dims,
   1153         dataset_join=dataset_join,
   1154         fill_value=dataset_fill_value,
   1155         keep_attrs=keep_attrs,
   1156     )
   1157 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1158 elif any(isinstance(a, DataArray) for a in args):

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:441, in apply_dataset_vfunc(func, signature, join, dataset_join, fill_value, exclude_dims, keep_attrs, *args)
    436 list_of_coords = build_output_coords(
    437     args, signature, exclude_dims, combine_attrs=keep_attrs
    438 )
    439 args = [getattr(arg, "data_vars", arg) for arg in args]
--> 441 result_vars = apply_dict_of_variables_vfunc(
    442     func, *args, signature=signature, join=dataset_join, fill_value=fill_value
    443 )
    445 if signature.num_outputs > 1:
    446     out = tuple(_fast_dataset(*args) for args in zip(result_vars, list_of_coords))

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:385, in apply_dict_of_variables_vfunc(func, signature, join, fill_value, *args)
    383 result_vars = {}
    384 for name, variable_args in zip(names, grouped_by_name):
--> 385     result_vars[name] = func(*variable_args)
    387 if signature.num_outputs > 1:
    388     return _unpack_dict_tuples(result_vars, signature.num_outputs)

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:761, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    759 for dim, new_size in var.sizes.items():
    760     if dim in dim_sizes and new_size != dim_sizes[dim]:
--> 761         raise ValueError(
    762             "size of dimension {!r} on inputs was unexpectedly "
    763             "changed by applied function from {} to {}. Only "
    764             "dimensions specified in ``exclude_dims`` with "
    765             "xarray.apply_ufunc are allowed to change size.".format(
    766                 dim, dim_sizes[dim], new_size
    767             )
    768         )
    770 var.attrs = attrs
    771 output.append(var)

ValueError: size of dimension 'lat' on inputs was unexpectedly changed by applied function from 25 to 53. Only dimensions specified in ``exclude_dims`` with xarray.apply_ufunc are allowed to change size.

This next error is a little confusing.

size of dimension 'lat' on inputs was unexpectedly changed by applied function from 25 to 53. Only dimensions specified in exclude_dims with xarray.apply_ufunc are allowed to change size.

A good trick here is to pass a little wrapper function to apply_ufunc instead and inspect the shapes of data received by the wrapper.

def wrapper(array, **kwargs):
    print(f"received {type(array)} shape: {array.shape}, kwargs: {kwargs}")
    result = np.mean(array, **kwargs)
    print(f"result.shape: {result.shape}")
    return result


xr.apply_ufunc(
    wrapper,
    ds,
    # specify core dimensions as a list of lists
    # here 'time' is the core dimension on `ds`
    input_core_dims=[["time"]],
    kwargs={"axis": 0},
)
received <class 'numpy.ndarray'> shape: (25, 53, 2920), kwargs: {'axis': 0}
result.shape: (53, 2920)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [16], in <cell line: 8>()
      4     print(f"result.shape: {result.shape}")
      5     return result
----> 8 xr.apply_ufunc(
      9     wrapper,
     10     ds,
     11     # specify core dimensions as a list of lists
     12     # here 'time' is the core dimension on `ds`
     13     input_core_dims=[["time"]],
     14     kwargs={"axis": 0},
     15 )

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:1147, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1145 # feed datasets apply_variable_ufunc through apply_dataset_vfunc
   1146 elif any(is_dict_like(a) for a in args):
-> 1147     return apply_dataset_vfunc(
   1148         variables_vfunc,
   1149         *args,
   1150         signature=signature,
   1151         join=join,
   1152         exclude_dims=exclude_dims,
   1153         dataset_join=dataset_join,
   1154         fill_value=dataset_fill_value,
   1155         keep_attrs=keep_attrs,
   1156     )
   1157 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1158 elif any(isinstance(a, DataArray) for a in args):

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:441, in apply_dataset_vfunc(func, signature, join, dataset_join, fill_value, exclude_dims, keep_attrs, *args)
    436 list_of_coords = build_output_coords(
    437     args, signature, exclude_dims, combine_attrs=keep_attrs
    438 )
    439 args = [getattr(arg, "data_vars", arg) for arg in args]
--> 441 result_vars = apply_dict_of_variables_vfunc(
    442     func, *args, signature=signature, join=dataset_join, fill_value=fill_value
    443 )
    445 if signature.num_outputs > 1:
    446     out = tuple(_fast_dataset(*args) for args in zip(result_vars, list_of_coords))

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:385, in apply_dict_of_variables_vfunc(func, signature, join, fill_value, *args)
    383 result_vars = {}
    384 for name, variable_args in zip(names, grouped_by_name):
--> 385     result_vars[name] = func(*variable_args)
    387 if signature.num_outputs > 1:
    388     return _unpack_dict_tuples(result_vars, signature.num_outputs)

File /usr/share/miniconda3/envs/xarray-tutorial/lib/python3.10/site-packages/xarray/core/computation.py:761, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    759 for dim, new_size in var.sizes.items():
    760     if dim in dim_sizes and new_size != dim_sizes[dim]:
--> 761         raise ValueError(
    762             "size of dimension {!r} on inputs was unexpectedly "
    763             "changed by applied function from {} to {}. Only "
    764             "dimensions specified in ``exclude_dims`` with "
    765             "xarray.apply_ufunc are allowed to change size.".format(
    766                 dim, dim_sizes[dim], new_size
    767             )
    768         )
    770 var.attrs = attrs
    771 output.append(var)

ValueError: size of dimension 'lat' on inputs was unexpectedly changed by applied function from 25 to 53. Only dimensions specified in ``exclude_dims`` with xarray.apply_ufunc are allowed to change size.

Now we see the issue:

received <class 'numpy.ndarray'> shape: (25, 53, 2920), kwargs: {'axis': 0}
result.shape: (53, 2920)

The time dimension, of size 2920, was initially the first axis but is now the last axis of the array:

ds.air.get_axis_num("time")
0

This illustrates an important concept: arrays are transposed so that core dimensions are at the end.

With apply_ufunc, core dimensions are recognized by name and then moved to the last axes of every input argument before the given function is applied. This means that for functions that accept an axis argument, you usually need to set axis=-1.

Such behaviour means that our functions (like wrapper or np.mean) do not need to know the exact order of dimensions. They can rely on the core dimensions being at the end, allowing us to write very general code!
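A small sketch of both points on assumed synthetic data: the function below only ever reduces its last axis, yet it reduces whichever dimension is declared as the core dimension, and it gives the same answer for a transposed input because apply_ufunc moves the core dimension to the end first.

```python
import numpy as np
import xarray as xr

received = []  # shapes seen by the applied function


def last_axis_mean(arr):
    received.append(arr.shape)
    # relies only on the core dimension being the last axis
    return np.mean(arr, axis=-1)


da = xr.DataArray(np.arange(24, dtype=float).reshape(4, 2, 3), dims=("time", "lat", "lon"))

# "time" (size 4) is moved to the end before the call
out_time = xr.apply_ufunc(last_axis_mean, da, input_core_dims=[["time"]])
# "lat" (size 2) is moved to the end instead
out_lat = xr.apply_ufunc(last_axis_mean, da, input_core_dims=[["lat"]])
# A transposed input still arrives with "time" last
out_time_t = xr.apply_ufunc(
    last_axis_mean, da.transpose("lon", "time", "lat"), input_core_dims=[["time"]]
)

print(received)
# The non-core dimensions keep their input order, so align before comparing
xr.testing.assert_equal(out_time, out_time_t.transpose("lat", "lon"))
```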

We can fix our apply_ufunc call by specifying axis=-1 instead.

def wrapper(array, **kwargs):
    print(f"received {type(array)} shape: {array.shape}, kwargs: {kwargs}")
    result = np.mean(array, **kwargs)
    print(f"result.shape: {result.shape}")
    return result


xr.apply_ufunc(
    wrapper,
    ds,
    input_core_dims=[["time"]],
    kwargs={"axis": -1},
)
received <class 'numpy.ndarray'> shape: (25, 53, 2920), kwargs: {'axis': -1}
result.shape: (25, 53)
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
Data variables:
    air      (lat, lon) float32 260.4 260.2 259.9 259.5 ... 297.3 297.3 297.3
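The same rule covers more than one core dimension: every listed core dimension is moved to the end, in the order given, so a function that reduces over the last two axes computes a spatial mean. A sketch on an assumed synthetic DataArray, cross-checked against Xarray's own named reduction:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(24, dtype=float).reshape(4, 2, 3), dims=("time", "lat", "lon"))

# "lat" and "lon" become the last two axes, so axis=(-2, -1) reduces over both
spatial_mean = xr.apply_ufunc(
    np.mean,
    da,
    input_core_dims=[["lat", "lon"]],
    kwargs={"axis": (-2, -1)},
)

expected = da.mean(["lat", "lon"])
```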

Exercise

Use apply_ufunc to apply scipy.integrate.trapezoid (the older name is scipy.integrate.trapz) along the time axis.

import scipy.integrate

# trapezoid is the current name; the older alias trapz is deprecated in recent SciPy releases
xr.apply_ufunc(scipy.integrate.trapezoid, ds, input_core_dims=[["time"]], kwargs={"axis": -1})
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
Data variables:
    air      (lat, lon) float32 7.601e+05 7.595e+05 ... 8.678e+05 8.678e+05
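To check an answer without SciPy, a hand-written trapezoidal rule can be plugged into the same call. trapezoid_unit_spacing is a hypothetical helper that assumes unit spacing between samples, matching SciPy's default when neither x nor dx is given. With that default, the integral is measured in samples, not in days or hours, since the time coordinate is ignored.

```python
import numpy as np
import xarray as xr


def trapezoid_unit_spacing(y, axis=-1):
    """Hypothetical helper: trapezoidal rule with unit sample spacing,
    i.e. the sum of (y[i] + y[i + 1]) / 2 along `axis`."""
    y = np.moveaxis(y, axis, -1)
    return ((y[..., 1:] + y[..., :-1]) / 2).sum(axis=-1)


# Assumed synthetic data standing in for ds
da = xr.DataArray(np.arange(24, dtype=float).reshape(4, 2, 3), dims=("time", "lat", "lon"))

integral = xr.apply_ufunc(
    trapezoid_unit_spacing, da, input_core_dims=[["time"]], kwargs={"axis": -1}
)
```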