diff --git a/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py b/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py index ea6a1af3..98502894 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py +++ b/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py @@ -62,6 +62,13 @@ def expand_pipeline(original: PipelineBranchPoint, length: int) -> list[Pipeline return new_pipeline +def _check_if_index(pipe: Pipeline): + """Check if pipeline is an index pipeline/branch""" + if isinstance(pipe.steps[0], PipelineBranchPoint): + return all(map(_check_if_index, pipe.steps[0].sub_pipelines)) + return isinstance(pipe.steps[0], Index) + + class PipelineBranchPoint(_Pipeline, Operation): """ Branch Point in a `Pipeline`. @@ -154,7 +161,9 @@ def apply(self, sample): raise PipelineRuntimeError( f"Cannot map sample to branches as length differ. {len(sample)} != {len(self.sub_pipelines)}." ) - elif self._map_copy: + elif ( + self._map_copy + ): # pragma: no cover # cannot be fully tested - if this line is reached, self._map_copy is True self.sub_pipelines = expand_pipeline(self, len(sample)) for s, pipe in zip(sample, self.sub_pipelines): @@ -213,13 +222,7 @@ def undo(self, sample): # ) result = tuple(self.parallel_interface.collect(sub_samples)) - def check_if_index(pipe: Pipeline): - """Check if pipeline is an index pipeline/branch""" - if isinstance(pipe.steps[0], PipelineBranchPoint): - return all(map(check_if_index, pipe.steps[0].sub_pipelines)) - return isinstance(pipe.steps[0], Index) - - if all(len(pipe.steps) == 1 or check_if_index(pipe) for pipe in self.sub_pipelines): + if all(len(pipe.steps) == 1 or _check_if_index(pipe) for pipe in self.sub_pipelines): if all(map(lambda x: result[0] == x, result[1:])): return result[0] return result diff --git a/packages/pipeline/tests/fake_pipeline_steps.py b/packages/pipeline/tests/fake_pipeline_steps.py index 604885cb..b1e74c53 100644 --- a/packages/pipeline/tests/fake_pipeline_steps.py +++ b/packages/pipeline/tests/fake_pipeline_steps.py @@ -33,6 +33,9 @@ def __init__(self, override: int | None = None): def get(self, idx): return self._overrideValue or idx + def __eq__(self, other): + return type(other) is FakeIndex and other._overrideValue == self._overrideValue + class MultiplicationOperation(Operation): def __init__(self, factor): diff --git a/packages/pipeline/tests/test_controller/test_branchpoints.py b/packages/pipeline/tests/test_controller/test_branchpoints.py index 8898fc61..c99cbd54 100644 --- a/packages/pipeline/tests/test_controller/test_branchpoints.py +++ b/packages/pipeline/tests/test_controller/test_branchpoints.py @@ -19,6 +19,7 @@ import pyearthtools.utils from pyearthtools.pipeline import Pipeline, exceptions, branching +from pyearthtools.pipeline.warnings import PipelineWarning from tests.fake_pipeline_steps import FakeIndex, MultiplicationOperation, MultiplicationOperationUnunifiedable @@ -28,6 +29,7 @@ def test_branchingpoint_basic(): pipe = Pipeline((FakeIndex(), FakeIndex())) assert pipe[1] == (1, 1) + assert pipe.complete_steps == (((FakeIndex(),), (FakeIndex(),)),) def test_branch_differing_operations(): @@ -87,6 +89,8 @@ def test_branch_differing_operations_nested_larger(): def test_branch_differing_operations_undo(): pipe = Pipeline(FakeIndex(), (MultiplicationOperation(10), MultiplicationOperation(2))) assert pipe.undo(pipe[1]) == 1 + with pytest.warns(PipelineWarning): + pipe._steps[1].undo((30, 20, 10)) # def test_branch_differing_operations_undo_unify(): @@ -168,6 +172,8 @@ def test_branch_with_mapping_copy(): (MultiplicationOperation(1), "map_copy"), ) assert pipe[1] == (1, 2) + # test round robin application of undo + assert pipe._steps[1].undo((30, 20, 10)) == (30, 20, 10) def test_branch_with_mapping_not_tuple(): @@ -199,3 +205,48 @@ def test_branch_with_source(): (MultiplicationOperation(2), FakeIndex()), ) assert pipe[1] == (2, 1) + + # test error when trying to map to a datasource + pipe = Pipeline((FakeIndex(1),), (FakeIndex(2), "map")) + with pytest.raises(ValueError): + pipe[0] + + # test branching with None index + pipe = Pipeline((FakeIndex(1),), (FakeIndex(2),)) + with pytest.raises(ValueError): + pipe[None] + + +def test_check_index(): + pipe = Pipeline( + (MultiplicationOperation(2), MultiplicationOperation(3)), + ) + assert not branching.branching._check_if_index(pipe) + + +def test_nested_branching_undo_nosource(): + pipe = Pipeline( + ( + ( + MultiplicationOperation(2), + MultiplicationOperation(3), + ), + MultiplicationOperation(4), + ) + ) + assert pipe.undo((18, 8)) == (3, 2) + + +def test_expand_pipeline_skip_non_pipeline(): + """Tests that expand_pipeline function ignore non-pipeline objects.""" + + # instantiate pipline with branchpoint at step 0. + pipe = Pipeline( + (MultiplicationOperation(3), MultiplicationOperation(4)), + ) + + # add the non-pipeline object to branchpoint sub pipelines + pipe._steps[0].sub_pipelines.append((1,)) + + # test that the non pipeline object isn't in the resultant expanded pipeline. + assert (1,) not in branching.branching.expand_pipeline(pipe._steps[0], 3)