From 4840eca3e36af7a46a947ea312954897cf3386ca Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 13 Nov 2025 10:34:07 +1100 Subject: [PATCH 1/6] add test for branch with data source error --- packages/pipeline/tests/test_controller/test_branchpoints.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/packages/pipeline/tests/test_controller/test_branchpoints.py b/packages/pipeline/tests/test_controller/test_branchpoints.py index 8898fc61..9480b076 100644 --- a/packages/pipeline/tests/test_controller/test_branchpoints.py +++ b/packages/pipeline/tests/test_controller/test_branchpoints.py @@ -199,3 +199,8 @@ def test_branch_with_source(): (MultiplicationOperation(2), FakeIndex()), ) assert pipe[1] == (2, 1) + + # test error when trying to map to a datasource + pipe = Pipeline((FakeIndex(1),), (FakeIndex(2), "map")) + with pytest.raises(ValueError): + pipe[0] From 3540edfb771ec3026c937a80bb4eeb2800933706 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 13 Nov 2025 11:46:46 +1100 Subject: [PATCH 2/6] Add branch test when index is None --- packages/pipeline/tests/test_controller/test_branchpoints.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/packages/pipeline/tests/test_controller/test_branchpoints.py b/packages/pipeline/tests/test_controller/test_branchpoints.py index 9480b076..99533d6e 100644 --- a/packages/pipeline/tests/test_controller/test_branchpoints.py +++ b/packages/pipeline/tests/test_controller/test_branchpoints.py @@ -204,3 +204,8 @@ def test_branch_with_source(): pipe = Pipeline((FakeIndex(1),), (FakeIndex(2), "map")) with pytest.raises(ValueError): pipe[0] + + # test branching with None index + pipe = Pipeline((FakeIndex(1),), (FakeIndex(2),)) + with pytest.raises(ValueError): + pipe[None] From a815fcd66f886d7d4d969ddd52e912e0516fbe21 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 13 Nov 2025 12:58:26 +1100 Subject: [PATCH 3/6] complete cover for branchpoint undo --- packages/pipeline/tests/test_controller/test_branchpoints.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/packages/pipeline/tests/test_controller/test_branchpoints.py b/packages/pipeline/tests/test_controller/test_branchpoints.py index 99533d6e..eef20898 100644 --- a/packages/pipeline/tests/test_controller/test_branchpoints.py +++ b/packages/pipeline/tests/test_controller/test_branchpoints.py @@ -19,6 +19,7 @@ import pyearthtools.utils from pyearthtools.pipeline import Pipeline, exceptions, branching +from pyearthtools.pipeline.warnings import PipelineWarning from tests.fake_pipeline_steps import FakeIndex, MultiplicationOperation, MultiplicationOperationUnunifiedable @@ -87,6 +88,8 @@ def test_branch_differing_operations_nested_larger(): def test_branch_differing_operations_undo(): pipe = Pipeline(FakeIndex(), (MultiplicationOperation(10), MultiplicationOperation(2))) assert pipe.undo(pipe[1]) == 1 + with pytest.warns(PipelineWarning): + pipe._steps[1].undo((30, 20, 10)) # def test_branch_differing_operations_undo_unify(): @@ -168,6 +171,8 @@ def test_branch_with_mapping_copy(): (MultiplicationOperation(1), "map_copy"), ) assert pipe[1] == (1, 2) + # test round robin application of undo + assert pipe._steps[1].undo((30, 20, 10)) == (30, 20, 10) def test_branch_with_mapping_not_tuple(): From 613239a15a83e4dedf48414b48f756a7c8ad7874 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 13 Nov 2025 14:32:36 +1100 Subject: [PATCH 4/6] make check_if_index module function --- .../pyearthtools/pipeline/branching/branching.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py b/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py index ea6a1af3..e6501eeb 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py +++ b/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py @@ -62,6 +62,13 @@ def expand_pipeline(original: PipelineBranchPoint, length: int) -> list[Pipeline return new_pipeline +def _check_if_index(pipe: Pipeline): + """Check if pipeline is an index pipeline/branch""" + if isinstance(pipe.steps[0], PipelineBranchPoint): + return all(map(_check_if_index, pipe.steps[0].sub_pipelines)) + return isinstance(pipe.steps[0], Index) + + class PipelineBranchPoint(_Pipeline, Operation): """ Branch Point in a `Pipeline`. @@ -213,13 +220,7 @@ def undo(self, sample): # ) result = tuple(self.parallel_interface.collect(sub_samples)) - def check_if_index(pipe: Pipeline): - """Check if pipeline is an index pipeline/branch""" - if isinstance(pipe.steps[0], PipelineBranchPoint): - return all(map(check_if_index, pipe.steps[0].sub_pipelines)) - return isinstance(pipe.steps[0], Index) - - if all(len(pipe.steps) == 1 or check_if_index(pipe) for pipe in self.sub_pipelines): + if all(len(pipe.steps) == 1 or _check_if_index(pipe) for pipe in self.sub_pipelines): if all(map(lambda x: result[0] == x, result[1:])): return result[0] return result From 0d3d118120f2d6efd25e90c7c5c9f76425df1e92 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 13 Nov 2025 14:33:56 +1100 Subject: [PATCH 5/6] test cover untested lines in branching.py --- packages/pipeline/tests/fake_pipeline_steps.py | 3 +++ .../pipeline/tests/test_controller/test_branchpoints.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/packages/pipeline/tests/fake_pipeline_steps.py b/packages/pipeline/tests/fake_pipeline_steps.py index 604885cb..b1e74c53 100644 --- a/packages/pipeline/tests/fake_pipeline_steps.py +++ b/packages/pipeline/tests/fake_pipeline_steps.py @@ -33,6 +33,9 @@ def __init__(self, override: int | None = None): def get(self, idx): return self._overrideValue or idx + def __eq__(self, other): + return type(other) is FakeIndex and other._overrideValue == self._overrideValue + class MultiplicationOperation(Operation): def __init__(self, factor): diff --git a/packages/pipeline/tests/test_controller/test_branchpoints.py b/packages/pipeline/tests/test_controller/test_branchpoints.py index eef20898..0b39cf82 100644 --- a/packages/pipeline/tests/test_controller/test_branchpoints.py +++ b/packages/pipeline/tests/test_controller/test_branchpoints.py @@ -29,6 +29,7 @@ def test_branchingpoint_basic(): pipe = Pipeline((FakeIndex(), FakeIndex())) assert pipe[1] == (1, 1) + assert pipe.complete_steps == (((FakeIndex(),), (FakeIndex(),)),) def test_branch_differing_operations(): @@ -214,3 +215,10 @@ def test_branch_with_source(): pipe = Pipeline((FakeIndex(1),), (FakeIndex(2),)) with pytest.raises(ValueError): pipe[None] + + +def test_check_index(): + pipe = Pipeline( + (MultiplicationOperation(2), MultiplicationOperation(3)), + ) + assert not branching.branching._check_if_index(pipe) From c650742e57ac985d3eb1196428ec55086d746567 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Fri, 14 Nov 2025 09:05:40 +1100 Subject: [PATCH 6/6] bring branching.py to 100% test coverage --- .../pipeline/branching/branching.py | 4 ++- .../test_controller/test_branchpoints.py | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py b/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py index e6501eeb..98502894 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py +++ b/packages/pipeline/src/pyearthtools/pipeline/branching/branching.py @@ -161,7 +161,9 @@ def apply(self, sample): raise PipelineRuntimeError( f"Cannot map sample to branches as length differ. {len(sample)} != {len(self.sub_pipelines)}." ) - elif self._map_copy: + elif ( + self._map_copy + ): # pragma: no cover # cannot be fully tested - if this line is reached, self._map_copy is True self.sub_pipelines = expand_pipeline(self, len(sample)) for s, pipe in zip(sample, self.sub_pipelines): diff --git a/packages/pipeline/tests/test_controller/test_branchpoints.py b/packages/pipeline/tests/test_controller/test_branchpoints.py index 0b39cf82..c99cbd54 100644 --- a/packages/pipeline/tests/test_controller/test_branchpoints.py +++ b/packages/pipeline/tests/test_controller/test_branchpoints.py @@ -222,3 +222,31 @@ def test_check_index(): (MultiplicationOperation(2), MultiplicationOperation(3)), ) assert not branching.branching._check_if_index(pipe) + + +def test_nested_branching_undo_nosource(): + pipe = Pipeline( + ( + ( + MultiplicationOperation(2), + MultiplicationOperation(3), + ), + MultiplicationOperation(4), + ) + ) + assert pipe.undo((18, 8)) == (3, 2) + + +def test_expand_pipeline_skip_non_pipeline(): + """Tests that expand_pipeline function ignore non-pipeline objects.""" + + # instantiate pipline with branchpoint at step 0. + pipe = Pipeline( + (MultiplicationOperation(3), MultiplicationOperation(4)), + ) + + # add the non-pipeline object to branchpoint sub pipelines + pipe._steps[0].sub_pipelines.append((1,)) + + # test that the non pipeline object isn't in the resultant expanded pipeline. + assert (1,) not in branching.branching.expand_pipeline(pipe._steps[0], 3)