From ebcd8dba4a0bb847cca7cb780977435df41d7e72 Mon Sep 17 00:00:00 2001 From: Michael Pegios Date: Thu, 13 Nov 2025 11:17:13 +1100 Subject: [PATCH 1/3] Add test for supersampler __add__ dunder --- packages/pipeline/tests/test_sampling.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/packages/pipeline/tests/test_sampling.py b/packages/pipeline/tests/test_sampling.py index 30c39ff8..4057fa0f 100644 --- a/packages/pipeline/tests/test_sampling.py +++ b/packages/pipeline/tests/test_sampling.py @@ -18,6 +18,13 @@ from tests.fake_pipeline_steps import FakeIndex +def _get_elements(obj): + if isinstance(obj, tuple): + return obj + else: + return (obj,) + + @pytest.mark.parametrize( "sampler,length", [ @@ -39,6 +46,13 @@ def test_samplers(sampler, length): if length is not None: assert len(list(pipe)) == length, "Length differs from expected" + super_sampler = samplers.SuperSampler(sampler) + super_sampler + super_sampler + + assert isinstance(super_sampler, samplers.Sampler) + for sample in _get_elements(sampler): + assert isinstance(sample + sample, samplers.Sampler) + iteration_1 = list(pipe) iteration_2 = list(pipe) From aabb630dce087682f2c1b928a990c37f5142306f Mon Sep 17 00:00:00 2001 From: Michael Pegios Date: Thu, 13 Nov 2025 11:17:57 +1100 Subject: [PATCH 2/3] create test for add_coordinate operation transform --- .../transform/test_add_coordinates.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 packages/pipeline/tests/operations/transform/test_add_coordinates.py diff --git a/packages/pipeline/tests/operations/transform/test_add_coordinates.py b/packages/pipeline/tests/operations/transform/test_add_coordinates.py new file mode 100644 index 00000000..5e913606 --- /dev/null +++ b/packages/pipeline/tests/operations/transform/test_add_coordinates.py @@ -0,0 +1,44 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import numpy as np +import xarray as xr + +from pyearthtools.pipeline.operations.transform import add_coordinates + + +def test_time(): + times = [datetime.datetime(2020, 1, 1) + datetime.timedelta(days=day) for day in range(365)] + + da = xr.DataArray(coords={"time": times, "level": [1, 2]}, data=np.ones((365, 2))) + + ds = xr.Dataset(coords={"time": times, "level": [1, 2]}, data_vars={"temperature": da}) + + orig_ds = ds.copy() + + # Test using a coordinate that doesn't exist returns same dataset + toy = add_coordinates.AddCoordinates("doesnt_exist") + result = toy.apply(ds) + + assert result.equals(orig_ds) + assert toy._info_ == {"coordinates": ["doesnt_exist"]} + + # Test coordinate adds to variable + toy = add_coordinates.AddCoordinates("time") + result = toy.apply(ds) + + assert "var_time" in result.data_vars + assert toy._info_ == {"coordinates": ["time"]} From 5301583cb5d6efbe55f0c6a25898f608e5a6a368 Mon Sep 17 00:00:00 2001 From: Michael Pegios Date: Thu, 13 Nov 2025 11:18:26 +1100 Subject: [PATCH 3/3] Create tests for data transform variables * added test for Drop * added test for Select * added test for Trim --- .../data/tests/transform/test_variables.py | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 packages/data/tests/transform/test_variables.py diff --git a/packages/data/tests/transform/test_variables.py b/packages/data/tests/transform/test_variables.py new file mode 100644 index 00000000..1f5060b9 --- /dev/null +++ b/packages/data/tests/transform/test_variables.py @@ -0,0 +1,112 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import xarray as xr + +from pyearthtools.data.transforms.variables import Drop, Select, Trim + + +def test_trim(): + data = xr.Dataset( + { + "var_to_keep": ("time", [0, 0, 7]), + "var_to_ignore1": ("time", [0, 0, 7]), + "var_to_ignore2": ("time", [0, 0, 7]), + } + ) + + # Test variable drops var as intended + orig_data = data.copy() + transform = Trim("var_to_keep") + + transformed_data = transform.apply(data) + assert "var_to_ignore1" not in transformed_data.data_vars + assert "var_to_ignore2" not in transformed_data.data_vars + + # Check transformed dataset hasn't been modified + assert data.equals(orig_data) + + # Test a var that is not in the list returns the same dataset + transform = Trim("var_not_found") + transformed_data = transform.apply(data) + assert transformed_data.equals(orig_data) + + # Test None returns the same dataset + transform = Trim(None) + transformed_data = transform.apply(data) + assert transformed_data.equals(orig_data) + + +def test_drop(): + data = xr.Dataset( + { + "var_to_keep": ("time", [0, 0, 7]), + "var_to_drop": ("time", [0, 0, 7]), + } + ) + + # Test variable drops var as intended + orig_data = data.copy() + transform = Drop("var_to_drop") + transformed_data = transform.apply(data) + assert "var_to_drop" not in transformed_data.data_vars + # We don't want to modify the original data, so testing it's not in place + assert data.equals(orig_data) + + # Test empty list returns the same dataset + transform = Drop([]) + transformed_data = transform.apply(data) + assert data.equals(transformed_data) + + # Test a var that is not in the list + with pytest.raises(ValueError): + transform = Drop("var_not_found") + _ = transform.apply(data) + + # Test None raise valueError + with pytest.raises(ValueError): + transform = Drop(None) + _ = transform.apply(data) + + +def test_select(): + data = xr.Dataset( + { + "var_to_keep": ("time", [0, 0, 7]), + "var_to_ignore1": ("time", [0, 0, 7]), + "var_to_ignore2": ("time", [0, 0, 7]), + } + ) + orig_data = data.copy() + + transform = Select("var_to_keep") + transformed_data = transform.apply(data) + assert "var_to_ignore1" not in transformed_data.data_vars + assert "var_to_ignore2" not in transformed_data.data_vars + + # Test empty list returns the same dataset + transform = Select([]) + transformed_data = transform.apply(data) + assert orig_data.equals(transformed_data) + + # Test a var that is not in the list returns the same dataset + transform = Select("var_not_found") + transformed_data = transform.apply(data) + assert transformed_data.equals(orig_data) + + # Test None returns the same dataset + transform = Select(None) + transformed_data = transform.apply(data) + assert transformed_data.equals(orig_data)