16 changes: 9 additions & 7 deletions packages/data/src/pyearthtools/data/transforms/derive.py
@@ -390,22 +390,24 @@ def derive_equations(
attrs["equation"] = eq

LOG.debug(f"Setting {key!r} to result of {eq!r}.")
- result, eq_drop_vars = _evaluate(eq, dataset=dataset)
+ # shallow copy dataset, so new variables aren't added to old dataset
+ dataset_copy = dataset.copy(deep=False)
+ result, eq_drop_vars = _evaluate(eq, dataset=dataset_copy)

- if key in list(dataset.coords.keys()):
-     dataset = dataset.assign_coords({key: result})
+ if key in list(dataset_copy.coords.keys()):
+     dataset_copy = dataset_copy.assign_coords({key: result})
else:
- dataset[key] = result
+ dataset_copy[key] = result

if attrs.pop("drop", drop):
_ = list(drop_vars.append(var) for var in eq_drop_vars)

- dataset[key].attrs.update(**attrs)
+ dataset_copy[key].attrs.update(**attrs)

# Drop variables used in the calculation
- dataset = dataset.drop(set(drop_vars).intersection(dataset.data_vars), errors="ignore") # type: ignore
+ dataset_copy = dataset_copy.drop_vars(set(drop_vars).intersection(dataset_copy.data_vars), errors="ignore") # type: ignore

- return dataset
+ return dataset_copy


class Derive(Transform):
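For readers unfamiliar with xarray's copy semantics, the shallow copy introduced above is what keeps derived variables from being written back onto the caller's dataset: Dataset.copy(deep=False) returns a new container with its own variable mapping while the underlying arrays stay shared. A minimal sketch (the variable names are made up for illustration):

import numpy as np
import xarray as xr

ds = xr.Dataset({"t2m": ("x", np.array([280.0, 281.5, 283.0]))})

# Shallow copy: new Dataset object and variable mapping, shared data buffers.
ds_copy = ds.copy(deep=False)
ds_copy["t2m_degC"] = ds_copy["t2m"] - 273.15

print("t2m_degC" in ds)       # False -- the original dataset is untouched
print("t2m_degC" in ds_copy)  # True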
@@ -53,7 +53,6 @@ def __init__(
Value to be used to fill negative infinity values,
If no value is passed then negative infinity values will be replaced with a very small (or negative) number. Defaults to None.
"""
- raise NotImplementedError("Not implemented")

super().__init__(
operation="apply",
@@ -69,7 +68,7 @@ def __init__(
self.neginf = neginf

def apply_func(self, sample: da.Array):
- return da.nan_to_num(da.array(sample), self.nan, self.posinf, self.neginf)
+ return da.nan_to_num(da.array(sample), True, self.nan, self.posinf, self.neginf)


class MaskValue(DaskOperation):
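For context on the nan_to_num fix above: the second positional parameter of nan_to_num (in both numpy and dask.array) is copy, not nan, so the old call shifted every replacement value one parameter to the left. A small numpy sketch of the intended behaviour:

import numpy as np

x = np.array([np.nan, np.inf, -np.inf])

# Signature: nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)
print(np.nan_to_num(x, nan=123.0, posinf=456.0, neginf=-789.0))
# [ 123.  456. -789.]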
@@ -118,7 +117,7 @@ def __init__(
self.value = value
self.replacement_value = replacement_value

- self._mask_transform = pyearthtools.data.transforms.mask.replace_value(value, operation, replacement_value)
+ self._mask_transform = pyearthtools.data.transforms.mask.Replace(value, operation, replacement_value)

def apply_func(self, sample: da.Array) -> da.Array:
"""
@@ -135,9 +134,9 @@ def apply_func(self, sample: da.Array) -> da.Array:
return self._mask_transform(sample) # type: ignore


- class ForceNormalised(DaskOperation):
+ class Clip(DaskOperation):
Collaborator comment:
I'm fine with this change, and there are no issues with dependencies on it at the moment. However, the docs will need to be updated (e.g. docs/api/pipeline/pipeline_api.md) and the notebooks (I think maybe only docs/notebooks/pipeline/Operations.ipynb).

"""
- Operation to force data within a certain range, by default 0 & 1
+ Operation to force data to be within a certain range, by default 0 & 1
"""

_override_interface = ["Serial"]
@@ -170,10 +169,8 @@ def __init__(

self.record_initialisation()

- self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None
- self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None
+ self._min_value = min_value
+ self._max_value = max_value

def apply_func(self, sample):
- for func in (func for func in [self._force_min, self._force_max] if func is not None):
-     sample = func.apply_func(sample)
- return sample
+ return da.clip(sample, self._min_value, self._max_value)
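Replacing the pair of MaskValue operations with a single clip keeps the same thresholding behaviour for ordinary values, while NaN propagates through unchanged (which the new tests below rely on). A quick sketch, assuming the default 0-1 range:

import dask.array as da
import numpy as np

x = da.from_array(np.array([-2.0, 0.5, 3.0, np.nan]))

print(da.clip(x, 0.0, 1.0).compute())
# [0.  0.5 1.  nan]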
@@ -112,7 +112,7 @@ def __init__(
self.value = value
self.replacement_value = replacement_value

- self._mask_transform = pyearthtools.data.transforms.mask.replace_value(value, operation, replacement_value)
+ self._mask_transform = pyearthtools.data.transforms.mask.Replace(value, operation, replacement_value)

def apply_func(self, sample: np.ndarray) -> np.ndarray:
"""
@@ -129,9 +129,9 @@ def apply_func(self, sample: np.ndarray) -> np.ndarray:
return self._mask_transform(sample) # type: ignore


- class ForceNormalised(Operation):
+ class Clip(Operation):
"""
- Operation to force data within a certain range, by default 0 & 1
+ Operation to force data to be within a certain range, by default 0 & 1
"""

_override_interface = ["Delayed", "Serial"]
@@ -164,10 +164,8 @@ def __init__(

self.record_initialisation()

- self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None
- self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None
+ self._min_value = min_value
+ self._max_value = max_value

def apply_func(self, sample):
- for func in (func for func in [self._force_min, self._force_max] if func is not None):
-     sample = func.apply_func(sample)
- return sample
+ return np.clip(sample, a_min=self._min_value, a_max=self._max_value)
@@ -67,7 +67,24 @@ def __init__(
self.neginf = neginf

def apply_func(self, sample: T) -> T:
- return sample.fillna(self.nan)

Collaborator comment:
Just reading the docstring above (it won't let me comment on unchanged lines), it says "If no value is passed then positive infinity values will be replaced with a very large number. Defaults to None." But None is not a very large number, so by default that comment doesn't parse for me.

Reply from @nikeethr (Jan 16, 2026):
Weird, I'm guilty of commenting on unchanged code a lot. I usually just hover over the vertical separator until it gives me a blue "+" - maybe a different colour/style based on your theme. See the attached screenshot; hopefully that helps.
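On the docstring question above: posinf=None does not mean infinities are left in place; when no value is given, nan_to_num substitutes the largest finite value representable by the array's dtype (and the most negative finite value for neginf). For example:

import numpy as np

x = np.array([np.nan, np.inf, -np.inf])

print(np.nan_to_num(x))
# [ 0.00000000e+000  1.79769313e+308 -1.79769313e+308]

print(np.finfo(x.dtype).max)  # 1.7976931348623157e+308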

if not (isinstance(sample, xr.DataArray) or isinstance(sample, xr.Dataset)):
Collaborator comment:
I wonder if we should have a general xarray subclass of Operation which handles xarray type checking, rather than doing it down in apply_func. This is okay, but it might be better to put it up higher.

Reply from @nikeethr (Jan 16, 2026):
I did have a similar review comment, but I didn't mention it because I wasn't familiar enough with the codebase to suggest where an ideal place would be, or how much impact putting it at a higher level would have on everything else. I guess (in hindsight) those would be the considerations, and also whether it needs to be a separate issue. But yes, it does seem like a very common "entrypoint" check that applies to many things.

raise TypeError("sample must be xr.DataArray or xr.Dataset.")

# create copy of input, with np.nan_to_num applied to underlying numpy arrays
if isinstance(sample, xr.DataArray):
return sample.copy(
deep=True, # since data is provided, deep copy only applies to coordinates
data=np.nan_to_num(sample.values, nan=self.nan, posinf=self.posinf, neginf=self.neginf),
)
else:
return sample.copy(
deep=True,
data={
k: np.nan_to_num(v.values, nan=self.nan, posinf=self.posinf, neginf=self.neginf)
for k, v in sample.items()
},
)
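On the reviewer suggestion above about hoisting the type check: one possible shape for a shared base class is sketched below. The class name and import path are hypothetical, Operation stands in for the existing pipeline base class, and real subclasses would still forward their usual constructor arguments.

import xarray as xr

from pyearthtools.pipeline import Operation  # assumed import path


class XarrayOperation(Operation):
    """Hypothetical base class: validate xarray input once, then delegate."""

    def apply_func(self, sample):
        # Shared "entrypoint" check, so subclasses can assume an xarray object.
        if not isinstance(sample, (xr.DataArray, xr.Dataset)):
            raise TypeError("sample must be xr.DataArray or xr.Dataset.")
        return self.apply_xarray(sample)

    def apply_xarray(self, sample):
        # Subclasses implement the actual transformation here.
        raise NotImplementedError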


class MaskValue(Operation):
@@ -132,9 +149,9 @@ def apply_func(self, sample: T) -> T:
return self._mask_transform(sample)


- class ForceNormalised(Operation):
+ class Clip(Operation):
"""
- Operation to force data within a certain range, by default 0 & 1
+ Operation to force data to be within a certain range, by default 0 & 1
"""

_override_interface = "Serial"
@@ -166,13 +183,11 @@ def __init__(

self.record_initialisation()

- self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None
- self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None
+ self._min_value = min_value
+ self._max_value = max_value

def apply_func(self, sample):
- for func in (func for func in [self._force_min, self._force_max] if func is not None):
-     sample = func.apply_func(sample)
- return sample
+ return sample.clip(min=self._min_value, max=self._max_value)


class Derive(Operation):
@@ -204,7 +219,7 @@ def __init__(
**derivations (Union[str, tuple[str, dict[str, Any]]]):
Kwarg form of `derivation`.
"""
- super().__init__(split_tuples=True, recursively_split_tuples=True, recognised_types=(xr.DataArray, xr.Dataset))
+ super().__init__(split_tuples=True, recursively_split_tuples=True, recognised_types=(xr.Dataset,))
self.record_initialisation()

derivation = derivation or {}
160 changes: 160 additions & 0 deletions packages/pipeline/tests/operations/dask/test_dask_values.py
@@ -0,0 +1,160 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2025.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyearthtools.pipeline.operations.dask import values

import dask.array as da
import numpy as np
import pytest


@pytest.fixture(scope="module")
def example_data():
return da.array(
[
[1, 2, 3],
[-np.inf, np.inf, -np.inf],
[1, np.nan, np.nan],
]
)


def test_fillnan(example_data):
"""Tests dask FillNan operation class."""
op = values.FillNan(nan=123, posinf=456, neginf=-789)
result = op.apply_func(example_data)

assert (
(
result
== da.array(
[
[1, 2, 3],
[-789, 456, -789],
[1, 123, 123],
]
)
)
.all()
.compute()
)


def test_maskvalue(example_data):
"""Tests dask MaskValue operation class."""

# pass invalid operation
with pytest.raises(KeyError):
values.MaskValue(1, operation="*")

# test default op (==)
op = values.MaskValue(1)
result = op.apply_func(example_data)

assert da.allclose(
result,
da.array(
[
[np.nan, 2.0, 3.0],
[-np.inf, np.inf, -np.inf],
[np.nan, np.nan, np.nan],
],
),
equal_nan=True,
)

# test <= op
op = values.MaskValue(2, operation="<=")
result = op.apply_func(example_data)

assert da.allclose(
result,
da.array(
[
[np.nan, np.nan, 3.0],
[np.nan, np.inf, np.nan],
[np.nan, np.nan, np.nan],
],
),
equal_nan=True,
)

# test < op
op = values.MaskValue(2, operation="<")
result = op.apply_func(example_data)

assert da.allclose(
result,
da.array(
[
[np.nan, 2.0, 3.0],
[np.nan, np.inf, np.nan],
[np.nan, np.nan, np.nan],
],
),
equal_nan=True,
)

# test >= op
op = values.MaskValue(2, operation=">=")
result = op.apply_func(example_data)

assert np.allclose(
result,
da.array(
[
[1.0, np.nan, np.nan],
[-np.inf, np.nan, -np.inf],
[1.0, np.nan, np.nan],
],
dtype=result.dtype,
),
equal_nan=True,
)

# test > op
op = values.MaskValue(2, operation=">")
result = op.apply_func(example_data)

assert np.allclose(
result,
da.array(
[
[1.0, 2.0, np.nan],
[-np.inf, np.nan, -np.inf],
[1.0, np.nan, np.nan],
],
dtype=result.dtype,
),
equal_nan=True,
)


def test_clip(example_data):
"""Tests dask Clip operation class."""
op = values.Clip()
result = op.apply_func(example_data)

assert da.allclose(
result,
da.array(
[
[1.0, 1.0, 1.0],
[0.0, 1.0, 0.0],
[1.0, np.nan, np.nan],
],
dtype=result.dtype,
),
equal_nan=True,
)