diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py index d9fdddfc..bec3eaa7 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/select.py @@ -65,27 +65,20 @@ def __init__( self.tuple_index = tuple_index def _index(self, data, array_index): - shape = data.shape - for i, index in enumerate(reversed(array_index)): - if index is None: - pass - selected_data = da.take(data, indices=index, axis=-(i + 1)) - if len(selected_data.shape) < len(shape): - selected_data = da.expand_dims(selected_data, axis=-(i + 1)) - data = selected_data - return data + # below comprehension: + # - ensures indexer is tuple (requirement) + # - converts instances of None into slice(None) + indexer = tuple(slice(None) if index is None else index for index in array_index) + return data[indexer] def apply_func(self, data): array_index = self.array_index if isinstance(data, tuple): - data = list(data) if self.tuple_index is None: - return tuple(map(lambda x: self._index(x, array_index), data)) + return tuple(self._index(x, array_index) for x in data) - data[self.tuple_index] = self._index(data[self.tuple_index], array_index) - data = tuple(data) - return data + return tuple(self._index(arr, array_index) if i == self.tuple_index else arr for i, arr in enumerate(data)) return self._index(data, array_index) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py index d3d7f463..5d1acdec 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/select.py @@ -63,27 +63,20 @@ def __init__( self.tuple_index = tuple_index def _index(self, data, array_index): - shape = data.shape - for i, index in enumerate(reversed(array_index)): - if index is None: - pass - selected_data = np.take(data, indices=index, axis=-(i + 1)) - if len(selected_data.shape) < len(shape): - selected_data = np.expand_dims(selected_data, axis=-(i + 1)) - data = selected_data - return data + # below comprehension: + # - ensures indexer is tuple (requirement) + # - converts instances of None into slice(None) + indexer = tuple(slice(None) if index is None else index for index in array_index) + return data[indexer] def apply_func(self, data): array_index = self.array_index if isinstance(data, tuple): - data = list(data) if self.tuple_index is None: - return tuple(map(lambda x: self._index(x, array_index), data)) + return tuple(self._index(x, array_index) for x in data) - data[self.tuple_index] = self._index(data[self.tuple_index], array_index) - data = tuple(data) - return data + return tuple(self._index(arr, array_index) if i == self.tuple_index else arr for i, arr in enumerate(data)) return self._index(data, array_index) diff --git a/packages/pipeline/tests/operations/dask/test_dask_select.py b/packages/pipeline/tests/operations/dask/test_dask_select.py new file mode 100644 index 00000000..7a8c7a49 --- /dev/null +++ b/packages/pipeline/tests/operations/dask/test_dask_select.py @@ -0,0 +1,73 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import dask.array as da + + +from pyearthtools.pipeline.operations.dask import select + + +@pytest.fixture(scope="module") +def sample(): + """Test dask array.""" + return da.array(range(24)).reshape((2, 3, 4)) + + +def test_Select(sample): + """Tests the Select dask operation.""" + + s = select.Select([0]) + + output = s.apply_func(sample) + + assert output.shape == (3, 4) + assert (output == sample[0, :, :]).all().compute() + + # multi-dimensional indexing + s = select.Select([0, None, 3]) + + output = s.apply_func(sample) + + assert output.shape == (3,) + assert (output == sample[0, :, 3]).all().compute() + + # pass tuple of arrays + output = s.apply_func((sample, sample)) + for arr in output: + assert arr.shape == (3,) + assert (arr == sample[0, :, 3]).all().compute() + + # pass tuple of arrays with tuple index + s = select.Select(array_index=(0,), tuple_index=1) + output = s.apply_func((sample, sample)) + assert output[0].shape == sample.shape + assert (output[0] == sample).all().compute() + assert output[1].shape == (3, 4) + assert (output[1] == sample[0]).all().compute() + + +def test_Slice(sample): + """Tests the Slice dask operation.""" + + s = select.Slice((1,), (2,), (1, 4)) + output = s.apply_func(sample) + assert output.shape == (1, 2, 3) + assert (output == sample[:1, :2, 1:4]).all().compute() + + # test reverse_slice + s = select.Slice((1,), (2,), reverse_slice=True) + output = s.apply_func(sample) + assert output.shape == (2, 1, 2) + assert (output == sample[:, :1, :2]).all().compute() diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_select.py b/packages/pipeline/tests/operations/numpy/test_numpy_select.py new file mode 100644 index 00000000..d9d3f8ab --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_select.py @@ -0,0 +1,73 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np + + +from pyearthtools.pipeline.operations.numpy import select + + +@pytest.fixture(scope="module") +def sample(): + """Test numpy array.""" + return np.array(range(24)).reshape((2, 3, 4)) + + +def test_Select(sample): + """Tests the Select numpy operation.""" + + s = select.Select([0]) + + output = s.apply_func(sample) + + assert output.shape == (3, 4) + assert np.array_equal(output, sample[0, :, :]) + + # multi-dimensional indexing + s = select.Select([0, None, 3]) + + output = s.apply_func(sample) + + assert output.shape == (3,) + assert np.array_equal(output, sample[0, :, 3]) + + # pass tuple of arrays + output = s.apply_func((sample, sample)) + for arr in output: + assert arr.shape == (3,) + assert np.array_equal(arr, sample[0, :, 3]) + + # pass tuple of arrays with tuple index + s = select.Select(array_index=(0,), tuple_index=1) + output = s.apply_func((sample, sample)) + assert output[0].shape == sample.shape + assert np.array_equal(output[0], sample) + assert output[1].shape == (3, 4) + assert np.array_equal(output[1], sample[0]) + + +def test_Slice(sample): + """Tests the Slice numpy operations.""" + + s = select.Slice((1,), (2,), (1, 4)) + output = s.apply_func(sample) + assert output.shape == (1, 2, 3) + assert np.array_equal(output, sample[:1, :2, 1:4]) + + # test reverse_slice + s = select.Slice((1,), (2,), reverse_slice=True) + output = s.apply_func(sample) + assert output.shape == (2, 1, 2) + assert np.array_equal(output, sample[:, :1, :2]) diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_select.py b/packages/pipeline/tests/operations/xarray/test_xarray_select.py new file mode 100644 index 00000000..28af89ea --- /dev/null +++ b/packages/pipeline/tests/operations/xarray/test_xarray_select.py @@ -0,0 +1,79 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import xarray as xr +import pytest +import numpy as np + + +from pyearthtools.pipeline.operations.xarray import select + + +@pytest.fixture(scope="module") +def sample(): + """Test xarray dataset.""" + coords = {"dim0": range(3), "dim1": range(3)} + return xr.Dataset( + { + "var1": xr.DataArray(np.array(range(9)).reshape((3, 3)), coords), + "var2": xr.DataArray(np.array(range(9, 18)).reshape((3, 3)), coords), + }, + ) + + +def test_SelectDataset(sample): + """Tests the SelectDataset xarray operation.""" + + s = select.SelectDataset(("var1",)) + + output = s.apply_func(sample) + + assert "var1" in output + assert "var2" not in output + assert output["var1"].equals(sample["var1"]) + + +def test_DropDataset(sample): + """Tests the DropDataset xarray operation.""" + + s = select.DropDataset(("var1",)) + + output = s.apply_func(sample) + assert "var1" not in output + assert "var2" in output + assert output["var2"].equals(sample["var2"]) + + +def test_SliceDataset(sample): + """Tests the SliceDataset xarray operation.""" + + args = {"dim0": (0, 2, 2), "dim1": (0, 1)} + + def test_slicer(slicer, sample): + + output = s.apply_func(sample) + + assert np.array_equal(output.coords["dim0"].values, [0, 2]) + assert np.array_equal(output.coords["dim1"].values, [0, 1]) + + # test passing dict to SliceDataset + s = select.SliceDataset(args) + test_slicer(s, sample) + + # test passing kwargs to SliceDataset + s = select.SliceDataset(**args) + test_slicer(s, sample) + + # test passing dataarray to slicer + test_slicer(s, sample["var1"])