A gentle introduction
apply_ufunc is a more advanced wrapper designed to apply functions that expect and return NumPy arrays (or other array types). For example, this includes all of SciPy's API. Since apply_ufunc operates on lower-level NumPy or Dask objects, it skips the overhead of using Xarray objects, which makes it a good choice for performance-critical functions. Xarray uses apply_ufunc internally to implement much of its API, so it is quite powerful!
Learning goals:
Learn that apply_ufunc automates aspects of applying computation functions that are designed for pure arrays (like NumPy arrays) to xarray objects.
Setup
import numpy as np
import xarray as xr
xr.set_options(display_expand_data=False)
Let's load the air temperature tutorial dataset:
ds = xr.tutorial.load_dataset("air_temperature")
ds
<xarray.Dataset>
Dimensions:  (lat: 25, time: 2920, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 241.2 242.5 243.5 ... 296.5 296.2 295.7
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day). These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
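xr.tutorial.load_dataset downloads the file on first use. For trying the ideas in this tutorial without network access, a hypothetical stand-in can be built by hand. The sketch below assumes only that you want the same dimension names and order as ds.air; the sizes are tiny and the values are random, so it is not the real air temperature data.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical offline stand-in for the "air_temperature" tutorial dataset:
# same dimension names and order as ds.air, tiny sizes, random (not real) values.
rng = np.random.default_rng(seed=0)
time = pd.date_range("2013-01-01", periods=8, freq="6h")  # 4x daily, like the original
lat = np.array([75.0, 72.5, 70.0], dtype="float32")
lon = np.array([200.0, 202.5, 205.0, 207.5], dtype="float32")

air = rng.uniform(240.0, 300.0, size=(time.size, lat.size, lon.size)).astype("float32")
ds_small = xr.Dataset(
    {"air": (("time", "lat", "lon"), air)},
    coords={"time": time, "lat": lat, "lon": lon},
)
print(ds_small)
```

The printed numbers will differ from the real dataset, but the dimension layout matches, which is what the apply_ufunc calls below care about.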
A simple example
Simple functions that act independently on each value should work without any additional arguments. Consider the following squared_error function:
def squared_error(x, y):
    return (x - y) ** 2
We can apply this manually by extracting the underlying NumPy array with .data:
numpy_result = squared_error(ds.air.data, 1)
numpy_result
array([[[57696.04 , 58322.25 , 58806.25 , ..., 53731.234, 54990.25 ,
56453.754],
[58951.836, 59292.25 , 59389.688, ..., 53731.234, 54896.484,
56786.883],
[62001. , 61901.434, 61449.453, ..., 53916.84 , 55408.453,
57936.49 ],
...,
[87379.37 , 87143.03 , 87261.16 , ..., 86671.36 , 86494.81 ,
86259.68 ],
[86966.01 , 87143.03 , 87491.73 , ..., 86966.01 , 86966.01 ,
86553.63 ],
[87196.19 , 87491.73 , 87675.21 , ..., 87556.805, 87491.73 ,
87379.37 ]],
[[58129.207, 58418.887, 58612.406, ..., 53361. , 54102.754,
55131.035],
[58854.754, 59097.605, 59146.24 , ..., 52900. , 53592.25 ,
55084.09 ],
[63604.84 , 63448.57 , 63051.207, ..., 52808.035, 54005.113,
56406.25 ],
...,
[87261.16 , 86966.01 , 87143.03 , ..., 86671.36 , 86494.81 ,
86312.57 ],
[87143.03 , 87438.48 , 87491.73 , ..., 86789.164, 86730.25 ,
86494.81 ],
[87196.19 , 87734.43 , 87852.95 , ..., 87261.16 , 87261.16 ,
87379.37 ]],
[[58225.684, 58177.438, 58225.684, ..., 54428.883, 55272.004,
56501.29 ],
[59340.957, 59238.69 , 59049. , ..., 52578.484, 53361. ,
55084.09 ],
[65127.03 , 64770.25 , 64110.24 , ..., 52992.04 , 53916.84 ,
56263.84 ],
...,
[86789.164, 86671.36 , 86671.36 , ..., 87196.19 , 86606.61 ,
86436. ],
[87143.03 , 87320.25 , 87196.19 , ..., 87261.16 , 87025. ,
86789.164],
[87261.16 , 87196.19 , 87261.16 , ..., 87616. , 87616. ,
87491.73 ]],
...,
[[58801.395, 58559.156, 58124.387, ..., 59141.37 , 59287.375,
59482.332],
[61548.645, 61499.035, 61300.805, ..., 57403.367, 57739.28 ,
58414.05 ],
[68481.66 , 68220.22 , 67959.28 , ..., 56829.793, 57931.67 ,
59628.75 ],
...,
[86312.57 , 86606.61 , 87906.31 , ..., 86724.35 , 86665.46 ,
86253.82 ],
[87491.73 , 88143.664, 88381.35 , ..., 86724.35 , 86724.35 ,
86312.57 ],
[88321.9 , 88917.28 , 88678.89 , ..., 87078.11 , 86901.15 ,
86901.15 ]],
[[59922.14 , 59433.562, 58801.395, ..., 58704.44 , 59044.137,
59433.562],
[61946.23 , 61647.92 , 61251.297, ..., 57739.28 , 58317.414,
59190.02 ],
[68324.73 , 68011.43 , 67750.89 , ..., 57355.457, 58607.566,
60461.89 ],
...,
[85667.44 , 85784.54 , 86665.46 , ..., 86488.93 , 86253.82 ,
86019.03 ],
[87196.19 , 87728.516, 87965.625, ..., 86606.61 , 86488.93 ,
86077.68 ],
[88084.31 , 88440.805, 88500.3 , ..., 86842.195, 86724.35 ,
86547.76 ]],
[[59579.926, 59190.02 , 58704.44 , ..., 57931.67 , 57835.434,
57979.82 ],
[61946.23 , 61647.92 , 61201.812, ..., 56925.188, 57259.7 ,
57931.67 ],
[68638.76 , 68220.22 , 67802.945, ..., 57068.434, 58365.727,
60167.18 ],
...,
[85725.99 , 85667.44 , 86488.93 , ..., 86606.61 , 86488.93 ,
86253.82 ],
[87078.11 , 87550.88 , 87728.516, ..., 86842.195, 86842.195,
86547.76 ],
[88024.96 , 88262.47 , 88262.47 , ..., 87314.336, 87137.14 ,
86842.195]]], dtype=float32)
To convert this result to a DataArray, we could do it manually
xr.DataArray(data=numpy_result, dims=ds.air.dims, coords=ds.air.coords)
<xarray.DataArray (time: 2920, lat: 25, lon: 53)>
5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
A shorter version uses DataArray.copy
ds.air.copy(data=numpy_result)
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]
Using DataArray.copy works for such simple cases but doesn't generalize well. For example, consider a function that removes one dimension and adds a new one: its result no longer has the shape of the original array, so it cannot simply be copied in. apply_ufunc can handle such cases. Here's how to use it with squared_error:
xr.apply_ufunc(squared_error, ds.air, 1)
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
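To make the point about DataArray.copy concrete, here is a minimal sketch on a small made-up DataArray (the values are arbitrary). A reduction over time changes the shape, so copy(data=...) rejects the new data, while apply_ufunc can be told through input_core_dims which dimension the function consumes. That argument is explained in the "Reductions and core dimensions" section of this tutorial.

```python
import numpy as np
import xarray as xr

# Small made-up DataArray with the same dimension order as ds.air
da = xr.DataArray(
    np.arange(24, dtype="float64").reshape(4, 2, 3),
    dims=("time", "lat", "lon"),
)

# A reduction over "time" drops that axis: (4, 2, 3) -> (2, 3)
reduced = np.mean(da.data, axis=0)

# DataArray.copy(data=...) insists that the new data keeps the old shape
try:
    da.copy(data=reduced)
    copy_failed = False
except ValueError as err:
    copy_failed = True
    print("copy failed:", err)

# apply_ufunc can be told that the function consumes "time"; core dimensions
# are passed as the last axis, hence axis=-1
result = xr.apply_ufunc(np.mean, da, input_core_dims=[["time"]], kwargs={"axis": -1})
print(result.dims)  # ('lat', 'lon')
```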
How does apply_ufunc work?
To illustrate how apply_ufunc works, let us write a small wrapper function. This will let us examine what data is received by, and returned from, the applied function.
Tip: This trick is very useful for debugging.
def wrapper(x, y):
    # report what apply_ufunc actually hands to the function
    print(f"received x of type {type(x)}, shape {x.shape}")
    print(f"received y of type {type(y)}")
    return squared_error(x, y)
xr.apply_ufunc(wrapper, ds.air, 1)
received x of type <class 'numpy.ndarray'>, shape (2920, 25, 53)
received y of type <class 'int'>
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
5.77e+04 5.832e+04 5.881e+04 5.905e+04 ... 8.731e+04 8.714e+04 8.684e+04
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
We see that wrapper receives the underlying NumPy array (ds.air.data) and the integer 1.
Essentially, apply_ufunc does the following:
extracts the underlying array data,
passes it to the user function,
receives the returned values, and
then wraps them back up as an xarray object (here, a DataArray).
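For the simple elementwise case, the four steps above can be sketched by hand. manual_apply is a hypothetical helper written for illustration only; it is not how xarray implements apply_ufunc and it skips alignment, core dimensions, and Dataset support entirely.

```python
import numpy as np
import xarray as xr

def squared_error(x, y):
    return (x - y) ** 2

def manual_apply(func, da, *args):
    """Hypothetical helper mirroring the four steps for one DataArray and an
    elementwise func. Far simpler than apply_ufunc: no alignment, no core
    dimensions, no Dataset support."""
    data = da.data                # 1. extract the underlying array data
    result = func(data, *args)    # 2./3. pass it to func and receive the result
    return da.copy(data=result)   # 4. wrap it back up with da's dims and coords

da = xr.DataArray(np.array([[1.0, 2.0], [3.0, 4.0]]), dims=("lat", "lon"))
manual = manual_apply(squared_error, da, 1)
auto = xr.apply_ufunc(squared_error, da, 1)
print(manual.values)
```

One visible difference: copy keeps the original attributes, whereas apply_ufunc drops them unless keep_attrs is set, which is why the apply_ufunc outputs above show no Attributes section.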
apply_ufunc handles both DataArrays and Datasets.
When passed a Dataset, apply_ufunc loops over the data variables and passes each one to the applied function in turn. So the function (here wrapper, and through it squared_error) always receives a NumPy array:
xr.apply_ufunc(wrapper, ds, 1)
received x of type <class 'numpy.ndarray'>, shape (2920, 25, 53)
received y of type <class 'int'>
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 5.77e+04 5.832e+04 ... 8.714e+04 8.684e+04
xr.apply_ufunc(squared_error, ds, 1)
<xarray.Dataset>
Dimensions:  (time: 2920, lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
    air      (time, lat, lon) float32 5.77e+04 5.832e+04 ... 8.714e+04 8.684e+04
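The air dataset has only one data variable, so the loop over variables is invisible above. A minimal sketch with a made-up Dataset holding two variables of different shapes makes it observable: a recording function (hypothetical, for illustration) is called once per data variable, each time with a plain NumPy array. Coordinates are not passed to the function, only the data variables.

```python
import numpy as np
import xarray as xr

received_shapes = []  # one entry per call of the applied function

def recording_squared_error(x, y):
    received_shapes.append(x.shape)
    return (x - y) ** 2

# Made-up Dataset with two data variables of different shapes
ds2 = xr.Dataset(
    {
        "air": (("time", "lat"), np.ones((3, 2))),
        "pressure": (("time",), np.arange(3.0)),
    }
)
out = xr.apply_ufunc(recording_squared_error, ds2, 1)
print(received_shapes)  # one call per data variable
print(out["pressure"].values)
```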
Reductions and core dimensions
squared_error operated on a per-element basis. How about a reduction like np.mean?
Such functions involve the concept of “core dimensions”. One way to think about core dimensions is as the smallest dimensionality of data necessary to apply the function.
For more complex operations that consider several array values collectively, it's important to understand the idea of core dimensions. Usually, they correspond to the fundamental dimensions over which an operation is defined, e.g., the summed axis in np.sum. A good clue that core dimensions are needed is the presence of an axis argument on the corresponding NumPy function.
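The "smallest dimensionality" idea can be checked with plain NumPy on a made-up (time, lat, lon) array: np.mean with axis=0 gives the same result as calling np.mean on each 1-D time series separately. The time axis is the core dimension, and lat and lon are simply looped over.

```python
import numpy as np

# Made-up (time, lat, lon) array
arr = np.random.default_rng(0).normal(size=(5, 2, 3))

# np.mean with axis=0 treats "time" as the core dimension...
vectorized = np.mean(arr, axis=0)

# ...which is equivalent to applying a function that only ever sees a 1-D
# time series (the core dimension) to every (lat, lon) point in turn
looped = np.empty(arr.shape[1:])
for i in range(arr.shape[1]):
    for j in range(arr.shape[2]):
        looped[i, j] = np.mean(arr[:, i, j])  # 1-D input of length 5

print(np.allclose(vectorized, looped))
```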
Let's write a function that computes the mean along time for a provided xarray object. This function requires one core dimension: time. For ds.air, note that time is the 0th axis.
ds.air.dims
('time', 'lat', 'lon')
np.mean(ds.air, axis=ds.air.get_axis_num("time"))
<xarray.DataArray 'air' (lat: 25, lon: 53)>
260.4 260.2 259.9 259.5 259.0 258.6 ... 298.0 297.9 297.8 297.3 297.3 297.3
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
np.mean(ds.air.data, axis=0)
array([[260.37564, 260.1826 , 259.88593, ..., 250.81511, 251.93733,
253.43741],
[262.7337 , 262.7936 , 262.7489 , ..., 249.75496, 251.5852 ,
254.35849],
[264.7681 , 264.3271 , 264.0614 , ..., 250.60707, 253.58247,
257.71475],
...,
[297.64932, 296.95294, 296.62912, ..., 296.81033, 296.28793,
295.81622],
[298.1287 , 297.93646, 297.47006, ..., 296.8591 , 296.77686,
296.44348],
[298.36594, 298.38593, 298.11386, ..., 297.33777, 297.28104,
297.30502]], dtype=float32)
Let's try to use apply_ufunc to replicate np.mean(ds.air.data, axis=0):
xr.apply_ufunc(
# function to apply
np.mean,
# object with data to pass to function
ds,
# keyword arguments to pass to np.mean
kwargs={"axis": 0},
)
ValueError: applied function returned data with unexpected number of dimensions. Received 2 dimension(s) but expected 3 dimensions with names: ('time', 'lat', 'lon')
The error here,
applied function returned data with unexpected number of dimensions. Received 2 dimension(s) but expected 3 dimensions with names: ('time', 'lat', 'lon')
means that while np.mean did indeed reduce one dimension, we did not tell apply_ufunc that this would happen. That is, we need to specify the core dimensions on the input.
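input_core_dims is a list of lists with one inner list per positional argument, which is why the call below wraps ["time"] in an extra list. A minimal sketch with two array arguments, both carrying the core dimension time, shows the pattern; the data, the weights and the weighted_mean helper are made up for illustration. Inside the function, time arrives as the last axis of each input.

```python
import numpy as np
import xarray as xr

# Made-up data: a (lat, time) field and one weight per time step
x = xr.DataArray(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), dims=("lat", "time"))
w = xr.DataArray(np.array([1.0, 0.0, 1.0]), dims=("time",))

def weighted_mean(values, weights):
    # "time" is the last axis of both inputs here
    return np.average(values, axis=-1, weights=weights)

# One inner list per positional argument: both inputs carry the core dim "time"
result = xr.apply_ufunc(weighted_mean, x, w, input_core_dims=[["time"], ["time"]])
print(result.values)  # [2. 5.]
```

A scalar argument, like the 1 passed to squared_error earlier, would get an empty inner list ([]) in the same position.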
xr.apply_ufunc(
np.mean,
ds,
# specify core dimensions as a list of lists
# here 'time' is the core dimension on `ds`
input_core_dims=[["time"]],
kwargs={"axis": 0},
)
ValueError: size of dimension 'lat' on inputs was unexpectedly changed by applied function from 25 to 53. Only dimensions specified in ``exclude_dims`` with xarray.apply_ufunc are allowed to change size.
This next error is a little confusing:
size of dimension 'lat' on inputs was unexpectedly changed by applied function from 25 to 53. Only dimensions specified in exclude_dims with xarray.apply_ufunc are allowed to change size.
A good trick here is to pass a little wrapper function to apply_ufunc
instead and inspect the shapes of data received by the wrapper.
def wrapper(array, **kwargs):
    # report the input that apply_ufunc passes in, then the result we return
    print(f"received {type(array)} shape: {array.shape}, kwargs: {kwargs}")
    result = np.mean(array, **kwargs)
    print(f"result.shape: {result.shape}")
    return result
xr.apply_ufunc(
wrapper,
ds,
# specify core dimensions as a list of lists
# here 'time' is the core dimension on `ds`
input_core_dims=[["time"]],
kwargs={"axis": 0},
)
received <class 'numpy.ndarray'> shape: (25, 53, 2920), kwargs: {'axis': 0}
result.shape: (53, 2920)
ValueError: size of dimension 'lat' on inputs was unexpectedly changed by applied function from 25 to 53. Only dimensions specified in ``exclude_dims`` with xarray.apply_ufunc are allowed to change size.
Now we see the issue:
received <class 'numpy.ndarray'> shape: (25, 53, 2920), kwargs: {'axis': 0}
result.shape: (53, 2920)
The time dimension is of size 2920 and is now the last axis of the array, but it was initially the first axis:
ds.air.get_axis_num("time")
0
This illustrates an important concept: arrays are transposed so that core dimensions are at the end.
With apply_ufunc, core dimensions are recognized by name and then moved to the last axes of every input argument before the given function is applied. This means that for functions that accept an axis argument, you usually need to set axis=-1.
Such behaviour means that our functions (like wrapper or np.mean) do not need to know the exact order of dimensions. They can rely on the core dimensions being at the end, allowing us to write very general code!
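The transposition can be observed directly on a small made-up array whose dimension order matches ds.air. The recording function below is hypothetical; it only stores the shape it receives. The second call also shows that with several core dimensions, all of them move to the end in the order they are listed in input_core_dims.

```python
import numpy as np
import xarray as xr

seen_shapes = []  # shapes the applied function actually receives

def record_then_mean(array):
    seen_shapes.append(array.shape)
    return np.mean(array, axis=-1)  # reduce the last axis

# Made-up array whose dimension order matches ds.air: "time" comes first
da = xr.DataArray(np.zeros((4, 2, 3)), dims=("time", "lat", "lon"))

xr.apply_ufunc(record_then_mean, da, input_core_dims=[["time"]])
first_call = list(seen_shapes)
print(first_call)  # [(2, 3, 4)]: "time" (size 4) now sits on the last axis

# With two core dimensions, both move to the end in the order they are listed
seen_shapes.clear()
xr.apply_ufunc(
    lambda a: record_then_mean(a).mean(axis=-1),  # consume both core axes
    da,
    input_core_dims=[["time", "lat"]],
)
print(seen_shapes)  # [(3, 4, 2)]: lon, then time, then lat
```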
We can fix our apply_ufunc call by specifying axis=-1 instead:
def wrapper(array, **kwargs):
    # report the input that apply_ufunc passes in, then the result we return
    print(f"received {type(array)} shape: {array.shape}, kwargs: {kwargs}")
    result = np.mean(array, **kwargs)
    print(f"result.shape: {result.shape}")
    return result
xr.apply_ufunc(
wrapper,
ds,
input_core_dims=[["time"]],
kwargs={"axis": -1},
)
received <class 'numpy.ndarray'> shape: (25, 53, 2920), kwargs: {'axis': -1}
result.shape: (25, 53)
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
Data variables:
    air      (lat, lon) float32 260.4 260.2 259.9 259.5 ... 297.3 297.3 297.3
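Because the core dimension is chosen by name and always arrives on the last axis, the same call works for any dimension. The sketch below wraps it in a hypothetical helper, mean_over, and checks it against xarray's built-in DataArray.mean on made-up data. One caveat: DataArray.mean skips NaN by default for float data while np.mean does not, so the two only agree when there are no missing values.

```python
import numpy as np
import xarray as xr

def mean_over(obj, dim):
    """Hypothetical helper: mean over a named dimension via apply_ufunc."""
    return xr.apply_ufunc(np.mean, obj, input_core_dims=[[dim]], kwargs={"axis": -1})

rng = np.random.default_rng(0)
da = xr.DataArray(rng.normal(size=(4, 2, 3)), dims=("time", "lat", "lon"))

# The same axis=-1 works for every dimension, wherever it sits in da.dims
for dim in da.dims:
    expected = da.mean(dim)  # xarray's built-in reduction
    got = mean_over(da, dim)
    xr.testing.assert_allclose(got, expected)
print("matches DataArray.mean for", da.dims)
```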
Exercise
Use apply_ufunc to apply scipy.integrate.trapezoid along the time axis. (Older SciPy releases also provided this function as trapz, which has since been deprecated and removed.)
import scipy.integrate
xr.apply_ufunc(scipy.integrate.trapezoid, ds, input_core_dims=[["time"]], kwargs={"axis": -1})
<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
Data variables:
    air      (lat, lon) float32 7.601e+05 7.595e+05 ... 8.678e+05 8.678e+05
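Earlier we mentioned a function that removes one dimension and adds a new one. A minimal sketch of that case, on made-up data: time_quantiles (a hypothetical helper) consumes time and returns one value per requested quantile. input_core_dims declares the consumed dimension and output_core_dims names the new one, which the function must place on the last axis of its result. xarray already offers DataArray.quantile for this; the sketch only illustrates the mechanism.

```python
import numpy as np
import xarray as xr

def time_quantiles(array, q):
    # "time" arrives as the last axis; np.quantile puts the quantile axis
    # first, so move it to the end, where apply_ufunc expects output core dims
    return np.moveaxis(np.quantile(array, q, axis=-1), 0, -1)

rng = np.random.default_rng(0)
da = xr.DataArray(rng.normal(size=(10, 2, 3)), dims=("time", "lat", "lon"))
q = [0.1, 0.5, 0.9]

result = xr.apply_ufunc(
    time_quantiles,
    da,
    input_core_dims=[["time"]],       # dimension consumed by the function
    output_core_dims=[["quantile"]],  # new dimension created by the function
    kwargs={"q": q},
)
result = result.assign_coords(quantile=q)
print(result.dims)  # ('lat', 'lon', 'quantile')
```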