Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
try:
import dask.array as da
from dask.delayed import Delayed, delayed
except (ImportError, ModuleNotFoundError) as _:
DASK_IMPORTED = False
except (ImportError, ModuleNotFoundError) as _: # pragma: no cover # manually tested
DASK_IMPORTED = False # pragma: no cover

MERGE_FUNCTIONS = {
xr.Dataset: xr.combine_by_coords,
Expand Down Expand Up @@ -169,8 +169,8 @@ def trim(s):
if self._merge_function is not None:
return self._merge_function(sample, **self._merge_kwargs)

if DASK_IMPORTED and types[0] == Delayed:
return delayed(self._run_merge)(sample)
if DASK_IMPORTED and types[0] == Delayed: # pragma: no cover
return delayed(self._run_merge)(sample) # pragma: no cover

if types[0] not in MERGE_FUNCTIONS:
warnings.warn(f"Cannot merge samples of type {types[0]}.", PipelineWarning)
Expand Down Expand Up @@ -255,6 +255,7 @@ def map_to_time(mod):
return tuple(map(map_to_time, mod))
return pyearthtools.data.TimeDelta(mod)

# packs modifications and extra_mods into single tuple
if extra_mods:
modification = (
*(modification if isinstance(modification, tuple) else (modification,)),
Expand Down
178 changes: 133 additions & 45 deletions packages/pipeline/tests/pipeline/test_idx_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,18 @@

import pyearthtools.utils

from pyearthtools.pipeline import Pipeline, modifications
from pyearthtools.pipeline import Operation
from pyearthtools.data import Index
from pyearthtools.pipeline import Pipeline, Operation
from pyearthtools.data import Index, Petdt
from pyearthtools.pipeline.modifications import (
IdxModifier,
IdxOverride,
TimeIdxModifier,
TemporalWindow,
TemporalRetrieval,
)
from pyearthtools.data.time import TimeDelta
import xarray as xr
import numpy as np


class FakeIndex(Index):
Expand Down Expand Up @@ -57,33 +66,18 @@ def test_multiplication_undo():
assert orig == 2


def test_IdxModifier_basic():
pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0,)))
assert pipe[0] == (0,)


def test_IdxModifier_basic_no_tuple():
pipe = Pipeline(FakeIndex(), modifications.IdxModifier(0))
assert pipe[0] == 0


def test_IdxModifier_two_samples():
pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, 1)))
assert pipe[0] == (0, 1)


def test_IdxModifier_nested():
pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, 2))))
assert pipe[0] == (0, (1, 2))


def test_IdxModifier_nested_double():
pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, (2, 3)))))
assert pipe[0] == (0, (1, (2, 3)))
@pytest.mark.parametrize("mod", ((0,), 0, (0, 1), (0, (1, 2)), (0, (1, (2, 3)))))
def test_IdxModifier_basic(mod):
pipe = Pipeline(FakeIndex(), IdxModifier(mod))
assert pipe[0] == mod
# check that extra_mods gets passed through
if type(mod) is tuple and len(mod) > 1:
pipe = Pipeline(FakeIndex(), IdxModifier(*mod))
assert pipe[0] == mod


def test_IdxModifier_nested_merge():
pipe = Pipeline(FakeIndex(), modifications.IdxModifier((0, (1, 2)), merge=True, merge_function=sum))
pipe = Pipeline(FakeIndex(), IdxModifier((0, (1, 2)), merge=True, merge_function=sum))
assert pipe[0] == (0, 3)


Expand All @@ -99,15 +93,15 @@ def test_IdxModifier_nested_merge():
def test_IdxModifier_merge_depth(depth, result):
pipe = Pipeline(
FakeIndex(),
modifications.IdxModifier((1, (2, (3, 4))), merge=depth, merge_function=sum),
IdxModifier((1, (2, (3, 4))), merge=depth, merge_function=sum),
)
assert pipe[0] == result


def test_IdxModifier_unmergeable():
pipe = Pipeline(
FakeIndex("test"), # type: ignore
modifications.IdxModifier(("t", "a"), merge=True, merge_function=sum),
IdxModifier(("t", "a"), merge=True, merge_function=sum),
)
with pytest.raises(TypeError):
assert pipe[1] == (1, 5)
Expand All @@ -116,17 +110,17 @@ def test_IdxModifier_unmergeable():
def test_IdxMod_stacked():
pipe = Pipeline(
FakeIndex(),
modifications.IdxModifier((0, 1)),
modifications.IdxModifier((0, 1)),
IdxModifier((0, 1)),
IdxModifier((0, 1)),
)
assert pipe[1] == ((1, 2), (2, 3))


def test_IdxMod_stacked_with_mult():
pipe = Pipeline(
FakeIndex(),
modifications.IdxModifier((0, 1)),
modifications.IdxModifier((0, 1)),
IdxModifier((0, 1)),
IdxModifier((0, 1)),
MultiplicationOperation(2),
)
assert pipe[1] == ((2, 4), (4, 6))
Expand All @@ -135,7 +129,7 @@ def test_IdxMod_stacked_with_mult():
def test_IdxMod_with_branch():
pipe = Pipeline(
FakeIndex(),
modifications.IdxModifier((0, 1)),
IdxModifier((0, 1)),
(
(MultiplicationOperation(1),),
(MultiplicationOperation(2),),
Expand All @@ -147,7 +141,7 @@ def test_IdxMod_with_branch():
def test_IdxMod_with_branch_mapping():
pipe = Pipeline(
FakeIndex(),
modifications.IdxModifier((0, 1)),
IdxModifier((0, 1)),
((MultiplicationOperation(1),), (MultiplicationOperation(2),), "map"),
)
assert pipe[1] == (1, 4)
Expand All @@ -157,31 +151,125 @@ def test_IdxMod_with_branch_mapping():


def test_IdxOverride_basic():
pipe = Pipeline(FakeIndex(), modifications.IdxOverride(0))
pipe = Pipeline(FakeIndex(), IdxOverride(0))
assert pipe[1] == 0


#### TimeIdxModifier


def test_TimeIdxModifier_basic():
import pyearthtools.data

pipe = Pipeline(FakeIndex(), modifications.TimeIdxModifier("6 hours"))
assert pipe[pyearthtools.data.Petdt("2000-01-01T00")] == pyearthtools.data.Petdt("2000-01-01T06")
pipe = Pipeline(FakeIndex(), TimeIdxModifier("6 hours"))
assert pipe[Petdt("2000-01-01T00")] == Petdt("2000-01-01T06")


# def test_TimeIdxModifier_basic_tuple():
# import pyearthtools.data
# pipe = Pipeline(FakeIndex(), pipelines.TimeIdxModifier((6, 'hours')))
# assert pipe[pyearthtools.data.Petdt('2000-01-01T00')] == pyearthtools.data.Petdt('2000-01-01T06')
# assert pipe[Petdt('2000-01-01T00')] == Petdt('2000-01-01T06')


def test_TimeIdxModifier_nested():
import pyearthtools.data

pipe = Pipeline(FakeIndex(), modifications.TimeIdxModifier(("6 hours", "12 hours")))
assert pipe[pyearthtools.data.Petdt("2000-01-01T00")] == (
pyearthtools.data.Petdt("2000-01-01T06"),
pyearthtools.data.Petdt("2000-01-01T12"),
pipe = Pipeline(FakeIndex(), TimeIdxModifier(("6 hours", "12 hours")))
assert pipe[Petdt("2000-01-01T00")] == (
Petdt("2000-01-01T06"),
Petdt("2000-01-01T12"),
)


def test_TimeIdxModifier_extramods():
"""Tests TimeIdxModifier with modifications passed as variable args (extra_mods)"""
# first arg to TimeIdxModifier goes to "modifications" and second goes to extra_args
pipe = Pipeline(FakeIndex(), TimeIdxModifier("6 hours", "12 hours"))
assert pipe[Petdt("2000-01-01T00")] == (
Petdt("2000-01-01T06"),
Petdt("2000-01-01T12"),
)


class test_data_accessor(Index):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._data = {
"2025-01-01": 1,
"2025-01-02": 2,
"2025-01-03": 3,
}

def get(self, time):
return self._data[time]


@pytest.mark.parametrize("merge_method, expected", ((None, (1, 3)), (sum, 4)))
def test_TemporalRetrieval(merge_method, expected):
"""Test temporal retrieval."""

# instantiate temporal retrieval with merge_method - retrieving from two steps behind.
temporal_retrieval_step = TemporalRetrieval(-2, merge_function=merge_method, delta_unit="day")

pipeline = Pipeline(test_data_accessor(), temporal_retrieval_step)

assert expected == pipeline[Petdt("2025-01-03")]


def test_TemporalRetrieval_xarrayaccessor():
"""Tests temporal retrieval default merger for xarray data."""

# insantiate TemporalRetrieval without merge method
temporal_retrieval_step = TemporalRetrieval(-2)

# create a data accessor with fake xarray data.
data_accessor = test_data_accessor()
data_accessor._data = {date: xr.DataArray([val] * 2, name=f"arr{val}") for date, val in data_accessor._data.items()}
pipeline = Pipeline(data_accessor, temporal_retrieval_step)

assert xr.merge((data_accessor["2025-01-01"], data_accessor["2025-01-03"])) == pipeline["2025-01-03"]


def test_TemporalRetrieval_npconcat():
"""Tests temporal retrieval default merger for numpy data."""

# insantiate TemporalRetrieval with concat
temporal_retrieval_step = TemporalRetrieval(-2, concat=True)

# create a data accessor with fake numpy data.
data_accessor = test_data_accessor()
data_accessor._data = {date: np.array((val, val + 1)) for date, val in data_accessor._data.items()}
pipeline = Pipeline(data_accessor, temporal_retrieval_step)

assert np.array_equal(np.array((1, 2, 3, 4)), pipeline["2025-01-03"])


def test_TemporalRetrieval_invalid():
"""Tests errors when using/instantiating TemporalRetrieval."""
with pytest.raises(ValueError):
TemporalRetrieval(None) # index ought to be int or iterable of ints
tr = TemporalRetrieval(-1)
with pytest.raises(TypeError):
tr["a"] # not convertable to Petdt

# this is actually covering a type error in IdxModifier._run_merge
invalid_accessor = test_data_accessor()
invalid_accessor._data["2025-01-01"] = "a"
pipeline = Pipeline(invalid_accessor, TemporalRetrieval(-2))
with pytest.raises(TypeError):
pipeline["2025-01-03"]


@pytest.mark.parametrize("merge_method, expected", ((None, ([1, 2], [3])), (sum, (3, 3))))
def test_TemporalWindow(merge_method, expected):
"""Test temporal window."""

# instantiate temporal window with merge method
temporal_window_step = TemporalWindow(
prior_indexes=[-2, -1],
posterior_indexes=[0],
timedelta=TimeDelta((1, "day")),
merge_method=merge_method,
)

# Instantiate pipeline with test data and temporal window
pipeline = Pipeline(test_data_accessor(), temporal_window_step)
assert expected == pipeline["2025-01-03"]