Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions packages/pipeline/src/pyearthtools/pipeline/branching/branching.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def expand_pipeline(original: PipelineBranchPoint, length: int) -> list[Pipeline
return new_pipeline


def _check_if_index(pipe: Pipeline):
"""Check if pipeline is an index pipeline/branch"""
if isinstance(pipe.steps[0], PipelineBranchPoint):
return all(map(_check_if_index, pipe.steps[0].sub_pipelines))
return isinstance(pipe.steps[0], Index)


class PipelineBranchPoint(_Pipeline, Operation):
"""
Branch Point in a `Pipeline`.
Expand Down Expand Up @@ -154,7 +161,9 @@ def apply(self, sample):
raise PipelineRuntimeError(
f"Cannot map sample to branches as length differ. {len(sample)} != {len(self.sub_pipelines)}."
)
elif self._map_copy:
elif (
self._map_copy
): # pragma: no cover # cannot be fully tested - if this line is reached, self._map_copy is True
self.sub_pipelines = expand_pipeline(self, len(sample))

for s, pipe in zip(sample, self.sub_pipelines):
Expand Down Expand Up @@ -213,13 +222,7 @@ def undo(self, sample):
# )
result = tuple(self.parallel_interface.collect(sub_samples))

def check_if_index(pipe: Pipeline):
"""Check if pipeline is an index pipeline/branch"""
if isinstance(pipe.steps[0], PipelineBranchPoint):
return all(map(check_if_index, pipe.steps[0].sub_pipelines))
return isinstance(pipe.steps[0], Index)

if all(len(pipe.steps) == 1 or check_if_index(pipe) for pipe in self.sub_pipelines):
if all(len(pipe.steps) == 1 or _check_if_index(pipe) for pipe in self.sub_pipelines):
if all(map(lambda x: result[0] == x, result[1:])):
return result[0]
return result
Expand Down
3 changes: 3 additions & 0 deletions packages/pipeline/tests/fake_pipeline_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def __init__(self, override: int | None = None):
def get(self, idx):
return self._overrideValue or idx

def __eq__(self, other):
return type(other) is FakeIndex and other._overrideValue == self._overrideValue


class MultiplicationOperation(Operation):
def __init__(self, factor):
Expand Down
51 changes: 51 additions & 0 deletions packages/pipeline/tests/test_controller/test_branchpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pyearthtools.utils

from pyearthtools.pipeline import Pipeline, exceptions, branching
from pyearthtools.pipeline.warnings import PipelineWarning

from tests.fake_pipeline_steps import FakeIndex, MultiplicationOperation, MultiplicationOperationUnunifiedable

Expand All @@ -28,6 +29,7 @@
def test_branchingpoint_basic():
pipe = Pipeline((FakeIndex(), FakeIndex()))
assert pipe[1] == (1, 1)
assert pipe.complete_steps == (((FakeIndex(),), (FakeIndex(),)),)


def test_branch_differing_operations():
Expand Down Expand Up @@ -87,6 +89,8 @@ def test_branch_differing_operations_nested_larger():
def test_branch_differing_operations_undo():
pipe = Pipeline(FakeIndex(), (MultiplicationOperation(10), MultiplicationOperation(2)))
assert pipe.undo(pipe[1]) == 1
with pytest.warns(PipelineWarning):
pipe._steps[1].undo((30, 20, 10))


# def test_branch_differing_operations_undo_unify():
Expand Down Expand Up @@ -168,6 +172,8 @@ def test_branch_with_mapping_copy():
(MultiplicationOperation(1), "map_copy"),
)
assert pipe[1] == (1, 2)
# test round robin application of undo
assert pipe._steps[1].undo((30, 20, 10)) == (30, 20, 10)


def test_branch_with_mapping_not_tuple():
Expand Down Expand Up @@ -199,3 +205,48 @@ def test_branch_with_source():
(MultiplicationOperation(2), FakeIndex()),
)
assert pipe[1] == (2, 1)

# test error when trying to map to a datasource
pipe = Pipeline((FakeIndex(1),), (FakeIndex(2), "map"))
with pytest.raises(ValueError):
pipe[0]

# test branching with None index
pipe = Pipeline((FakeIndex(1),), (FakeIndex(2),))
with pytest.raises(ValueError):
pipe[None]


def test_check_index():
pipe = Pipeline(
(MultiplicationOperation(2), MultiplicationOperation(3)),
)
assert not branching.branching._check_if_index(pipe)


def test_nested_branching_undo_nosource():
pipe = Pipeline(
(
(
MultiplicationOperation(2),
MultiplicationOperation(3),
),
MultiplicationOperation(4),
)
)
assert pipe.undo((18, 8)) == (3, 2)


def test_expand_pipeline_skip_non_pipeline():
"""Tests that expand_pipeline function ignore non-pipeline objects."""

# instantiate pipline with branchpoint at step 0.
pipe = Pipeline(
(MultiplicationOperation(3), MultiplicationOperation(4)),
)

# add the non-pipeline object to branchpoint sub pipelines
pipe._steps[0].sub_pipelines.append((1,))

# test that the non pipeline object isn't in the resultant expanded pipeline.
assert (1,) not in branching.branching.expand_pipeline(pipe._steps[0], 3)
Loading