From 116ca48947bacdf7a671c9fb6cfa1bf368b53e2d Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 7 Nov 2025 13:51:40 +1100 Subject: [PATCH 1/3] add test for dataset coord aligner --- .../tests/operations/xarray/test_sort.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/packages/pipeline/tests/operations/xarray/test_sort.py b/packages/pipeline/tests/operations/xarray/test_sort.py index e5683dba..e0b22f4e 100644 --- a/packages/pipeline/tests/operations/xarray/test_sort.py +++ b/packages/pipeline/tests/operations/xarray/test_sort.py @@ -15,7 +15,7 @@ import pytest import xarray as xr -from pyearthtools.pipeline.operations.xarray import _sort as sort +from pyearthtools.pipeline.operations.xarray import _sort as sort, AlignDataVariableDimensionsToDatasetCoords SIMPLE_DA1 = xr.DataArray( [ @@ -37,6 +37,51 @@ SIMPLE_DS2 = xr.Dataset({"Humidity": SIMPLE_DA1, "Temperature": SIMPLE_DA1, "WombatsPerKm2": SIMPLE_DA1}) +def test_align(): + """Tests that the dataset dimension alignment operation works.""" + align_op = AlignDataVariableDimensionsToDatasetCoords() + + # create dataset with arrays that are not consistently ordered + ds = xr.Dataset( + { + "Temperature": SIMPLE_DA1.transpose("lat", "height", "lon"), + "Humidity": SIMPLE_DA1, + "WombatsPerKm2": SIMPLE_DA1.transpose("lon", "height", "lat"), + } + ) + + # check that dataset dims are indeed unaligned + assert ds["Temperature"].dims != ds["Humidity"].dims + assert ds["Temperature"].dims != ds["WombatsPerKm2"].dims + + # apply aligner to dataset and check that dataset dims now align + ds_aligned = align_op.apply_func(ds) + assert ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims + assert ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims + + ## Test that alignment works even when coordinate names don't match dims + da_with_named_coords = xr.DataArray( + SIMPLE_DA1.data, + coords={"h": ("height", [10, 20]), "x": ("lat", [0, 1, 2]), "y": ("lon", [5, 6, 
7])}, + dims=["height", "lat", "lon"], + ) + ds = xr.Dataset( + { + "Temperature": da_with_named_coords.transpose("lat", "height", "lon"), + "Humidity": da_with_named_coords, + "WombatsPerKm2": da_with_named_coords.transpose("lon", "height", "lat"), + } + ) + # check that dataset dims are indeed unaligned + assert ds["Temperature"].dims != ds["Humidity"].dims + assert ds["Temperature"].dims != ds["WombatsPerKm2"].dims + + # apply aligner to dataset and check that dataset dims now align + ds_aligned = align_op.apply_func(ds) + assert ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims + assert ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims + + def test_Sort(): s = sort.Sort() From 3aa9cf0cc2ad5b936706b4f3fa674b4def62b0d6 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 7 Nov 2025 14:00:37 +1100 Subject: [PATCH 2/3] use coord.dims instead of coords --- .../src/pyearthtools/pipeline/operations/xarray/_sort.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py index 3d009531..8de827a8 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py @@ -35,7 +35,8 @@ class AlignDataVariableDimensionsToDatasetCoords(Operation): """ def apply_func(self, data: xr.Dataset) -> xr.Dataset: - dataset_ordering = list(data.coords) + # use coords.dims for when coordinates don't have the same name as dimensions + dataset_ordering = list(data.coords.dims) data = data.transpose(*dataset_ordering) return data From 3506e39e80ca2c34cb6850e20afcbe396f16bb67 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 7 Nov 2025 14:12:44 +1100 Subject: [PATCH 3/3] raise notimplemented for dataset align undo --- .../src/pyearthtools/pipeline/operations/xarray/_sort.py | 2 +- 
packages/pipeline/tests/operations/xarray/test_sort.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py index 8de827a8..1eb0c59a 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/_sort.py @@ -44,7 +44,7 @@ def apply_func(self, data: xr.Dataset) -> xr.Dataset: def undo_func(self, data: xr.Dataset) -> xr.Dataset: # TODO: Record all the original orderings and transpose them back, I guess - return data + raise NotImplementedError("Don't yet know how to undo data variable alignment.") class Sort(Operation): diff --git a/packages/pipeline/tests/operations/xarray/test_sort.py b/packages/pipeline/tests/operations/xarray/test_sort.py index e0b22f4e..cf5f72f9 100644 --- a/packages/pipeline/tests/operations/xarray/test_sort.py +++ b/packages/pipeline/tests/operations/xarray/test_sort.py @@ -81,6 +81,10 @@ def test_align(): assert ds_aligned["Temperature"].dims == ds_aligned["Humidity"].dims assert ds_aligned["Temperature"].dims == ds_aligned["WombatsPerKm2"].dims + # placeholder test for undo method + with pytest.raises(NotImplementedError): + align_op.undo_func(ds) + def test_Sort():