diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py index 3d009531..1eb0c59a 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py @@ -35,7 +35,8 @@ class AlignDataVariableDimensionsToDatasetCoords(Operation): """ def apply_func(self, data: xr.Dataset) -> xr.Dataset: - dataset_ordering = list(data.coords) + # use coords.dim for when coordinates don't have the same name as dimensions + dataset_ordering = list(data.coords.dims) data = data.transpose(*dataset_ordering) return data @@ -43,7 +44,7 @@ def apply_func(self, data: xr.Dataset) -> xr.Dataset: def undo_func(self, data: xr.Dataset) -> xr.Dataset: # TODO: Record all the original orderings and transpose them back, I guess - return data + raise NotImplementedError("Don't yet know how to undo data variable alignment.") class Sort(Operation): diff --git a/packages/pipeline/tests/operations/xarray/test_sort.py b/packages/pipeline/tests/operations/xarray/test_sort.py index e5683dba..cf5f72f9 100644 --- a/packages/pipeline/tests/operations/xarray/test_sort.py +++ b/packages/pipeline/tests/operations/xarray/test_sort.py @@ -15,7 +15,7 @@ import pytest import xarray as xr -from pyearthtools.pipeline.operations.xarray import _sort as sort +from pyearthtools.pipeline.operations.xarray import _sort as sort, AlignDataVariableDimensionsToDatasetCoords SIMPLE_DA1 = xr.DataArray( [ @@ -37,6 +37,55 @@ SIMPLE_DS2 = xr.Dataset({"Humidity": SIMPLE_DA1, "Temperature": SIMPLE_DA1, "WombatsPerKm2": SIMPLE_DA1}) +def test_align(): + """Tests that the dataset dimension alignment operation works.""" + align_op = AlignDataVariableDimensionsToDatasetCoords() + + # create dataset with arrays that are not consistently ordered + ds = xr.Dataset( + { + "Temperature": SIMPLE_DA1.transpose("lat", "height", "lon"), + "Humidity": SIMPLE_DA1, + "WombatsPerKm2": SIMPLE_DA1.transpose("lon", "height", "lat"), + } + ) + + # check that dataset dims are indeed unaligned + assert ds["Temperature"].dims != ds["Humidity"].dims + assert ds["Temperature"].dims != ds["WombatsPerKm2"].dims + + # apply aligner to dataset and check that dataset dims now align + ds_aligned = align_op.apply_func(ds) + assert ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims + assert ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims + + ## Test that alignment works even when coordinate names don't match dims + da_with_named_coords = xr.DataArray( + SIMPLE_DA1.data, + coords={"h": ("height", [10, 20]), "x": ("lat", [0, 1, 2]), "y": ("lon", [5, 6, 7])}, + dims=["height", "lat", "lon"], + ) + ds = xr.Dataset( + { + "Temperature": da_with_named_coords.transpose("lat", "height", "lon"), + "Humidity": da_with_named_coords, + "WombatsPerKm2": da_with_named_coords.transpose("lon", "height", "lat"), + } + ) + # check that dataset dims are indeed unaligned + assert ds["Temperature"].dims != ds["Humidity"].dims + assert ds["Temperature"].dims != ds["WombatsPerKm2"].dims + + # apply aligner to dataset and check that dataset dims now align + ds_aligned = align_op.apply_func(ds) + assert ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims + assert ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims + + # placeholder test for undo method + with pytest.raises(NotImplementedError): + align_op.undo_func(ds) + + def test_Sort(): s = sort.Sort()