diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/split.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/split.py index 4069c10c..25742625 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/split.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/split.py @@ -106,9 +106,7 @@ def split(self, sample: da.Array) -> tuple[da.Array]: def join(self, sample: tuple[da.Array]) -> da.Array: """Join `sample` together""" - data = da.stack(sample, axis=0) - data = da.moveaxis(data, 0, self.axis) - return data + return da.concatenate(sample, axis=self.axis) # class VSplit(Spliter, DaskOperation): diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py index 1fee2dd2..bab763bf 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/split.py @@ -107,9 +107,7 @@ def split(self, sample: np.ndarray) -> tuple[np.ndarray]: def join(self, sample: tuple[np.ndarray]) -> np.ndarray: """Join `sample` together""" - data = np.stack(sample, axis=0) - data = np.moveaxis(data, 0, self.axis) - return data + return np.concat(sample, axis=self.axis) class VSplit(Spliter): @@ -135,7 +133,8 @@ def __init__( self.record_initialisation() def split(self, sample: np.ndarray) -> tuple[np.ndarray]: - return np.vsplit(sample) # type: ignore + # splits into equally-sized subsections. + return np.vsplit(sample, indices_or_sections=sample.shape[0]) # type: ignore def join(self, sample: tuple[np.ndarray]) -> np.ndarray: """Join `sample` together""" @@ -165,7 +164,10 @@ def __init__( self.record_initialisation() def split(self, sample: np.ndarray) -> tuple[np.ndarray]: - return np.hsplit(sample) # type: ignore + # splits into equally-sized subsections. + + nsections = sample.shape[min(1, sample.ndim - 1)] # np.hsplit will split along axis 0 if ndims==1 + return np.hsplit(sample, indices_or_sections=nsections) # type: ignore def join(self, sample: tuple[np.ndarray]) -> np.ndarray: """Join `sample` together""" diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py index 5a240139..d152bd18 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/split.py @@ -43,7 +43,7 @@ def __init__( Kwargs needed for merge on the `undo`. Defaults to None. """ super().__init__( - recognised_types=(xr.DataArray, xr.Dataset), + recognised_types=(xr.Dataset,), recursively_split_tuples=True, ) self.record_initialisation() @@ -90,5 +90,5 @@ def __init__( def split(self, sample: T) -> tuple[T, ...]: return tuple(sample.sel(**{self.coordinate: i}) for i in sample.coords[self.coordinate]) - def undo(self, sample: tuple[T, ...]) -> xr.Dataset: - return xr.merge(sample, **(self._merge_kwargs or {})) + def join(self, sample: tuple[T, ...]) -> xr.Dataset: + return xr.concat(sample, dim=self.coordinate) diff --git a/packages/pipeline/tests/operations/dask/test_dask_split.py b/packages/pipeline/tests/operations/dask/test_dask_split.py new file mode 100644 index 00000000..59cf3d4f --- /dev/null +++ b/packages/pipeline/tests/operations/dask/test_dask_split.py @@ -0,0 +1,49 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.dask import split + +import dask.array as da +import pytest + + +@pytest.fixture(scope="module") +def example_data(): + return da.array(range(2 * 3 * 4)).reshape((2, 3, 4)) + + +def test_onaxis(example_data): + """Tests dask OnAxis split operation class.""" + op = split.OnAxis(axis=1) + + # try join before splitting + with pytest.raises(RuntimeError): + op.join((example_data, example_data)) + result = op.split(example_data) + assert all((arr == example_data[:, d, :]).all().compute() for d, arr in enumerate(result)) + + orig = op.join(result) + assert (orig == example_data).all().compute() + + +def test_onslice(example_data): + """Tests dask OnSlice split operation class.""" + slices = ((0, 1), (1, 2), (2, 4)) + op = split.OnSlice(*slices, axis=2) + result = op.split(example_data) + for sl, arr in zip(slices, result, strict=True): + assert (arr == example_data[:, :, sl[0] : sl[1]]).all().compute() + + orig = op.join(result) + assert (orig == example_data).all().compute() diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_split.py b/packages/pipeline/tests/operations/numpy/test_numpy_split.py new file mode 100644 index 00000000..88ee96b5 --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_split.py @@ -0,0 +1,67 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import split + +import numpy as np +import pytest + + +@pytest.fixture(scope="module") +def example_data(): + return np.array(range(2 * 3 * 4)).reshape((2, 3, 4)) + + +def test_onaxis(example_data): + """Tests numpy OnAxis split operation class.""" + op = split.OnAxis(axis=1) + + # try join before splitting + with pytest.raises(RuntimeError): + op.join((example_data, example_data)) + result = op.split(example_data) + assert all(np.array_equal(arr, example_data[:, d, :]) for d, arr in enumerate(result)) + + orig = op.join(result) + assert np.array_equal(orig, example_data) + + +def test_onslice(example_data): + """Tests numpy OnSlice split operation class.""" + slices = ((0, 1), (1, 2), (2, 4)) + op = split.OnSlice(*slices, axis=2) + result = op.split(example_data) + for sl, arr in zip(slices, result, strict=True): + assert np.array_equal(arr, example_data[:, :, sl[0] : sl[1]]) + + orig = op.join(result) + assert np.array_equal(orig, example_data) + + +def test_vsplit(example_data): + """Tests numpy VSplit split operation class.""" + op = split.VSplit() + result = op.split(example_data) + assert all(np.array_equal(arr, result[i]) for i, arr in enumerate(result)) + orig = op.join(result) + assert np.array_equal(orig, example_data) + + +def test_hsplit(example_data): + """Tests numpy HSplit split operation class.""" + op = split.HSplit() + result = op.split(example_data) + assert all(np.array_equal(arr, example_data[:, i : i + 1, :]) for i, arr in enumerate(result)) + orig = op.join(result) + assert np.array_equal(orig, example_data) diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_split.py b/packages/pipeline/tests/operations/xarray/test_xarray_split.py new file mode 100644 index 00000000..d2cc063b --- /dev/null +++ b/packages/pipeline/tests/operations/xarray/test_xarray_split.py @@ -0,0 +1,74 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.xarray import split + +import numpy as np +import xarray as xr +import pytest + + +@pytest.fixture(scope="module") +def example_dataarray(): + return xr.DataArray( + np.array(range(2 * 3 * 4)).reshape((2, 3, 4)), coords={"c0": range(2), "c1": range(3), "c2": range(4)} + ) + + +@pytest.fixture(scope="module") +def example_dataset(example_dataarray): + return xr.Dataset({"a": example_dataarray, "b": 2 * example_dataarray}) + + +def test_onvariables(example_dataset): + """Tests xarray OnVariables operation class.""" + + # split on all variables + op = split.OnVariables() + result = op.split(example_dataset) + assert result[0].equals(example_dataset.drop_vars("b")) + assert result[1].equals(example_dataset.drop_vars("a")) + + # join datasets + orig = op.join(result) + assert orig.equals(example_dataset) + + # split on selected variables + op = split.OnVariables(variables=("a",)) + result = op.split(example_dataset) + assert result[0].equals(example_dataset.drop_vars("b")) + + # split on non-existent variable + op = split.OnVariables(variables=("c",)) + with pytest.raises(ValueError): + op.split(example_dataset) + + +def test_oncoordinate(example_dataarray, example_dataset): + """Tests xarray OnCoordinate operation class.""" + op = split.OnCoordinate("c1") + result = op.split(example_dataset) + for i, arr in enumerate(result): + assert arr["a"].equals(example_dataset["a"].loc[:, i, :]) + assert arr["b"].equals(example_dataset["b"].loc[:, i, :]) + + orig = op.join(result) + assert orig.broadcast_equals(example_dataset) + + result = op.split(example_dataarray) + for i, arr in enumerate(result): + assert arr.equals(example_dataarray.loc[:, i, :]) + + orig = op.join(result) + assert orig.broadcast_equals(example_dataarray)