Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,20 @@ def __init__(
self.tuple_index = tuple_index

def _index(self, data, array_index):
    """Index ``data`` with one entry of ``array_index`` per leading axis.

    Args:
        data: Array-like object that supports tuple indexing via ``data[...]``.
        array_index: Iterable of per-axis indexers. A ``None`` entry means
            "take the whole axis" and is translated to ``slice(None)``.

    Returns:
        The result of ``data[indexer]``, where ``indexer`` is the tuple
        built from ``array_index``.
    """
    # The indexer must be a tuple so that each entry addresses its own axis;
    # ``None`` is swapped for ``slice(None)`` so it selects the full axis
    # rather than being interpreted by the array library.
    indexer = tuple(slice(None) if index is None else index for index in array_index)
    return data[indexer]

def apply_func(self, data):
    """Apply the configured array index to ``data``.

    Args:
        data: Either a single array, or a tuple of arrays.

    Returns:
        For a single array, the indexed array. For a tuple, a new tuple:
        every member is indexed when ``self.tuple_index`` is ``None``;
        otherwise only the member at position ``self.tuple_index`` is
        indexed and the others are passed through unchanged.
    """
    array_index = self.array_index

    if isinstance(data, tuple):
        if self.tuple_index is None:
            return tuple(self._index(x, array_index) for x in data)

        # Build a fresh tuple instead of mutating a list copy. Positions come
        # from ``enumerate``, so only a non-negative ``tuple_index`` can match.
        return tuple(
            self._index(arr, array_index) if i == self.tuple_index else arr
            for i, arr in enumerate(data)
        )

    return self._index(data, array_index)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,20 @@ def __init__(
self.tuple_index = tuple_index

def _index(self, data, array_index):
    """Index ``data`` with one entry of ``array_index`` per leading axis.

    Args:
        data: Array-like object that supports tuple indexing via ``data[...]``.
        array_index: Iterable of per-axis indexers. A ``None`` entry means
            "take the whole axis" and is translated to ``slice(None)``.

    Returns:
        The result of ``data[indexer]``, where ``indexer`` is the tuple
        built from ``array_index``.
    """
    # The indexer must be a tuple so that each entry addresses its own axis;
    # ``None`` is swapped for ``slice(None)`` so it selects the full axis
    # instead of inserting a new one.
    indexer = tuple(slice(None) if index is None else index for index in array_index)
    return data[indexer]

def apply_func(self, data):
    """Apply the configured array index to ``data``.

    Args:
        data: Either a single array, or a tuple of arrays.

    Returns:
        For a single array, the indexed array. For a tuple, a new tuple:
        every member is indexed when ``self.tuple_index`` is ``None``;
        otherwise only the member at position ``self.tuple_index`` is
        indexed and the others are passed through unchanged.
    """
    array_index = self.array_index

    if isinstance(data, tuple):
        if self.tuple_index is None:
            return tuple(self._index(x, array_index) for x in data)

        # Build a fresh tuple instead of mutating a list copy. Positions come
        # from ``enumerate``, so only a non-negative ``tuple_index`` can match.
        return tuple(
            self._index(arr, array_index) if i == self.tuple_index else arr
            for i, arr in enumerate(data)
        )

    return self._index(data, array_index)

Expand Down
73 changes: 73 additions & 0 deletions packages/pipeline/tests/operations/dask/test_dask_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import dask.array as da


from pyearthtools.pipeline.operations.dask import select


@pytest.fixture(scope="module")
def sample():
    """Module-scoped dask array holding 0..23 arranged as shape (2, 3, 4)."""
    flat = da.array(range(24))
    return flat.reshape((2, 3, 4))


def test_Select(sample):
    """Check the dask ``Select`` operation on single arrays and on tuples."""

    # A lone leading index: the result is expected to lose the first axis.
    leading = select.Select([0])
    picked = leading.apply_func(sample)
    assert picked.shape == (3, 4)
    assert (picked == sample[0, :, :]).all().compute()

    # Multi-dimensional indexing, with ``None`` standing in for a full axis.
    multi = select.Select([0, None, 3])
    picked = multi.apply_func(sample)
    assert picked.shape == (3,)
    assert (picked == sample[0, :, 3]).all().compute()

    # Tuple input without a tuple index: every member is indexed.
    for member in multi.apply_func((sample, sample)):
        assert member.shape == (3,)
        assert (member == sample[0, :, 3]).all().compute()

    # Tuple input with ``tuple_index=1``: only the second member is indexed.
    targeted = select.Select(array_index=(0,), tuple_index=1)
    results = targeted.apply_func((sample, sample))
    assert results[0].shape == sample.shape
    assert (results[0] == sample).all().compute()
    assert results[1].shape == (3, 4)
    assert (results[1] == sample[0]).all().compute()


def test_Slice(sample):
    """Check the dask ``Slice`` operation, with and without ``reverse_slice``."""

    # One slice specification per axis.
    forward = select.Slice((1,), (2,), (1, 4))
    sliced = forward.apply_func(sample)
    assert sliced.shape == (1, 2, 3)
    assert (sliced == sample[:1, :2, 1:4]).all().compute()

    # With ``reverse_slice=True`` the expected result leaves the first axis whole.
    backward = select.Slice((1,), (2,), reverse_slice=True)
    sliced = backward.apply_func(sample)
    assert sliced.shape == (2, 1, 2)
    assert (sliced == sample[:, :1, :2]).all().compute()
73 changes: 73 additions & 0 deletions packages/pipeline/tests/operations/numpy/test_numpy_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import numpy as np


from pyearthtools.pipeline.operations.numpy import select


@pytest.fixture(scope="module")
def sample():
    """Module-scoped numpy array holding 0..23 arranged as shape (2, 3, 4)."""
    flat = np.array(range(24))
    return flat.reshape((2, 3, 4))


def test_Select(sample):
    """Check the numpy ``Select`` operation on single arrays and on tuples."""

    # A lone leading index: the result is expected to lose the first axis.
    leading = select.Select([0])
    picked = leading.apply_func(sample)
    assert picked.shape == (3, 4)
    assert np.array_equal(picked, sample[0, :, :])

    # Multi-dimensional indexing, with ``None`` standing in for a full axis.
    multi = select.Select([0, None, 3])
    picked = multi.apply_func(sample)
    assert picked.shape == (3,)
    assert np.array_equal(picked, sample[0, :, 3])

    # Tuple input without a tuple index: every member is indexed.
    for member in multi.apply_func((sample, sample)):
        assert member.shape == (3,)
        assert np.array_equal(member, sample[0, :, 3])

    # Tuple input with ``tuple_index=1``: only the second member is indexed.
    targeted = select.Select(array_index=(0,), tuple_index=1)
    results = targeted.apply_func((sample, sample))
    assert results[0].shape == sample.shape
    assert np.array_equal(results[0], sample)
    assert results[1].shape == (3, 4)
    assert np.array_equal(results[1], sample[0])


def test_Slice(sample):
    """Check the numpy ``Slice`` operation, with and without ``reverse_slice``."""

    # One slice specification per axis.
    forward = select.Slice((1,), (2,), (1, 4))
    sliced = forward.apply_func(sample)
    assert sliced.shape == (1, 2, 3)
    assert np.array_equal(sliced, sample[:1, :2, 1:4])

    # With ``reverse_slice=True`` the expected result leaves the first axis whole.
    backward = select.Slice((1,), (2,), reverse_slice=True)
    sliced = backward.apply_func(sample)
    assert sliced.shape == (2, 1, 2)
    assert np.array_equal(sliced, sample[:, :1, :2])
79 changes: 79 additions & 0 deletions packages/pipeline/tests/operations/xarray/test_xarray_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import xarray as xr
import pytest
import numpy as np


from pyearthtools.pipeline.operations.xarray import select


@pytest.fixture(scope="module")
def sample():
    """Module-scoped xarray dataset with two 3x3 variables, ``var1`` and ``var2``."""
    coords = {"dim0": range(3), "dim1": range(3)}
    # var1 holds 0..8 and var2 holds 9..17, both on the same coordinates.
    var1 = xr.DataArray(np.array(range(9)).reshape((3, 3)), coords)
    var2 = xr.DataArray(np.array(range(9, 18)).reshape((3, 3)), coords)
    return xr.Dataset({"var1": var1, "var2": var2})


def test_SelectDataset(sample):
    """Check that ``SelectDataset`` keeps only the requested variable."""

    keep_var1 = select.SelectDataset(("var1",))
    result = keep_var1.apply_func(sample)

    # The selected variable survives unchanged; the other one is gone.
    assert "var1" in result
    assert "var2" not in result
    assert result["var1"].equals(sample["var1"])


def test_DropDataset(sample):
    """Check that ``DropDataset`` removes only the requested variable."""

    drop_var1 = select.DropDataset(("var1",))
    result = drop_var1.apply_func(sample)

    # The dropped variable is gone; the remaining one is unchanged.
    assert "var1" not in result
    assert "var2" in result
    assert result["var2"].equals(sample["var2"])


def test_SliceDataset(sample):
    """Check ``SliceDataset`` built from a dict and from keyword arguments.

    The same slice specification is applied to the whole dataset and to a
    single data array taken from it.
    """

    args = {"dim0": (0, 2, 2), "dim1": (0, 1)}

    def check_slicer(slicer, data):
        """Apply ``slicer`` to ``data`` and verify the surviving coordinates."""
        # Use the slicer that was passed in, not a variable captured from the
        # enclosing scope, so each call checks exactly the object it is given.
        output = slicer.apply_func(data)

        assert np.array_equal(output.coords["dim0"].values, [0, 2])
        assert np.array_equal(output.coords["dim1"].values, [0, 1])

    # test passing dict to SliceDataset
    s = select.SliceDataset(args)
    check_slicer(s, sample)

    # test passing kwargs to SliceDataset
    s = select.SliceDataset(**args)
    check_slicer(s, sample)

    # test passing dataarray to slicer
    check_slicer(s, sample["var1"])