-
Notifications
You must be signed in to change notification settings - Fork 24
Add pipeline value operations tests #234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
9157a6a
f22c67a
23f91f1
c1a6a9c
e4cceb6
578b9fd
5ef7a6e
eba6df0
f8f50c2
085abea
898786c
241c73b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,7 +53,6 @@ def __init__( | |
| Value to be used to fill negative infinity values, | ||
| If no value is passed then negative infinity values will be replaced with a very small (or negative) number. Defaults to None. | ||
| """ | ||
| raise NotImplementedError("Not implemented") | ||
|
|
||
| super().__init__( | ||
| operation="apply", | ||
|
|
@@ -69,7 +68,7 @@ def __init__( | |
| self.neginf = neginf | ||
|
|
||
| def apply_func(self, sample: da.Array): | ||
| return da.nan_to_num(da.array(sample), self.nan, self.posinf, self.neginf) | ||
| return da.nan_to_num(da.array(sample), True, self.nan, self.posinf, self.neginf) | ||
|
|
||
|
|
||
| class MaskValue(DaskOperation): | ||
|
|
@@ -118,7 +117,7 @@ def __init__( | |
| self.value = value | ||
| self.replacement_value = replacement_value | ||
|
|
||
| self._mask_transform = pyearthtools.data.transforms.mask.replace_value(value, operation, replacement_value) | ||
| self._mask_transform = pyearthtools.data.transforms.mask.Replace(value, operation, replacement_value) | ||
|
|
||
| def apply_func(self, sample: da.Array) -> da.Array: | ||
| """ | ||
|
|
@@ -135,9 +134,9 @@ def apply_func(self, sample: da.Array) -> da.Array: | |
| return self._mask_transform(sample) # type: ignore | ||
|
|
||
|
|
||
| class ForceNormalised(DaskOperation): | ||
| class Clip(DaskOperation): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm fine with this change, and there are no issues with dependencies on it at the moment. However, the docs will need to be updated (e.g. docs/api/pipeline/pipeline_api.md) and the notebooks (I think maybe only docs/notebooks/pipeline/Operations.ipynb). |
||
| """ | ||
| Operation to force data within a certain range, by default 0 & 1 | ||
| Operation to force data to be within a certain range, by default 0 & 1 | ||
| """ | ||
|
|
||
| _override_interface = ["Serial"] | ||
|
|
@@ -170,10 +169,8 @@ def __init__( | |
|
|
||
| self.record_initialisation() | ||
|
|
||
| self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None | ||
| self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None | ||
| self._min_value = min_value | ||
| self._max_value = max_value | ||
|
|
||
| def apply_func(self, sample): | ||
| for func in (func for func in [self._force_min, self._force_max] if func is not None): | ||
| sample = func.apply_func(sample) | ||
| return sample | ||
| return da.clip(sample, self._min_value, self._max_value) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,7 +67,24 @@ def __init__( | |
| self.neginf = neginf | ||
|
|
||
| def apply_func(self, sample: T) -> T: | ||
| return sample.fillna(self.nan) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just reading the docstring above (it won't let me comment on unchanged lines), it says "If no value is passed then positive infinity values will be replaced with a very large number. Defaults to None." But None is not a very large number. So by default that comment doesn't parse for me.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| if not (isinstance(sample, xr.DataArray) or isinstance(sample, xr.Dataset)): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if we should have a general xarray subclass of operation which handles xarray type checking rather than doing it down in apply_func... this is okay, but it might be better to put it up higher.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I did have a similar review comment, but I didn't mention it because I wasn't familiar enough with the codebase to suggest where an ideal place would be, and how much impact putting it at a higher level would have on everything else. I guess (in hindsight) those would be the considerations, and also if it needs to be in a separate issue. But yes, it does seem like a very common "entrypoint" check that applies to many things. |
||
| raise TypeError("sample must be xr.DataArray or xr.Dataset.") | ||
|
|
||
| # create copy of input, with np.nan_to_num applied to underlying numpy arrays | ||
| if isinstance(sample, xr.DataArray): | ||
| return sample.copy( | ||
| deep=True, # since data is provided, deep copy only applies to coordinates | ||
| data=np.nan_to_num(sample.values, nan=self.nan, posinf=self.posinf, neginf=self.neginf), | ||
| ) | ||
| else: | ||
| return sample.copy( | ||
| deep=True, | ||
| data={ | ||
| k: np.nan_to_num(v.values, nan=self.nan, posinf=self.posinf, neginf=self.neginf) | ||
| for k, v in sample.items() | ||
| }, | ||
| ) | ||
|
|
||
|
|
||
| class MaskValue(Operation): | ||
|
|
@@ -132,9 +149,9 @@ def apply_func(self, sample: T) -> T: | |
| return self._mask_transform(sample) | ||
|
|
||
|
|
||
| class ForceNormalised(Operation): | ||
| class Clip(Operation): | ||
| """ | ||
| Operation to force data within a certain range, by default 0 & 1 | ||
| Operation to force data to be within a certain range, by default 0 & 1 | ||
| """ | ||
|
|
||
| _override_interface = "Serial" | ||
|
|
@@ -166,13 +183,11 @@ def __init__( | |
|
|
||
| self.record_initialisation() | ||
|
|
||
| self._force_min = MaskValue(min_value, "<", min_value) if min_value is not None else None | ||
| self._force_max = MaskValue(max_value, ">", max_value) if max_value is not None else None | ||
| self._min_value = min_value | ||
| self._max_value = max_value | ||
|
|
||
| def apply_func(self, sample): | ||
| for func in (func for func in [self._force_min, self._force_max] if func is not None): | ||
| sample = func.apply_func(sample) | ||
| return sample | ||
| return sample.clip(min=self._min_value, max=self._max_value) | ||
|
|
||
|
|
||
| class Derive(Operation): | ||
|
|
@@ -204,7 +219,7 @@ def __init__( | |
| **derivations (Union[str, tuple[str, dict[str, Any]]]): | ||
| Kwarg form of `derivation`. | ||
| """ | ||
| super().__init__(split_tuples=True, recursively_split_tuples=True, recognised_types=(xr.DataArray, xr.Dataset)) | ||
| super().__init__(split_tuples=True, recursively_split_tuples=True, recognised_types=(xr.Dataset,)) | ||
| self.record_initialisation() | ||
|
|
||
| derivation = derivation or {} | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| # Copyright Commonwealth of Australia, Bureau of Meteorology 2025. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from pyearthtools.pipeline.operations.dask import values | ||
|
|
||
| import dask.array as da | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def example_data(): | ||
| return da.array( | ||
| [ | ||
| [1, 2, 3], | ||
| [-np.inf, np.inf, -np.inf], | ||
| [1, np.nan, np.nan], | ||
| ] | ||
| ) | ||
|
|
||
|
|
||
| def test_fillnan(example_data): | ||
| """Tests dask FillNan operation class.""" | ||
| op = values.FillNan(nan=123, posinf=456, neginf=-789) | ||
| result = op.apply_func(example_data) | ||
|
|
||
| assert ( | ||
| ( | ||
| result | ||
| == da.array( | ||
| [ | ||
| [1, 2, 3], | ||
| [-789, 456, -789], | ||
| [1, 123, 123], | ||
| ] | ||
| ) | ||
| ) | ||
| .all() | ||
| .compute() | ||
| ) | ||
|
|
||
|
|
||
| def test_maskvalue(example_data): | ||
| """Tests dask MaskValue operation class.""" | ||
|
|
||
| # pass invalid operation | ||
| with pytest.raises(KeyError): | ||
| values.MaskValue(1, operation="*") | ||
|
|
||
| # test default op (==) | ||
| op = values.MaskValue(1) | ||
| result = op.apply_func(example_data) | ||
|
|
||
| assert da.allclose( | ||
| result, | ||
| da.array( | ||
| [ | ||
| [np.nan, 2.0, 3.0], | ||
| [-np.inf, np.inf, -np.inf], | ||
| [np.nan, np.nan, np.nan], | ||
| ], | ||
| ), | ||
| equal_nan=True, | ||
| ) | ||
|
|
||
| # test <= op | ||
| op = values.MaskValue(2, operation="<=") | ||
| result = op.apply_func(example_data) | ||
|
|
||
| assert da.allclose( | ||
| result, | ||
| da.array( | ||
| [ | ||
| [np.nan, np.nan, 3.0], | ||
| [np.nan, np.inf, np.nan], | ||
| [np.nan, np.nan, np.nan], | ||
| ], | ||
| ), | ||
| equal_nan=True, | ||
| ) | ||
|
|
||
| # test < op | ||
| op = values.MaskValue(2, operation="<") | ||
| result = op.apply_func(example_data) | ||
|
|
||
| assert da.allclose( | ||
| result, | ||
| da.array( | ||
| [ | ||
| [np.nan, 2.0, 3.0], | ||
| [np.nan, np.inf, np.nan], | ||
| [np.nan, np.nan, np.nan], | ||
| ], | ||
| ), | ||
| equal_nan=True, | ||
| ) | ||
|
|
||
| # test >= op | ||
| op = values.MaskValue(2, operation=">=") | ||
| result = op.apply_func(example_data) | ||
|
|
||
| assert np.allclose( | ||
| result, | ||
| da.array( | ||
| [ | ||
| [1.0, np.nan, np.nan], | ||
| [-np.inf, np.nan, -np.inf], | ||
| [1.0, np.nan, np.nan], | ||
| ], | ||
| dtype=result.dtype, | ||
| ), | ||
| equal_nan=True, | ||
| ) | ||
|
|
||
| # test > op | ||
| op = values.MaskValue(2, operation=">") | ||
| result = op.apply_func(example_data) | ||
|
|
||
| assert np.allclose( | ||
| result, | ||
| da.array( | ||
| [ | ||
| [1.0, 2.0, np.nan], | ||
| [-np.inf, np.nan, -np.inf], | ||
| [1.0, np.nan, np.nan], | ||
| ], | ||
| dtype=result.dtype, | ||
| ), | ||
| equal_nan=True, | ||
| ) | ||
|
|
||
|
|
||
| def test_clip(example_data): | ||
| """Tests dask Clip operation class.""" | ||
| op = values.Clip() | ||
| result = op.apply_func(example_data) | ||
|
|
||
| assert da.allclose( | ||
| result, | ||
| da.array( | ||
| [ | ||
| [1.0, 1.0, 1.0], | ||
| [0.0, 1.0, 0.0], | ||
| [1.0, np.nan, np.nan], | ||
| ], | ||
| dtype=result.dtype, | ||
| ), | ||
| equal_nan=True, | ||
| ) |

Uh oh!
There was an error while loading. Please reload this page.