From d66d2d9a7426d82083065e2c6aeea5efda236286 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Wed, 12 Nov 2025 17:25:56 +1100 Subject: [PATCH 1/7] Add a few tests for data.transforms.derive --- packages/data/tests/transform/test_derive.py | 34 +++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/packages/data/tests/transform/test_derive.py b/packages/data/tests/transform/test_derive.py index 2e973cb5..1b9fae62 100644 --- a/packages/data/tests/transform/test_derive.py +++ b/packages/data/tests/transform/test_derive.py @@ -16,7 +16,8 @@ import pytest import math -from pyearthtools.data.transforms.derive import evaluate +from numpy import nan, isnan +from pyearthtools.data.transforms.derive import evaluate, EquationException @pytest.mark.parametrize( @@ -46,6 +47,20 @@ def test_evaluate_only_eq(eq, result): assert evaluate(eq) == float(result) +@pytest.mark.parametrize( + "eq", + [ + ("1 + (2"), + ("1 + 2)"), + ("1 + ((2 + 3)"), + ("1 + (2 + 3))"), + ], +) +def test_evaluate_mismatched_brackets(eq): + with pytest.raises(EquationException): + evaluate(eq) + + @pytest.mark.parametrize( "eq, result", [ @@ -55,3 +70,20 @@ def test_evaluate_only_eq(eq, result): ) def test_constants(eq, result): assert evaluate(eq) == float(result) + + +@pytest.mark.parametrize( + "eq, result", + [ + ("2 not_nan 3", 2.0), + ("nan not_nan 3", 3.0), + ("nan not_nan 3 not_nan 4", 3.0), + ("nan not_nan nan not_nan 4", 4.0), + ], +) +def test_evaluate_only_not_nan(eq, result): + assert evaluate(eq) == result + + +def test_evaluate_only_not_nan_all_nan(): + assert isnan(evaluate("nan not_nan nan")) From fb1889c5852605dbed6789007aa2eeaf76234efd Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Thu, 13 Nov 2025 12:33:41 +1100 Subject: [PATCH 2/7] Expand tests for pipeline.operations.numpy.conversion --- .../operations/numpy/test_numpy_conversion.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py b/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py index fd63a552..72f1b4c2 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py @@ -16,6 +16,7 @@ import numpy as np import xarray as xr +import dask.array as da def test_ToXarray_with_DataArray(): @@ -54,19 +55,29 @@ def test_drop_coords(): coords = {"x": list(range(5)), "y": list(range(5))} data = np.ones((5, 5)) - _data1 = np.ones((1, 5, 5)) + data1 = np.ones((1, 5, 5)) sample_da = xr.DataArray(coords=coords, data=data) sample_ds = xr.Dataset(coords=coords, data_vars={"z": sample_da}) + expected_data = np.ones((5)) + expected_da = xr.DataArray(coords={"y": list(range(5))}, data=expected_data) + expected_ds = xr.Dataset(coords={"y": list(range(5))}, data_vars={"z": sample_da}) + tox = conversion.ToXarray.like(sample_ds, drop_coords=["x"]) - assert tox is not None + result = tox.apply_func(data1) + assert (result == expected_ds).all() + as_numpy = tox.undo_func(sample_ds) + assert (as_numpy == data1).all() def test_ToDask(): data = np.ones((5, 5)) + expected = da.from_array(data) tod = conversion.ToDask() - da = tod.apply_func(data) - orig = tod.undo_func(da) + result = tod.apply_func(data) + da.assert_eq(result, expected) + + orig = tod.undo_func(result) assert (orig == data).all() From 2e418e4cf7802274d755093503aea707e8dd7bf0 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Thu, 13 Nov 2025 16:56:11 +1100 Subject: [PATCH 3/7] Use random data instead of 1s --- .../operations/numpy/test_numpy_conversion.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py b/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py index 72f1b4c2..20243f23 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_conversion.py @@ -22,7 +22,7 @@ def test_ToXarray_with_DataArray(): coords = {"x": list(range(5)), "y": list(range(5))} - data = np.ones((5, 5)) + data = np.random.randn(5, 5) sample = xr.DataArray(coords=coords, data=data) tox = conversion.ToXarray.like(sample) @@ -37,42 +37,40 @@ def test_ToXarray_with_DataArray(): def test_ToXarray_with_Dataset(): coords = {"x": list(range(5)), "y": list(range(5))} - data = np.ones((5, 5)) - data1 = np.ones((1, 5, 5)) - sample_da = xr.DataArray(coords=coords, data=data) + data_3d = np.random.randn(1, 5, 5) + data_2d = data_3d[0] + sample_da = xr.DataArray(coords=coords, data=data_2d) sample_ds = xr.Dataset(coords=coords, data_vars={"z": sample_da}) tox = conversion.ToXarray.like(sample_ds) - result = tox.apply_func(data1) + result = tox.apply_func(data_3d) assert (result == sample_ds).all() as_numpy = tox.undo_func(sample_ds) - assert (as_numpy == data1).all() + assert (as_numpy == data_3d).all() def test_drop_coords(): coords = {"x": list(range(5)), "y": list(range(5))} - data = np.ones((5, 5)) - data1 = np.ones((1, 5, 5)) - sample_da = xr.DataArray(coords=coords, data=data) - sample_ds = xr.Dataset(coords=coords, data_vars={"z": sample_da}) - expected_data = np.ones((5)) - expected_da = xr.DataArray(coords={"y": list(range(5))}, data=expected_data) - expected_ds = xr.Dataset(coords={"y": list(range(5))}, data_vars={"z": sample_da}) + data_3d = np.random.randn(1, 5, 5) + data_2d = data_3d[0] + sample_da = xr.DataArray(coords=coords, data=data_2d) + sample_ds = xr.Dataset(coords=coords, data_vars={"z": sample_da}) tox = conversion.ToXarray.like(sample_ds, drop_coords=["x"]) - result = tox.apply_func(data1) - assert (result == expected_ds).all() + result = tox.apply_func(data_3d) + assert (result == sample_ds).all() as_numpy = tox.undo_func(sample_ds) - assert (as_numpy == data1).all() + assert (as_numpy == data_3d).all() + def test_ToDask(): - data = np.ones((5, 5)) + data = np.random.randn(5, 5) expected = da.from_array(data) tod = conversion.ToDask() From b14a9529f6f161de13d0cf20226403f01d13ba34 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Thu, 13 Nov 2025 16:56:23 +1100 Subject: [PATCH 4/7] Add augment tests --- .../operations/numpy/test_numpy_augment.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 packages/pipeline/tests/operations/numpy/test_numpy_augment.py diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_augment.py b/packages/pipeline/tests/operations/numpy/test_numpy_augment.py new file mode 100644 index 00000000..a1d0eab5 --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_augment.py @@ -0,0 +1,93 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import augment + +import numpy as np +import pytest + + +@pytest.mark.parametrize( + "seed, rotations", + [ + (42, 0), + (1, 1), + (4, 2), + (2, 3), + ] +) +def test_Rotate(seed, rotations): + + original = np.array([ + [1, 2], + [4, 3] + ]) + + # The result depends on the random seed. This one has been manually checked + # to produce a certain number of rotations the first time. + match rotations: + case 0: + expected = np.array([ + [1, 2], + [4, 3] + ]) + case 1: + expected = np.array([ + [4, 1], + [3, 2] + ]) + case 2: + expected = np.array([ + [3, 4], + [2, 1] + ]) + case 3: + expected = np.array([ + [2, 3], + [1, 4] + ]) + + + rotate = augment.Rotate(seed=seed, axis=(1, 0)) + + result = rotate.apply_func(original) + assert (result == expected).all() + + +@pytest.mark.parametrize( + "seed, should_flip", + [ + (0, True), + (1, False), + ] +) +def test_Flip(seed, should_flip): + + original = np.array([ + [1, 2], + [3, 4] + ]) + + flipped = np.array([ + [4, 3], + [2, 1] + ]) + + # The result depends on the random seed. This one has been manually checked + # to produce a single rotation the first time. + expected = flipped if should_flip else original + flip = augment.Flip(seed=seed, axis=(1, 0)) + + result = flip.apply_func(original) + assert (result == expected).all() From f847214ff62149f14e27c7571bab742a11ffe7e3 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Fri, 14 Nov 2025 10:23:10 +1100 Subject: [PATCH 5/7] 100% coverage for operations.numpy.augment --- .../operations/numpy/test_numpy_augment.py | 36 ++++++++++++++++--- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_augment.py b/packages/pipeline/tests/operations/numpy/test_numpy_augment.py index a1d0eab5..4da12533 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_augment.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_augment.py @@ -19,6 +19,8 @@ @pytest.mark.parametrize( + # The result depends on the random seed. This one has been manually + # checked to produce a certain number of rotations the first time. "seed, rotations", [ (42, 0), @@ -34,8 +36,6 @@ def test_Rotate(seed, rotations): [4, 3] ]) - # The result depends on the random seed. This one has been manually checked - # to produce a certain number of rotations the first time. match rotations: case 0: expected = np.array([ @@ -58,13 +58,16 @@ def test_Rotate(seed, rotations): [1, 4] ]) - rotate = augment.Rotate(seed=seed, axis=(1, 0)) result = rotate.apply_func(original) assert (result == expected).all() +def test_Rotate_axis_must_be_tuple(): + with pytest.raises(TypeError): + augment.Rotate(axis=0) + @pytest.mark.parametrize( "seed, should_flip", [ @@ -76,11 +79,11 @@ def test_Flip(seed, should_flip): original = np.array([ [1, 2], - [3, 4] + [4, 3] ]) flipped = np.array([ - [4, 3], + [3, 4], [2, 1] ]) @@ -91,3 +94,26 @@ def test_Flip(seed, should_flip): result = flip.apply_func(original) assert (result == expected).all() + + +@pytest.mark.parametrize( + "seed, should_flip", + [ + (0, True), + (1, False), + ] +) +def test_FlipAndRotate(seed, should_flip): + + original = np.array([ + [1, 2], + [4, 3] + ]) + + flip_and_rotate = augment.FlipAndRotate() + + result = flip_and_rotate.apply_func(original) + # Don't worry about the number of flips and rotations, just check the + # shape and type returned + assert isinstance(result, np.ndarray) + assert result.shape == (2, 2) From 1445c7b6d34974a9544c543f1c7c91e95810ae23 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Fri, 14 Nov 2025 12:00:29 +1100 Subject: [PATCH 6/7] Some coverage for operations.numpy.filters --- .../operations/numpy/test_numpy_augment.py | 49 +++++----------- .../operations/numpy/test_numpy_filter.py | 57 +++++++++++++++++++ 2 files changed, 70 insertions(+), 36 deletions(-) create mode 100644 packages/pipeline/tests/operations/numpy/test_numpy_filter.py diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_augment.py b/packages/pipeline/tests/operations/numpy/test_numpy_augment.py index 4da12533..630ec40c 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_augment.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_augment.py @@ -27,36 +27,21 @@ (1, 1), (4, 2), (2, 3), - ] + ], ) def test_Rotate(seed, rotations): - original = np.array([ - [1, 2], - [4, 3] - ]) + original = np.array([[1, 2], [4, 3]]) match rotations: case 0: - expected = np.array([ - [1, 2], - [4, 3] - ]) + expected = np.array([[1, 2], [4, 3]]) case 1: - expected = np.array([ - [4, 1], - [3, 2] - ]) + expected = np.array([[4, 1], [3, 2]]) case 2: - expected = np.array([ - [3, 4], - [2, 1] - ]) + expected = np.array([[3, 4], [2, 1]]) case 3: - expected = np.array([ - [2, 3], - [1, 4] - ]) + expected = np.array([[2, 3], [1, 4]]) rotate = augment.Rotate(seed=seed, axis=(1, 0)) @@ -68,24 +53,19 @@ def test_Rotate_axis_must_be_tuple(): with pytest.raises(TypeError): augment.Rotate(axis=0) + @pytest.mark.parametrize( "seed, should_flip", [ (0, True), (1, False), - ] + ], ) def test_Flip(seed, should_flip): - original = np.array([ - [1, 2], - [4, 3] - ]) - - flipped = np.array([ - [3, 4], - [2, 1] - ]) + original = np.array([[1, 2], [4, 3]]) + + flipped = np.array([[3, 4], [2, 1]]) # The result depends on the random seed. This one has been manually checked # to produce a single rotation the first time. @@ -101,14 +81,11 @@ def test_Flip(seed, should_flip): [ (0, True), (1, False), - ] + ], ) def test_FlipAndRotate(seed, should_flip): - original = np.array([ - [1, 2], - [4, 3] - ]) + original = np.array([[1, 2], [4, 3]]) flip_and_rotate = augment.FlipAndRotate() diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_filter.py b/packages/pipeline/tests/operations/numpy/test_numpy_filter.py new file mode 100644 index 00000000..3265a77c --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_filter.py @@ -0,0 +1,57 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import filters +from pyearthtools.pipeline.exceptions import PipelineFilterException + +import numpy as np +import pytest + + +def test_DropAnyNan_false(): + + original = np.array([[1, 2], [4, 3]]) + + drop = filters.DropAnyNan() + # No return value, just check no exception is raised + drop.filter(original) + + +def test_DropAnyNan_true(): + + original = np.array([[1, 2], [4, np.nan]]) + + drop = filters.DropAnyNan() + + with pytest.raises(PipelineFilterException): + result = drop.filter(original) + + +def test_DropAllNan_false(): + + original = np.array([[1, 2], [np.nan, 3]]) + + drop = filters.DropAllNan() + # No return value, just check no exception is raised + drop.filter(original) + + +def test_DropAllNan_true(): + + original = np.array([[np.nan, np.nan], [np.nan, np.nan]]) + + drop = filters.DropAllNan() + + with pytest.raises(PipelineFilterException): + result = drop.filter(original) From dd43fbbb63e366486b5dd50371d4065142b28ce0 Mon Sep 17 00:00:00 2001 From: Luke Hoffmann Date: Fri, 14 Nov 2025 12:02:52 +1100 Subject: [PATCH 7/7] Fix nan filter implementations and --- .../pipeline/operations/numpy/filters.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py index 51553011..bd04042a 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/filters.py @@ -48,16 +48,16 @@ def __init__(self) -> None: self.record_initialisation() def filter(self, sample: np.ndarray): - """Check if any of the sample is nan + """Reject the sample if any value is nan Args: sample (np.ndarray): Sample to check - Returns: - (bool): - If sample contains nan's + Raises: + (PipelineFilterException): + If sample contains one or more nan value """ - if not bool(np.array(list(np.isnan(sample))).any()): + if bool(np.array(list(np.isnan(sample))).any()): raise PipelineFilterException(sample, "Data contained nan's.") @@ -76,16 +76,16 @@ def __init__(self) -> None: self.record_initialisation() def filter(self, sample: np.ndarray): - """Check if all of the sample is nan + """Reject the sample if all of its values are nan Args: sample (np.ndarray): Sample to check - Returns: - (bool): - If sample contains nan's + Raises: + (PipelineFilterException): + If sample contains only nan values """ - if not bool(np.array(list(np.isnan(sample))).all()): + if bool(np.array(list(np.isnan(sample))).all()): raise PipelineFilterException(sample, "Data contained all nan's.")