diff --git a/packages/data/src/pyearthtools/data/transforms/derive.py b/packages/data/src/pyearthtools/data/transforms/derive.py index 7e097756..bbe9fcae 100644 --- a/packages/data/src/pyearthtools/data/transforms/derive.py +++ b/packages/data/src/pyearthtools/data/transforms/derive.py @@ -390,22 +390,24 @@ def derive_equations( attrs["equation"] = eq LOG.debug(f"Setting {key!r} to result of {eq!r}.") - result, eq_drop_vars = _evaluate(eq, dataset=dataset) + # shallow copy dataset, so new variables aren't added to old dataset + dataset_copy = dataset.copy(deep=False) + result, eq_drop_vars = _evaluate(eq, dataset=dataset_copy) - if key in list(dataset.coords.keys()): - dataset = dataset.assign_coords({key: result}) + if key in list(dataset_copy.coords.keys()): + dataset_copy = dataset_copy.assign_coords({key: result}) else: - dataset[key] = result + dataset_copy[key] = result if attrs.pop("drop", drop): _ = list(drop_vars.append(var) for var in eq_drop_vars) - dataset[key].attrs.update(**attrs) + dataset_copy[key].attrs.update(**attrs) # Drop variables used in the calculation - dataset = dataset.drop(set(drop_vars).intersection(dataset.data_vars), errors="ignore") # type: ignore + dataset_copy = dataset_copy.drop_vars(set(drop_vars).intersection(dataset_copy.data_vars), errors="ignore") # type: ignore - return dataset + return dataset_copy class Derive(Transform): diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py index 3197ad8a..197012d3 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/values.py @@ -53,7 +53,6 @@ def __init__( Value to be used to fill negative infinity values, If no value is passed then negative infinity values will be replaced with a very small (or negative) number. Defaults to None. 
""" - raise NotImplementedError("Not implemented") super().__init__( operation="apply", @@ -69,7 +68,7 @@ def __init__( self.neginf = neginf def apply_func(self, sample: da.Array): - return da.nan_to_num(da.array(sample), self.nan, self.posinf, self.neginf) + return da.nan_to_num(da.array(sample), True, self.nan, self.posinf, self.neginf) class MaskValue(DaskOperation): @@ -118,7 +117,7 @@ def __init__( self.value = value self.replacement_value = replacement_value - self._mask_transform = pyearthtools.data.transforms.mask.replace_value(value, operation, replacement_value) + self._mask_transform = pyearthtools.data.transforms.mask.Replace(value, operation, replacement_value) def apply_func(self, sample: da.Array) -> da.Array: """ @@ -135,9 +134,9 @@ def apply_func(self, sample: da.Array) -> da.Array: return self._mask_transform(sample) # type: ignore -class ForceNormalised(DaskOperation): +class Clip(DaskOperation): """ - Operation to force data within a certain range, by default 0 & 1 + Operation to force data to be within a certain range, by default 0 & 1 """ _override_interface = ["Serial"] @@ -170,10 +169,8 @@ def __init__( self.record_initialisation() - self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None - self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None + self._min_value = min_value + self._max_value = max_value def apply_func(self, sample): - for func in (func for func in [self._force_min, self._force_max] if func is not None): - sample = func.apply_func(sample) - return sample + return da.clip(sample, self._min_value, self._max_value) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py index 3356017e..c795fae1 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/values.py @@ -112,7 
+112,7 @@ def __init__( self.value = value self.replacement_value = replacement_value - self._mask_transform = pyearthtools.data.transforms.mask.replace_value(value, operation, replacement_value) + self._mask_transform = pyearthtools.data.transforms.mask.Replace(value, operation, replacement_value) def apply_func(self, sample: np.ndarray) -> np.ndarray: """ @@ -129,9 +129,9 @@ def apply_func(self, sample: np.ndarray) -> np.ndarray: return self._mask_transform(sample) # type: ignore -class ForceNormalised(Operation): +class Clip(Operation): """ - Operation to force data within a certain range, by default 0 & 1 + Operation to force data to be within a certain range, by default 0 & 1 """ _override_interface = ["Delayed", "Serial"] @@ -164,10 +164,8 @@ def __init__( self.record_initialisation() - self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None - self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None + self._min_value = min_value + self._max_value = max_value def apply_func(self, sample): - for func in (func for func in [self._force_min, self._force_max] if func is not None): - sample = func.apply_func(sample) - return sample + return np.clip(sample, a_min=self._min_value, a_max=self._max_value) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py index 2b18ca49..7db9ea46 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/values.py @@ -67,7 +67,24 @@ def __init__( self.neginf = neginf def apply_func(self, sample: T) -> T: - return sample.fillna(self.nan) + + if not (isinstance(sample, xr.DataArray) or isinstance(sample, xr.Dataset)): + raise TypeError("sample must be xr.DataArray or xr.Dataset.") + + # create copy of input, with np.nan_to_num applied to underlying numpy arrays + if 
isinstance(sample, xr.DataArray): + return sample.copy( + deep=True, # since data is provided, deep copy only applies to coordinates + data=np.nan_to_num(sample.values, nan=self.nan, posinf=self.posinf, neginf=self.neginf), + ) + else: + return sample.copy( + deep=True, + data={ + k: np.nan_to_num(v.values, nan=self.nan, posinf=self.posinf, neginf=self.neginf) + for k, v in sample.items() + }, + ) class MaskValue(Operation): @@ -132,9 +149,9 @@ def apply_func(self, sample: T) -> T: return self._mask_transform(sample) -class ForceNormalised(Operation): +class Clip(Operation): """ - Operation to force data within a certain range, by default 0 & 1 + Operation to force data to be within a certain range, by default 0 & 1 """ _override_interface = "Serial" @@ -166,13 +183,11 @@ def __init__( self.record_initialisation() - self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None - self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None + self._min_value = min_value + self._max_value = max_value def apply_func(self, sample): - for func in (func for func in [self._force_min, self._force_max] if func is not None): - sample = func.apply_func(sample) - return sample + return sample.clip(min=self._min_value, max=self._max_value) class Derive(Operation): @@ -204,7 +219,7 @@ def __init__( **derivations (Union[str, tuple[str, dict[str, Any]]]): Kwarg form of `derivation`. 
""" - super().__init__(split_tuples=True, recursively_split_tuples=True, recognised_types=(xr.DataArray, xr.Dataset)) + super().__init__(split_tuples=True, recursively_split_tuples=True, recognised_types=(xr.Dataset,)) self.record_initialisation() derivation = derivation or {} diff --git a/packages/pipeline/tests/operations/dask/test_dask_values.py b/packages/pipeline/tests/operations/dask/test_dask_values.py new file mode 100644 index 00000000..d25472a6 --- /dev/null +++ b/packages/pipeline/tests/operations/dask/test_dask_values.py @@ -0,0 +1,160 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pyearthtools.pipeline.operations.dask import values + +import dask.array as da +import numpy as np +import pytest + + +@pytest.fixture(scope="module") +def example_data(): + return da.array( + [ + [1, 2, 3], + [-np.inf, np.inf, -np.inf], + [1, np.nan, np.nan], + ] + ) + + +def test_fillnan(example_data): + """Tests dask FillNan operation class.""" + op = values.FillNan(nan=123, posinf=456, neginf=-789) + result = op.apply_func(example_data) + + assert ( + ( + result + == da.array( + [ + [1, 2, 3], + [-789, 456, -789], + [1, 123, 123], + ] + ) + ) + .all() + .compute() + ) + + +def test_maskvalue(example_data): + """Tests dask MaskValue operation class.""" + + # pass invalid operation + with pytest.raises(KeyError): + values.MaskValue(1, operation="*") + + # test default op (==) + op = values.MaskValue(1) + result = op.apply_func(example_data) + + assert da.allclose( + result, + da.array( + [ + [np.nan, 2.0, 3.0], + [-np.inf, np.inf, -np.inf], + [np.nan, np.nan, np.nan], + ], + ), + equal_nan=True, + ) + + # test <= op + op = values.MaskValue(2, operation="<=") + result = op.apply_func(example_data) + + assert da.allclose( + result, + da.array( + [ + [np.nan, np.nan, 3.0], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ], + ), + equal_nan=True, + ) + + # test < op + op = values.MaskValue(2, operation="<") + result = op.apply_func(example_data) + + assert da.allclose( + result, + da.array( + [ + [np.nan, 2.0, 3.0], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ], + ), + equal_nan=True, + ) + + # test >= op + op = values.MaskValue(2, operation=">=") + result = op.apply_func(example_data) + + assert np.allclose( + result, + da.array( + [ + [1.0, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test > op + op = values.MaskValue(2, operation=">") + result = op.apply_func(example_data) + + assert np.allclose( + result, + da.array( + [ + [1.0, 2.0, np.nan], + 
[-np.inf, np.nan, -np.inf], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + +def test_clip(example_data): + """Tests dask Clip operation class.""" + op = values.Clip() + result = op.apply_func(example_data) + + assert da.allclose( + result, + da.array( + [ + [1.0, 1.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_values.py b/packages/pipeline/tests/operations/numpy/test_numpy_values.py new file mode 100644 index 00000000..e01873de --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_values.py @@ -0,0 +1,159 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pyearthtools.pipeline.operations.numpy import values + +import numpy as np +import pytest + + +@pytest.fixture(scope="module") +def example_data(): + return np.array( + [ + [1, 2, 3], + [-np.inf, np.inf, -np.inf], + [1, np.nan, np.nan], + ] + ) + + +def test_fillnan(example_data): + """Tests numpy FillNan operation class.""" + op = values.FillNan(nan=123, posinf=456, neginf=-789) + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [1, 2, 3], + [-789, 456, -789], + [1, 123, 123], + ], + dtype=result.dtype, + ), + ) + + +def test_maskvalue(example_data): + """Tests numpy MaskValue operation class.""" + + # pass invalid operation + with pytest.raises(KeyError): + values.MaskValue(1, operation="*") + + # test default op (==) + op = values.MaskValue(1) + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [np.nan, 2.0, 3.0], + [-np.inf, np.inf, -np.inf], + [np.nan, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test <= op + op = values.MaskValue(2, operation="<=") + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [np.nan, np.nan, 3.0], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test < op + op = values.MaskValue(2, operation="<") + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [np.nan, 2.0, 3.0], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test >= op + op = values.MaskValue(2, operation=">=") + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [1.0, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + # test > op + op = values.MaskValue(2, operation=">") + result = op.apply_func(example_data) + + assert 
np.array_equal( + result, + np.array( + [ + [1.0, 2.0, np.nan], + [-np.inf, np.nan, -np.inf], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) + + +def test_clip(example_data): + """Tests numpy Clip operation class.""" + op = values.Clip() + result = op.apply_func(example_data) + + assert np.array_equal( + result, + np.array( + [ + [1.0, 1.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, np.nan, np.nan], + ], + dtype=result.dtype, + ), + equal_nan=True, + ) diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_values.py b/packages/pipeline/tests/operations/xarray/test_xarray_values.py new file mode 100644 index 00000000..6887d9c7 --- /dev/null +++ b/packages/pipeline/tests/operations/xarray/test_xarray_values.py @@ -0,0 +1,293 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pyearthtools.pipeline.operations.xarray import values + +import numpy as np +import xarray as xr +import pytest + + +@pytest.fixture(scope="module") +def example_dataarray(): + return xr.DataArray( + [ + [1, 2, 3], + [-np.inf, np.inf, -np.inf], + [1, np.nan, np.nan], + ] + ) + + +@pytest.fixture(scope="module") +def example_dataset(example_dataarray): + return xr.Dataset( + { + "a": example_dataarray, + "b": 2 * example_dataarray, + } + ) + + +def test_fillnan(example_dataarray, example_dataset): + """Tests xarray FillNan operation class for DataArray and Dataset inputs.""" + op = values.FillNan(nan=123, posinf=456, neginf=-789) + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [1, 2, 3], + [-789, 456, -789], + [1, 123, 123], + ], + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [1, 2, 3], + [-789, 456, -789], + [1, 123, 123], + ], + ), + "b": xr.DataArray( + [ + [2, 4, 6], + [-789, 456, -789], + [2, 123, 123], + ], + ), + } + ) + ) + + with pytest.raises(TypeError): + op.apply_func(1) + + +def test_maskvalue(example_dataarray, example_dataset): + """Tests xarray MaskValue operation class.""" + + # pass invalid operation + with pytest.raises(KeyError): + values.MaskValue(1, operation="*") + + # test default op (==) + op = values.MaskValue(1) + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [np.nan, 2, 3], + [-np.inf, np.inf, -np.inf], + [np.nan, np.nan, np.nan], + ] + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [np.nan, 2, 3], + [-np.inf, np.inf, -np.inf], + [np.nan, np.nan, np.nan], + ] + ), + "b": xr.DataArray( + [ + [2, 4, 6], + [-np.inf, np.inf, -np.inf], + [2, np.nan, np.nan], + ] + ), + } + ) + ) + + # test <= op + op = values.MaskValue(2, operation="<=") + result = op.apply_func(example_dataarray) + + assert 
result.equals( + xr.DataArray( + [ + [np.nan, np.nan, 3], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ] + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [np.nan, np.nan, 3], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ] + ), + "b": xr.DataArray([[np.nan, 4, 6], [np.nan, np.inf, np.nan], [np.nan, np.nan, np.nan]]), + } + ) + ) + + # test < op + op = values.MaskValue(2, operation="<") + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [np.nan, 2, 3], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ] + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [np.nan, 2, 3], + [np.nan, np.inf, np.nan], + [np.nan, np.nan, np.nan], + ] + ), + "b": xr.DataArray([[2, 4, 6], [np.nan, np.inf, np.nan], [2, np.nan, np.nan]]), + } + ) + ) + + # test >= op + op = values.MaskValue(2, operation=">=") + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [1, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [1, np.nan, np.nan], + ] + ) + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [1, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [1, np.nan, np.nan], + ], + ), + "b": xr.DataArray( + [ + [np.nan, np.nan, np.nan], + [-np.inf, np.nan, -np.inf], + [np.nan, np.nan, np.nan], + ] + ), + } + ) + ) + + # test > op + op = values.MaskValue(2, operation=">") + result = op.apply_func(example_dataarray) + + assert result.equals( + xr.DataArray( + [ + [1, 2, np.nan], + [-np.inf, np.nan, -np.inf], + [1, np.nan, np.nan], + ], + ), + ) + + result = op.apply_func(example_dataset) + + assert result.equals( + xr.Dataset( + { + "a": xr.DataArray( + [ + [1, 2, np.nan], + [-np.inf, np.nan, -np.inf], + [1, np.nan, np.nan], + ], + ), + "b": xr.DataArray( + [ + [2, np.nan, np.nan], + 
[-np.inf, np.nan, -np.inf], + [2, np.nan, np.nan], + ] + ), + } + ) + ) + + +def test_clip(example_dataarray, example_dataset): + """Tests xarray Clip operation class.""" + op = values.Clip() + result = op.apply_func(example_dataarray) + + correct_da = xr.DataArray( + [ + [1, 1, 1], + [0, 1, 0], + [1, np.nan, np.nan], + ], + ) + + assert result.equals(correct_da) + + result = op.apply_func(example_dataset) + + assert result.equals(xr.Dataset({"a": correct_da, "b": correct_da})) + + +def test_derive(example_dataset): + """Tests xarray Derive operation class.""" + op = values.Derive(c="a + b") + result = op.apply_func(example_dataset) + + assert result.equals(example_dataset.assign({"c": example_dataset["a"] + example_dataset["b"]})) + + assert op.undo_func(result).equals(example_dataset)