From ca286ada88c81d891e43fc14807463fa84424fab Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 5 Dec 2025 13:13:58 +1100 Subject: [PATCH 1/5] add xarray select tests --- .../operations/xarray/test_xarray_select.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 packages/pipeline/tests/operations/xarray/test_xarray_select.py diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_select.py b/packages/pipeline/tests/operations/xarray/test_xarray_select.py new file mode 100644 index 00000000..28af89ea --- /dev/null +++ b/packages/pipeline/tests/operations/xarray/test_xarray_select.py @@ -0,0 +1,79 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import xarray as xr +import pytest +import numpy as np + + +from pyearthtools.pipeline.operations.xarray import select + + +@pytest.fixture(scope="module") +def sample(): + """Test xarray dataset.""" + coords = {"dim0": range(3), "dim1": range(3)} + return xr.Dataset( + { + "var1": xr.DataArray(np.array(range(9)).reshape((3, 3)), coords), + "var2": xr.DataArray(np.array(range(9, 18)).reshape((3, 3)), coords), + }, + ) + + +def test_SelectDataset(sample): + """Tests the SelectDataset xarray operation.""" + + s = select.SelectDataset(("var1",)) + + output = s.apply_func(sample) + + assert "var1" in output + assert "var2" not in output + assert output["var1"].equals(sample["var1"]) + + +def test_DropDataset(sample): + """Tests the DropDataset xarray operation.""" + + s = select.DropDataset(("var1",)) + + output = s.apply_func(sample) + assert "var1" not in output + assert "var2" in output + assert output["var2"].equals(sample["var2"]) + + +def test_SliceDataset(sample): + """Tests the SliceDataset xarray operation.""" + + args = {"dim0": (0, 2, 2), "dim1": (0, 1)} + + def test_slicer(slicer, sample): + + output = s.apply_func(sample) + + assert np.array_equal(output.coords["dim0"].values, [0, 2]) + assert np.array_equal(output.coords["dim1"].values, [0, 1]) + + # test passing dict to SliceDataset + s = select.SliceDataset(args) + test_slicer(s, sample) + + # test passing kwargs to SliceDataset + s = select.SliceDataset(**args) + test_slicer(s, sample) + + # test passing dataarray to slicer + test_slicer(s, sample["var1"]) From 40b68c8477b93365c64dc9d7604ddd14cbe46bd4 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 8 Jan 2026 12:24:36 +1100 Subject: [PATCH 2/5] align numpy select functionality with its doc string this fixes the implementation of numpy.Select so that the results match what is stated in the doc string. This changes the implementation to use indexing with native slices. --- .../pipeline/operations/numpy/select.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py index d3d7f463..5d1acdec 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py @@ -63,27 +63,20 @@ def __init__( self.tuple_index = tuple_index def _index(self, data, array_index): - shape = data.shape - for i, index in enumerate(reversed(array_index)): - if index is None: - pass - selected_data = np.take(data, indices=index, axis=-(i + 1)) - if len(selected_data.shape) < len(shape): - selected_data = np.expand_dims(selected_data, axis=-(i + 1)) - data = selected_data - return data + # below comprehension: + # - ensures indexer is tuple (requirement) + # - converts instances of None into slice(None) + indexer = tuple(slice(None) if index is None else index for index in array_index) + return data[indexer] def apply_func(self, data): array_index = self.array_index if isinstance(data, tuple): - data = list(data) if self.tuple_index is None: - return tuple(map(lambda x: self._index(x, array_index), data)) + return tuple(self._index(x, array_index) for x in data) - data[self.tuple_index] = self._index(data[self.tuple_index], array_index) - data = tuple(data) - return data + return tuple(self._index(arr, array_index) if i == self.tuple_index else arr for i, arr in enumerate(data)) return self._index(data, array_index) From 6a74ae8d9ed73f65f9e6d35b505854682b28b066 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 8 Jan 2026 12:25:07 +1100 Subject: [PATCH 3/5] add numpy select tests --- .../operations/numpy/test_numpy_select.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 packages/pipeline/tests/operations/numpy/test_numpy_select.py diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_select.py b/packages/pipeline/tests/operations/numpy/test_numpy_select.py new file mode 100644 index 00000000..d9d3f8ab --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_select.py @@ -0,0 +1,73 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np + + +from pyearthtools.pipeline.operations.numpy import select + + +@pytest.fixture(scope="module") +def sample(): + """Test numpy array.""" + return np.array(range(24)).reshape((2, 3, 4)) + + +def test_Select(sample): + """Tests the Select numpy operation.""" + + s = select.Select([0]) + + output = s.apply_func(sample) + + assert output.shape == (3, 4) + assert np.array_equal(output, sample[0, :, :]) + + # multi-dimensional indexing + s = select.Select([0, None, 3]) + + output = s.apply_func(sample) + + assert output.shape == (3,) + assert np.array_equal(output, sample[0, :, 3]) + + # pass tuple of arrays + output = s.apply_func((sample, sample)) + for arr in output: + assert arr.shape == (3,) + assert np.array_equal(arr, sample[0, :, 3]) + + # pass tuple of arrays with tuple index + s = select.Select(array_index=(0,), tuple_index=1) + output = s.apply_func((sample, sample)) + assert output[0].shape == sample.shape + assert np.array_equal(output[0], sample) + assert output[1].shape == (3, 4) + assert np.array_equal(output[1], sample[0]) + + +def test_Slice(sample): + """Tests the Slice numpy operations.""" + + s = select.Slice((1,), (2,), (1, 4)) + output = s.apply_func(sample) + assert output.shape == (1, 2, 3) + assert np.array_equal(output, sample[:1, :2, 1:4]) + + # test reverse_slice + s = select.Slice((1,), (2,), reverse_slice=True) + output = s.apply_func(sample) + assert output.shape == (2, 1, 2) + assert np.array_equal(output, sample[:, :1, :2]) From f658551c0a74812d60ffe3801ab8bcb4ede1d0b0 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 8 Jan 2026 12:04:43 +1100 Subject: [PATCH 4/5] align dask select functionality with its doc string this fixes the implementation of dask.Select so that the results match what is stated in the doc string. This changes the implementation to use indexing with native slices. --- .../pipeline/operations/dask/select.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py index d9fdddfc..bec3eaa7 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py @@ -65,27 +65,20 @@ def __init__( self.tuple_index = tuple_index def _index(self, data, array_index): - shape = data.shape - for i, index in enumerate(reversed(array_index)): - if index is None: - pass - selected_data = da.take(data, indices=index, axis=-(i + 1)) - if len(selected_data.shape) < len(shape): - selected_data = da.expand_dims(selected_data, axis=-(i + 1)) - data = selected_data - return data + # below comprehension: + # - ensures indexer is tuple (requirement) + # - converts instances of None into slice(None) + indexer = tuple(slice(None) if index is None else index for index in array_index) + return data[indexer] def apply_func(self, data): array_index = self.array_index if isinstance(data, tuple): - data = list(data) if self.tuple_index is None: - return tuple(map(lambda x: self._index(x, array_index), data)) + return tuple(self._index(x, array_index) for x in data) - data[self.tuple_index] = self._index(data[self.tuple_index], array_index) - data = tuple(data) - return data + return tuple(self._index(arr, array_index) if i == self.tuple_index else arr for i, arr in enumerate(data)) return self._index(data, array_index) From fce0ca2753f56adf5705a783f5e7a357bc4c7981 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 8 Jan 2026 12:17:53 +1100 Subject: [PATCH 5/5] add dask select tests --- .../tests/operations/dask/test_dask_select.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 packages/pipeline/tests/operations/dask/test_dask_select.py diff --git a/packages/pipeline/tests/operations/dask/test_dask_select.py b/packages/pipeline/tests/operations/dask/test_dask_select.py new file mode 100644 index 00000000..7a8c7a49 --- /dev/null +++ b/packages/pipeline/tests/operations/dask/test_dask_select.py @@ -0,0 +1,73 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import dask.array as da + + +from pyearthtools.pipeline.operations.dask import select + + +@pytest.fixture(scope="module") +def sample(): + """Test dask array.""" + return da.array(range(24)).reshape((2, 3, 4)) + + +def test_Select(sample): + """Tests the Select dask operation.""" + + s = select.Select([0]) + + output = s.apply_func(sample) + + assert output.shape == (3, 4) + assert (output == sample[0, :, :]).all().compute() + + # multi-dimensional indexing + s = select.Select([0, None, 3]) + + output = s.apply_func(sample) + + assert output.shape == (3,) + assert (output == sample[0, :, 3]).all().compute() + + # pass tuple of arrays + output = s.apply_func((sample, sample)) + for arr in output: + assert arr.shape == (3,) + assert (arr == sample[0, :, 3]).all().compute() + + # pass tuple of arrays with tuple index + s = select.Select(array_index=(0,), tuple_index=1) + output = s.apply_func((sample, sample)) + assert output[0].shape == sample.shape + assert (output[0] == sample).all().compute() + assert output[1].shape == (3, 4) + assert (output[1] == sample[0]).all().compute() + + +def test_Slice(sample): + """Tests the Slice dask operation.""" + + s = select.Slice((1,), (2,), (1, 4)) + output = s.apply_func(sample) + assert output.shape == (1, 2, 3) + assert (output == sample[:1, :2, 1:4]).all().compute() + + # test reverse_slice + s = select.Slice((1,), (2,), reverse_slice=True) + output = s.apply_func(sample) + assert output.shape == (2, 1, 2) + assert (output == sample[:, :1, :2]).all().compute()