diff --git a/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py b/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py index b8f1ce3f..d242ab36 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py +++ b/packages/pipeline/src/pyearthtools/pipeline/modifications/idx_modification.py @@ -33,8 +33,8 @@ try: import dask.array as da from dask.delayed import Delayed, delayed -except (ImportError, ModuleNotFoundError) as _: - DASK_IMPORTED = False +except (ImportError, ModuleNotFoundError) as _: # pragma: no cover # manually tested + DASK_IMPORTED = False # pragma: no cover MERGE_FUNCTIONS = { xr.Dataset: xr.combine_by_coords, @@ -169,8 +169,8 @@ def trim(s): if self._merge_function is not None: return self._merge_function(sample, **self._merge_kwargs) - if DASK_IMPORTED and types[0] == Delayed: - return delayed(self._run_merge)(sample) + if DASK_IMPORTED and types[0] == Delayed: # pragma: no cover + return delayed(self._run_merge)(sample) # pragma: no cover if types[0] not in MERGE_FUNCTIONS: warnings.warn(f"Cannot merge samples of type {types[0]}.", PipelineWarning) @@ -255,6 +255,7 @@ def map_to_time(mod): return tuple(map(map_to_time, mod)) return pyearthtools.data.TimeDelta(mod) + # packs modifications and extra_mods into single tuple if extra_mods: modification = ( *(modification if isinstance(modification, tuple) else (modification,)), diff --git a/packages/pipeline/tests/pipeline/test_idx_mod.py b/packages/pipeline/tests/pipeline/test_idx_mod.py index ebbfc0d1..5e9a07c3 100644 --- a/packages/pipeline/tests/pipeline/test_idx_mod.py +++ b/packages/pipeline/tests/pipeline/test_idx_mod.py @@ -19,9 +19,18 @@ import pyearthtools.utils -from pyearthtools.pipeline import Pipeline, modifications -from pyearthtools.pipeline import Operation -from pyearthtools.data import Index +from pyearthtools.pipeline import Pipeline, Operation +from pyearthtools.data import Index, Petdt +from pyearthtools.pipeline.modifications import ( + IdxModifier, + IdxOverride, + TimeIdxModifier, + TemporalWindow, + TemporalRetrieval, +) +from pyearthtools.data.time import TimeDelta +import xarray as xr +import numpy as np class FakeIndex(Index): @@ -57,33 +66,18 @@ def test_multiplication_undo(): assert orig == 2 -def test_IdxModifier_basic(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0,))) - assert pipe[0] == (0,) - - -def test_IdxModifier_basic_no_tuple(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier(0)) - assert pipe[0] == 0 - - -def test_IdxModifier_two_samples(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, 1))) - assert pipe[0] == (0, 1) - - -def test_IdxModifier_nested(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, 2)))) - assert pipe[0] == (0, (1, 2)) - - -def test_IdxModifier_nested_double(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, (2, 3))))) - assert pipe[0] == (0, (1, (2, 3))) +@pytest.mark.parametrize("mod", ((0,), 0, (0, 1), (0, (1, 2)), (0, (1, (2, 3))))) +def test_IdxModifier_basic(mod): + pipe = Pipeline(FakeIndex(), IdxModifier(mod)) + assert pipe[0] == mod + # check that extra_mods gets passed through + if type(mod) is tuple and len(mod) > 1: + pipe = Pipeline(FakeIndex(), IdxModifier(*mod)) + assert pipe[0] == mod def test_IdxModifier_nested_merge(): - pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, 2)), merge=True, merge_function=sum)) + pipe = Pipeline(FakeIndex(), IdxModifier((0, (1, 2)), merge=True, merge_function=sum)) assert pipe[0] == (0, 3) @@ -99,7 +93,7 @@ def test_IdxModifier_nested_merge(): def test_IdxModifier_merge_depth(depth, result): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((1, (2, (3, 4))), merge=depth, merge_function=sum), + IdxModifier((1, (2, (3, 4))), merge=depth, merge_function=sum), ) assert pipe[0] == result @@ -107,7 +101,7 @@ def test_IdxModifier_merge_depth(depth, result): def test_IdxModifier_unmergeable(): pipe = Pipeline( FakeIndex("test"), # type: ignore - modifications.IdxModifier(("t", "a"), merge=True, merge_function=sum), + IdxModifier(("t", "a"), merge=True, merge_function=sum), ) with pytest.raises(TypeError): assert pipe[1] == (1, 5) @@ -116,8 +110,8 @@ def test_IdxModifier_unmergeable(): def test_IdxMod_stacked(): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((0, 1)), - modifications.IdxModifier((0, 1)), + IdxModifier((0, 1)), + IdxModifier((0, 1)), ) assert pipe[1] == ((1, 2), (2, 3)) @@ -125,8 +119,8 @@ def test_IdxMod_stacked(): def test_IdxMod_stacked_with_mult(): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((0, 1)), - modifications.IdxModifier((0, 1)), + IdxModifier((0, 1)), + IdxModifier((0, 1)), MultiplicationOperation(2), ) assert pipe[1] == ((2, 4), (4, 6)) @@ -135,7 +129,7 @@ def test_IdxMod_stacked_with_mult(): def test_IdxMod_with_branch(): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((0, 1)), + IdxModifier((0, 1)), ( (MultiplicationOperation(1),), (MultiplicationOperation(2),), @@ -147,7 +141,7 @@ def test_IdxMod_with_branch(): def test_IdxMod_with_branch_mapping(): pipe = Pipeline( FakeIndex(), - modifications.IdxModifier((0, 1)), + IdxModifier((0, 1)), ((MultiplicationOperation(1),), (MultiplicationOperation(2),), "map"), ) assert pipe[1] == (1, 4) @@ -157,7 +151,7 @@ def test_IdxMod_with_branch_mapping(): def test_IdxOverride_basic(): - pipe = Pipeline(FakeIndex(), modifications.IdxOverride(0)) + pipe = Pipeline(FakeIndex(), IdxOverride(0)) assert pipe[1] == 0 @@ -165,23 +159,117 @@ def test_IdxOverride_basic(): def test_TimeIdxModifier_basic(): - import pyearthtools.data - pipe = Pipeline(FakeIndex(), modifications.TimeIdxModifier("6 hours")) - assert pipe[pyearthtools.data.Petdt("2000-01-01T00")] == pyearthtools.data.Petdt("2000-01-01T06") + pipe = Pipeline(FakeIndex(), TimeIdxModifier("6 hours")) + assert pipe[Petdt("2000-01-01T00")] == Petdt("2000-01-01T06") # def test_TimeIdxModifier_basic_tuple(): # import pyearthtools.data # pipe = Pipeline(FakeIndex(), pipelines.TimeIdxModifier((6, 'hours'))) -# assert pipe[pyearthtools.data.Petdt('2000-01-01T00')] == pyearthtools.data.Petdt('2000-01-01T06') +# assert pipe[Petdt('2000-01-01T00')] == Petdt('2000-01-01T06') def test_TimeIdxModifier_nested(): - import pyearthtools.data - pipe = Pipeline(FakeIndex(), modifications.TimeIdxModifier(("6 hours", "12 hours"))) - assert pipe[pyearthtools.data.Petdt("2000-01-01T00")] == ( - pyearthtools.data.Petdt("2000-01-01T06"), - pyearthtools.data.Petdt("2000-01-01T12"), + pipe = Pipeline(FakeIndex(), TimeIdxModifier(("6 hours", "12 hours"))) + assert pipe[Petdt("2000-01-01T00")] == ( + Petdt("2000-01-01T06"), + Petdt("2000-01-01T12"), + ) + + +def test_TimeIdxModifier_extramods(): + """Tests TimeIdxModifier with modifications passed as variable args (extra_mods)""" + # first arg to TimeIdxModifier goes to "modifications" and second goes to extra_args + pipe = Pipeline(FakeIndex(), TimeIdxModifier("6 hours", "12 hours")) + assert pipe[Petdt("2000-01-01T00")] == ( + Petdt("2000-01-01T06"), + Petdt("2000-01-01T12"), + ) + + +class test_data_accessor(Index): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._data = { + "2025-01-01": 1, + "2025-01-02": 2, + "2025-01-03": 3, + } + + def get(self, time): + return self._data[time] + + +@pytest.mark.parametrize("merge_method, expected", ((None, (1, 3)), (sum, 4))) +def test_TemporalRetrieval(merge_method, expected): + """Test temporal retrieval.""" + + # instantiate temporal retrieval with merge_method - retrieving from two steps behind. + temporal_retrieval_step = TemporalRetrieval(-2, merge_function=merge_method, delta_unit="day") + + pipeline = Pipeline(test_data_accessor(), temporal_retrieval_step) + + assert expected == pipeline[Petdt("2025-01-03")] + + +def test_TemporalRetrieval_xarrayaccessor(): + """Tests temporal retrieval default merger for xarray data.""" + + # insantiate TemporalRetrieval without merge method + temporal_retrieval_step = TemporalRetrieval(-2) + + # create a data accessor with fake xarray data. + data_accessor = test_data_accessor() + data_accessor._data = {date: xr.DataArray([val] * 2, name=f"arr{val}") for date, val in data_accessor._data.items()} + pipeline = Pipeline(data_accessor, temporal_retrieval_step) + + assert xr.merge((data_accessor["2025-01-01"], data_accessor["2025-01-03"])) == pipeline["2025-01-03"] + + +def test_TemporalRetrieval_npconcat(): + """Tests temporal retrieval default merger for numpy data.""" + + # insantiate TemporalRetrieval with concat + temporal_retrieval_step = TemporalRetrieval(-2, concat=True) + + # create a data accessor with fake numpy data. + data_accessor = test_data_accessor() + data_accessor._data = {date: np.array((val, val + 1)) for date, val in data_accessor._data.items()} + pipeline = Pipeline(data_accessor, temporal_retrieval_step) + + assert np.array_equal(np.array((1, 2, 3, 4)), pipeline["2025-01-03"]) + + +def test_TemporalRetrieval_invalid(): + """Tests errors when using/instantiating TemporalRetrieval.""" + with pytest.raises(ValueError): + TemporalRetrieval(None) # index ought to be int or iterable of ints + tr = TemporalRetrieval(-1) + with pytest.raises(TypeError): + tr["a"] # not convertable to Petdt + + # this is actually covering a type error in IdxModifier._run_merge + invalid_accessor = test_data_accessor() + invalid_accessor._data["2025-01-01"] = "a" + pipeline = Pipeline(invalid_accessor, TemporalRetrieval(-2)) + with pytest.raises(TypeError): + pipeline["2025-01-03"] + + +@pytest.mark.parametrize("merge_method, expected", ((None, ([1, 2], [3])), (sum, (3, 3)))) +def test_TemporalWindow(merge_method, expected): + """Test temporal window.""" + + # instantiate temporal window with merge method + temporal_window_step = TemporalWindow( + prior_indexes=[-2, -1], + posterior_indexes=[0], + timedelta=TimeDelta((1, "day")), + merge_method=merge_method, ) + + # Instantiate pipeline with test data and temporal window + pipeline = Pipeline(test_data_accessor(), temporal_window_step) + assert expected == pipeline["2025-01-03"]