From 9157a6ac4d5de0bcc295341f4cb309aadabe3776 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 8 Jan 2026 16:15:57 +1100 Subject: [PATCH 01/12] use Replace class instead of replace_values --- .../src/pyearthtools/pipeline/operations/numpy/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py index 3356017e..2a682650 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py @@ -112,7 +112,7 @@ def __init__( self.value = value self.replacement_value = replacement_value - self._mask_transform = pyearthtools.data.transforms.mask.replace_value(value, operation, replacement_value) + self._mask_transform = pyearthtools.data.transforms.mask.Replace(value, operation, replacement_value) def apply_func(self, sample: np.ndarray) -> np.ndarray: """ From f22c67a840afa95254d7a7b745b12727a7925ef9 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 8 Jan 2026 16:17:24 +1100 Subject: [PATCH 02/12] rename numpy ForceNormalised to Clip Clip is more appropriate as there is an equivalent dask/numpy/xarray function. Additionally, no normalisation is occuring. --- .../pyearthtools/pipeline/operations/numpy/values.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py index 2a682650..c795fae1 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py @@ -129,9 +129,9 @@ def apply_func(self, sample: np.ndarray) -> np.ndarray: return self._mask_transform(sample) # type: ignore -class ForceNormalised(Operation): +class Clip(Operation): """ - Operation to force data within a certain range, by default 0 & 1 + Operation to force data to be within a certain range, by default 0 & 1 """ _override_interface = ["Delayed", "Serial"] @@ -164,10 +164,8 @@ def __init__( self.record_initialisation() - self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None - self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None + self._min_value = min_value + self._max_value = max_value def apply_func(self, sample): - for func in (func for func in [self._force_min, self._force_max] if func is not None): - sample = func.apply_func(sample) - return sample + return np.clip(sample, a_min=self._min_value, a_max=self._max_value) From 23f91f16c7bc1195788bf73f3d8aca3312c72a66 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 8 Jan 2026 16:17:49 +1100 Subject: [PATCH 03/12] add numpy values operations tests --- .../operations/numpy/test_numpy_values.py | 159 ++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 packages/pipeline/tests/operations/numpy/test_numpy_values.py diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_values.py b/packages/pipeline/tests/operations/numpy/test_numpy_values.py new file mode 100644 index 00000000..e01873de --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_values.py @@ -0,0 +1,159 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import values + +import numpy as np +import pytest + + +@pytest.fixture(scope="module") +def example_data(): + return np.array( + [ + [1, 2, 3], + [-np.inf, np.inf, -np.inf], + [1, np.nan, np.nan], + ] + ) + + +def test_fillnan(example_data): + """Tests numpy FillNan operation class.""" + op = values.FillNan(nan=123, posinf=456, neginf=-789) + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [1, 2, 3], + [-789, 456, -789], + [1, 123, 123], + ], + dtype=result.dtype, + ), + ) + + +def test_maskvalue(example_data): + """Tests numpy MaskValue operation class.""" + + # pass invalid operation + with pytest.raises(KeyError): + values.MaskValue(1, operation="*") + + # test default op (==) + op = values.MaskValue(1) + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [np.nan, 2.0, 3.0], + [-np.inf, np.inf, -np.inf], + [np.nan, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test <= op + op = values.MaskValue(2, operation="<=") + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [np.nan, np.nan, 3.0], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test < op + op = values.MaskValue(2, operation="<") + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [np.nan, 2.0, 3.0], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test >= op + op = values.MaskValue(2, operation=">=") + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [1.0, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test > op + op = values.MaskValue(2, operation=">") + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [1.0, 2.0, np.nan], + [-np.inf, np.nan, -np.inf], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + +def test_clip(example_data): + """Tests numpy Clip operation class.""" + op = values.Clip() + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [1.0, 1.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) From c1a6a9c4e7a6f79a7d5e8fbc1a7fc8c3b061a8f4 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 8 Jan 2026 17:40:12 +1100 Subject: [PATCH 04/12] add pos/neg inf replacement in xarray FillNan This aligns xarray FillNan with numpy FillNan --- .../pipeline/operations/xarray/values.py | 19 ++++- .../operations/xarray/test_xarray_values.py | 82 +++++++++++++++++++ 2 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 packages/pipeline/tests/operations/xarray/test_xarray_values.py diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py index 2b18ca49..add81d8a 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py @@ -67,7 +67,24 @@ def __init__( self.neginf = neginf def apply_func(self, sample: T) -> T: - return sample.fillna(self.nan) + + if not (isinstance(sample, xr.DataArray) or isinstance(sample, xr.Dataset)): + raise TypeError("sample must be xr.DataArray or xr.Dataset.") + + # create copy of input, with np.nan_to_num applied to underlying numpy arrays + if isinstance(sample, xr.DataArray): + return sample.copy( + deep=True, # since data is provided, deep copy only applies to coordinates + data=np.nan_to_num(sample.values, nan=self.nan, posinf=self.posinf, neginf=self.neginf), + ) + else: + return sample.copy( + deep=True, + data={ + k: np.nan_to_num(v.values, nan=self.nan, posinf=self.posinf, neginf=self.neginf) + for k, v in sample.items() + }, + ) class MaskValue(Operation): diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_values.py b/packages/pipeline/tests/operations/xarray/test_xarray_values.py new file mode 100644 index 00000000..d686fb19 --- /dev/null +++ b/packages/pipeline/tests/operations/xarray/test_xarray_values.py @@ -0,0 +1,82 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.xarray import values + +import numpy as np +import xarray as xr +import pytest + + +@pytest.fixture(scope="module") +def example_dataarray(): + return xr.DataArray( + [ + [1, 2, 3], + [-np.inf, np.inf, -np.inf], + [1, np.nan, np.nan], + ] + ) + + +@pytest.fixture(scope="module") +def example_dataset(example_dataarray): + return xr.Dataset( + { + "a": example_dataarray, + "b": 2 * example_dataarray, + } + ) + + +def test_fillnan(example_dataarray, example_dataset): + """Tests xarray FillNan operation class for DataArray and Dataset inputs.""" + op = values.FillNan(nan=123, posinf=456, neginf=-789) + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [1, 2, 3], + [-789, 456, -789], + [1, 123, 123], + ], + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [1, 2, 3], + [-789, 456, -789], + [1, 123, 123], + ], + ), + "b": xr.DataArray( + [ + [2, 4, 6], + [-789, 456, -789], + [2, 123, 123], + ], + ), + } + ) + ) + + with pytest.raises(TypeError): + op.apply_func(1) From e4cceb6f55b558e82806cedc2a0e8d1840909c44 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 08:42:54 +1100 Subject: [PATCH 05/12] add xarray MaskValue op tests --- .../operations/xarray/test_xarray_values.py | 185 +++++++++++++++++- 1 file changed, 183 insertions(+), 2 deletions(-) diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_values.py b/packages/pipeline/tests/operations/xarray/test_xarray_values.py index d686fb19..da269da4 100644 --- a/packages/pipeline/tests/operations/xarray/test_xarray_values.py +++ b/packages/pipeline/tests/operations/xarray/test_xarray_values.py @@ -1,10 +1,10 @@ # Copyright Commonwealth of Australia, Bureau of Meteorology 2025. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -80,3 +80,184 @@ def test_fillnan(example_dataarray, example_dataset): with pytest.raises(TypeError): op.apply_func(1) + + +def test_maskvalue(example_dataarray, example_dataset): + """Tests xarray MaskValue operation class.""" + + # pass invalid operation + with pytest.raises(KeyError): + values.MaskValue(1, operation="*") + + # test default op (==) + op = values.MaskValue(1) + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [np.nan, 2, 3], + [-np.inf, np.inf, -np.inf], + [np.nan, np.nan, np.nan], + ] + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [np.nan, 2, 3], + [-np.inf, np.inf, -np.inf], + [np.nan, np.nan, np.nan], + ] + ), + "b": xr.DataArray( + [ + [2, 4, 6], + [-np.inf, np.inf, -np.inf], + [2, np.nan, np.nan], + ] + ), + } + ) + ) + + # test <= op + op = values.MaskValue(2, operation="<=") + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [np.nan, np.nan, 3], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ] + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [np.nan, np.nan, 3], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ] + ), + "b": xr.DataArray([[np.nan, 4, 6], [np.nan, np.inf, np.nan], [np.nan, np.nan, np.nan]]), + } + ) + ) + + # test < op + op = values.MaskValue(2, operation="<") + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [np.nan, 2, 3], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ] + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [np.nan, 2, 3], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ] + ), + "b": xr.DataArray([[2, 4, 6], [np.nan, np.inf, np.nan], [2, np.nan, np.nan]]), + } + ) + ) + + # test >= op + op = values.MaskValue(2, operation=">=") + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [1, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [1, np.nan, np.nan], + ] + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [1, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [1, np.nan, np.nan], + ], + ), + "b": xr.DataArray( + [ + [np.nan, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [np.nan, np.nan, np.nan], + ] + ), + } + ) + ) + + # test > op + op = values.MaskValue(2, operation=">") + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [1, 2, np.nan], + [-np.inf, np.nan, -np.inf], + [1, np.nan, np.nan], + ], + ), + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [1, 2, np.nan], + [-np.inf, np.nan, -np.inf], + [1, np.nan, np.nan], + ], + ), + "b": xr.DataArray( + [ + [2, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [2, np.nan, np.nan], + ] + ), + } + ) + ) From 578b9fdbac8259b84e2eb1fdb69e3d136f45d2c4 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 08:51:57 +1100 Subject: [PATCH 06/12] rename xarray ForceNormalised to Clip Clip is more appropriate as there is an equivalent dask/numpy/xarray operation. --- .../pipeline/operations/xarray/values.py | 12 +++++------ .../operations/xarray/test_xarray_values.py | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py index add81d8a..7850afcb 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py @@ -149,9 +149,9 @@ def apply_func(self, sample: T) -> T: return self._mask_transform(sample) -class ForceNormalised(Operation): +class Clip(Operation): """ - Operation to force data within a certain range, by default 0 & 1 + Operation to force data to be within a certain range, by default 0 & 1 """ _override_interface = "Serial" @@ -183,13 +183,11 @@ def __init__( self.record_initialisation() - self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None - self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None + self._min_value = min_value + self._max_value = max_value def apply_func(self, sample): - for func in (func for func in [self._force_min, self._force_max] if func is not None): - sample = func.apply_func(sample) - return sample + return sample.clip(min=self._min_value, max=self._max_value) class Derive(Operation): diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_values.py b/packages/pipeline/tests/operations/xarray/test_xarray_values.py index da269da4..f209eae9 100644 --- a/packages/pipeline/tests/operations/xarray/test_xarray_values.py +++ b/packages/pipeline/tests/operations/xarray/test_xarray_values.py @@ -261,3 +261,23 @@ def test_maskvalue(example_dataarray, example_dataset): } ) ) + + +def test_clip(example_dataarray, example_dataset): + """Tests xarray Clip operation class.""" + op = values.Clip() + result = op.apply_func(example_dataarray) + + correct_da = xr.DataArray( + [ + [1, 1, 1], + [0, 1, 0], + [1, np.nan, np.nan], + ], + ) + + assert result.equals(correct_da) + + result = op.apply_func(example_dataset) + + assert result.equals(xr.Dataset({"a": correct_da, "b": correct_da})) From 5ef7a6e5982f85b4f9eea887a5d5d9dae4a497da Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 09:32:35 +1100 Subject: [PATCH 07/12] prevent derive from adding vars to input when creating a new variable with Derive, the new variable was being added to the input dataset. This commit fixes that by making a shallow copy of the input and adding the variables to that shallow copy before returning it. --- .../src/pyearthtools/data/transforms/derive.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/packages/data/src/pyearthtools/data/transforms/derive.py b/packages/data/src/pyearthtools/data/transforms/derive.py index 7e097756..31908933 100644 --- a/packages/data/src/pyearthtools/data/transforms/derive.py +++ b/packages/data/src/pyearthtools/data/transforms/derive.py @@ -390,22 +390,24 @@ def derive_equations( attrs["equation"] = eq LOG.debug(f"Setting {key!r} to result of {eq!r}.") - result, eq_drop_vars = _evaluate(eq, dataset=dataset) + # shallow copy dataset, so new variables aren't added to old dataset + dataset_copy = dataset.copy(deep=False) + result, eq_drop_vars = _evaluate(eq, dataset=dataset_copy) - if key in list(dataset.coords.keys()): - dataset = dataset.assign_coords({key: result}) + if key in list(dataset_copy.coords.keys()): + dataset_copy = dataset_copy.assign_coords({key: result}) else: - dataset[key] = result + dataset_copy[key] = result if attrs.pop("drop", drop): _ = list(drop_vars.append(var) for var in eq_drop_vars) - dataset[key].attrs.update(**attrs) + dataset_copy[key].attrs.update(**attrs) # Drop variables used in the calculation - dataset = dataset.drop(set(drop_vars).intersection(dataset.data_vars), errors="ignore") # type: ignore + dataset_copy = dataset_copy.drop(set(drop_vars).intersection(dataset_copy.data_vars), errors="ignore") # type: ignore - return dataset + return dataset_copy class Derive(Transform): From eba6df046808e20b75804f5fbdba987aa338f9df Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 09:33:26 +1100 Subject: [PATCH 08/12] add test for xarray derive op --- .../pyearthtools/pipeline/operations/xarray/values.py | 2 +- .../tests/operations/xarray/test_xarray_values.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py index 7850afcb..7db9ea46 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py @@ -219,7 +219,7 @@ def __init__( **derivations (Union[str, tuple[str, dict[str, Any]]]): Kwarg form of `derivation`. """ - super().__init__(split_tuples=True, recursively_split_tuples=True, recognised_types=(xr.DataArray, xr.Dataset)) + super().__init__(split_tuples=True, recursively_split_tuples=True, recognised_types=(xr.Dataset,)) self.record_initialisation() derivation = derivation or {} diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_values.py b/packages/pipeline/tests/operations/xarray/test_xarray_values.py index f209eae9..6887d9c7 100644 --- a/packages/pipeline/tests/operations/xarray/test_xarray_values.py +++ b/packages/pipeline/tests/operations/xarray/test_xarray_values.py @@ -281,3 +281,13 @@ def test_clip(example_dataarray, example_dataset): result = op.apply_func(example_dataset) assert result.equals(xr.Dataset({"a": correct_da, "b": correct_da})) + + +def test_derive(example_dataset): + """Tests xarray Derive operation class.""" + op = values.Derive(c="a + b") + result = op.apply_func(example_dataset) + + assert result.equals(example_dataset.assign({"c": example_dataset["a"] + example_dataset["b"]})) + + assert op.undo_func(result).equals(example_dataset) From f8f50c27ece908f9f11c74c94be15756c4c15e85 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 09:55:15 +1100 Subject: [PATCH 09/12] complete implementation of dask FillNan --- .../pipeline/operations/dask/values.py | 3 +- .../tests/operations/dask/test_dask_values.py | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 packages/pipeline/tests/operations/dask/test_dask_values.py diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py index 3197ad8a..c1aa7d12 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py @@ -53,7 +53,6 @@ def __init__( Value to be used to fill negative infinity values, If no value is passed then negative infinity values will be replaced with a very small (or negative) number. Defaults to None. """ - raise NotImplementedError("Not implemented") super().__init__( operation="apply", @@ -69,7 +68,7 @@ def __init__( self.neginf = neginf def apply_func(self, sample: da.Array): - return da.nan_to_num(da.array(sample), self.nan, self.posinf, self.neginf) + return da.nan_to_num(da.array(sample), True, self.nan, self.posinf, self.neginf) class MaskValue(DaskOperation): diff --git a/packages/pipeline/tests/operations/dask/test_dask_values.py b/packages/pipeline/tests/operations/dask/test_dask_values.py new file mode 100644 index 00000000..e0b0842c --- /dev/null +++ b/packages/pipeline/tests/operations/dask/test_dask_values.py @@ -0,0 +1,51 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.dask import values + +import dask.array as da +import numpy as np +import pytest + + +@pytest.fixture(scope="module") +def example_data(): + return da.array( + [ + [1, 2, 3], + [-np.inf, np.inf, -np.inf], + [1, np.nan, np.nan], + ] + ) + + +def test_fillnan(example_data): + """Tests dask FillNan operation class.""" + op = values.FillNan(nan=123, posinf=456, neginf=-789) + result = op.apply_func(example_data) + + assert ( + ( + result + == da.array( + [ + [1, 2, 3], + [-789, 456, -789], + [1, 123, 123], + ] + ) + ) + .all() + .compute() + ) From 085abeac0acd38414ace43db3d684372d34a7265 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 10:20:34 +1100 Subject: [PATCH 10/12] add dask maskvalue tests --- .../pipeline/operations/dask/values.py | 2 +- .../tests/operations/dask/test_dask_values.py | 90 +++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py index c1aa7d12..7270d556 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py @@ -117,7 +117,7 @@ def __init__( self.value = value self.replacement_value = replacement_value - self._mask_transform = pyearthtools.data.transforms.mask.replace_value(value, operation, replacement_value) + self._mask_transform = pyearthtools.data.transforms.mask.Replace(value, operation, replacement_value) def apply_func(self, sample: da.Array) -> da.Array: """ diff --git a/packages/pipeline/tests/operations/dask/test_dask_values.py b/packages/pipeline/tests/operations/dask/test_dask_values.py index e0b0842c..f1a7c7b3 100644 --- a/packages/pipeline/tests/operations/dask/test_dask_values.py +++ b/packages/pipeline/tests/operations/dask/test_dask_values.py @@ -49,3 +49,93 @@ def test_fillnan(example_data): .all() .compute() ) + + +def test_maskvalue(example_data): + """Tests dask MaskValue operation class.""" + + # pass invalid operation + with pytest.raises(KeyError): + values.MaskValue(1, operation="*") + + # test default op (==) + op = values.MaskValue(1) + result = op.apply_func(example_data) + + assert da.allclose( + result, + da.array( + [ + [np.nan, 2.0, 3.0], + [-np.inf, np.inf, -np.inf], + [np.nan, np.nan, np.nan], + ], + ), + equal_nan=True, + ) + + # test <= op + op = values.MaskValue(2, operation="<=") + result = op.apply_func(example_data) + + assert da.allclose( + result, + da.array( + [ + [np.nan, np.nan, 3.0], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ], + ), + equal_nan=True, + ) + + # test < op + op = values.MaskValue(2, operation="<") + result = op.apply_func(example_data) + + assert da.allclose( + result, + da.array( + [ + [np.nan, 2.0, 3.0], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ], + ), + equal_nan=True, + ) + + # test >= op + op = values.MaskValue(2, operation=">=") + result = op.apply_func(example_data) + + assert np.allclose( + result, + da.array( + [ + [1.0, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test > op + op = values.MaskValue(2, operation=">") + result = op.apply_func(example_data) + + assert np.allclose( + result, + da.array( + [ + [1.0, 2.0, np.nan], + [-np.inf, np.nan, -np.inf], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) From 898786c76c6aaa273dfc8267cd4f834de51bd5fb Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 10:22:18 +1100 Subject: [PATCH 11/12] rename dask ForceNormalised to Clip Clip is more appropriate as there is an equivalent dask/numpy/xarray operation. --- .../pipeline/operations/dask/values.py | 12 +++++------- .../tests/operations/dask/test_dask_values.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py index 7270d556..197012d3 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py @@ -134,9 +134,9 @@ def apply_func(self, sample: da.Array) -> da.Array: return self._mask_transform(sample) # type: ignore -class ForceNormalised(DaskOperation): +class Clip(DaskOperation): """ - Operation to force data within a certain range, by default 0 & 1 + Operation to force data to be within a certain range, by default 0 & 1 """ _override_interface = ["Serial"] @@ -169,10 +169,8 @@ def __init__( self.record_initialisation() - self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None - self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None + self._min_value = min_value + self._max_value = max_value def apply_func(self, sample): - for func in (func for func in [self._force_min, self._force_max] if func is not None): - sample = func.apply_func(sample) - return sample + return da.clip(sample, self._min_value, self._max_value) diff --git a/packages/pipeline/tests/operations/dask/test_dask_values.py b/packages/pipeline/tests/operations/dask/test_dask_values.py index f1a7c7b3..d25472a6 100644 --- a/packages/pipeline/tests/operations/dask/test_dask_values.py +++ b/packages/pipeline/tests/operations/dask/test_dask_values.py @@ -139,3 +139,22 @@ def test_maskvalue(example_data): ), equal_nan=True, ) + + +def test_clip(example_data): + """Tests dask Clip operation class.""" + op = values.Clip() + result = op.apply_func(example_data) + + assert da.allclose( + result, + da.array( + [ + [1.0, 1.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) From 241c73bbbb9148900d99c505a8c214dd0a658635 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 9 Jan 2026 10:25:01 +1100 Subject: [PATCH 12/12] use xr drop_vars instead of depracated drop --- packages/data/src/pyearthtools/data/transforms/derive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/data/src/pyearthtools/data/transforms/derive.py b/packages/data/src/pyearthtools/data/transforms/derive.py index 31908933..bbe9fcae 100644 --- a/packages/data/src/pyearthtools/data/transforms/derive.py +++ b/packages/data/src/pyearthtools/data/transforms/derive.py @@ -405,7 +405,7 @@ def derive_equations( dataset_copy[key].attrs.update(**attrs) # Drop variables used in the calculation - dataset_copy = dataset_copy.drop(set(drop_vars).intersection(dataset_copy.data_vars), errors="ignore") # type: ignore + dataset_copy = dataset_copy.drop_vars(set(drop_vars).intersection(dataset_copy.data_vars), errors="ignore") # type: ignore return dataset_copy