From 4249231111cd497dcae20b6df47b7c7771594cc5 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 12:40:17 +1100 Subject: [PATCH 1/6] add tests for numpy split OnAxis --- .../operations/numpy/test_numpy_split.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 packages/pipeline/tests/operations/numpy/test_numpy_split.py diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_split.py b/packages/pipeline/tests/operations/numpy/test_numpy_split.py new file mode 100644 index 00000000..4bbee2d8 --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_split.py @@ -0,0 +1,37 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import split + +import numpy as np +import pytest + + +@pytest.fixture(scope="module") +def example_data(): + return np.array(range(2 * 3 * 4)).reshape((2, 3, 4)) + + +def test_onaxis(example_data): + """Tests numpy OnAxis split operation class.""" + op = split.OnAxis(axis=1) + + # try join before splitting + with pytest.raises(RuntimeError): + op.join((example_data, example_data)) + result = op.split(example_data) + assert all(np.array_equal(arr, example_data[:, d, :]) for d, arr in enumerate(result)) + + orig = op.join(result) + assert np.array_equal(orig, example_data) From 7a4cf836d0b62ae5571f3eb26b25d451f9a32a8b Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 12:44:42 +1100 Subject: [PATCH 2/6] add split onslice tests and fix join The join fix is needed since the split onslice must keep the dimension that it was split on e.g. if splitting from (2, 3) into (2, 1), and (2, 2), the second dim must be kept and the join should join again on that second dim. Before, it was trying to join by creating a new dim --- .../pyearthtools/pipeline/operations/numpy/split.py | 4 +--- .../tests/operations/numpy/test_numpy_split.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py index 1fee2dd2..31223ac3 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py @@ -107,9 +107,7 @@ def split(self, sample: np.ndarray) -> tuple[np.ndarray]: def join(self, sample: tuple[np.ndarray]) -> np.ndarray: """Join `sample` together""" - data = np.stack(sample, axis=0) - data = np.moveaxis(data, 0, self.axis) - return data + return np.concat(sample, axis=self.axis) class VSplit(Spliter): diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_split.py b/packages/pipeline/tests/operations/numpy/test_numpy_split.py index 4bbee2d8..d3cf8ef7 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_split.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_split.py @@ -35,3 +35,15 @@ def test_onaxis(example_data): orig = op.join(result) assert np.array_equal(orig, example_data) + + +def test_onslice(example_data): + """Tests numpy OnSlice split operation class.""" + slices = ((0, 1), (1, 2), (2, 4)) + op = split.OnSlice(*slices, axis=2) + result = op.split(example_data) + for sl, arr in zip(slices, result, strict=True): + assert np.array_equal(arr, example_data[:, :, sl[0] : sl[1]]) + + orig = op.join(result) + assert np.array_equal(orig, example_data) From f5a23bed6a6b72cd3246a033d5e03427e6d5f86f Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 13:05:14 +1100 Subject: [PATCH 3/6] add tests for numpy h/vsplit Fixes hsplit and vsplit calls as they require an indices_or_sections argument. The value chosen is such that the number of arrays returned is equal to the number of elements in the 0th and 1st dimensions for vsplit and hsplit, respectively. --- .../pipeline/operations/numpy/split.py | 8 ++++++-- .../tests/operations/numpy/test_numpy_split.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py index 31223ac3..bab763bf 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py @@ -133,7 +133,8 @@ def __init__( self.record_initialisation() def split(self, sample: np.ndarray) -> tuple[np.ndarray]: - return np.vsplit(sample) # type: ignore + # splits into equally-sized subsections. + return np.vsplit(sample, indices_or_sections=sample.shape[0]) # type: ignore def join(self, sample: tuple[np.ndarray]) -> np.ndarray: """Join `sample` together""" @@ -163,7 +164,10 @@ def __init__( self.record_initialisation() def split(self, sample: np.ndarray) -> tuple[np.ndarray]: - return np.hsplit(sample) # type: ignore + # splits into equally-sized subsections. + + nsections = sample.shape[min(1, sample.ndim - 1)] # np.hsplit will split along axis 0 if ndims==1 + return np.hsplit(sample, indices_or_sections=nsections) # type: ignore def join(self, sample: tuple[np.ndarray]) -> np.ndarray: """Join `sample` together""" diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_split.py b/packages/pipeline/tests/operations/numpy/test_numpy_split.py index d3cf8ef7..88ee96b5 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_split.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_split.py @@ -47,3 +47,21 @@ def test_onslice(example_data): orig = op.join(result) assert np.array_equal(orig, example_data) + + +def test_vsplit(example_data): + """Tests numpy VSplit split operation class.""" + op = split.VSplit() + result = op.split(example_data) + assert all(np.array_equal(arr, result[i]) for i, arr in enumerate(result)) + orig = op.join(result) + assert np.array_equal(orig, example_data) + + +def test_hsplit(example_data): + """Tests numpy HSplit split operation class.""" + op = split.HSplit() + result = op.split(example_data) + assert all(np.array_equal(arr, example_data[:, i : i + 1, :]) for i, arr in enumerate(result)) + orig = op.join(result) + assert np.array_equal(orig, example_data) From aa1d1339c2861e31e1b0a5fbc2bdfb7cd375431b Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 13:27:58 +1100 Subject: [PATCH 4/6] add xarray split OnVariables tests --- .../pipeline/operations/xarray/split.py | 2 +- .../operations/xarray/test_xarray_split.py | 53 +++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 packages/pipeline/tests/operations/xarray/test_xarray_split.py diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py index 5a240139..dc7aea2f 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py @@ -43,7 +43,7 @@ def __init__( Kwargs needed for merge on the `undo`. Defaults to None. """ super().__init__( - recognised_types=(xr.DataArray, xr.Dataset), + recognised_types=(xr.Dataset,), recursively_split_tuples=True, ) self.record_initialisation() diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_split.py b/packages/pipeline/tests/operations/xarray/test_xarray_split.py new file mode 100644 index 00000000..91ae3385 --- /dev/null +++ b/packages/pipeline/tests/operations/xarray/test_xarray_split.py @@ -0,0 +1,53 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.xarray import split + +import numpy as np +import xarray as xr +import pytest + + +@pytest.fixture(scope="module") +def example_dataarray(): + return xr.DataArray(np.array(range(2 * 3 * 4)).reshape((2, 3, 4))) + + +@pytest.fixture(scope="module") +def example_dataset(example_dataarray): + return xr.Dataset({"a": example_dataarray, "b": 2 * example_dataarray}) + + +def test_onvariables(example_dataset): + """Tests xarray OnVariables operation class.""" + + # split on all variables + op = split.OnVariables() + result = op.split(example_dataset) + assert result[0].equals(example_dataset.drop_vars("b")) + assert result[1].equals(example_dataset.drop_vars("a")) + + # join datasets + orig = op.join(result) + assert orig.equals(example_dataset) + + # split on selected variables + op = split.OnVariables(variables=("a",)) + result = op.split(example_dataset) + assert result[0].equals(example_dataset.drop_vars("b")) + + # split on non-existent variable + op = split.OnVariables(variables=("c",)) + with pytest.raises(ValueError): + op.split(example_dataset) From 958396c2cbd876d03b6e97b212df0f9a3c9da129 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 14:06:18 +1100 Subject: [PATCH 5/6] add tests for xarray split OnCoordinate Also includes fix for the join method since xr.merge would fail due to incompatible coords after splitting. --- .../pipeline/operations/xarray/split.py | 4 ++-- .../operations/xarray/test_xarray_split.py | 23 ++++++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py index dc7aea2f..d152bd18 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py @@ -90,5 +90,5 @@ def __init__( def split(self, sample: T) -> tuple[T, ...]: return tuple(sample.sel(**{self.coordinate: i}) for i in sample.coords[self.coordinate]) - def undo(self, sample: tuple[T, ...]) -> xr.Dataset: - return xr.merge(sample, **(self._merge_kwargs or {})) + def join(self, sample: tuple[T, ...]) -> xr.Dataset: + return xr.concat(sample, dim=self.coordinate) diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_split.py b/packages/pipeline/tests/operations/xarray/test_xarray_split.py index 91ae3385..d2cc063b 100644 --- a/packages/pipeline/tests/operations/xarray/test_xarray_split.py +++ b/packages/pipeline/tests/operations/xarray/test_xarray_split.py @@ -21,7 +21,9 @@ @pytest.fixture(scope="module") def example_dataarray(): - return xr.DataArray(np.array(range(2 * 3 * 4)).reshape((2, 3, 4))) + return xr.DataArray( + np.array(range(2 * 3 * 4)).reshape((2, 3, 4)), coords={"c0": range(2), "c1": range(3), "c2": range(4)} + ) @pytest.fixture(scope="module") @@ -51,3 +53,22 @@ def test_onvariables(example_dataset): op = split.OnVariables(variables=("c",)) with pytest.raises(ValueError): op.split(example_dataset) + + +def test_oncoordinate(example_dataarray, example_dataset): + """Tests xarray OnCoordinate operation class.""" + op = split.OnCoordinate("c1") + result = op.split(example_dataset) + for i, arr in enumerate(result): + assert arr["a"].equals(example_dataset["a"].loc[:, i, :]) + assert arr["b"].equals(example_dataset["b"].loc[:, i, :]) + + orig = op.join(result) + assert orig.broadcast_equals(example_dataset) + + result = op.split(example_dataarray) + for i, arr in enumerate(result): + assert arr.equals(example_dataarray.loc[:, i, :]) + + orig = op.join(result) + assert orig.broadcast_equals(example_dataarray) From b09fd0f0d614e5dc010c37173df5a3c601cc71e1 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 14:14:08 +1100 Subject: [PATCH 6/6] add tests for dask split and fix OnSlice join Like numpy, concat is required to reverse the split. --- .../pipeline/operations/dask/split.py | 4 +- .../tests/operations/dask/test_dask_split.py | 49 +++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 packages/pipeline/tests/operations/dask/test_dask_split.py diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/split.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/split.py index 4069c10c..25742625 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/split.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/split.py @@ -106,9 +106,7 @@ def split(self, sample: da.Array) -> tuple[da.Array]: def join(self, sample: tuple[da.Array]) -> da.Array: """Join `sample` together""" - data = da.stack(sample, axis=0) - data = da.moveaxis(data, 0, self.axis) - return data + return da.concatenate(sample, axis=self.axis) # class VSplit(Spliter, DaskOperation): diff --git a/packages/pipeline/tests/operations/dask/test_dask_split.py b/packages/pipeline/tests/operations/dask/test_dask_split.py new file mode 100644 index 00000000..59cf3d4f --- /dev/null +++ b/packages/pipeline/tests/operations/dask/test_dask_split.py @@ -0,0 +1,49 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.dask import split + +import dask.array as da +import pytest + + +@pytest.fixture(scope="module") +def example_data(): + return da.array(range(2 * 3 * 4)).reshape((2, 3, 4)) + + +def test_onaxis(example_data): + """Tests dask OnAxis split operation class.""" + op = split.OnAxis(axis=1) + + # try join before splitting + with pytest.raises(RuntimeError): + op.join((example_data, example_data)) + result = op.split(example_data) + assert all((arr == example_data[:, d, :]).all().compute() for d, arr in enumerate(result)) + + orig = op.join(result) + assert (orig == example_data).all().compute() + + +def test_onslice(example_data): + """Tests dask OnSlice split operation class.""" + slices = ((0, 1), (1, 2), (2, 4)) + op = split.OnSlice(*slices, axis=2) + result = op.split(example_data) + for sl, arr in zip(slices, result, strict=True): + assert (arr == example_data[:, :, sl[0] : sl[1]]).all().compute() + + orig = op.join(result) + assert (orig == example_data).all().compute()