diff --git a/packages/data/tests/transform/test_derive.py b/packages/data/tests/transform/test_derive.py index 2e973cb5..1b9fae62 100644 --- a/packages/data/tests/transform/test_derive.py +++ b/packages/data/tests/transform/test_derive.py @@ -16,7 +16,8 @@ import pytest import math -from pyearthtools.data.transforms.derive import evaluate +from numpy import nan, isnan +from pyearthtools.data.transforms.derive import evaluate, EquationException @pytest.mark.parametrize( @@ -46,6 +47,20 @@ def test_evaluate_only_eq(eq, result): assert evaluate(eq) == float(result) +@pytest.mark.parametrize( + "eq", + [ + ("1 + (2"), + ("1 + 2)"), + ("1 + ((2 + 3)"), + ("1 + (2 + 3))"), + ], +) +def test_evaluate_mismatched_brackets(eq): + with pytest.raises(EquationException): + evaluate(eq) + + @pytest.mark.parametrize( "eq, result", [ @@ -55,3 +70,20 @@ def test_evaluate_only_eq(eq, result): ) def test_constants(eq, result): assert evaluate(eq) == float(result) + + +@pytest.mark.parametrize( + "eq, result", + [ + ("2 not_nan 3", 2.0), + ("nan not_nan 3", 3.0), + ("nan not_nan 3 not_nan 4", 3.0), + ("nan not_nan nan not_nan 4", 4.0), + ], +) +def test_evaluate_only_not_nan(eq, result): + assert evaluate(eq) == result + + +def test_evaluate_only_not_nan_all_nan(): + assert isnan(evaluate("nan not_nan nan")) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py index 51553011..bd04042a 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py @@ -48,16 +48,16 @@ def __init__(self) -> None: self.record_initialisation() def filter(self, sample: np.ndarray): - """Check if any of the sample is nan + """Reject the sample if any value is nan Args: sample (np.ndarray): Sample to check - Returns: - (bool): - If sample contains nan's + Raises: + (PipelineFilterException): + If sample contains one or more nan value """ - if not bool(np.array(list(np.isnan(sample))).any()): + if bool(np.array(list(np.isnan(sample))).any()): raise PipelineFilterException(sample, "Data contained nan's.") @@ -76,16 +76,16 @@ def __init__(self) -> None: self.record_initialisation() def filter(self, sample: np.ndarray): - """Check if all of the sample is nan + """Reject the sample if all of its values are nan Args: sample (np.ndarray): Sample to check - Returns: - (bool): - If sample contains nan's + Raises: + (PipelineFilterException): + If sample contains only nan values """ - if not bool(np.array(list(np.isnan(sample))).all()): + if bool(np.array(list(np.isnan(sample))).all()): raise PipelineFilterException(sample, "Data contained all nan's.") diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_augment.py b/packages/pipeline/tests/operations/numpy/test_numpy_augment.py new file mode 100644 index 00000000..630ec40c --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_augment.py @@ -0,0 +1,96 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import augment + +import numpy as np +import pytest + + +@pytest.mark.parametrize( + # The result depends on the random seed. This one has been manually + # checked to produce a certain number of rotations the first time. + "seed, rotations", + [ + (42, 0), + (1, 1), + (4, 2), + (2, 3), + ], +) +def test_Rotate(seed, rotations): + + original = np.array([[1, 2], [4, 3]]) + + match rotations: + case 0: + expected = np.array([[1, 2], [4, 3]]) + case 1: + expected = np.array([[4, 1], [3, 2]]) + case 2: + expected = np.array([[3, 4], [2, 1]]) + case 3: + expected = np.array([[2, 3], [1, 4]]) + + rotate = augment.Rotate(seed=seed, axis=(1, 0)) + + result = rotate.apply_func(original) + assert (result == expected).all() + + +def test_Rotate_axis_must_be_tuple(): + with pytest.raises(TypeError): + augment.Rotate(axis=0) + + +@pytest.mark.parametrize( + "seed, should_flip", + [ + (0, True), + (1, False), + ], +) +def test_Flip(seed, should_flip): + + original = np.array([[1, 2], [4, 3]]) + + flipped = np.array([[3, 4], [2, 1]]) + + # The result depends on the random seed. This one has been manually checked + # to produce a single rotation the first time. + expected = flipped if should_flip else original + flip = augment.Flip(seed=seed, axis=(1, 0)) + + result = flip.apply_func(original) + assert (result == expected).all() + + +@pytest.mark.parametrize( + "seed, should_flip", + [ + (0, True), + (1, False), + ], +) +def test_FlipAndRotate(seed, should_flip): + + original = np.array([[1, 2], [4, 3]]) + + flip_and_rotate = augment.FlipAndRotate() + + result = flip_and_rotate.apply_func(original) + # Don't worry about the number of flips and rotations, just check the + # shape and type returned + assert isinstance(result, np.ndarray) + assert result.shape == (2, 2) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py b/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py index fd63a552..20243f23 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py @@ -16,12 +16,13 @@ import numpy as np import xarray as xr +import dask.array as da def test_ToXarray_with_DataArray(): coords = {"x": list(range(5)), "y": list(range(5))} - data = np.ones((5, 5)) + data = np.random.randn(5, 5) sample = xr.DataArray(coords=coords, data=data) tox = conversion.ToXarray.like(sample) @@ -36,37 +37,45 @@ def test_ToXarray_with_DataArray(): def test_ToXarray_with_Dataset(): coords = {"x": list(range(5)), "y": list(range(5))} - data = np.ones((5, 5)) - data1 = np.ones((1, 5, 5)) - sample_da = xr.DataArray(coords=coords, data=data) + data_3d = np.random.randn(1, 5, 5) + data_2d = data_3d[0] + sample_da = xr.DataArray(coords=coords, data=data_2d) sample_ds = xr.Dataset(coords=coords, data_vars={"z": sample_da}) tox = conversion.ToXarray.like(sample_ds) - result = tox.apply_func(data1) + result = tox.apply_func(data_3d) assert (result == sample_ds).all() as_numpy = tox.undo_func(sample_ds) - assert (as_numpy == data1).all() + assert (as_numpy == data_3d).all() def test_drop_coords(): coords = {"x": list(range(5)), "y": list(range(5))} - data = np.ones((5, 5)) - _data1 = np.ones((1, 5, 5)) - sample_da = xr.DataArray(coords=coords, data=data) + + data_3d = np.random.randn(1, 5, 5) + data_2d = data_3d[0] + sample_da = xr.DataArray(coords=coords, data=data_2d) sample_ds = xr.Dataset(coords=coords, data_vars={"z": sample_da}) tox = conversion.ToXarray.like(sample_ds, drop_coords=["x"]) - assert tox is not None + result = tox.apply_func(data_3d) + assert (result == sample_ds).all() + + as_numpy = tox.undo_func(sample_ds) + assert (as_numpy == data_3d).all() def test_ToDask(): - data = np.ones((5, 5)) + data = np.random.randn(5, 5) + expected = da.from_array(data) tod = conversion.ToDask() - da = tod.apply_func(data) - orig = tod.undo_func(da) + result = tod.apply_func(data) + da.assert_eq(result, expected) + + orig = tod.undo_func(result) assert (orig == data).all() diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_filter.py b/packages/pipeline/tests/operations/numpy/test_numpy_filter.py new file mode 100644 index 00000000..3265a77c --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_filter.py @@ -0,0 +1,57 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import filters +from pyearthtools.pipeline.exceptions import PipelineFilterException + +import numpy as np +import pytest + + +def test_DropAnyNan_false(): + + original = np.array([[1, 2], [4, 3]]) + + drop = filters.DropAnyNan() + # No return value, just check no exception is raised + drop.filter(original) + + +def test_DropAnyNan_true(): + + original = np.array([[1, 2], [4, np.nan]]) + + drop = filters.DropAnyNan() + + with pytest.raises(PipelineFilterException): + result = drop.filter(original) + + +def test_DropAllNan_false(): + + original = np.array([[1, 2], [np.nan, 3]]) + + drop = filters.DropAllNan() + # No return value, just check no exception is raised + drop.filter(original) + + +def test_DropAllNan_true(): + + original = np.array([[np.nan, np.nan], [np.nan, np.nan]]) + + drop = filters.DropAllNan() + + with pytest.raises(PipelineFilterException): + result = drop.filter(original)