From c6660c28fe36556e738a7870d94e7fcb195f67b2 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 12 Nov 2025 09:29:34 +1100 Subject: [PATCH 1/4] add TemporalRetrieval test and move TemporalWindow test --- .../pipeline/tests/pipeline/test_idx_mod.py | 104 ++++++++++++++---- 1 file changed, 84 insertions(+), 20 deletions(-) diff --git a/packages/pipeline/tests/pipeline/test_idx_mod.py b/packages/pipeline/tests/pipeline/test_idx_mod.py index ebbfc0d1..d22c54e7 100644 --- a/packages/pipeline/tests/pipeline/test_idx_mod.py +++ b/packages/pipeline/tests/pipeline/test_idx_mod.py @@ -19,9 +19,16 @@ import pyearthtools.utils -from pyearthtools.pipeline import Pipeline, modifications -from pyearthtools.pipeline import Operation -from pyearthtools.data import Index +from pyearthtools.pipeline import Pipeline, Operation +from pyearthtools.data import Index, Petdt +from pyearthtools.pipeline.modifications import ( + IdxModifier, + IdxOverride, + TimeIdxModifier, + TemporalWindow, + TemporalRetrieval, +) +from pyearthtools.data.time import TimeDelta class FakeIndex(Index): @@ -58,32 +65,32 @@ def test_multiplication_undo(): def test_IdxModifier_basic(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0,))) + pipe = Pipeline(FakeIndex(), IdxModifier((0,))) assert pipe[0] == (0,) def test_IdxModifier_basic_no_tuple(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier(0)) + pipe = Pipeline(FakeIndex(), IdxModifier(0)) assert pipe[0] == 0 def test_IdxModifier_two_samples(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, 1))) + pipe = Pipeline(FakeIndex(), IdxModifier((0, 1))) assert pipe[0] == (0, 1) def test_IdxModifier_nested(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, 2)))) + pipe = Pipeline(FakeIndex(), IdxModifier((0, (1, 2)))) assert pipe[0] == (0, (1, 2)) def test_IdxModifier_nested_double(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, (2, 3))))) + pipe = Pipeline(FakeIndex(), IdxModifier((0, (1, (2, 3))))) assert pipe[0] == (0, (1, (2, 3))) def test_IdxModifier_nested_merge(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, 2)), merge=True, merge_function=sum)) + pipe = Pipeline(FakeIndex(), IdxModifier((0, (1, 2)), merge=True, merge_function=sum)) assert pipe[0] == (0, 3) @@ -99,7 +106,7 @@ def test_IdxModifier_nested_merge(): def test_IdxModifier_merge_depth(depth, result): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((1, (2, (3, 4))), merge=depth, merge_function=sum), + IdxModifier((1, (2, (3, 4))), merge=depth, merge_function=sum), ) assert pipe[0] == result @@ -107,7 +114,7 @@ def test_IdxModifier_merge_depth(depth, result): def test_IdxModifier_unmergeable(): pipe = Pipeline( FakeIndex("test"), # type: ignore - modifications.IdxModifier(("t", "a"), merge=True, merge_function=sum), + IdxModifier(("t", "a"), merge=True, merge_function=sum), ) with pytest.raises(TypeError): assert pipe[1] == (1, 5) @@ -116,8 +123,8 @@ def test_IdxModifier_unmergeable(): def test_IdxMod_stacked(): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((0, 1)), - modifications.IdxModifier((0, 1)), + IdxModifier((0, 1)), + IdxModifier((0, 1)), ) assert pipe[1] == ((1, 2), (2, 3)) @@ -125,8 +132,8 @@ def test_IdxMod_stacked(): def test_IdxMod_stacked_with_mult(): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((0, 1)), - modifications.IdxModifier((0, 1)), + IdxModifier((0, 1)), + IdxModifier((0, 1)), MultiplicationOperation(2), ) assert pipe[1] == ((2, 4), (4, 6)) @@ -135,7 +142,7 @@ def test_IdxMod_stacked_with_mult(): def test_IdxMod_with_branch(): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((0, 1)), + IdxModifier((0, 1)), ( (MultiplicationOperation(1),), (MultiplicationOperation(2),), @@ -147,7 +154,7 @@ def test_IdxMod_with_branch(): def test_IdxMod_with_branch_mapping(): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((0, 1)), + IdxModifier((0, 1)), ((MultiplicationOperation(1),), (MultiplicationOperation(2),), "map"), ) assert pipe[1] == (1, 4) @@ -157,7 +164,7 @@ def test_IdxMod_with_branch_mapping(): def test_IdxOverride_basic(): - pipe = Pipeline(FakeIndex(), modifications.IdxOverride(0)) + pipe = Pipeline(FakeIndex(), IdxOverride(0)) assert pipe[1] == 0 @@ -167,7 +174,7 @@ def test_IdxOverride_basic(): def test_TimeIdxModifier_basic(): import pyearthtools.data - pipe = Pipeline(FakeIndex(), modifications.TimeIdxModifier("6 hours")) + pipe = Pipeline(FakeIndex(), TimeIdxModifier("6 hours")) assert pipe[pyearthtools.data.Petdt("2000-01-01T00")] == pyearthtools.data.Petdt("2000-01-01T06") @@ -180,8 +187,65 @@ def test_TimeIdxModifier_basic(): def test_TimeIdxModifier_nested(): import pyearthtools.data - pipe = Pipeline(FakeIndex(), modifications.TimeIdxModifier(("6 hours", "12 hours"))) + pipe = Pipeline(FakeIndex(), TimeIdxModifier(("6 hours", "12 hours"))) assert pipe[pyearthtools.data.Petdt("2000-01-01T00")] == ( pyearthtools.data.Petdt("2000-01-01T06"), pyearthtools.data.Petdt("2000-01-01T12"), ) + + +class test_data_accessor(Index): + _data = { + "2025-01-01": 1, + "2025-01-02": 2, + "2025-01-03": 3, + } + + def get(self, time): + return self._data[time] + + +@pytest.mark.parametrize("merge_method", (None, sum)) +def test_temporal_retrieval(merge_method): + """Test temporal retrieval.""" + + # instantiate temporal retrieval with merge_method - retrieving from two steps behind. + temporal_retrieval_step = TemporalRetrieval(-2, merge_function=merge_method, delta_unit="day") + + pipeline = Pipeline(test_data_accessor(), temporal_retrieval_step) + + result = pipeline[Petdt("2025-01-03")] + if merge_method is sum: + assert result == test_data_accessor._data["2025-01-03"] + test_data_accessor._data["2025-01-01"] + elif merge_method is None: + assert result == (test_data_accessor._data["2025-01-01"], test_data_accessor._data["2025-01-03"]) + + +def test_temporal_retrieval_invalid(): + """Tests errors when using/instantiating TemporalRetrieval.""" + with pytest.raises(ValueError): + TemporalRetrieval(None) # index ought to be int or iterable of ints + tr = TemporalRetrieval(-1) + with pytest.raises(TypeError): + tr["a"] # not convertable to Petdt + + +@pytest.mark.parametrize("merge_method", (None, sum)) +def test_temporal_window(merge_method): + """Test temporal window.""" + + # instantiate temporal window with merge method + temporal_window_step = TemporalWindow( + prior_indexes=[-2, -1], + posterior_indexes=[0], + timedelta=TimeDelta((1, "day")), + merge_method=merge_method, + ) + + # Instantiate pipeline with test data and temporal window + pipeline = Pipeline(test_data_accessor(), temporal_window_step) + result = pipeline["2025-01-03"] + if merge_method is None: + assert result == ([1, 2], [3]) + elif merge_method is sum: + assert result == (3, 3) From 6d5505516ff3b4c6750a71585eba93ae19f71252 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 12 Nov 2025 09:38:44 +1100 Subject: [PATCH 2/4] cover extra_mods case of TimeIdxModifier --- .../modifications/idx_modification.py | 1 + .../pipeline/tests/pipeline/test_idx_mod.py | 44 +++++++++---------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py b/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py index b8f1ce3f..209635b6 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py +++ b/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py @@ -255,6 +255,7 @@ def map_to_time(mod): return tuple(map(map_to_time, mod)) return pyearthtools.data.TimeDelta(mod) + # packs modifications and extra_mods into single tuple if extra_mods: modification = ( *(modification if isinstance(modification, tuple) else (modification,)), diff --git a/packages/pipeline/tests/pipeline/test_idx_mod.py b/packages/pipeline/tests/pipeline/test_idx_mod.py index d22c54e7..ac06ee4e 100644 --- a/packages/pipeline/tests/pipeline/test_idx_mod.py +++ b/packages/pipeline/tests/pipeline/test_idx_mod.py @@ -172,25 +172,33 @@ def test_IdxOverride_basic(): def test_TimeIdxModifier_basic(): - import pyearthtools.data pipe = Pipeline(FakeIndex(), TimeIdxModifier("6 hours")) - assert pipe[pyearthtools.data.Petdt("2000-01-01T00")] == pyearthtools.data.Petdt("2000-01-01T06") + assert pipe[Petdt("2000-01-01T00")] == Petdt("2000-01-01T06") # def test_TimeIdxModifier_basic_tuple(): # import pyearthtools.data # pipe = Pipeline(FakeIndex(), pipelines.TimeIdxModifier((6, 'hours'))) -# assert pipe[pyearthtools.data.Petdt('2000-01-01T00')] == pyearthtools.data.Petdt('2000-01-01T06') +# assert pipe[Petdt('2000-01-01T00')] == Petdt('2000-01-01T06') def test_TimeIdxModifier_nested(): - import pyearthtools.data pipe = Pipeline(FakeIndex(), TimeIdxModifier(("6 hours", "12 hours"))) - assert pipe[pyearthtools.data.Petdt("2000-01-01T00")] == ( - pyearthtools.data.Petdt("2000-01-01T06"), - pyearthtools.data.Petdt("2000-01-01T12"), + assert pipe[Petdt("2000-01-01T00")] == ( + Petdt("2000-01-01T06"), + Petdt("2000-01-01T12"), + ) + + +def test_TimeIdxModifier_extramods(): + """Tests TimeIdxModifier with modifications passed as variable args (extra_mods)""" + # first arg to TimeIdxModifier goes to "modifications" and second goes to extra_args + pipe = Pipeline(FakeIndex(), TimeIdxModifier("6 hours", "12 hours")) + assert pipe[Petdt("2000-01-01T00")] == ( + Petdt("2000-01-01T06"), + Petdt("2000-01-01T12"), ) @@ -205,8 +213,8 @@ def get(self, time): return self._data[time] -@pytest.mark.parametrize("merge_method", (None, sum)) -def test_temporal_retrieval(merge_method): +@pytest.mark.parametrize("merge_method, expected", ((None, (1, 3)), (sum, 4))) +def test_TemporalRetrieval(merge_method, expected): """Test temporal retrieval.""" # instantiate temporal retrieval with merge_method - retrieving from two steps behind. @@ -214,14 +222,10 @@ def test_temporal_retrieval(merge_method): pipeline = Pipeline(test_data_accessor(), temporal_retrieval_step) - result = pipeline[Petdt("2025-01-03")] - if merge_method is sum: - assert result == test_data_accessor._data["2025-01-03"] + test_data_accessor._data["2025-01-01"] - elif merge_method is None: - assert result == (test_data_accessor._data["2025-01-01"], test_data_accessor._data["2025-01-03"]) + assert expected == pipeline[Petdt("2025-01-03")] -def test_temporal_retrieval_invalid(): +def test_TemporalRetrieval_invalid(): """Tests errors when using/instantiating TemporalRetrieval.""" with pytest.raises(ValueError): TemporalRetrieval(None) # index ought to be int or iterable of ints @@ -230,8 +234,8 @@ def test_temporal_retrieval_invalid(): tr["a"] # not convertable to Petdt -@pytest.mark.parametrize("merge_method", (None, sum)) -def test_temporal_window(merge_method): +@pytest.mark.parametrize("merge_method, expected", ((None, ([1, 2], [3])), (sum, (3, 3)))) +def test_TemporalWindow(merge_method, expected): """Test temporal window.""" # instantiate temporal window with merge method @@ -244,8 +248,4 @@ def test_temporal_window(merge_method): # Instantiate pipeline with test data and temporal window pipeline = Pipeline(test_data_accessor(), temporal_window_step) - result = pipeline["2025-01-03"] - if merge_method is None: - assert result == ([1, 2], [3]) - elif merge_method is sum: - assert result == (3, 3) + assert expected == pipeline["2025-01-03"] From f0a0dc3fa14eb02ea4886bed6a7700aad315df6e Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 12 Nov 2025 14:34:52 +1100 Subject: [PATCH 3/4] flag DASK_IMPORTED as not covered --- .../pipeline/modifications/idx_modification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py b/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py index 209635b6..d242ab36 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py +++ b/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py @@ -33,8 +33,8 @@ try: import dask.array as da from dask.delayed import Delayed, delayed -except (ImportError, ModuleNotFoundError) as _: - DASK_IMPORTED = False +except (ImportError, ModuleNotFoundError) as _: # pragma: no cover # manually tested + DASK_IMPORTED = False # pragma: no cover MERGE_FUNCTIONS = { xr.Dataset: xr.combine_by_coords, @@ -169,8 +169,8 @@ def trim(s): if self._merge_function is not None: return self._merge_function(sample, **self._merge_kwargs) - if DASK_IMPORTED and types[0] == Delayed: - return delayed(self._run_merge)(sample) + if DASK_IMPORTED and types[0] == Delayed: # pragma: no cover + return delayed(self._run_merge)(sample) # pragma: no cover if types[0] not in MERGE_FUNCTIONS: warnings.warn(f"Cannot merge samples of type {types[0]}.", PipelineWarning) From c53ff60d863ad69d6222ee1dd9fe341157d059e5 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 12 Nov 2025 14:36:18 +1100 Subject: [PATCH 4/4] test rest of idx_modification --- .../pipeline/tests/pipeline/test_idx_mod.py | 80 ++++++++++++------- 1 file changed, 52 insertions(+), 28 deletions(-) diff --git a/packages/pipeline/tests/pipeline/test_idx_mod.py b/packages/pipeline/tests/pipeline/test_idx_mod.py index ac06ee4e..5e9a07c3 100644 --- a/packages/pipeline/tests/pipeline/test_idx_mod.py +++ b/packages/pipeline/tests/pipeline/test_idx_mod.py @@ -29,6 +29,8 @@ TemporalRetrieval, ) from pyearthtools.data.time import TimeDelta +import xarray as xr +import numpy as np class FakeIndex(Index): @@ -64,29 +66,14 @@ def test_multiplication_undo(): assert orig == 2 -def test_IdxModifier_basic(): - pipe = Pipeline(FakeIndex(), IdxModifier((0,))) - assert pipe[0] == (0,) - - -def test_IdxModifier_basic_no_tuple(): - pipe = Pipeline(FakeIndex(), IdxModifier(0)) - assert pipe[0] == 0 - - -def test_IdxModifier_two_samples(): - pipe = Pipeline(FakeIndex(), IdxModifier((0, 1))) - assert pipe[0] == (0, 1) - - -def test_IdxModifier_nested(): - pipe = Pipeline(FakeIndex(), IdxModifier((0, (1, 2)))) - assert pipe[0] == (0, (1, 2)) - - -def test_IdxModifier_nested_double(): - pipe = Pipeline(FakeIndex(), IdxModifier((0, (1, (2, 3))))) - assert pipe[0] == (0, (1, (2, 3))) +@pytest.mark.parametrize("mod", ((0,), 0, (0, 1), (0, (1, 2)), (0, (1, (2, 3))))) +def test_IdxModifier_basic(mod): + pipe = Pipeline(FakeIndex(), IdxModifier(mod)) + assert pipe[0] == mod + # check that extra_mods gets passed through + if type(mod) is tuple and len(mod) > 1: + pipe = Pipeline(FakeIndex(), IdxModifier(*mod)) + assert pipe[0] == mod def test_IdxModifier_nested_merge(): @@ -203,11 +190,13 @@ def test_TimeIdxModifier_extramods(): class test_data_accessor(Index): - _data = { - "2025-01-01": 1, - "2025-01-02": 2, - "2025-01-03": 3, - } + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._data = { + "2025-01-01": 1, + "2025-01-02": 2, + "2025-01-03": 3, + } def get(self, time): return self._data[time] @@ -225,6 +214,34 @@ def test_TemporalRetrieval(merge_method, expected): assert expected == pipeline[Petdt("2025-01-03")] +def test_TemporalRetrieval_xarrayaccessor(): + """Tests temporal retrieval default merger for xarray data.""" + + # insantiate TemporalRetrieval without merge method + temporal_retrieval_step = TemporalRetrieval(-2) + + # create a data accessor with fake xarray data. + data_accessor = test_data_accessor() + data_accessor._data = {date: xr.DataArray([val] * 2, name=f"arr{val}") for date, val in data_accessor._data.items()} + pipeline = Pipeline(data_accessor, temporal_retrieval_step) + + assert xr.merge((data_accessor["2025-01-01"], data_accessor["2025-01-03"])) == pipeline["2025-01-03"] + + +def test_TemporalRetrieval_npconcat(): + """Tests temporal retrieval default merger for numpy data.""" + + # insantiate TemporalRetrieval with concat + temporal_retrieval_step = TemporalRetrieval(-2, concat=True) + + # create a data accessor with fake numpy data. + data_accessor = test_data_accessor() + data_accessor._data = {date: np.array((val, val + 1)) for date, val in data_accessor._data.items()} + pipeline = Pipeline(data_accessor, temporal_retrieval_step) + + assert np.array_equal(np.array((1, 2, 3, 4)), pipeline["2025-01-03"]) + + def test_TemporalRetrieval_invalid(): """Tests errors when using/instantiating TemporalRetrieval.""" with pytest.raises(ValueError): @@ -233,6 +250,13 @@ def test_TemporalRetrieval_invalid(): with pytest.raises(TypeError): tr["a"] # not convertable to Petdt + # this is actually covering a type error in IdxModifier._run_merge + invalid_accessor = test_data_accessor() + invalid_accessor._data["2025-01-01"] = "a" + pipeline = Pipeline(invalid_accessor, TemporalRetrieval(-2)) + with pytest.raises(TypeError): + pipeline["2025-01-03"] + @pytest.mark.parametrize("merge_method, expected", ((None, ([1, 2], [3])), (sum, (3, 3)))) def test_TemporalWindow(merge_method, expected):