From f7ea2547661b29473967955b080235c8dfa21cb8 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 16 Jun 2025 20:37:08 +0000 Subject: [PATCH 1/5] Refactor PipelineBase._validate_input() method I've refactored the `PipelineBase._validate_input()` method to improve clarity and maintainability. This involved breaking down the method into smaller helper functions and enhancing error messages for better specificity. I also added comprehensive unit tests to cover various input validation scenarios, ensuring the refactored method behaves as expected. --- haystack/core/pipeline/base.py | 66 +++++++++++++++++------- test/core/pipeline/test_pipeline_base.py | 58 +++++++++++++++++++++ 2 files changed, 104 insertions(+), 20 deletions(-) diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index a44cb5f724..7e854e464e 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -899,6 +899,39 @@ def _create_component_span( parent_span=parent_span, ) + def _validate_component_input(self, component_name: str, component_inputs: Dict[str, Any], data: Dict[str, Any]) -> None: + """ + Validates input data for a specific component. + + :param component_name: Name of the component. + :param component_inputs: Inputs provided for the component. + :param data: All pipeline input data. + :raises ValueError: If inputs are invalid. + """ + if component_name not in self.graph.nodes: + raise ValueError(f"Component '{component_name}' not found in the pipeline. Available components: {list(self.graph.nodes.keys())}") + instance = self.graph.nodes[component_name]["instance"] + + # Validate that all mandatory inputs are provided either directly or by senders + for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): + if socket.is_mandatory and not socket.senders and socket_name not in component_inputs: + raise ValueError(f"Missing mandatory input '{socket_name}' for component '{component_name}'.") + + # Validate that provided inputs exist in the component's input sockets + for input_name in component_inputs.keys(): + if input_name not in instance.__haystack_input__._sockets_dict: + raise ValueError(f"Unexpected input '{input_name}' for component '{component_name}'. Available inputs: {list(instance.__haystack_input__._sockets_dict.keys())}") + + # Validate that inputs are not multiply defined (already sent by another component and also provided directly) + # unless the socket is variadic + for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): + if socket.senders and socket_name in component_inputs and not socket.is_variadic: + raise ValueError( + f"Input '{socket_name}' for component '{component_name}' is already provided by component " + f"'{socket.senders[0]}'. Do not provide it directly." + ) + + def _validate_input(self, data: Dict[str, Any]) -> None: """ Validates pipeline input data. @@ -916,26 +949,19 @@ def _validate_input(self, data: Dict[str, Any]) -> None: If inputs are invalid according to the above. 
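
        Example (an illustrative sketch, not part of the public API contract;
        it assumes a pipeline holding a single component named "comp" that
        declares a mandatory input socket "x" with no connected senders):

            pipe._validate_input({"comp": {"x": 1}})  # passes silently
            pipe._validate_input({"comp": {}})        # raises ValueError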
""" for component_name, component_inputs in data.items(): - if component_name not in self.graph.nodes: - raise ValueError(f"Component named {component_name} not found in the pipeline.") - instance = self.graph.nodes[component_name]["instance"] - for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): - if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: - raise ValueError(f"Missing input for component {component_name}: {socket_name}") - for input_name in component_inputs.keys(): - if input_name not in instance.__haystack_input__._sockets_dict: - raise ValueError(f"Input {input_name} not found in component {component_name}.") - - for component_name in self.graph.nodes: - instance = self.graph.nodes[component_name]["instance"] - for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): - component_inputs = data.get(component_name, {}) - if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs: - raise ValueError(f"Missing input for component {component_name}: {socket_name}") - if socket.senders and socket_name in component_inputs and not socket.is_variadic: - raise ValueError( - f"Input {socket_name} for component {component_name} is already sent by {socket.senders}." - ) + self._validate_component_input(component_name, component_inputs, data) + + # Additionally, check for components that might be missing inputs, + # even if they were not explicitly mentioned in the `data` dictionary. + # This covers cases where a component has mandatory inputs but receives no data. + for component_name_in_graph in self.graph.nodes: + if component_name_in_graph not in data: + # This component was not in the input data dictionary, check if it has mandatory inputs without senders + instance = self.graph.nodes[component_name_in_graph]["instance"] + for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): + if socket.is_mandatory and not socket.senders: + raise ValueError(f"Missing mandatory input '{socket_name}' for component '{component_name_in_graph}' which was not provided in the input data.") + def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: """ diff --git a/test/core/pipeline/test_pipeline_base.py b/test/core/pipeline/test_pipeline_base.py index 7c2d62e8cf..067a6f13e2 100644 --- a/test/core/pipeline/test_pipeline_base.py +++ b/test/core/pipeline/test_pipeline_base.py @@ -1866,3 +1866,61 @@ def test_pipeline_show_called_with_keyword_args_triggers_no_warning(self, mock_i with warnings.catch_warnings(record=True) as w: pipeline.show(server_url="http://localhost:3000") assert len(w) == 0, "No warning should be triggered when using keyword arguments" + + +class TestValidateInput: + def test_validate_input_valid_data(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})() + pipe.add_component("comp1", comp1) + pipe._validate_input(data={"comp1": {"x": 1}}) + # No exception should be raised + + def test_validate_input_missing_mandatory_input(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})() + pipe.add_component("comp1", comp1) + with pytest.raises(ValueError, match="Missing mandatory input 'x' for component 'comp1'"): + pipe._validate_input(data={"comp1": {}}) + + def test_validate_input_missing_mandatory_input_for_component_not_in_data(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", 
input_types={"x": int}, output_types={"y": int})() + comp2 = component_class("Comp2", input_types={"a": str}, output_types={"b": str})() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) # comp2 requires 'a' but is not in data + with pytest.raises(ValueError, match="Missing mandatory input 'a' for component 'comp2' which was not provided in the input data."): + pipe._validate_input(data={"comp1": {"x": 1}}) + + + def test_validate_input_to_already_connected_socket(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})() + comp2 = component_class("Comp2", input_types={"a": int}, output_types={"b": int})() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1.y", "comp2.a") + with pytest.raises(ValueError, match="Input 'a' for component 'comp2' is already provided by component 'comp1'. Do not provide it directly."): + pipe._validate_input(data={"comp2": {"a": 1}}) + + def test_validate_input_for_non_existent_component(self): + pipe = PipelineBase() + with pytest.raises(ValueError, match="Component 'non_existent' not found in the pipeline. Available components: \\[\\]"): + pipe._validate_input(data={"non_existent": {"x": 1}}) + + def test_validate_input_with_unexpected_input_name(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})() + pipe.add_component("comp1", comp1) + with pytest.raises(ValueError, match="Unexpected input 'z' for component 'comp1'. Available inputs: \\['x'\\]"): + pipe._validate_input(data={"comp1": {"z": 1}}) + + def test_validate_input_variadic_socket_can_receive_multiple_inputs(self): + pipe = PipelineBase() + comp1 = component_class("Comp1", output_types={"y": int})() + comp2 = component_class("Comp2", input_types={"a": Variadic[int]}, output_types={"b": int})() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1.y", "comp2.a") + # Should not raise an error, as variadic sockets can accept multiple inputs + pipe._validate_input(data={"comp2": {"a": 1}}) From 65b455f3d7c892147834100aa09943e997d2f169 Mon Sep 17 00:00:00 2001 From: carlosrinc Date: Wed, 18 Jun 2025 09:42:25 -0500 Subject: [PATCH 2/5] =?UTF-8?q?restaurar=20el=20nombre=20p=C3=BAblico=20de?= =?UTF-8?q?l=20m=C3=A9todo=20a=20validate=5Finput=20en=20PipelineBase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- haystack/core/pipeline/base.py | 4 +- test/core/pipeline/test_pipeline_base.py | 3258 ++++++++++++---------- 2 files changed, 1836 insertions(+), 1426 deletions(-) diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index d7df26c1be..3c8e817bc0 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -71,7 +71,7 @@ class ComponentPriority(IntEnum): BLOCKED = 5 -class PipelineBase: # noqa: PLW1641 +class PipelineBase: """ Components orchestration engine. 
@@ -1513,4 +1513,4 @@ def _write_to_standard_socket( # Only overwrite if there's no existing value, or we have a new value to provide if current_value is None or value is not _NO_OUTPUT_PRODUCED: - inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}] + inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}] \ No newline at end of file diff --git a/test/core/pipeline/test_pipeline_base.py b/test/core/pipeline/test_pipeline_base.py index 7e854e464e..2d60fb44c1 100644 --- a/test/core/pipeline/test_pipeline_base.py +++ b/test/core/pipeline/test_pipeline_base.py @@ -2,1515 +2,1925 @@ # # SPDX-License-Identifier: Apache-2.0 -import itertools -from collections import defaultdict -from datetime import datetime -from enum import IntEnum -from pathlib import Path -from typing import Any, ContextManager, Dict, Iterator, List, Optional, Set, TextIO, Tuple, Type, TypeVar, Union +import logging +from typing import List, Optional +from unittest.mock import patch -import networkx # type:ignore +import pytest -from haystack import logging, tracing -from haystack.core.component import Component, InputSocket, OutputSocket, component +from pandas import DataFrame + +from haystack import Document +from haystack.core.component import component +from haystack.core.component.types import InputSocket, OutputSocket, Variadic, GreedyVariadic, _empty from haystack.core.errors import ( DeserializationError, - PipelineComponentsBlockedError, PipelineConnectError, PipelineDrawingError, PipelineError, PipelineMaxComponentRuns, - PipelineUnmarshalError, - PipelineValidationError, -) -from haystack.core.pipeline.component_checks import ( - _NO_OUTPUT_PRODUCED, - all_predecessors_executed, - are_all_lazy_variadic_sockets_resolved, - are_all_sockets_ready, - can_component_run, - is_any_greedy_socket_ready, - is_socket_lazy_variadic, -) -from haystack.core.pipeline.utils import ( - FIFOPriorityQueue, - _deepcopy_with_exceptions, - args_deprecated, - parse_connect_string, ) -from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict -from haystack.core.type_utils import _type_name, _types_are_compatible -from haystack.marshal import Marshaller, YamlMarshaller -from haystack.utils import is_in_jupyter, type_serialization +from haystack.core.pipeline import PredefinedPipeline +from haystack.core.pipeline.base import PipelineBase +from haystack.core.pipeline.base import ComponentPriority, _NO_OUTPUT_PRODUCED +from haystack.core.pipeline.utils import FIFOPriorityQueue -from .descriptions import find_pipeline_inputs, find_pipeline_outputs -from .draw import _to_mermaid_image -from .template import PipelineTemplate, PredefinedPipeline +from haystack.core.serialization import DeserializationCallbacks +from haystack.testing.factory import component_class +from haystack.testing.sample_components import AddFixedValue, Double, Greet -DEFAULT_MARSHALLER = YamlMarshaller() +logging.basicConfig(level=logging.DEBUG) -# We use a generic type to annotate the return value of class methods, -# so that static analyzers won't be confused when derived classes -# use those methods. 
-T = TypeVar("T", bound="PipelineBase") -logger = logging.getLogger(__name__) +@component +class FakeComponent: + def __init__(self, an_init_param: Optional[str] = None): + pass + @component.output_types(value=str) + def run(self, input_: str): + return {"value": input_} -# Constants for tracing tags -_COMPONENT_INPUT = "haystack.component.input" -_COMPONENT_OUTPUT = "haystack.component.output" -_COMPONENT_VISITS = "haystack.component.visits" +@component +class FakeComponentSquared: + def __init__(self, an_init_param: Optional[str] = None): + self.an_init_param = an_init_param + self.inner = FakeComponent() -class ComponentPriority(IntEnum): - HIGHEST = 1 - READY = 2 - DEFER = 3 - DEFER_LAST = 4 - BLOCKED = 5 + @component.output_types(value=str) + def run(self, input_: str): + return {"value": input_} -class PipelineBase: - """ - Components orchestration engine. +@pytest.fixture +def regular_output_socket(): + """Output socket for a regular (non-variadic) connection with receivers""" + return OutputSocket("output1", int, receivers=["receiver1", "receiver2"]) - Builds a graph of components and orchestrates their execution according to the execution graph. - """ - def __init__( - self, - metadata: Optional[Dict[str, Any]] = None, - max_runs_per_component: int = 100, - connection_type_validation: bool = True, - ): - """ - Creates the Pipeline. - - :param metadata: - Arbitrary dictionary to store metadata about this `Pipeline`. Make sure all the values contained in - this dictionary can be serialized and deserialized if you wish to save this `Pipeline` to file. - :param max_runs_per_component: - How many times the `Pipeline` can run the same Component. - If this limit is reached a `PipelineMaxComponentRuns` exception is raised. - If not set defaults to 100 runs per Component. - :param connection_type_validation: Whether the pipeline will validate the types of the connections. - Defaults to True. - """ - self._telemetry_runs = 0 - self._last_telemetry_sent: Optional[datetime] = None - self.metadata = metadata or {} - self.graph = networkx.MultiDiGraph() - self._max_runs_per_component = max_runs_per_component - self._connection_type_validation = connection_type_validation - - def __eq__(self, other: object) -> bool: - """ - Pipeline equality is defined by their type and the equality of their serialized form. +@pytest.fixture +def regular_input_socket(): + """Regular (non-variadic) input socket with a single sender""" + return InputSocket("input1", int, senders=["sender1"]) - Pipelines of the same type share every metadata, node and edge, but they're not required to use - the same node instances: this allows pipeline saved and then loaded back to be equal to themselves. - """ - if not isinstance(self, type(other)): - return False - assert isinstance(other, PipelineBase) - return self.to_dict() == other.to_dict() - def __repr__(self) -> str: - """ - Returns a text representation of the Pipeline. 
- """ - res = f"{object.__repr__(self)}\n" - if self.metadata: - res += "🧱 Metadata\n" - for k, v in self.metadata.items(): - res += f" - {k}: {v}\n" +@pytest.fixture +def lazy_variadic_input_socket(): + """Lazy variadic input socket with multiple senders""" + return InputSocket("variadic_input", Variadic[int], senders=["sender1", "sender2"]) - res += "🚅 Components\n" - for name, instance in self.graph.nodes(data="instance"): # type: ignore # type wrongly defined in networkx - res += f" - {name}: {instance.__class__.__name__}\n" - res += "🛤️ Connections\n" - for sender, receiver, edge_data in self.graph.edges(data=True): - sender_socket = edge_data["from_socket"].name - receiver_socket = edge_data["to_socket"].name - res += f" - {sender}.{sender_socket} -> {receiver}.{receiver_socket} ({edge_data['conn_type']})\n" +class TestPipelineBase: + """ + This class contains only unit tests for the PipelineBase class. + It doesn't test Pipeline.run(), that is done separately in a different way. + """ - return res + def test_pipeline_dumps(self, test_files_path): + pipeline = PipelineBase(max_runs_per_component=99) + pipeline.add_component("Comp1", FakeComponent("Foo")) + pipeline.add_component("Comp2", FakeComponent()) + pipeline.connect("Comp1.value", "Comp2.input_") + result = pipeline.dumps() + with open(f"{test_files_path}/yaml/test_pipeline.yaml", "r") as f: + assert f.read() == result + + def test_pipeline_loads_invalid_data(self): + invalid_yaml = """components: + Comp1: + init_parameters: + an_init_param: null + type: test.core.pipeline.test_pipeline_base.FakeComponent + Comp2* + init_parameters: + an_init_param: null + type: test.core.pipeline.test_pipeline_base.FakeComponent + connections: + * receiver: Comp2.input_ + sender: Comp1.value + metadata: + """ + + with pytest.raises(DeserializationError, match="unmarshalling serialized"): + pipeline = PipelineBase.loads(invalid_yaml) + + invalid_init_parameter_yaml = """components: + Comp1: + init_parameters: + unknown: null + type: test.core.pipeline.test_pipeline_base.FakeComponent + Comp2: + init_parameters: + an_init_param: null + type: test.core.pipeline.test_pipeline_base.FakeComponent + connections: + - receiver: Comp2.input_ + sender: Comp1.value + metadata: {} + """ + + with pytest.raises(DeserializationError, match=".*Comp1.*unknown.*"): + pipeline = PipelineBase.loads(invalid_init_parameter_yaml) + + def test_pipeline_dump(self, test_files_path, tmp_path): + pipeline = PipelineBase(max_runs_per_component=99) + pipeline.add_component("Comp1", FakeComponent("Foo")) + pipeline.add_component("Comp2", FakeComponent()) + pipeline.connect("Comp1.value", "Comp2.input_") + with open(tmp_path / "out.yaml", "w") as f: + pipeline.dump(f) + # re-open and ensure it's the same data as the test file + with open(f"{test_files_path}/yaml/test_pipeline.yaml", "r") as test_f, open(tmp_path / "out.yaml", "r") as f: + assert f.read() == test_f.read() + + def test_pipeline_load(self, test_files_path): + with open(f"{test_files_path}/yaml/test_pipeline.yaml", "r") as f: + pipeline = PipelineBase.load(f) + assert pipeline._max_runs_per_component == 99 + assert isinstance(pipeline.get_component("Comp1"), FakeComponent) + assert isinstance(pipeline.get_component("Comp2"), FakeComponent) + + @patch("haystack.core.pipeline.base._to_mermaid_image") + @patch("haystack.core.pipeline.base.is_in_jupyter") + @patch("IPython.display.Image") + @patch("IPython.display.display") + def test_show_in_notebook( + self, mock_ipython_display, mock_ipython_image, 
mock_is_in_jupyter, mock_to_mermaid_image
+    ):
+        pipe = PipelineBase()
+
+        mock_to_mermaid_image.return_value = b"some_image_data"
+        mock_is_in_jupyter.return_value = True
+
+        pipe.show()
+        mock_ipython_image.assert_called_once_with(b"some_image_data")
+        mock_ipython_display.assert_called_once()
+
+    @patch("haystack.core.pipeline.base.is_in_jupyter")
+    def test_show_not_in_notebook(self, mock_is_in_jupyter):
+        pipe = PipelineBase()
+
+        mock_is_in_jupyter.return_value = False
+
+        with pytest.raises(PipelineDrawingError):
+            pipe.show()
+
+    @patch("haystack.core.pipeline.base._to_mermaid_image")
+    def test_draw(self, mock_to_mermaid_image, tmp_path):
+        pipe = PipelineBase()
+        mock_to_mermaid_image.return_value = b"some_image_data"
+
+        image_path = tmp_path / "test.png"
+        pipe.draw(path=image_path)
+        assert image_path.read_bytes() == mock_to_mermaid_image.return_value
+
+    def test_find_super_components(self):
+        """
+        Test that the pipeline can find the super components it contains.
+        """
+        from haystack import Pipeline
+        from haystack.components.converters import MultiFileConverter
+        from haystack.components.preprocessors import DocumentPreprocessor
+        from haystack.components.writers import DocumentWriter
+        from haystack.document_stores.in_memory import InMemoryDocumentStore
+
+        multi_file_converter = MultiFileConverter()
+        doc_processor = DocumentPreprocessor()
+
+        pipeline = Pipeline()
+        pipeline.add_component("converter", multi_file_converter)
+        pipeline.add_component("preprocessor", doc_processor)
+        pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
+        pipeline.connect("converter", "preprocessor")
+        pipeline.connect("preprocessor", "writer")
+
+        result = pipeline._find_super_components()
+
+        assert len(result) == 2
+        assert [("converter", multi_file_converter), ("preprocessor", doc_processor)] == result
+
+    def test_merge_super_component_pipelines(self):
+        from haystack import Pipeline
+        from haystack.components.converters import MultiFileConverter
+        from haystack.components.preprocessors import DocumentPreprocessor
+        from haystack.components.writers import DocumentWriter
+        from haystack.document_stores.in_memory import InMemoryDocumentStore
+
+        multi_file_converter = MultiFileConverter()
+        doc_processor = DocumentPreprocessor()
+
+        pipeline = Pipeline()
+        pipeline.add_component("converter", multi_file_converter)
+        pipeline.add_component("preprocessor", doc_processor)
+        pipeline.add_component("writer", DocumentWriter(document_store=InMemoryDocumentStore()))
+        pipeline.connect("converter", "preprocessor")
+        pipeline.connect("preprocessor", "writer")
+
+        merged_graph, super_component_components = pipeline._merge_super_component_pipelines()
+
+        assert super_component_components == {
+            "router": "converter",
+            "docx": "converter",
+            "html": "converter",
+            "json": "converter",
+            "md": "converter",
+            "text": "converter",
+            "pdf": "converter",
+            "pptx": "converter",
+            "xlsx": "converter",
+            "joiner": "converter",
+            "csv": "converter",
+            "splitter": "preprocessor",
+            "cleaner": "preprocessor",
+        }

-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serializes the pipeline to a dictionary.
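+        # The merged graph should expose every component wrapped by the two
+        # super-components as a top-level node, alongside the regular "writer"
+        # component; the expected nodes are listed alphabetically below.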
+ expected_nodes = [ + "cleaner", + "csv", + "docx", + "html", + "joiner", + "json", + "md", + "pdf", + "pptx", + "router", + "splitter", + "text", + "writer", + "xlsx", + ] + assert sorted(merged_graph.nodes) == expected_nodes + + expected_edges = [ + ("cleaner", "writer"), + ("csv", "joiner"), + ("docx", "joiner"), + ("html", "joiner"), + ("joiner", "splitter"), + ("json", "joiner"), + ("md", "joiner"), + ("pdf", "joiner"), + ("pptx", "joiner"), + ("router", "csv"), + ("router", "docx"), + ("router", "html"), + ("router", "json"), + ("router", "md"), + ("router", "pdf"), + ("router", "pptx"), + ("router", "text"), + ("router", "xlsx"), + ("splitter", "cleaner"), + ("text", "joiner"), + ("xlsx", "joiner"), + ] + actual_edges = [(u, v) for u, v, _ in merged_graph.edges] + assert sorted(actual_edges) == expected_edges + + # UNIT + def test_add_invalid_component_name(self): + pipe = PipelineBase() + with pytest.raises(ValueError): + pipe.add_component("this.is.not.a.valida.name", FakeComponent) + with pytest.raises(ValueError): + pipe.add_component("_debug", FakeComponent) + + def test_add_component_to_different_pipelines(self): + first_pipe = PipelineBase() + second_pipe = PipelineBase() + some_component = component_class("Some")() + + assert some_component.__haystack_added_to_pipeline__ is None + first_pipe.add_component("some", some_component) + assert some_component.__haystack_added_to_pipeline__ is first_pipe + + with pytest.raises(PipelineError): + second_pipe.add_component("some", some_component) + + def test_remove_component_raises_if_invalid_component_name(self): + pipe = PipelineBase() + component = component_class("Some")() + + pipe.add_component("1", component) + + with pytest.raises(ValueError): + pipe.remove_component("2") + + def test_remove_component_removes_component_and_its_edges(self): + pipe = PipelineBase() + component_1 = component_class("Type1")() + component_2 = component_class("Type2")() + component_3 = component_class("Type3")() + component_4 = component_class("Type4")() + + pipe.add_component("1", component_1) + pipe.add_component("2", component_2) + pipe.add_component("3", component_3) + pipe.add_component("4", component_4) + + pipe.connect("1", "2") + pipe.connect("2", "3") + pipe.connect("3", "4") + + pipe.remove_component("2") + + assert ["1", "3", "4"] == sorted(pipe.graph.nodes) + assert [("3", "4")] == sorted([(u, v) for (u, v) in pipe.graph.edges()]) + + def test_remove_component_allows_you_to_reuse_the_component(self): + pipe = PipelineBase() + Some = component_class("Some", input_types={"in": int}, output_types={"out": int}) + + pipe.add_component("component_1", Some()) + pipe.add_component("component_2", Some()) + pipe.add_component("component_3", Some()) + pipe.connect("component_1", "component_2") + pipe.connect("component_2", "component_3") + component_2 = pipe.remove_component("component_2") + + assert component_2.__haystack_added_to_pipeline__ is None + assert component_2.__haystack_input__._sockets_dict == {"in": InputSocket(name="in", type=int, senders=[])} + assert component_2.__haystack_output__._sockets_dict == { + "out": OutputSocket(name="out", type=int, receivers=[]) + } - This is meant to be an intermediate representation but it can be also used to save a pipeline to file. + pipe2 = PipelineBase() + pipe2.add_component("component_4", Some()) + pipe2.add_component("component_2", component_2) + pipe2.add_component("component_5", Some()) - :returns: - Dictionary with serialized data. 
- """ - components = {} - for name, instance in self.graph.nodes(data="instance"): # type:ignore - components[name] = component_to_dict(instance, name) - - connections = [] - for sender, receiver, edge_data in self.graph.edges.data(): - sender_socket = edge_data["from_socket"].name - receiver_socket = edge_data["to_socket"].name - connections.append({"sender": f"{sender}.{sender_socket}", "receiver": f"{receiver}.{receiver_socket}"}) - return { - "metadata": self.metadata, - "max_runs_per_component": self._max_runs_per_component, - "components": components, - "connections": connections, - "connection_type_validation": self._connection_type_validation, + pipe2.connect("component_4", "component_2") + pipe2.connect("component_2", "component_5") + assert component_2.__haystack_added_to_pipeline__ is pipe2 + assert component_2.__haystack_input__._sockets_dict == { + "in": InputSocket(name="in", type=int, senders=["component_4"]) + } + assert component_2.__haystack_output__._sockets_dict == { + "out": OutputSocket(name="out", type=int, receivers=["component_5"]) } - @classmethod - def from_dict( - cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs: Any - ) -> T: - """ - Deserializes the pipeline from a dictionary. - - :param data: - Dictionary to deserialize from. - :param callbacks: - Callbacks to invoke during deserialization. - :param kwargs: - `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new - ones. - :returns: - Deserialized component. - """ - data_copy = _deepcopy_with_exceptions(data) # to prevent modification of original data - metadata = data_copy.get("metadata", {}) - max_runs_per_component = data_copy.get("max_runs_per_component", 100) - connection_type_validation = data_copy.get("connection_type_validation", True) - pipe = cls( - metadata=metadata, - max_runs_per_component=max_runs_per_component, - connection_type_validation=connection_type_validation, + # instance = pipe2.get_component("some") + # assert instance == component + + # UNIT + def test_get_component_name(self): + pipe = PipelineBase() + some_component = component_class("Some")() + pipe.add_component("some", some_component) + + assert pipe.get_component_name(some_component) == "some" + + # UNIT + def test_get_component_name_not_added_to_pipeline(self): + pipe = PipelineBase() + some_component = component_class("Some")() + + assert pipe.get_component_name(some_component) == "" + + # UNIT + def test_repr(self): + pipe = PipelineBase(metadata={"test": "test"}) + pipe.add_component("add_two", AddFixedValue(add=2)) + pipe.add_component("add_default", AddFixedValue()) + pipe.add_component("double", Double()) + pipe.connect("add_two", "double") + pipe.connect("double", "add_default") + + expected_repr = ( + f"{object.__repr__(pipe)}\n" + "🧱 Metadata\n" + " - test: test\n" + "🚅 Components\n" + " - add_two: AddFixedValue\n" + " - add_default: AddFixedValue\n" + " - double: Double\n" + "🛤️ Connections\n" + " - add_two.result -> double.value (int)\n" + " - double.value -> add_default.value (int)\n" ) - components_to_reuse = kwargs.get("components", {}) - for name, component_data in data_copy.get("components", {}).items(): - if name in components_to_reuse: - # Reuse an instance - instance = components_to_reuse[name] - else: - if "type" not in component_data: - raise PipelineError(f"Missing 'type' in component '{name}'") - - if component_data["type"] not in component.registry: - try: - # Import the module first... 
- module, _ = component_data["type"].rsplit(".", 1) - logger.debug("Trying to import module {module_name}", module_name=module) - type_serialization.thread_safe_import(module) - # ...then try again - if component_data["type"] not in component.registry: - raise PipelineError( - f"Successfully imported module '{module}' but couldn't find " - f"'{component_data['type']}' in the component registry.\n" - f"The component might be registered under a different path. " - f"Here are the registered components:\n {list(component.registry.keys())}\n" - ) - except (ImportError, PipelineError, ValueError) as e: - raise PipelineError( - f"Component '{component_data['type']}' (name: '{name}') not imported. Please " - f"check that the package is installed and the component path is correct." - ) from e - - # Create a new one - component_class = component.registry[component_data["type"]] - - try: - instance = component_from_dict(component_class, component_data, name, callbacks) - except Exception as e: - msg = ( - f"Couldn't deserialize component '{name}' of class '{component_class.__name__}' " - f"with the following data: {str(component_data)}. Possible reasons include " - "malformed serialized data, mismatch between the serialized component and the " - "loaded one (due to a breaking change, see " - "https://github.com/deepset-ai/haystack/releases), etc." - ) - raise DeserializationError(msg) from e - pipe.add_component(name=name, instance=instance) - - for connection in data.get("connections", []): - if "sender" not in connection: - raise PipelineError(f"Missing sender in connection: {connection}") - if "receiver" not in connection: - raise PipelineError(f"Missing receiver in connection: {connection}") - pipe.connect(sender=connection["sender"], receiver=connection["receiver"]) - - return pipe - - def dumps(self, marshaller: Marshaller = DEFAULT_MARSHALLER) -> str: - """ - Returns the string representation of this pipeline according to the format dictated by the `Marshaller` in use. - - :param marshaller: - The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. - :returns: - A string representing the pipeline. - """ - return marshaller.marshal(self.to_dict()) - - def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> None: - """ - Writes the string representation of this pipeline to the file-like object passed in the `fp` argument. - - :param fp: - A file-like object ready to be written to. - :param marshaller: - The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. - """ - fp.write(marshaller.marshal(self.to_dict())) - - @classmethod - def loads( - cls: Type[T], - data: Union[str, bytes, bytearray], - marshaller: Marshaller = DEFAULT_MARSHALLER, - callbacks: Optional[DeserializationCallbacks] = None, - ) -> T: - """ - Creates a `Pipeline` object from the string representation passed in the `data` argument. - - :param data: - The string representation of the pipeline, can be `str`, `bytes` or `bytearray`. - :param marshaller: - The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. - :param callbacks: - Callbacks to invoke during deserialization. - :raises DeserializationError: - If an error occurs during deserialization. - :returns: - A `Pipeline` object. - """ - try: - deserialized_data = marshaller.unmarshal(data) - except Exception as e: - raise DeserializationError( - "Error while unmarshalling serialized pipeline data. 
This is usually " - "caused by malformed or invalid syntax in the serialized representation." - ) from e - - return cls.from_dict(deserialized_data, callbacks) - - @classmethod - def load( - cls: Type[T], - fp: TextIO, - marshaller: Marshaller = DEFAULT_MARSHALLER, - callbacks: Optional[DeserializationCallbacks] = None, - ) -> T: - """ - Creates a `Pipeline` object a string representation. - The string representation is read from the file-like object passed in the `fp` argument. - - - :param fp: - A file-like object ready to be read from. - :param marshaller: - The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. - :param callbacks: - Callbacks to invoke during deserialization. - :raises DeserializationError: - If an error occurs during deserialization. - :returns: - A `Pipeline` object. - """ - return cls.loads(fp.read(), marshaller, callbacks) - - def add_component(self, name: str, instance: Component) -> None: - """ - Add the given component to the pipeline. - - Components are not connected to anything by default: use `Pipeline.connect()` to connect components together. - Component names must be unique, but component instances can be reused if needed. - - :param name: - The name of the component to add. - :param instance: - The component instance to add. - - :raises ValueError: - If a component with the same name already exists. - :raises PipelineValidationError: - If the given instance is not a component. - """ - # Component names are unique - if name in self.graph.nodes: - raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.") - - # Components can't be named `_debug` - if name == "_debug": - raise ValueError("'_debug' is a reserved name for debug output. Choose another name.") - - # Component names can't have "." - if "." in name: - raise ValueError(f"{name} is an invalid component name, cannot contain '.' (dot) characters.") - - # Component instances must be components - if not isinstance(instance, Component): - raise PipelineValidationError( - f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?" - ) - - if getattr(instance, "__haystack_added_to_pipeline__", None): - msg = ( - "Component has already been added in another Pipeline. Components can't be shared between Pipelines. " - "Create a new instance instead." 
- ) - raise PipelineError(msg) - - setattr(instance, "__haystack_added_to_pipeline__", self) - setattr(instance, "__component_name__", name) - - # Add component to the graph, disconnected - logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance) - # We're completely sure the fields exist so we ignore the type error - self.graph.add_node( - name, - instance=instance, - input_sockets=instance.__haystack_input__._sockets_dict, # type: ignore[attr-defined] - output_sockets=instance.__haystack_output__._sockets_dict, # type: ignore[attr-defined] - visits=0, + assert repr(pipe) == expected_repr + + # UNIT + def test_to_dict(self): + add_two = AddFixedValue(add=2) + add_default = AddFixedValue() + double = Double() + pipe = PipelineBase(metadata={"test": "test"}, max_runs_per_component=42) + pipe.add_component("add_two", add_two) + pipe.add_component("add_default", add_default) + pipe.add_component("double", double) + pipe.connect("add_two", "double") + pipe.connect("double", "add_default") + + res = pipe.to_dict() + expected = { + "metadata": {"test": "test"}, + "max_runs_per_component": 42, + "connection_type_validation": True, + "components": { + "add_two": { + "type": "haystack.testing.sample_components.add_value.AddFixedValue", + "init_parameters": {"add": 2}, + }, + "add_default": { + "type": "haystack.testing.sample_components.add_value.AddFixedValue", + "init_parameters": {"add": 1}, + }, + "double": {"type": "haystack.testing.sample_components.double.Double", "init_parameters": {}}, + }, + "connections": [ + {"sender": "add_two.result", "receiver": "double.value"}, + {"sender": "double.value", "receiver": "add_default.value"}, + ], + } + assert res == expected + + def test_from_dict(self): + data = { + "metadata": {"test": "test"}, + "max_runs_per_component": 101, + "components": { + "add_two": { + "type": "haystack.testing.sample_components.add_value.AddFixedValue", + "init_parameters": {"add": 2}, + }, + "add_default": { + "type": "haystack.testing.sample_components.add_value.AddFixedValue", + "init_parameters": {"add": 1}, + }, + "double": {"type": "haystack.testing.sample_components.double.Double", "init_parameters": {}}, + }, + "connections": [ + {"sender": "add_two.result", "receiver": "double.value"}, + {"sender": "double.value", "receiver": "add_default.value"}, + ], + } + pipe = PipelineBase.from_dict(data) + + assert pipe.metadata == {"test": "test"} + assert pipe._max_runs_per_component == 101 + + # Components + assert len(pipe.graph.nodes) == 3 + ## add_two + add_two = pipe.graph.nodes["add_two"] + assert add_two["instance"].add == 2 + assert add_two["input_sockets"] == { + "value": InputSocket(name="value", type=int), + "add": InputSocket(name="add", type=Optional[int], default_value=None), + } + assert add_two["output_sockets"] == {"result": OutputSocket(name="result", type=int, receivers=["double"])} + assert add_two["visits"] == 0 + + ## add_default + add_default = pipe.graph.nodes["add_default"] + assert add_default["instance"].add == 1 + assert add_default["input_sockets"] == { + "value": InputSocket(name="value", type=int, senders=["double"]), + "add": InputSocket(name="add", type=Optional[int], default_value=None), + } + assert add_default["output_sockets"] == {"result": OutputSocket(name="result", type=int)} + assert add_default["visits"] == 0 + + ## double + double = pipe.graph.nodes["double"] + assert double["instance"] + assert double["input_sockets"] == {"value": InputSocket(name="value", type=int, 
senders=["add_two"])} + assert double["output_sockets"] == {"value": OutputSocket(name="value", type=int, receivers=["add_default"])} + assert double["visits"] == 0 + + # Connections + connections = list(pipe.graph.edges(data=True)) + assert len(connections) == 2 + assert connections[0] == ( + "add_two", + "double", + { + "conn_type": "int", + "from_socket": OutputSocket(name="result", type=int, receivers=["double"]), + "to_socket": InputSocket(name="value", type=int, senders=["add_two"]), + "mandatory": True, + }, + ) + assert connections[1] == ( + "double", + "add_default", + { + "conn_type": "int", + "from_socket": OutputSocket(name="value", type=int, receivers=["add_default"]), + "to_socket": InputSocket(name="value", type=int, senders=["double"]), + "mandatory": True, + }, ) - def remove_component(self, name: str) -> Component: - """ - Remove and returns component from the pipeline. - - Remove an existing component from the pipeline by providing its name. - All edges that connect to the component will also be deleted. - - :param name: - The name of the component to remove. - :returns: - The removed Component instance. - - :raises ValueError: - If there is no component with that name already in the Pipeline. - """ - - # Check that a component with that name is in the Pipeline - try: - instance = self.get_component(name) - except ValueError as exc: - raise ValueError( - f"There is no component named '{name}' in the pipeline. The valid component names are: ", - ", ".join(n for n in self.graph.nodes), - ) from exc - - # Delete component from the graph, deleting all its connections - self.graph.remove_node(name) - - # Reset the Component sockets' senders and receivers - input_sockets = instance.__haystack_input__._sockets_dict # type: ignore[attr-defined] - for socket in input_sockets.values(): - socket.senders = [] - - output_sockets = instance.__haystack_output__._sockets_dict # type: ignore[attr-defined] - for socket in output_sockets.values(): - socket.receivers = [] - - # Reset the Component's pipeline reference - setattr(instance, "__haystack_added_to_pipeline__", None) - - return instance + # TODO: Remove this, this should be a component test. + # The pipeline can't handle this in any case nor way. + def test_from_dict_with_callbacks(self): + data = { + "metadata": {"test": "test"}, + "components": { + "add_two": { + "type": "haystack.testing.sample_components.add_value.AddFixedValue", + "init_parameters": {"add": 2}, + }, + "add_default": { + "type": "haystack.testing.sample_components.add_value.AddFixedValue", + "init_parameters": {"add": 1}, + }, + "double": {"type": "haystack.testing.sample_components.double.Double", "init_parameters": {}}, + "greet": { + "type": "haystack.testing.sample_components.greet.Greet", + "init_parameters": {"message": "test"}, + }, + }, + "connections": [ + {"sender": "add_two.result", "receiver": "double.value"}, + {"sender": "double.value", "receiver": "add_default.value"}, + ], + } - def connect(self, sender: str, receiver: str) -> "PipelineBase": # noqa: PLR0915 PLR0912 - """ - Connects two components together. + components_seen_in_callback = [] - All components to connect must exist in the pipeline. - If connecting to a component that has several output connections, specify the inputs and output names as - 'component_name.connections_name'. 
+ def component_pre_init_callback(name, component_cls, init_params): + assert name in ["add_two", "add_default", "double", "greet"] + assert component_cls in [AddFixedValue, Double, Greet] - :param sender: - The component that delivers the value. This can be either just a component name or can be - in the format `component_name.connection_name` if the component has multiple outputs. - :param receiver: - The component that receives the value. This can be either just a component name or can be - in the format `component_name.connection_name` if the component has multiple inputs. + if name == "add_two": + assert init_params == {"add": 2} + elif name == "add_default": + assert init_params == {"add": 1} + elif name == "greet": + assert init_params == {"message": "test"} - :returns: - The Pipeline instance. + components_seen_in_callback.append(name) - :raises PipelineConnectError: - If the two components cannot be connected (for example if one of the components is - not present in the pipeline, or the connections don't match by type, and so on). - """ - # Edges may be named explicitly by passing 'node_name.edge_name' to connect(). - sender_component_name, sender_socket_name = parse_connect_string(sender) - receiver_component_name, receiver_socket_name = parse_connect_string(receiver) - - if sender_component_name == receiver_component_name: - raise PipelineConnectError("Connecting a Component to itself is not supported.") - - # Get the nodes data. - try: - sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"] - except KeyError as exc: - raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc - try: - receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"] - except KeyError as exc: - raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc - - # If the name of either socket is given, get the socket - sender_socket: Optional[OutputSocket] = None - if sender_socket_name: - sender_socket = sender_sockets.get(sender_socket_name) - if not sender_socket: - raise PipelineConnectError( - f"'{sender} does not exist. " - f"Output connections of {sender_component_name} are: " - + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()]) - ) - - receiver_socket: Optional[InputSocket] = None - if receiver_socket_name: - receiver_socket = receiver_sockets.get(receiver_socket_name) - if not receiver_socket: - raise PipelineConnectError( - f"'{receiver} does not exist. " - f"Input connections of {receiver_component_name} are: " - + ", ".join( - [f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()] - ) - ) - - # Look for a matching connection among the possible ones. - # Note that if there is more than one possible connection but two sockets match by name, they're paired. 
- sender_socket_candidates: List[OutputSocket] = ( - [sender_socket] if sender_socket else list(sender_sockets.values()) + pipe = PipelineBase.from_dict( + data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback) ) - receiver_socket_candidates: List[InputSocket] = ( - [receiver_socket] if receiver_socket else list(receiver_sockets.values()) + assert components_seen_in_callback == ["add_two", "add_default", "double", "greet"] + add_two = pipe.graph.nodes["add_two"]["instance"] + assert add_two.add == 2 + add_default = pipe.graph.nodes["add_default"]["instance"] + assert add_default.add == 1 + greet = pipe.graph.nodes["greet"]["instance"] + assert greet.message == "test" + assert greet.log_level == "INFO" + + def component_pre_init_callback_modify(name, component_cls, init_params): + assert name in ["add_two", "add_default", "double", "greet"] + assert component_cls in [AddFixedValue, Double, Greet] + + if name == "add_two": + init_params["add"] = 3 + elif name == "add_default": + init_params["add"] = 0 + elif name == "greet": + init_params["message"] = "modified test" + init_params["log_level"] = "DEBUG" + + pipe = PipelineBase.from_dict( + data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify) ) - - # Find all possible connections between these two components - possible_connections = [] - for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates): - if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation): - possible_connections.append((sender_sock, receiver_sock)) - - # We need this status for error messages, since we might need it in multiple places we calculate it here - status = _connections_status( - sender_node=sender_component_name, - sender_sockets=sender_socket_candidates, - receiver_node=receiver_component_name, - receiver_sockets=receiver_socket_candidates, + add_two = pipe.graph.nodes["add_two"]["instance"] + assert add_two.add == 3 + add_default = pipe.graph.nodes["add_default"]["instance"] + assert add_default.add == 0 + greet = pipe.graph.nodes["greet"]["instance"] + assert greet.message == "modified test" + assert greet.log_level == "DEBUG" + + # Test with a component that internally instantiates another component + def component_pre_init_callback_check_class(name, component_cls, init_params): + assert name == "fake_component_squared" + assert component_cls == FakeComponentSquared + + pipe = PipelineBase() + pipe.add_component("fake_component_squared", FakeComponentSquared()) + pipe = PipelineBase.from_dict( + pipe.to_dict(), + callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_check_class), ) - - if not possible_connections: - # There's no possible connection between these two components - if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1: - msg = ( - f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with " - f"'{receiver_component_name}.{receiver_socket_candidates[0].name}': " - f"their declared input and output types do not match.\n{status}" - ) - else: - msg = ( - f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': " - f"no matching connections available.\n{status}" - ) - raise PipelineConnectError(msg) - - if len(possible_connections) == 1: - # There's only one possible connection, use it - sender_socket = possible_connections[0][0] - receiver_socket = possible_connections[0][1] - - if 
len(possible_connections) > 1: - # There are multiple possible connection, let's try to match them by name - name_matches = [ - (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name - ] - if len(name_matches) != 1: - # There's are either no matches or more than one, we can't pick one reliably - msg = ( - f"Cannot connect '{sender_component_name}' with " - f"'{receiver_component_name}': more than one connection is possible " - "between these components. Please specify the connection name, like: " - f"pipeline.connect('{sender_component_name}.{possible_connections[0][0].name}', " - f"'{receiver_component_name}.{possible_connections[0][1].name}').\n{status}" - ) - raise PipelineConnectError(msg) - - # Get the only possible match - sender_socket = name_matches[0][0] - receiver_socket = name_matches[0][1] - - # Connection must be valid on both sender/receiver sides - if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name: - if sender_component_name and sender_socket: - sender_repr = f"{sender_component_name}.{sender_socket.name} ({_type_name(sender_socket.type)})" - else: - sender_repr = "input needed" - - if receiver_component_name and receiver_socket: - receiver_repr = f"({_type_name(receiver_socket.type)}) {receiver_component_name}.{receiver_socket.name}" - else: - receiver_repr = "output" - msg = f"Connection must have both sender and receiver: {sender_repr} -> {receiver_repr}" - raise PipelineConnectError(msg) - - logger.debug( - "Connecting '{sender_component}.{sender_socket_name}' to '{receiver_component}.{receiver_socket_name}'", - sender_component=sender_component_name, - sender_socket_name=sender_socket.name, - receiver_component=receiver_component_name, - receiver_socket_name=receiver_socket.name, + assert type(pipe.graph.nodes["fake_component_squared"]["instance"].inner) == FakeComponent + + # UNIT + def test_from_dict_with_empty_dict(self): + assert PipelineBase() == PipelineBase.from_dict({}) + + # TODO: UNIT, consider deprecating this argument + def test_from_dict_with_components_instances(self): + add_two = AddFixedValue(add=2) + add_default = AddFixedValue() + components = {"add_two": add_two, "add_default": add_default} + data = { + "metadata": {"test": "test"}, + "components": { + "add_two": {}, + "add_default": {}, + "double": {"type": "haystack.testing.sample_components.double.Double", "init_parameters": {}}, + }, + "connections": [ + {"sender": "add_two.result", "receiver": "double.value"}, + {"sender": "double.value", "receiver": "add_default.value"}, + ], + } + pipe = PipelineBase.from_dict(data, components=components) + assert pipe.metadata == {"test": "test"} + + # Components + assert len(pipe.graph.nodes) == 3 + ## add_two + add_two_data = pipe.graph.nodes["add_two"] + assert add_two_data["instance"] is add_two + assert add_two_data["instance"].add == 2 + assert add_two_data["input_sockets"] == { + "value": InputSocket(name="value", type=int), + "add": InputSocket(name="add", type=Optional[int], default_value=None), + } + assert add_two_data["output_sockets"] == {"result": OutputSocket(name="result", type=int, receivers=["double"])} + assert add_two_data["visits"] == 0 + + ## add_default + add_default_data = pipe.graph.nodes["add_default"] + assert add_default_data["instance"] is add_default + assert add_default_data["instance"].add == 1 + assert add_default_data["input_sockets"] == { + "value": InputSocket(name="value", type=int, senders=["double"]), + "add": 
InputSocket(name="add", type=Optional[int], default_value=None), + } + assert add_default_data["output_sockets"] == {"result": OutputSocket(name="result", type=int, receivers=[])} + assert add_default_data["visits"] == 0 + + ## double + double = pipe.graph.nodes["double"] + assert double["instance"] + assert double["input_sockets"] == {"value": InputSocket(name="value", type=int, senders=["add_two"])} + assert double["output_sockets"] == {"value": OutputSocket(name="value", type=int, receivers=["add_default"])} + assert double["visits"] == 0 + + # Connections + connections = list(pipe.graph.edges(data=True)) + assert len(connections) == 2 + assert connections[0] == ( + "add_two", + "double", + { + "conn_type": "int", + "from_socket": OutputSocket(name="result", type=int, receivers=["double"]), + "to_socket": InputSocket(name="value", type=int, senders=["add_two"]), + "mandatory": True, + }, ) - - if receiver_component_name in sender_socket.receivers and sender_component_name in receiver_socket.senders: - # This is already connected, nothing to do - return self - - if receiver_socket.senders and not receiver_socket.is_variadic: - # Only variadic input sockets can receive from multiple senders - msg = ( - f"Cannot connect '{sender_component_name}.{sender_socket.name}' with " - f"'{receiver_component_name}.{receiver_socket.name}': " - f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n" - ) - raise PipelineConnectError(msg) - - # Update the sockets with the new connection - sender_socket.receivers.append(receiver_component_name) - receiver_socket.senders.append(sender_component_name) - - # Create the new connection - self.graph.add_edge( - sender_component_name, - receiver_component_name, - key=f"{sender_socket.name}/{receiver_socket.name}", - conn_type=_type_name(sender_socket.type), - from_socket=sender_socket, - to_socket=receiver_socket, - mandatory=receiver_socket.is_mandatory, + assert connections[1] == ( + "double", + "add_default", + { + "conn_type": "int", + "from_socket": OutputSocket(name="value", type=int, receivers=["add_default"]), + "to_socket": InputSocket(name="value", type=int, senders=["double"]), + "mandatory": True, + }, ) - return self - def get_component(self, name: str) -> Component: - """ - Get the component with the specified name from the pipeline. + # UNIT + def test_from_dict_without_component_type(self): + data = { + "metadata": {"test": "test"}, + "components": {"add_two": {"init_parameters": {"add": 2}}}, + "connections": [], + } + with pytest.raises(PipelineError) as err: + PipelineBase.from_dict(data) - :param name: - The name of the component. - :returns: - The instance of that component. + err.match("Missing 'type' in component 'add_two'") - :raises ValueError: - If a component with that name is not present in the pipeline. - """ - try: - return self.graph.nodes[name]["instance"] - except KeyError as exc: - raise ValueError(f"Component named {name} not found in the pipeline.") from exc + # UNIT + def test_from_dict_without_registered_component_type(self): + data = { + "metadata": {"test": "test"}, + "components": {"add_two": {"type": "foo.bar.baz", "init_parameters": {"add": 2}}}, + "connections": [], + } + with pytest.raises(PipelineError) as err: + PipelineBase.from_dict(data) - def get_component_name(self, instance: Component) -> str: - """ - Returns the name of the Component instance if it has been added to this Pipeline or an empty string otherwise. 
+ err.match(r"Component .+ not imported.") - :param instance: - The Component instance to look for. - :returns: - The name of the Component instance. - """ - for name, inst in self.graph.nodes(data="instance"): # type: ignore # type wrongly defined in networkx - if inst == instance: - return name - return "" + def test_from_dict_with_invalid_type(self): + data = { + "metadata": {"test": "test"}, + "components": {"add_two": {"type": "", "init_parameters": {"add": 2}}}, + "connections": [], + } + with pytest.raises(PipelineError) as err: + PipelineBase.from_dict(data) - def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]: - """ - Returns a dictionary containing the inputs of a pipeline. + err.match( + r"Component '' \(name: 'add_two'\) not imported. Please check that the package is installed and the component path is correct." + ) - Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes - the input sockets of that component, including their types and whether they are optional. + def test_from_dict_with_correct_import_but_invalid_type(self): + # Test case: Module imports but component not found in registry. + data_registry_error = { + "metadata": {"test": "test"}, + "components": {"add_two": {"type": "haystack.testing.NonExistentComponent", "init_parameters": {"add": 2}}}, + "connections": [], + } - :param include_components_with_connected_inputs: - If `False`, only components that have disconnected input edges are - included in the output. - :returns: - A dictionary where each key is a pipeline component name and each value is a dictionary of - inputs sockets of that component. - """ - inputs: Dict[str, Dict[str, Any]] = {} - for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items(): - sockets_description = {} - for socket in data: - sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory} - if not socket.is_mandatory: - sockets_description[socket.name]["default_value"] = socket.default_value - - if sockets_description: - inputs[component_name] = sockets_description - return inputs - - def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]: - """ - Returns a dictionary containing the outputs of a pipeline. + # Patch thread_safe_import so it doesn't raise an ImportError. + with patch("haystack.utils.type_serialization.thread_safe_import") as mock_import: + mock_import.return_value = None + with pytest.raises(PipelineError) as err_info: + PipelineBase.from_dict(data_registry_error) + outer_message = str(err_info.value) + inner_message = str(err_info.value.__cause__) + + assert "Component 'haystack.testing.NonExistentComponent' (name: 'add_two') not imported." in outer_message + assert "Successfully imported module 'haystack.testing' but couldn't find" in inner_message + assert "in the component registry." in inner_message + assert "registered under a different path." 
in inner_message + + # UNIT + def test_from_dict_without_connection_sender(self): + data = {"metadata": {"test": "test"}, "components": {}, "connections": [{"receiver": "some.receiver"}]} + with pytest.raises(PipelineError) as err: + PipelineBase.from_dict(data) + + err.match("Missing sender in connection: {'receiver': 'some.receiver'}") + + # UNIT + def test_from_dict_without_connection_receiver(self): + data = {"metadata": {"test": "test"}, "components": {}, "connections": [{"sender": "some.sender"}]} + with pytest.raises(PipelineError) as err: + PipelineBase.from_dict(data) + + err.match("Missing receiver in connection: {'sender': 'some.sender'}") + + def test_describe_input_only_no_inputs_components(self): + A = component_class("A", input_types={}, output={"x": 0}) + B = component_class("B", input_types={}, output={"y": 0}) + C = component_class("C", input_types={"x": int, "y": int}, output={"z": 0}) + p = PipelineBase() + p.add_component("a", A()) + p.add_component("b", B()) + p.add_component("c", C()) + p.connect("a.x", "c.x") + p.connect("b.y", "c.y") + assert p.inputs() == {} + assert p.inputs(include_components_with_connected_inputs=True) == { + "c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}} + } - Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes - the output sockets of that component. + def test_describe_input_some_components_with_no_inputs(self): + A = component_class("A", input_types={}, output={"x": 0}) + B = component_class("B", input_types={"y": int}, output={"y": 0}) + C = component_class("C", input_types={"x": int, "y": int}, output={"z": 0}) + p = PipelineBase() + p.add_component("a", A()) + p.add_component("b", B()) + p.add_component("c", C()) + p.connect("a.x", "c.x") + p.connect("b.y", "c.y") + assert p.inputs() == {"b": {"y": {"type": int, "is_mandatory": True}}} + assert p.inputs(include_components_with_connected_inputs=True) == { + "b": {"y": {"type": int, "is_mandatory": True}}, + "c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}, + } - :param include_components_with_connected_outputs: - If `False`, only components that have disconnected output edges are - included in the output. - :returns: - A dictionary where each key is a pipeline component name and each value is a dictionary of - output sockets of that component. 
- """ - outputs = { - comp: {socket.name: {"type": socket.type} for socket in data} - for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items() - if data + def test_describe_input_all_components_have_inputs(self): + A = component_class("A", input_types={"x": Optional[int]}, output={"x": 0}) + B = component_class("B", input_types={"y": int}, output={"y": 0}) + C = component_class("C", input_types={"x": int, "y": int}, output={"z": 0}) + p = PipelineBase() + p.add_component("a", A()) + p.add_component("b", B()) + p.add_component("c", C()) + p.connect("a.x", "c.x") + p.connect("b.y", "c.y") + assert p.inputs() == { + "a": {"x": {"type": Optional[int], "is_mandatory": True}}, + "b": {"y": {"type": int, "is_mandatory": True}}, + } + assert p.inputs(include_components_with_connected_inputs=True) == { + "a": {"x": {"type": Optional[int], "is_mandatory": True}}, + "b": {"y": {"type": int, "is_mandatory": True}}, + "c": {"x": {"type": int, "is_mandatory": True}, "y": {"type": int, "is_mandatory": True}}, } - return outputs - @args_deprecated - def show( - self, - server_url: str = "https://mermaid.ink", - params: Optional[dict] = None, - timeout: int = 30, - super_component_expansion: bool = False, - ) -> None: + def test_describe_output_multiple_possible(self): """ - Display an image representing this `Pipeline` in a Jupyter notebook. - - This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in - the notebook. - - :param server_url: - The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink'). - See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more - info on how to set up your own Mermaid server. - - :param params: - Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details - Supported keys: - - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'. - - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'. - - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'. - - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white'). - - width: Width of the output image (integer). - - height: Height of the output image (integer). - - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified. - - fit: Whether to fit the diagram size to the page (PDF only, boolean). - - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true. - - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true. - - :param timeout: - Timeout in seconds for the request to the Mermaid server. - - :param super_component_expansion: - If set to True and the pipeline contains SuperComponents the diagram will show the internal structure of - super-components as if they were components part of the pipeline instead of a "black-box". - Otherwise, only the super-component itself will be displayed. - - :raises PipelineDrawingError: - If the function is called outside of a Jupyter notebook or if there is an issue with rendering. 
+ This pipeline has two outputs: + {"b": {"output_b": {"type": str}}, "a": {"output_a": {"type": str}}} """ + A = component_class("A", input_types={"input_a": str}, output={"output_a": "str", "output_b": "str"}) + B = component_class("B", input_types={"input_b": str}, output={"output_b": "str"}) - # Call the internal implementation with keyword arguments - self._show_internal( - server_url=server_url, params=params, timeout=timeout, super_component_expansion=super_component_expansion - ) - - def _show_internal( - self, - *, - server_url: str = "https://mermaid.ink", - params: Optional[dict] = None, - timeout: int = 30, - super_component_expansion: bool = False, - ) -> None: - """ - Internal implementation of show() that uses keyword-only arguments. + pipe = PipelineBase() + pipe.add_component("a", A()) + pipe.add_component("b", B()) + pipe.connect("a.output_b", "b.input_b") - ToDo: after 2.14.0 release make this the main function and remove the old one. - """ - if is_in_jupyter(): - from IPython.display import Image, display # type: ignore - - if super_component_expansion: - graph, super_component_mapping = self._merge_super_component_pipelines() - else: - graph = self.graph - super_component_mapping = None - - image_data = _to_mermaid_image( - graph, - server_url=server_url, - params=params, - timeout=timeout, - super_component_mapping=super_component_mapping, - ) - display(Image(image_data)) - else: - msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally." - raise PipelineDrawingError(msg) + assert pipe.outputs() == {"b": {"output_b": {"type": str}}, "a": {"output_a": {"type": str}}} + assert pipe.outputs(include_components_with_connected_outputs=True) == { + "a": {"output_a": {"type": str}, "output_b": {"type": str}}, + "b": {"output_b": {"type": str}}, + } - @args_deprecated - def draw( # pylint: disable=too-many-positional-arguments - self, - path: Path, - server_url: str = "https://mermaid.ink", - params: Optional[dict] = None, - timeout: int = 30, - super_component_expansion: bool = False, - ) -> None: - """ - Save an image representing this `Pipeline` to the specified file path. - - This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path. - - :param path: - The file path where the generated image will be saved. - - :param server_url: - The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink'). - See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more - info on how to set up your own Mermaid server. - - :param params: - Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details - Supported keys: - - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'. - - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'. - - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'. - - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white'). - - width: Width of the output image (integer). - - height: Height of the output image (integer). - - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified. - - fit: Whether to fit the diagram size to the page (PDF only, boolean). - - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true. - - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true. 
-
-        :param timeout:
-            Timeout in seconds for the request to the Mermaid server.
-
-        :param super_component_expansion:
-            If set to True and the pipeline contains SuperComponents the diagram will show the internal structure of
-            super-components as if they were components part of the pipeline instead of a "black-box".
-            Otherwise, only the super-component itself will be displayed.
-
-        :raises PipelineDrawingError:
-            If there is an issue with rendering or saving the image.
-        """
+    def test_describe_output_single(self):
+        """
+        This pipeline has one output:
+        {"c": {"z": {"type": int}}}
+        """
+        A = component_class("A", input_types={"x": Optional[int]}, output={"x": 0})
+        B = component_class("B", input_types={"y": int}, output={"y": 0})
+        C = component_class("C", input_types={"x": int, "y": int}, output={"z": 0})
+        p = PipelineBase()
+        p.add_component("a", A())
+        p.add_component("b", B())
+        p.add_component("c", C())
+        p.connect("a.x", "c.x")
+        p.connect("b.y", "c.y")
+
+        assert p.outputs() == {"c": {"z": {"type": int}}}
+        assert p.outputs(include_components_with_connected_outputs=True) == {
+            "a": {"x": {"type": int}},
+            "b": {"y": {"type": int}},
+            "c": {"z": {"type": int}},
+        }

-        # Call the internal implementation with keyword arguments
-        self._draw_internal(
-            path=path,
-            server_url=server_url,
-            params=params,
-            timeout=timeout,
-            super_component_expansion=super_component_expansion,
-        )
+    def test_describe_no_outputs(self):
+        """
+        This pipeline sets up elaborate connections between three components but in fact it has no outputs:
+        Check that p.outputs() == {}
+        """
+        A = component_class("A", input_types={"x": Optional[int]}, output={"x": 0})
+        B = component_class("B", input_types={"y": int}, output={"y": 0})
+        C = component_class("C", input_types={"x": int, "y": int}, output={})
+        p = PipelineBase()
+        p.add_component("a", A())
+        p.add_component("b", B())
+        p.add_component("c", C())
+        p.connect("a.x", "c.x")
+        p.connect("b.y", "c.y")
+        assert p.outputs() == {}
+        assert p.outputs(include_components_with_connected_outputs=True) == {
+            "a": {"x": {"type": int}},
+            "b": {"y": {"type": int}},
+        }

-    def _draw_internal(
-        self,
-        *,
-        path: Path,
-        server_url: str = "https://mermaid.ink",
-        params: Optional[dict] = None,
-        timeout: int = 30,
-        super_component_expansion: bool = False,
-    ) -> None:
-        """
-        Internal implementation of draw() that uses keyword-only arguments.
+    def test_from_template(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "fake_key")
+        pipe = PipelineBase.from_template(PredefinedPipeline.INDEXING)
+        assert pipe.get_component("cleaner")
+
+    def test_walk_pipeline_with_no_cycles(self):
+        """
+        This pipeline has two source nodes, source1 and source2, one hello3 node in between, and one sink node, joiner.
+        pipeline.walk() should return each component exactly once. The order is not guaranteed.
+        """
+
+        @component
+        class Hello:
+            @component.output_types(output=str)
+            def run(self, word: str):
+                """
+                Takes a string in input and returns "Hello, <word>!" in output.
+                """
+                return {"output": f"Hello, {word}!"}
+
+        @component
+        class Joiner:
+            @component.output_types(output=str)
+            def run(self, word1: str, word2: str):
+                """
+                Takes two strings in input and returns "Hello, <word1> and <word2>!" in output.
+ """ + return {"output": f"Hello, {word1} and {word2}!"} + + pipeline = PipelineBase() + source1 = Hello() + source2 = Hello() + hello3 = Hello() + joiner = Joiner() + pipeline.add_component("source1", source1) + pipeline.add_component("source2", source2) + pipeline.add_component("hello3", hello3) + pipeline.add_component("joiner", joiner) + + pipeline.connect("source1", "joiner.word1") + pipeline.connect("source2", "hello3") + pipeline.connect("hello3", "joiner.word2") + + expected_components = [("source1", source1), ("source2", source2), ("joiner", joiner), ("hello3", hello3)] + assert sorted(expected_components) == sorted(pipeline.walk()) + + def test_walk_pipeline_with_cycles(self): + """ + This pipeline consists of two components, which would run three times in a loop. + pipeline.walk() should return these components exactly once. The order is not guaranteed. + """ + + @component + class Hello: + def __init__(self): + self.iteration_counter = 0 + + @component.output_types(intermediate=str, final=str) + def run(self, word: str, intermediate: Optional[str] = None): + """ + Takes a string in input and returns "Hello, !" in output. + """ + if self.iteration_counter < 3: + self.iteration_counter += 1 + return {"intermediate": f"Hello, {intermediate or word}!"} + return {"final": f"Hello, {intermediate or word}!"} + + pipeline = PipelineBase() + hello = Hello() + hello_again = Hello() + pipeline.add_component("hello", hello) + pipeline.add_component("hello_again", hello_again) + pipeline.connect("hello.intermediate", "hello_again.intermediate") + pipeline.connect("hello_again.intermediate", "hello.intermediate") + assert {("hello", hello), ("hello_again", hello_again)} == set(pipeline.walk()) + + def test__prepare_component_input_data(self): + MockComponent = component_class("MockComponent", input_types={"x": List[str], "y": str}) + pipe = PipelineBase() + pipe.add_component("first_mock", MockComponent()) + pipe.add_component("second_mock", MockComponent()) + + res = pipe._prepare_component_input_data({"x": ["some data"], "y": "some other data"}) + assert res == { + "first_mock": {"x": ["some data"], "y": "some other data"}, + "second_mock": {"x": ["some data"], "y": "some other data"}, + } + assert id(res["first_mock"]["x"]) != id(res["second_mock"]["x"]) - ToDo: after 2.14.0 release make this the main function and remove the old one. - """ - # Before drawing we edit a bit the graph, to avoid modifying the original that is - # used for running the pipeline we copy it. 
- if super_component_expansion: - graph, super_component_mapping = self._merge_super_component_pipelines() - else: - graph = self.graph - super_component_mapping = None - - image_data = _to_mermaid_image( - graph, - server_url=server_url, - params=params, - timeout=timeout, - super_component_mapping=super_component_mapping, + def test__prepare_component_input_data_with_connected_inputs(self): + MockComponent = component_class( + "MockComponent", input_types={"x": List[str], "y": str}, output_types={"z": str} + ) + pipe = PipelineBase() + pipe.add_component("first_mock", MockComponent()) + pipe.add_component("second_mock", MockComponent()) + pipe.connect("first_mock.z", "second_mock.y") + + res = pipe._prepare_component_input_data({"x": ["some data"], "y": "some other data"}) + assert res == {"first_mock": {"x": ["some data"], "y": "some other data"}, "second_mock": {"x": ["some data"]}} + assert id(res["first_mock"]["x"]) != id(res["second_mock"]["x"]) + + def test__prepare_component_input_data_with_non_existing_input(self, caplog): + pipe = PipelineBase() + res = pipe._prepare_component_input_data({"input_name": 1}) + assert res == {} + assert ( + "Inputs ['input_name'] were not matched to any component inputs, " + "please check your run parameters." in caplog.text ) - Path(path).write_bytes(image_data) - - def walk(self) -> Iterator[Tuple[str, Component]]: - """ - Visits each component in the pipeline exactly once and yields its name and instance. - - No guarantees are provided on the visiting order. - - :returns: - An iterator of tuples of component name and component instance. - """ - for component_name, instance in self.graph.nodes(data="instance"): # type: ignore # type is wrong in networkx - yield component_name, instance - - def warm_up(self) -> None: - """ - Make sure all nodes are warm. - It's the node's responsibility to make sure this method can be called at every `Pipeline.run()` - without re-initializing everything. 
- """ - for node in self.graph.nodes: - if hasattr(self.graph.nodes[node]["instance"], "warm_up"): - logger.info("Warming up component {node}...", node=node) - self.graph.nodes[node]["instance"].warm_up() - - @staticmethod - def _create_component_span( - component_name: str, instance: Component, inputs: Dict[str, Any], parent_span: Optional[tracing.Span] = None - ) -> ContextManager[tracing.Span]: - return tracing.tracer.trace( - "haystack.component.run", - tags={ - "haystack.component.name": component_name, - "haystack.component.type": instance.__class__.__name__, - "haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()}, - "haystack.component.input_spec": { - key: { - "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)), - "senders": value.senders, - } - for key, value in instance.__haystack_input__._sockets_dict.items() # type: ignore + def test_connect(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + assert pipe.connect("comp1.value", "comp2.value") is pipe + + assert comp1.__haystack_output__.value.receivers == ["comp2"] + assert comp2.__haystack_input__.value.senders == ["comp1"] + assert list(pipe.graph.edges) == [("comp1", "comp2", "value/value")] + + def test_connect_already_connected(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1.value", "comp2.value") + pipe.connect("comp1.value", "comp2.value") + + assert comp1.__haystack_output__.value.receivers == ["comp2"] + assert comp2.__haystack_input__.value.senders == ["comp1"] + assert list(pipe.graph.edges) == [("comp1", "comp2", "value/value")] + + def test_connect_with_sender_component_name(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1", "comp2.value") + + assert comp1.__haystack_output__.value.receivers == ["comp2"] + assert comp2.__haystack_input__.value.senders == ["comp1"] + assert list(pipe.graph.edges) == [("comp1", "comp2", "value/value")] + + def test_connect_with_receiver_component_name(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1.value", "comp2") + + assert comp1.__haystack_output__.value.receivers == ["comp2"] + assert comp2.__haystack_input__.value.senders == ["comp1"] + assert list(pipe.graph.edges) == [("comp1", "comp2", "value/value")] + + def test_connect_with_sender_and_receiver_component_name(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1", "comp2") + + assert comp1.__haystack_output__.value.receivers == ["comp2"] + assert comp2.__haystack_input__.value.senders == ["comp1"] + assert list(pipe.graph.edges) == [("comp1", "comp2", "value/value")] 
+ + def test_connect_with_sender_not_in_pipeline(self): + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp2", comp2) + with pytest.raises(ValueError): + pipe.connect("comp1.value", "comp2.value") + + def test_connect_with_receiver_not_in_pipeline(self): + comp1 = component_class("Comp1", output_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + with pytest.raises(ValueError): + pipe.connect("comp1.value", "comp2.value") + + def test_connect_with_sender_socket_name_not_in_pipeline(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + with pytest.raises(PipelineConnectError): + pipe.connect("comp1.non_existing", "comp2.value") + + def test_connect_with_receiver_socket_name_not_in_pipeline(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + with pytest.raises(PipelineConnectError): + pipe.connect("comp1.value", "comp2.non_existing") + + def test_connect_with_no_matching_types_and_same_names(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": str})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + with pytest.raises(PipelineConnectError): + pipe.connect("comp1", "comp2") + + def test_connect_with_multiple_sender_connections_with_same_type_and_differing_name(self): + comp1 = component_class("Comp1", output_types={"val1": int, "val2": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + with pytest.raises(PipelineConnectError): + pipe.connect("comp1", "comp2") + + def test_connect_with_multiple_receiver_connections_with_same_type_and_differing_name(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"val1": int, "val2": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + with pytest.raises(PipelineConnectError): + pipe.connect("comp1", "comp2") + + def test_connect_with_multiple_sender_connections_with_same_type_and_same_name(self): + comp1 = component_class("Comp1", output_types={"value": int, "other": int})() + comp2 = component_class("Comp2", input_types={"value": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1", "comp2") + + assert comp1.__haystack_output__.value.receivers == ["comp2"] + assert comp2.__haystack_input__.value.senders == ["comp1"] + assert list(pipe.graph.edges) == [("comp1", "comp2", "value/value")] + + def test_connect_with_multiple_receiver_connections_with_same_type_and_same_name(self): + comp1 = component_class("Comp1", output_types={"value": int})() + comp2 = component_class("Comp2", input_types={"value": int, "other": int})() + pipe = PipelineBase() + pipe.add_component("comp1", comp1) + pipe.add_component("comp2", comp2) + pipe.connect("comp1", "comp2") + + assert comp1.__haystack_output__.value.receivers == ["comp2"] + assert 
comp2.__haystack_input__.value.senders == ["comp1"]
+        assert list(pipe.graph.edges) == [("comp1", "comp2", "value/value")]
+
+    def test_connect_multiple_outputs_to_non_variadic_input(self):
+        comp1 = component_class("Comp1", output_types={"value": int})()
+        comp2 = component_class("Comp2", output_types={"value": int})()
+        comp3 = component_class("Comp3", input_types={"value": int})()
+        pipe = PipelineBase()
+        pipe.add_component("comp1", comp1)
+        pipe.add_component("comp2", comp2)
+        pipe.add_component("comp3", comp3)
+        pipe.connect("comp1.value", "comp3.value")
+        with pytest.raises(PipelineConnectError):
+            pipe.connect("comp2.value", "comp3.value")
+
+    def test_connect_multiple_outputs_to_variadic_input(self):
+        comp1 = component_class("Comp1", output_types={"value": int})()
+        comp2 = component_class("Comp2", output_types={"value": int})()
+        comp3 = component_class("Comp3", input_types={"value": Variadic[int]})()
+        pipe = PipelineBase()
+        pipe.add_component("comp1", comp1)
+        pipe.add_component("comp2", comp2)
+        pipe.add_component("comp3", comp3)
+        pipe.connect("comp1.value", "comp3.value")
+        pipe.connect("comp2.value", "comp3.value")
+
+        assert comp1.__haystack_output__.value.receivers == ["comp3"]
+        assert comp2.__haystack_output__.value.receivers == ["comp3"]
+        assert comp3.__haystack_input__.value.senders == ["comp1", "comp2"]
+        assert list(pipe.graph.edges) == [("comp1", "comp3", "value/value"), ("comp2", "comp3", "value/value")]
+
+    def test_connect_same_component_as_sender_and_receiver(self):
+        """
+        This pipeline consists of one component, which we try to connect to itself.
+        Connecting a component to itself raises PipelineConnectError.
+        """
+        pipe = PipelineBase()
+        single_component = FakeComponent()
+        pipe.add_component("single_component", single_component)
+        with pytest.raises(PipelineConnectError):
+            pipe.connect("single_component.out", "single_component.in")
+
+    @pytest.mark.parametrize(
+        "component_inputs,sockets,expected_inputs",
+        [
+            ({"mandatory": 1}, {"mandatory": InputSocket("mandatory", int)}, {"mandatory": 1}),
+            ({}, {"optional": InputSocket("optional", str, default_value="test")}, {"optional": "test"}),
+            (
+                {"mandatory": 1},
+                {
+                    "mandatory": InputSocket("mandatory", int),
+                    "optional": InputSocket("optional", str, default_value="test"),
+                },
+                {"mandatory": 1, "optional": "test"},
+            ),
+            (
+                {},
+                {"optional_variadic": InputSocket("optional_variadic", Variadic[str], default_value="test")},
+                {"optional_variadic": ["test"]},
+            ),
+            (
+                {},
+                {
+                    "optional_1": InputSocket("optional_1", int, default_value=1),
+                    "optional_2": InputSocket("optional_2", int, default_value=2),
+                },
+                {"optional_1": 1, "optional_2": 2},
+            ),
+        ],
+        ids=["no-defaults", "only-default", "mixed-default", "variadic-default", "multiple_defaults"],
+    )
+    def test__add_missing_defaults(self, component_inputs, sockets, expected_inputs):
+        filled_inputs = PipelineBase._add_missing_input_defaults(component_inputs, sockets)
+
+        assert filled_inputs == expected_inputs
+
+    def test__find_receivers_from(self):
+        sentence_builder = component_class(
+            "SentenceBuilder", input_types={"words": List[str]}, output_types={"text": str}
+        )()
+        document_builder = component_class(
+            "DocumentBuilder", input_types={"text": str}, output_types={"doc": Document}
+        )()
+        conditional_document_builder = component_class(
+            "ConditionalDocumentBuilder", output_types={"doc": Document, "noop": None}
+        )()
+
+        document_joiner = component_class("DocumentJoiner", input_types={"docs": Variadic[Document]})()
+
+        pipe =
PipelineBase() + pipe.add_component("sentence_builder", sentence_builder) + pipe.add_component("document_builder", document_builder) + pipe.add_component("document_joiner", document_joiner) + pipe.add_component("conditional_document_builder", conditional_document_builder) + pipe.connect("sentence_builder.text", "document_builder.text") + pipe.connect("document_builder.doc", "document_joiner.docs") + pipe.connect("conditional_document_builder.doc", "document_joiner.docs") + + res = pipe._find_receivers_from("sentence_builder") + assert res == [ + ( + "document_builder", + OutputSocket(name="text", type=str, receivers=["document_builder"]), + InputSocket(name="text", type=str, default_value=_empty, senders=["sentence_builder"]), + ) + ] + + res = pipe._find_receivers_from("document_builder") + assert res == [ + ( + "document_joiner", + OutputSocket(name="doc", type=Document, receivers=["document_joiner"]), + InputSocket( + name="docs", + type=Variadic[Document], + default_value=_empty, + senders=["document_builder", "conditional_document_builder"], + ), + ) + ] + + res = pipe._find_receivers_from("document_joiner") + assert res == [] + + res = pipe._find_receivers_from("conditional_document_builder") + assert res == [ + ( + "document_joiner", + OutputSocket(name="doc", type=Document, receivers=["document_joiner"]), + InputSocket( + name="docs", + type=Variadic[Document], + default_value=_empty, + senders=["document_builder", "conditional_document_builder"], + ), + ) + ] + + @pytest.mark.parametrize( + "component, inputs, expected_priority, test_description", + [ + # Test case 1: BLOCKED - Missing mandatory input + ( + { + "instance": "mock_instance", + "visits": 0, + "input_sockets": { + "mandatory_input": InputSocket("mandatory_input", int), + "optional_input": InputSocket( + "optional_input", str, default_value="default", senders=["previous_component"] + ), + }, + }, + {"optional_input": [{"sender": "previous_component", "value": "test"}]}, + ComponentPriority.BLOCKED, + "Component should be BLOCKED when mandatory input is missing", + ), + # Test case 2: BLOCKED - No trigger after first visit + ( + { + "instance": "mock_instance", + "visits": 1, # Already visited + "input_sockets": { + "mandatory_input": InputSocket("mandatory_input", int), + "optional_input": InputSocket("optional_input", str, default_value="default"), + }, + }, + {"mandatory_input": [{"sender": None, "value": 42}]}, + ComponentPriority.BLOCKED, + "Component should be BLOCKED when there's no new trigger after first visit", + ), + # Test case 3: HIGHEST - Greedy socket ready + ( + { + "instance": "mock_instance", + "visits": 0, + "input_sockets": { + "greedy_input": InputSocket("greedy_input", GreedyVariadic[int], senders=["component1"]), + "normal_input": InputSocket("normal_input", str, senders=["component2"]), + }, + }, + { + "greedy_input": [{"sender": "component1", "value": 42}], + "normal_input": [{"sender": "component2", "value": "test"}], + }, + ComponentPriority.HIGHEST, + "Component should have HIGHEST priority when greedy socket has valid input", + ), + # Test case 4: DEFER - Greedy socket ready but optional missing + ( + { + "instance": "mock_instance", + "visits": 0, + "input_sockets": { + "greedy_input": InputSocket("greedy_input", GreedyVariadic[int], senders=["component1"]), + "optional_input": InputSocket( + "optional_input", str, senders=["component2"], default_value="test" + ), + }, + }, + {"greedy_input": [{"sender": "component1", "value": 42}]}, + ComponentPriority.DEFER, + "Component should DEFER when 
greedy socket has valid input but expected optional input is missing",
+            ),
+            # Test case 5: READY - All predecessors executed
+            (
+                {
+                    "instance": "mock_instance",
+                    "visits": 0,
+                    "input_sockets": {
+                        "mandatory_input": InputSocket("mandatory_input", int, senders=["previous_component"]),
+                        "optional_input": InputSocket(
+                            "optional_input", str, senders=["another_component"], default_value="default"
+                        ),
+                    },
+                },
+                {
+                    "mandatory_input": [{"sender": "previous_component", "value": 42}],
+                    "optional_input": [{"sender": "another_component", "value": "test"}],
+                },
+                ComponentPriority.READY,
+                "Component should be READY when all predecessors have executed",
+            ),
+            # Test case 6: DEFER - Lazy variadic sockets resolved and optional missing.
+            (
+                {
+                    "instance": "mock_instance",
+                    "visits": 0,
+                    "input_sockets": {
+                        "variadic_input": InputSocket(
+                            "variadic_input", Variadic[int], senders=["component1", "component2"]
+                        ),
+                        "normal_input": InputSocket("normal_input", str, senders=["component3"]),
+                        "optional_input": InputSocket(
+                            "optional_input", str, default_value="default", senders=["component4"]
+                        ),
+                    },
+                },
+                {
+                    "variadic_input": [
+                        {"sender": "component1", "value": "test"},
+                        {"sender": "component2", "value": _NO_OUTPUT_PRODUCED},
+                    ],
+                    "normal_input": [{"sender": "component3", "value": "test"}],
+                },
+                ComponentPriority.DEFER,
+                "Component should DEFER when all lazy variadic sockets are resolved",
+            ),
+            # Test case 7: DEFER_LAST - Incomplete variadic inputs
+            (
+                {
+                    "instance": "mock_instance",
+                    "visits": 0,
+                    "input_sockets": {
+                        "variadic_input": InputSocket(
+                            "variadic_input", Variadic[int], senders=["component1", "component2"]
+                        ),
+                        "normal_input": InputSocket("normal_input", str),
+                    },
+                },
+                {
+                    "variadic_input": [{"sender": "component1", "value": 42}],  # Missing component2
+                    "normal_input": [{"sender": "component3", "value": "test"}],
                },
-                "haystack.component.output_spec": {
-                    key: {
-                        "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
-                        "receivers": value.receivers,
+                ComponentPriority.DEFER_LAST,
+                "Component should be DEFER_LAST when not all variadic senders have produced output",
+            ),
+            # Test case 8: READY - No input sockets, first visit
+            (
+                {
+                    "instance": "mock_instance",
+                    "visits": 0,
+                    "input_sockets": {"optional_input": InputSocket("optional_input", str, default_value="default")},
+                },
+                {},  # no inputs
+                ComponentPriority.READY,
+                "Component should be READY on first visit when it has no input sockets",
+            ),
+            # Test case 9: BLOCKED - No connected input sockets, subsequent visit
+            (
+                {
+                    "instance": "mock_instance",
+                    "visits": 1,
+                    "input_sockets": {"optional_input": InputSocket("optional_input", str, default_value="default")},
+                },
+                {},  # no inputs
+                ComponentPriority.BLOCKED,
+                "Component should be BLOCKED on subsequent visits when it has no input sockets",
+            ),
+        ],
+        ids=lambda p: p.name if isinstance(p, ComponentPriority) else str(p),
+    )
+    def test__calculate_priority(self, component, inputs, expected_priority, test_description):
+        """Test priority calculation for various component and input combinations."""
+        # For variadic inputs, set up senders if needed
+        for socket in component["input_sockets"].values():
+            if socket.is_variadic and not hasattr(socket, "senders"):
+                socket.senders = ["component1", "component2"]
+
+        assert PipelineBase._calculate_priority(component, inputs) == expected_priority
+
+    @pytest.mark.parametrize(
+        "pipeline_inputs,expected_output",
+        [
+            # Test case 1: Empty input
+            ({}, {}),
+            #
Test case 2: Single component, multiple inputs + ( + {"component1": {"input1": 42, "input2": "test", "input3": True}}, + { + "component1": { + "input1": [{"sender": None, "value": 42}], + "input2": [{"sender": None, "value": "test"}], + "input3": [{"sender": None, "value": True}], } - for key, value in instance.__haystack_output__._sockets_dict.items() # type: ignore }, - }, - parent_span=parent_span, + ), + # Test case 3: Multiple components + ( + { + "component1": {"input1": 42, "input2": "test"}, + "component2": {"input3": [1, 2, 3], "input4": {"key": "value"}}, + }, + { + "component1": { + "input1": [{"sender": None, "value": 42}], + "input2": [{"sender": None, "value": "test"}], + }, + "component2": { + "input3": [{"sender": None, "value": [1, 2, 3]}], + "input4": [{"sender": None, "value": {"key": "value"}}], + }, + }, + ), + ], + ids=["empty_input", "single_component_multiple_inputs", "multiple_components"], + ) + def test__convert_to_internal_format(self, pipeline_inputs, expected_output): + """Test conversion of legacy pipeline inputs to internal format.""" + result = PipelineBase._convert_to_internal_format(pipeline_inputs) + assert result == expected_output + + @pytest.mark.parametrize( + "socket_type,existing_inputs,expected_count", + [ + ("regular", None, 1), # Regular socket should overwrite + ("regular", [{"sender": "other", "value": 24}], 1), # Should still overwrite + ("lazy_variadic", None, 1), # First input to lazy variadic + ("lazy_variadic", [{"sender": "other", "value": 24}], 2), # Should append + ], + ids=["regular-new", "regular-existing", "variadic-new", "variadic-existing"], + ) + def test__write_component_outputs_different_sockets( + self, + socket_type, + existing_inputs, + expected_count, + regular_output_socket, + regular_input_socket, + lazy_variadic_input_socket, + ): + """Test writing to different socket types with various existing input states""" + receiver_socket = lazy_variadic_input_socket if socket_type == "lazy_variadic" else regular_input_socket + socket_name = receiver_socket.name + receivers = [("receiver1", regular_output_socket, receiver_socket)] + + inputs = {} + if existing_inputs: + inputs = {"receiver1": {socket_name: existing_inputs}} + + component_outputs = {"output1": 42} + + PipelineBase._write_component_outputs( + component_name="sender1", + component_outputs=component_outputs, + inputs=inputs, + receivers=receivers, + include_outputs_from=[], ) - def _validate_component_input(self, component_name: str, component_inputs: Dict[str, Any], data: Dict[str, Any]) -> None: - """ - Validates input data for a specific component. - - :param component_name: Name of the component. - :param component_inputs: Inputs provided for the component. - :param data: All pipeline input data. - :raises ValueError: If inputs are invalid. - """ - if component_name not in self.graph.nodes: - raise ValueError(f"Component '{component_name}' not found in the pipeline. 
Available components: {list(self.graph.nodes.keys())}") - instance = self.graph.nodes[component_name]["instance"] - - # Validate that all mandatory inputs are provided either directly or by senders - for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): - if socket.is_mandatory and not socket.senders and socket_name not in component_inputs: - raise ValueError(f"Missing mandatory input '{socket_name}' for component '{component_name}'.") - - # Validate that provided inputs exist in the component's input sockets - for input_name in component_inputs.keys(): - if input_name not in instance.__haystack_input__._sockets_dict: - raise ValueError(f"Unexpected input '{input_name}' for component '{component_name}'. Available inputs: {list(instance.__haystack_input__._sockets_dict.keys())}") - - # Validate that inputs are not multiply defined (already sent by another component and also provided directly) - # unless the socket is variadic - for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): - if socket.senders and socket_name in component_inputs and not socket.is_variadic: - raise ValueError( - f"Input '{socket_name}' for component '{component_name}' is already provided by component " - f"'{socket.senders[0]}'. Do not provide it directly." - ) - - - def _validate_input(self, data: Dict[str, Any]) -> None: - """ - Validates pipeline input data. - - Validates that data: - * Each Component name actually exists in the Pipeline - * Each Component is not missing any input - * Each Component has only one input per input socket, if not variadic - * Each Component doesn't receive inputs that are already sent by another Component - - :param data: - A dictionary of inputs for the pipeline's components. Each key is a component name. - - :raises ValueError: - If inputs are invalid according to the above. - """ - for component_name, component_inputs in data.items(): - self._validate_component_input(component_name, component_inputs, data) - - # Additionally, check for components that might be missing inputs, - # even if they were not explicitly mentioned in the `data` dictionary. - # This covers cases where a component has mandatory inputs but receives no data. - for component_name_in_graph in self.graph.nodes: - if component_name_in_graph not in data: - # This component was not in the input data dictionary, check if it has mandatory inputs without senders - instance = self.graph.nodes[component_name_in_graph]["instance"] - for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): - if socket.is_mandatory and not socket.senders: - raise ValueError(f"Missing mandatory input '{socket_name}' for component '{component_name_in_graph}' which was not provided in the input data.") - - - def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: - """ - Prepares input data for pipeline components. - - Organizes input data for pipeline components and identifies any inputs that are not matched to any - component's input slots. Deep-copies data items to avoid sharing mutables across multiple components. - - This method processes a flat dictionary of input data, where each key-value pair represents an input name - and its corresponding value. It distributes these inputs to the appropriate pipeline components based on - their input requirements. Inputs that don't match any component's input slots are classified as unresolved. - - :param data: - A dictionary potentially having input names as keys and input values as values. 
- - :returns: - A dictionary mapping component names to their respective matched inputs. - """ - # check whether the data is a nested dictionary of component inputs where each key is a component name - # and each value is a dictionary of input parameters for that component - is_nested_component_input = all(isinstance(value, dict) for value in data.values()) - if not is_nested_component_input: - # flat input, a dict where keys are input names and values are the corresponding values - # we need to convert it to a nested dictionary of component inputs and then run the pipeline - # just like in the previous case - pipeline_input_data: Dict[str, Dict[str, Any]] = defaultdict(dict) - unresolved_kwargs = {} - - # Retrieve the input slots for each component in the pipeline - available_inputs: Dict[str, Dict[str, Any]] = self.inputs() - - # Go through all provided to distribute them to the appropriate component inputs - for input_name, input_value in data.items(): - resolved_at_least_once = False - - # Check each component to see if it has a slot for the current kwarg - for component_name, component_inputs in available_inputs.items(): - if input_name in component_inputs: - # If a match is found, add the kwarg to the component's input data - pipeline_input_data[component_name][input_name] = input_value - resolved_at_least_once = True - - if not resolved_at_least_once: - unresolved_kwargs[input_name] = input_value - - if unresolved_kwargs: - logger.warning( - "Inputs {input_keys} were not matched to any component inputs, please check your run parameters.", - input_keys=list(unresolved_kwargs.keys()), - ) - - data = dict(pipeline_input_data) - - # deepcopying the inputs prevents the Pipeline run logic from being altered unexpectedly - # when the same input reference is passed to multiple components. - for component_name, component_inputs in data.items(): - data[component_name] = {k: _deepcopy_with_exceptions(v) for k, v in component_inputs.items()} - - return data - - @classmethod - def from_template( - cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None - ) -> "PipelineBase": - """ - Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options. - - :param predefined_pipeline: - The predefined pipeline to use. - :param template_params: - An optional dictionary of parameters to use when rendering the pipeline template. - :returns: - An instance of `Pipeline`. - """ - tpl = PipelineTemplate.from_predefined(predefined_pipeline) - # If tpl.render() fails, we let bubble up the original error - rendered = tpl.render(template_params) - - # If there was a problem with the rendered version of the - # template, we add it to the error stack for debugging - try: - return cls.loads(rendered) - except Exception as e: - msg = f"Error unmarshalling pipeline: {e}\n" - msg += f"Source:\n{rendered}" - raise PipelineUnmarshalError(msg) - - def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]: - """ - Utility function to find all Components that receive input from `component_name`. 
- - :param component_name: - Name of the sender Component - - :returns: - List of tuples containing name of the receiver Component and sender OutputSocket - and receiver InputSocket instances - """ - res = [] - for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True): - sender_socket: OutputSocket = connection["from_socket"] - receiver_socket: InputSocket = connection["to_socket"] - res.append((receiver_name, sender_socket, receiver_socket)) - return res - - @staticmethod - def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Dict[str, List]]: - """ - Converts the inputs to the pipeline to the format that is needed for the internal `Pipeline.run` logic. - - Example Input: - {'prompt_builder': {'question': 'Who lives in Paris?'}, 'retriever': {'query': 'Who lives in Paris?'}} - Example Output: - {'prompt_builder': {'question': [{'sender': None, 'value': 'Who lives in Paris?'}]}, - 'retriever': {'query': [{'sender': None, 'value': 'Who lives in Paris?'}]}} - - :param pipeline_inputs: Inputs to the pipeline. - :returns: Converted inputs that can be used by the internal `Pipeline.run` logic. - """ - inputs: Dict[str, Dict[str, List[Dict[str, Any]]]] = {} - for component_name, socket_dict in pipeline_inputs.items(): - inputs[component_name] = {} - for socket_name, value in socket_dict.items(): - inputs[component_name][socket_name] = [{"sender": None, "value": value}] - - return inputs - - @staticmethod - def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]: - """ - Extracts the inputs needed to run for the component and removes them from the global inputs state. - - :param component_name: The name of a component. - :param component: Component with component metadata. - :param inputs: Global inputs state. - :returns: The inputs for the component. - """ - component_inputs = inputs.get(component_name, {}) - consumed_inputs = {} - greedy_inputs_to_remove = set() - for socket_name, socket in component["input_sockets"].items(): - socket_inputs = component_inputs.get(socket_name, []) - socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED] - if socket_inputs: - if not socket.is_variadic: - # We only care about the first input provided to the socket. - consumed_inputs[socket_name] = socket_inputs[0] - elif socket.is_greedy: - # We need to keep track of greedy inputs because we always remove them, even if they come from - # outside the pipeline. Otherwise, a greedy input from the user would trigger a pipeline to run - # indefinitely. - greedy_inputs_to_remove.add(socket_name) - consumed_inputs[socket_name] = [socket_inputs[0]] - elif is_socket_lazy_variadic(socket): - # We use all inputs provided to the socket on a lazy variadic socket. - consumed_inputs[socket_name] = socket_inputs - - # We prune all inputs except for those that were provided from outside the pipeline (e.g. user inputs). 
- pruned_inputs = { - socket_name: [ - sock for sock in socket if sock["sender"] is None and not socket_name in greedy_inputs_to_remove - ] - for socket_name, socket in component_inputs.items() - } - pruned_inputs = {socket_name: socket for socket_name, socket in pruned_inputs.items() if len(socket) > 0} - - inputs[component_name] = pruned_inputs - - return consumed_inputs - - def _fill_queue( - self, component_names: List[str], inputs: Dict[str, Any], component_visits: Dict[str, int] - ) -> FIFOPriorityQueue: - """ - Calculates the execution priority for each component and inserts it into the priority queue. - - :param component_names: Names of the components to put into the queue. - :param inputs: Inputs to the components. - :param component_visits: Current state of component visits. - :returns: A prioritized queue of component names. - """ - priority_queue = FIFOPriorityQueue() - for component_name in component_names: - component = self._get_component_with_graph_metadata_and_visits( - component_name, component_visits[component_name] - ) - priority = self._calculate_priority(component, inputs.get(component_name, {})) - priority_queue.push(component_name, priority) - - return priority_queue + assert len(inputs["receiver1"][socket_name]) == expected_count + assert {"sender": "sender1", "value": 42} in inputs["receiver1"][socket_name] + + @pytest.mark.parametrize( + "component_outputs,include_outputs,expected_pruned", + [ + ({"output1": 42, "output2": 24}, [], {"output2": 24}), # Prune consumed outputs only + ({"output1": 42, "output2": 24}, ["sender1"], {"output1": 42, "output2": 24}), # Keep all outputs + ({}, [], {}), # No outputs case + ], + ids=["prune-consumed", "keep-all", "no-outputs"], + ) + def test__write_component_outputs_output_pruning( + self, component_outputs, include_outputs, expected_pruned, regular_output_socket, regular_input_socket + ): + """Test output pruning behavior under different scenarios""" + receivers = [("receiver1", regular_output_socket, regular_input_socket)] + + pruned_outputs = PipelineBase._write_component_outputs( + component_name="sender1", + component_outputs=component_outputs, + inputs={}, + receivers=receivers, + include_outputs_from=include_outputs, + ) - @staticmethod - def _calculate_priority(component: Dict, inputs: Dict) -> ComponentPriority: - """ - Calculates the execution priority for a component depending on the component's inputs. + assert pruned_outputs == expected_pruned - :param component: Component metadata and component instance. - :param inputs: Inputs to the component. - :returns: Priority value for the component. - """ - if not can_component_run(component, inputs): - return ComponentPriority.BLOCKED - elif is_any_greedy_socket_ready(component, inputs) and are_all_sockets_ready(component, inputs): - return ComponentPriority.HIGHEST - elif all_predecessors_executed(component, inputs): - return ComponentPriority.READY - elif are_all_lazy_variadic_sockets_resolved(component, inputs): - return ComponentPriority.DEFER - else: - return ComponentPriority.DEFER_LAST - - def _get_component_with_graph_metadata_and_visits(self, component_name: str, visits: int) -> Dict[str, Any]: - """ - Returns the component instance alongside input/output-socket metadata from the graph and adds current visits. 
-        We can't store visits in the pipeline graph because this would prevent reentrance / thread-safe execution.
-
-        :param component_name: The name of the component.
-        :param visits: Number of visits for the component.
-        :returns: Dict including component instance, input/output-sockets and visits.
-        """
-        comp_dict = self.graph.nodes[component_name]
-        comp_dict = {**comp_dict, "visits": visits}
-        return comp_dict
-
-    def _get_next_runnable_component(
-        self, priority_queue: FIFOPriorityQueue, component_visits: Dict[str, int]
-    ) -> Union[Tuple[ComponentPriority, str, Dict[str, Any]], None]:
-        """
-        Returns the next runnable component alongside its metadata from the priority queue.
-
-        :param priority_queue: Priority queue of component names.
-        :param component_visits: Current state of component visits.
-        :returns: The next runnable component, the component name, and its priority
-            or None if no component in the queue can run.
-        :raises: PipelineMaxComponentRuns if the next runnable component has exceeded the maximum number of runs.
-        """
-        priority_and_component_name: Union[Tuple[ComponentPriority, str], None] = (
-            None if (item := priority_queue.get()) is None else (ComponentPriority(item[0]), str(item[1]))
-        )
-        if priority_and_component_name is not None and priority_and_component_name[0] != ComponentPriority.BLOCKED:
-            priority, component_name = priority_and_component_name
-            component = self._get_component_with_graph_metadata_and_visits(
-                component_name, component_visits[component_name]
-            )
-            if component["visits"] > self._max_runs_per_component:
-                msg = f"Maximum run count {self._max_runs_per_component} reached for component '{component_name}'"
-                raise PipelineMaxComponentRuns(msg)
-            return priority, component_name, component
-
-        return None
-
-    @staticmethod
-    def _add_missing_input_defaults(
-        component_inputs: Dict[str, Any], component_input_sockets: Dict[str, InputSocket]
-    ) -> Dict[str, Any]:
-        """
-        Updates the inputs with the default values for the inputs that are missing
-
-        :param component_inputs: Inputs for the component.
-        :param component_input_sockets: Input sockets of the component.
-        """
-        for name, socket in component_input_sockets.items():
-            if not socket.is_mandatory and name not in component_inputs:
-                if socket.is_variadic:
-                    component_inputs[name] = [socket.default_value]
-                else:
-                    component_inputs[name] = socket.default_value
-
-        return component_inputs
-
-    def _tiebreak_waiting_components(
-        self,
-        component_name: str,
-        priority: ComponentPriority,
-        priority_queue: FIFOPriorityQueue,
-        topological_sort: Union[Dict[str, int], None],
-    ) -> Tuple[str, Union[Dict[str, int], None]]:
-        """
-        Decides which component to run when multiple components are waiting for inputs with the same priority.
-
-        :param component_name: The name of the component.
-        :param priority: Priority of the component.
-        :param priority_queue: Priority queue of component names.
-        :param topological_sort: Cached topological sort of all components in the pipeline.
-        """
-        components_with_same_priority = [component_name]
-
-        while len(priority_queue) > 0:
-            next_priority, next_component_name = priority_queue.peek()
-            if next_priority == priority:
-                priority_queue.pop()  # actually remove the component
-                components_with_same_priority.append(next_component_name)
-            else:
-                break
-
-        if len(components_with_same_priority) > 1:
-            if topological_sort is None:
-                if networkx.is_directed_acyclic_graph(self.graph):
-                    topological_sort = networkx.lexicographical_topological_sort(self.graph)
-                    topological_sort = {node: idx for idx, node in enumerate(topological_sort)}
-                else:
-                    condensed = networkx.condensation(self.graph)
-                    condensed_sorted = {node: idx for idx, node in enumerate(networkx.topological_sort(condensed))}
-                    topological_sort = {
-                        component_name: condensed_sorted[node]
-                        for component_name, node in condensed.graph["mapping"].items()
-                    }
-
-            components_with_same_priority = sorted(
-                components_with_same_priority, key=lambda comp_name: (topological_sort[comp_name], comp_name.lower())
-            )
-            component_name = components_with_same_priority[0]
-
-        return component_name, topological_sort
-
-    @staticmethod
-    def _write_component_outputs(
-        component_name: str,
-        component_outputs: Dict[str, Any],
-        inputs: Dict[str, Any],
-        receivers: List[Tuple],
-        include_outputs_from: Set[str],
-    ) -> Dict[str, Any]:
-        """
-        Distributes the outputs of a component to the input sockets that it is connected to.
-
-        :param component_name: The name of the component.
-        :param component_outputs: The outputs of the component.
-        :param inputs: The current global input state.
-        :param receivers: List of components that receive inputs from the component.
-        :param include_outputs_from: List of component names that should always return an output from the pipeline.
-        """
-        for receiver_name, sender_socket, receiver_socket in receivers:
-            # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
-            # that the sender did not produce an output for this socket.
-            # This allows us to track if a predecessor already ran but did not produce an output.
-            value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
-
-            if receiver_name not in inputs:
-                inputs[receiver_name] = {}
-
-            if is_socket_lazy_variadic(receiver_socket):
-                # If the receiver socket is lazy variadic, we append the new input.
-                # Lazy variadic sockets can collect multiple inputs.
-                _write_to_lazy_variadic_socket(
-                    inputs=inputs,
-                    receiver_name=receiver_name,
-                    receiver_socket_name=receiver_socket.name,
-                    component_name=component_name,
-                    value=value,
-                )
-            else:
-                # If the receiver socket is not lazy variadic, it is greedy variadic or non-variadic.
-                # We overwrite with the new input if it's not _NO_OUTPUT_PRODUCED or if the current value is None.
-                _write_to_standard_socket(
-                    inputs=inputs,
-                    receiver_name=receiver_name,
-                    receiver_socket_name=receiver_socket.name,
-                    component_name=component_name,
-                    value=value,
-                )
-
-        # If we want to include all outputs from this actor in the final outputs, we don't need to prune any consumed
-        # outputs
-        if component_name in include_outputs_from:
-            return component_outputs
-
-        # We prune outputs that were consumed by any receiving sockets.
-        # All remaining outputs will be added to the final outputs of the pipeline.
-        consumed_outputs = {sender_socket.name for _, sender_socket, __ in receivers}
-        pruned_outputs = {key: value for key, value in component_outputs.items() if key not in consumed_outputs}
-
-        return pruned_outputs
-
-    @staticmethod
-    def _is_queue_stale(priority_queue: FIFOPriorityQueue) -> bool:
-        """
-        Checks if the priority queue needs to be recomputed because the priorities might have changed.
-
-        :param priority_queue: Priority queue of component names.
-        """
-        return len(priority_queue) == 0 or priority_queue.peek()[0] > ComponentPriority.READY
-
-    @staticmethod
-    def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
-        """
-        Validate the pipeline to check if it is blocked or has no valid entry point.
-
-        :param priority_queue: Priority queue of component names.
-        :raises PipelineRuntimeError:
-            If the pipeline is blocked or has no valid entry point.
-        """
-        if len(priority_queue) == 0:
-            return
-
-        candidate = priority_queue.peek()
-        if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
-            raise PipelineComponentsBlockedError()
-
-    def _find_super_components(self) -> list[tuple[str, Component]]:
-        """
-        Find all SuperComponents in the pipeline.
-
-        :returns:
-            List of tuples containing (component_name, component_instance) representing a SuperComponent.
-        """
-        super_components = []
-        for comp_name, comp in self.walk():
-            # a SuperComponent has a "pipeline" attribute which itself a Pipeline instance
-            # we don't test against SuperComponent because doing so always lead to circular imports
-            if hasattr(comp, "pipeline") and isinstance(comp.pipeline, self.__class__):
-                super_components.append((comp_name, comp))
-        return super_components
-
-    def _merge_super_component_pipelines(self) -> Tuple["networkx.MultiDiGraph", Dict[str, str]]:
-        """
-        Merge the internal pipelines of SuperComponents into the main pipeline graph structure.
-
-        This creates a new networkx.MultiDiGraph containing all the components from both the main pipeline
-        and all the internal SuperComponents' pipelines. The SuperComponents are removed and their internal
-        components are connected to corresponding input and output sockets of the main pipeline.
-
-        :returns:
-            A tuple containing:
-            - A networkx.MultiDiGraph with the expanded structure of the main pipeline and all it's SuperComponents
-            - A dictionary mapping component names to boolean indicating that this component was part of a
-              SuperComponent
-            - A dictionary mapping component names to their SuperComponent name
-        """
-        merged_graph = self.graph.copy()
-        super_component_mapping: Dict[str, str] = {}
-
-        for super_name, super_component in self._find_super_components():
-            internal_pipeline = super_component.pipeline  # type: ignore
-            internal_graph = internal_pipeline.graph.copy()
-
-            # Mark all components in the internal pipeline as being part of a SuperComponent
-            for node in internal_graph.nodes():
-                super_component_mapping[node] = super_name
-
-            # edges connected to the super component
-            incoming_edges = list(merged_graph.in_edges(super_name, data=True))
-            outgoing_edges = list(merged_graph.out_edges(super_name, data=True))
-
-            # merge the SuperComponent graph into the main graph and remove the super component node
-            # since its components are now part of the main graph
-            merged_graph = networkx.compose(merged_graph, internal_graph)
-            merged_graph.remove_node(super_name)
-
-            # get the entry and exit points of the SuperComponent internal pipeline
-            entry_points = [n for n in internal_graph.nodes() if internal_graph.in_degree(n) == 0]
-            exit_points = [n for n in internal_graph.nodes() if internal_graph.out_degree(n) == 0]
-
-            # connect the incoming edges to entry points
-            for sender, _, edge_data in incoming_edges:
-                sender_socket = edge_data["from_socket"]
-                for entry_point in entry_points:
-                    # find a matching input socket in the entry point
-                    entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
-                    for socket_name, socket in entry_point_sockets.items():
-                        if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
-                            merged_graph.add_edge(
-                                sender,
-                                entry_point,
-                                key=f"{sender_socket.name}/{socket_name}",
-                                conn_type=_type_name(sender_socket.type),
-                                from_socket=sender_socket,
-                                to_socket=socket,
-                                mandatory=socket.is_mandatory,
-                            )
-
-            # connect outgoing edges from exit points
-            for _, receiver, edge_data in outgoing_edges:
-                receiver_socket = edge_data["to_socket"]
-                for exit_point in exit_points:
-                    # find a matching output socket in the exit point
-                    exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
-                    for socket_name, socket in exit_point_sockets.items():
-                        if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
-                            merged_graph.add_edge(
-                                exit_point,
-                                receiver,
-                                key=f"{socket_name}/{receiver_socket.name}",
-                                conn_type=_type_name(socket.type),
-                                from_socket=socket,
-                                to_socket=receiver_socket,
-                                mandatory=receiver_socket.is_mandatory,
-                            )
-
-        return merged_graph, super_component_mapping
-
-
-def _connections_status(
-    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
-) -> str:
-    """
-    Lists the status of the sockets, for error messages.
-    """
-    sender_sockets_entries = []
-    for sender_socket in sender_sockets:
-        sender_sockets_entries.append(f"  - {sender_socket.name}: {_type_name(sender_socket.type)}")
-    sender_sockets_list = "\n".join(sender_sockets_entries)
-
-    receiver_sockets_entries = []
-    for receiver_socket in receiver_sockets:
-        if receiver_socket.senders:
-            sender_status = f"sent by {','.join(receiver_socket.senders)}"
-        else:
-            sender_status = "available"
-        receiver_sockets_entries.append(
-            f"  - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
-        )
-    receiver_sockets_list = "\n".join(receiver_sockets_entries)
-
-    return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
-
-
-# Utility functions for writing to sockets
-
-
-def _write_to_lazy_variadic_socket(
-    inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
-) -> None:
-    """
-    Write to a lazy variadic socket.
-
-    Mutates inputs in place.
-    """
-    if not inputs[receiver_name].get(receiver_socket_name):
-        inputs[receiver_name][receiver_socket_name] = []
-
-    inputs[receiver_name][receiver_socket_name].append({"sender": component_name, "value": value})
-
-
-def _write_to_standard_socket(
-    inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
-) -> None:
-    """
-    Write to a greedy variadic or non-variadic socket.
-
-    Mutates inputs in place.
-    """
-    current_value = inputs[receiver_name].get(receiver_socket_name)
-
-    # Only overwrite if there's no existing value, or we have a new value to provide
-    if current_value is None or value is not _NO_OUTPUT_PRODUCED:
-        inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}]
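For readers skimming the diff, the two write rules removed above boil down to the following sketch. It is illustrative only and not part of the patch: plain dicts stand in for Haystack's socket machinery, and the local sentinel mirrors the library's _NO_OUTPUT_PRODUCED marker.

_NO_OUTPUT_PRODUCED = object()  # local stand-in for haystack's sentinel

def write_lazy_variadic(inputs, receiver, socket, sender, value):
    # Lazy variadic sockets accumulate every value they are offered.
    inputs.setdefault(receiver, {}).setdefault(socket, []).append({"sender": sender, "value": value})

def write_standard(inputs, receiver, socket, sender, value):
    # Standard and greedy variadic sockets hold at most one entry, and the
    # sentinel never overwrites a real value.
    slot = inputs.setdefault(receiver, {})
    if slot.get(socket) is None or value is not _NO_OUTPUT_PRODUCED:
        slot[socket] = [{"sender": sender, "value": value}]

inputs = {}
write_standard(inputs, "r", "in1", "s", 42)
write_standard(inputs, "r", "in1", "s", _NO_OUTPUT_PRODUCED)  # 42 is kept
write_lazy_variadic(inputs, "r", "in2", "s", 1)
write_lazy_variadic(inputs, "r", "in2", "s", 2)               # in2 collects [1, 2]
assert inputs["r"]["in1"][0]["value"] == 42
assert [e["value"] for e in inputs["r"]["in2"]] == [1, 2]

The tests added below pin down exactly these two behaviors, plus the pruning of consumed outputs.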
+    @pytest.mark.parametrize(
+        "output_value",
+        [42, None, _NO_OUTPUT_PRODUCED, "string_value", 3.14],
+        ids=["int", "none", "no-output", "string", "float"],
+    )
+    def test__write_component_outputs_different_output_values(
+        self, output_value, regular_output_socket, regular_input_socket
+    ):
+        """Test handling of different output values"""
+        receivers = [("receiver1", regular_output_socket, regular_input_socket)]
+        component_outputs = {"output1": output_value}
+        inputs = {}
+        PipelineBase._write_component_outputs(
+            component_name="sender1",
+            component_outputs=component_outputs,
+            inputs=inputs,
+            receivers=receivers,
+            include_outputs_from=[],
+        )
+        assert inputs["receiver1"]["input1"] == [{"sender": "sender1", "value": output_value}]
+
+    def test__write_component_outputs_dont_overwrite_with_no_output(self, regular_output_socket, regular_input_socket):
+        """Test that existing inputs are not overwritten with _NO_OUTPUT_PRODUCED"""
+        receivers = [("receiver1", regular_output_socket, regular_input_socket)]
+        component_outputs = {"output1": _NO_OUTPUT_PRODUCED}
+        inputs = {"receiver1": {"input1": [{"sender": "sender1", "value": "keep"}]}}
+        PipelineBase._write_component_outputs(
+            component_name="sender1",
+            component_outputs=component_outputs,
+            inputs=inputs,
+            receivers=receivers,
+            include_outputs_from=[],
+        )
+        assert inputs["receiver1"]["input1"] == [{"sender": "sender1", "value": "keep"}]
+
+    @pytest.mark.parametrize("receivers_count", [1, 2, 3], ids=["single-receiver", "two-receivers", "three-receivers"])
+    def test__write_component_outputs_multiple_receivers(
+        self, receivers_count, regular_output_socket, regular_input_socket
+    ):
+        """Test writing to multiple receivers"""
+        receivers = [(f"receiver{i}", regular_output_socket, regular_input_socket) for i in range(receivers_count)]
+        component_outputs = {"output1": 42}
+
+        inputs = {}
+        PipelineBase._write_component_outputs(
+            component_name="sender1",
+            component_outputs=component_outputs,
+            inputs=inputs,
+            receivers=receivers,
+            include_outputs_from=[],
+        )
+        for i in range(receivers_count):
+            receiver_name = f"receiver{i}"
+            assert receiver_name in inputs
+            assert inputs[receiver_name]["input1"] == [{"sender": "sender1", "value": 42}]
+
+    def test__get_next_runnable_component_empty(self):
+        """Test with empty queue returns None"""
+        queue = FIFOPriorityQueue()
+        pipeline = PipelineBase()
+        result = pipeline._get_next_runnable_component(queue, component_visits={})
+        assert result is None
+
+    def test__get_next_runnable_component_blocked(self):
+        """Test component with BLOCKED priority returns None"""
+        pipeline = PipelineBase()
+        queue = FIFOPriorityQueue()
+        queue.push("blocked_component", ComponentPriority.BLOCKED)
+        result = pipeline._get_next_runnable_component(queue, component_visits={"blocked_component": 0})
+        assert result is None
+
+    @patch("haystack.core.pipeline.base.PipelineBase._get_component_with_graph_metadata_and_visits")
+    def test__get_next_runnable_component_max_visits(self, mock_get_component_with_graph_metadata_and_visits):
+        """Test component exceeding max visits raises exception"""
+        pipeline = PipelineBase(max_runs_per_component=2)
+        queue = FIFOPriorityQueue()
+        queue.push("ready_component", ComponentPriority.READY)
+        mock_get_component_with_graph_metadata_and_visits.return_value = {"instance": "test", "visits": 3}
+
+        with pytest.raises(PipelineMaxComponentRuns) as exc_info:
+            pipeline._get_next_runnable_component(queue, component_visits={"ready_component": 3})
+
+        assert "Maximum run count 2 reached for component 'ready_component'" in str(exc_info.value)
+
+    @patch("haystack.core.pipeline.base.PipelineBase._get_component_with_graph_metadata_and_visits")
+    def test__get_next_runnable_component_ready(self, mock_get_component_with_graph_metadata_and_visits):
+        """Test component that is READY"""
+        pipeline = PipelineBase()
+        queue = FIFOPriorityQueue()
+        queue.push("ready_component", ComponentPriority.READY)
+        mock_get_component_with_graph_metadata_and_visits.return_value = {"instance": "test", "visits": 1}
+
+        priority, component_name, component = pipeline._get_next_runnable_component(
+            queue, component_visits={"ready_component": 1}
+        )
+
+        assert priority == ComponentPriority.READY
+        assert component_name == "ready_component"
+        assert component == {"instance": "test", "visits": 1}
+
+    @pytest.mark.parametrize(
+        "queue_setup,expected_stale",
+        [
+            # Empty queue case
+            (None, True),
+            # READY priority case
+            ((ComponentPriority.READY, "component1"), False),
+            # DEFER priority case
+            ((ComponentPriority.DEFER, "component1"), True),
+        ],
+        ids=["empty-queue", "ready-component", "deferred-component"],
+    )
+    def test__is_queue_stale(self, queue_setup, expected_stale):
+        queue = FIFOPriorityQueue()
+        if queue_setup:
+            priority, component_name = queue_setup
+            queue.push(component_name, priority)
+
+        result = PipelineBase._is_queue_stale(queue)
+        assert result == expected_stale
+
+    @patch("haystack.core.pipeline.base.PipelineBase._calculate_priority")
+    @patch("haystack.core.pipeline.base.PipelineBase._get_component_with_graph_metadata_and_visits")
+    def test_fill_queue(self, mock_get_metadata, mock_calc_priority):
+        pipeline = PipelineBase()
+        component_names = ["comp1", "comp2"]
+        inputs = {"comp1": {"input1": "value1"}, "comp2": {"input2": "value2"}}
+
+        mock_get_metadata.side_effect = lambda name, _: {"component": f"mock_{name}"}
+        mock_calc_priority.side_effect = [1, 2]  # Different priorities for testing
+
+        queue = pipeline._fill_queue(component_names, inputs, component_visits={"comp1": 1, "comp2": 1})
+
+        assert mock_get_metadata.call_count == 2
+        assert mock_calc_priority.call_count == 2
+
+        # Verify correct calls for first component
+        mock_get_metadata.assert_any_call("comp1", 1)
+        mock_calc_priority.assert_any_call({"component": "mock_comp1"}, {"input1": "value1"})
+
+        # Verify correct calls for second component
+        mock_get_metadata.assert_any_call("comp2", 1)
+        mock_calc_priority.assert_any_call({"component": "mock_comp2"}, {"input2": "value2"})
+
+        assert queue.pop() == (1, "comp1")
+        assert queue.pop() == (2, "comp2")
+
+    @pytest.mark.parametrize(
+        "input_sockets,component_inputs,expected_consumed,expected_remaining",
+        [
+            # Regular socket test
+            (
+                {"input1": InputSocket("input1", int)},
+                {"input1": [{"sender": "comp1", "value": 42}, {"sender": "comp2", "value": 24}]},
+                {"input1": 42},  # Should take first valid input
+                {},  # All pipeline inputs should be removed
+            ),
+            # Regular socket with user input
+            (
+                {"input1": InputSocket("input1", int)},
+                {
+                    "input1": [
+                        {"sender": "comp1", "value": 42},
+                        {"sender": None, "value": 24},  # User input
+                    ]
+                },
+                {"input1": 42},
+                {"input1": [{"sender": None, "value": 24}]},  # User input should remain
+            ),
+            # Greedy variadic socket
+            (
+                {"greedy": InputSocket("greedy", GreedyVariadic[int])},
+                {
+                    "greedy": [
+                        {"sender": "comp1", "value": 42},
+                        {"sender": None, "value": 24},  # User input
+                        {"sender": "comp2", "value": 33},
+                    ]
+                },
+                {"greedy": [42]},  # Takes first valid input
+                {},  # All inputs removed for greedy sockets
+            ),
+            # Lazy variadic socket
+            (
+                {"lazy": InputSocket("lazy", Variadic[int])},
+                {
+                    "lazy": [
+                        {"sender": "comp1", "value": 42},
+                        {"sender": "comp2", "value": 24},
+                        {"sender": None, "value": 33},  # User input
+                    ]
+                },
+                {"lazy": [42, 24, 33]},  # Takes all valid inputs
+                {"lazy": [{"sender": None, "value": 33}]},  # User input remains
+            ),
+            # Mixed socket types
+            (
+                {
+                    "regular": InputSocket("regular", int),
+                    "greedy": InputSocket("greedy", GreedyVariadic[int]),
+                    "lazy": InputSocket("lazy", Variadic[int]),
+                },
+                {
+                    "regular": [{"sender": "comp1", "value": 42}, {"sender": None, "value": 24}],
+                    "greedy": [{"sender": "comp2", "value": 33}, {"sender": None, "value": 15}],
+                    "lazy": [{"sender": "comp3", "value": 55}, {"sender": "comp4", "value": 66}],
+                },
+                {"regular": 42, "greedy": [33], "lazy": [55, 66]},
+                {"regular": [{"sender": None, "value": 24}]},  # Only non-greedy user input remains
+            ),
+            # Filtering _NO_OUTPUT_PRODUCED
+            (
+                {"input1": InputSocket("input1", int)},
+                {
+                    "input1": [
+                        {"sender": "comp1", "value": _NO_OUTPUT_PRODUCED},
+                        {"sender": "comp2", "value": 42},
+                        {"sender": "comp2", "value": _NO_OUTPUT_PRODUCED},
+                    ]
+                },
+                {"input1": 42},  # Should skip _NO_OUTPUT_PRODUCED values
+                {},  # All inputs consumed
+            ),
+        ],
+        ids=[
+            "regular-socket",
+            "regular-with-user-input",
+            "greedy-variadic",
+            "lazy-variadic",
+            "mixed-sockets",
+            "no-output-filtering",
+        ],
+    )
+    def test__consume_component_inputs(self, input_sockets, component_inputs, expected_consumed, expected_remaining):
+        # Setup
+        component = {"input_sockets": input_sockets}
+        inputs = {"test_component": component_inputs}
+
+        # Run
+        consumed = PipelineBase._consume_component_inputs("test_component", component, inputs)
+
+        # Verify
+        assert consumed == expected_consumed
+        assert inputs["test_component"] == expected_remaining
+
+    def test__consume_component_inputs_with_df(self, regular_input_socket):
+        component = {"input_sockets": {"input1": regular_input_socket}}
+        inputs = {"test_component": {"input1": [{"sender": "sender1", "value": DataFrame({"a": [1, 2], "b": [1, 2]})}]}}
+
+        consumed = PipelineBase._consume_component_inputs("test_component", component, inputs)
+
+        assert consumed["input1"].equals(DataFrame({"a": [1, 2], "b": [1, 2]}))
+
+    @patch("haystack.core.pipeline.draw.requests")
+    def test_pipeline_draw_called_with_positional_args_triggers_a_warning(self, mock_requests):
+        """
+        Test that calling the pipeline draw method with positional arguments triggers a warning.
+        """
+        from pathlib import Path
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        out_file = Path("original_pipeline.png")
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.draw(out_file, server_url="http://localhost:3000")
+        assert len(w) == 1
+        assert issubclass(w[0].category, DeprecationWarning)
+        assert (
+            "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+            in str(w[0].message)
+        )
+
+    @patch("haystack.core.pipeline.draw.requests")
+    @patch("haystack.core.pipeline.base.is_in_jupyter")
+    def test_pipeline_show_called_with_positional_args_triggers_a_warning(self, mock_is_in_jupyter, mock_requests):
+        """
+        Test that calling the pipeline show method with positional arguments triggers a warning.
+        """
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        mock_is_in_jupyter.return_value = True
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.show("http://localhost:3000")
+        assert len(w) == 1
+        assert issubclass(w[0].category, DeprecationWarning)
+        assert (
+            "Warning: In an upcoming release, this method will require keyword arguments for all parameters"
+            in str(w[0].message)
+        )
+
+    @patch("haystack.core.pipeline.draw.requests")
+    def test_pipeline_draw_called_with_keyword_args_triggers_no_warning(self, mock_requests):
+        """
+        Test that calling the pipeline draw method with keyword arguments does not trigger a warning.
+        """
+        from pathlib import Path
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        out_file = Path("original_pipeline.png")
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.draw(path=out_file, server_url="http://localhost:3000")
+        assert len(w) == 0, "No warning should be triggered when using keyword arguments"
+
+    @patch("haystack.core.pipeline.draw.requests")
+    @patch("haystack.core.pipeline.base.is_in_jupyter")
+    def test_pipeline_show_called_with_keyword_args_triggers_no_warning(self, mock_is_in_jupyter, mock_requests):
+        """
+        Test that calling the pipeline show method with keyword arguments does not trigger a warning.
+        """
+        import warnings
+
+        pipeline = PipelineBase()
+        mock_response = mock_requests.get.return_value
+        mock_response.status_code = 200
+        mock_response.content = b"image_data"
+        mock_is_in_jupyter.return_value = True
+
+        with warnings.catch_warnings(record=True) as w:
+            pipeline.show(server_url="http://localhost:3000")
+        assert len(w) == 0, "No warning should be triggered when using keyword arguments"
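The four warning tests above all rely on warnings.catch_warnings(record=True). One detail worth knowing when adapting them: Python's default filters can deduplicate or suppress DeprecationWarning depending on context, so a capture helper typically forces the "always" filter to keep repeated calls visible. The helper below is an illustrative sketch, not part of the test suite, and its name is invented:

import warnings

def emitted_warnings(fn, *args, **kwargs):
    # Record every warning fn raises; "always" disables the default
    # once-per-location filtering so repeated invocations stay observable.
    with warnings.catch_warnings(record=True) as captured:
        warnings.simplefilter("always")
        fn(*args, **kwargs)
    return [(w.category.__name__, str(w.message)) for w in captured]

# e.g. emitted_warnings(pipeline.draw, out_file, server_url="http://localhost:3000")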
+
+
+class TestValidateInput:
+    def test_validate_input_valid_data(self):
+        pipe = PipelineBase()
+        comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})()
+        pipe.add_component("comp1", comp1)
+        pipe._validate_input(data={"comp1": {"x": 1}})
+        # No exception should be raised
+
+    def test_validate_input_missing_mandatory_input(self):
+        pipe = PipelineBase()
+        comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})()
+        pipe.add_component("comp1", comp1)
+        with pytest.raises(ValueError, match="Missing mandatory input 'x' for component 'comp1'"):
+            pipe._validate_input(data={"comp1": {}})
+
+    def test_validate_input_missing_mandatory_input_for_component_not_in_data(self):
+        pipe = PipelineBase()
+        comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})()
+        comp2 = component_class("Comp2", input_types={"a": str}, output_types={"b": str})()
+        pipe.add_component("comp1", comp1)
+        pipe.add_component("comp2", comp2)  # comp2 requires 'a' but is not in data
+        with pytest.raises(
+            ValueError,
+            match="Missing mandatory input 'a' for component 'comp2' which was not provided in the input data.",
+        ):
+            pipe._validate_input(data={"comp1": {"x": 1}})
+
+    def test_validate_input_to_already_connected_socket(self):
+        pipe = PipelineBase()
+        comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})()
+        comp2 = component_class("Comp2", input_types={"a": int}, output_types={"b": int})()
+        pipe.add_component("comp1", comp1)
+        pipe.add_component("comp2", comp2)
+        pipe.connect("comp1.y", "comp2.a")
+        with pytest.raises(
+            ValueError,
+            match="Input 'a' for component 'comp2' is already provided by component 'comp1'. Do not provide it directly.",
+        ):
+            pipe._validate_input(data={"comp2": {"a": 1}})
+
+    def test_validate_input_for_non_existent_component(self):
+        pipe = PipelineBase()
+        with pytest.raises(
+            ValueError, match="Component 'non_existent' not found in the pipeline. Available components: \\[\\]"
+        ):
+            pipe._validate_input(data={"non_existent": {"x": 1}})
+
+    def test_validate_input_with_unexpected_input_name(self):
+        pipe = PipelineBase()
+        comp1 = component_class("Comp1", input_types={"x": int}, output_types={"y": int})()
+        pipe.add_component("comp1", comp1)
+        with pytest.raises(ValueError, match="Unexpected input 'z' for component 'comp1'. Available inputs: \\['x'\\]"):
+            pipe._validate_input(data={"comp1": {"z": 1}})
+
+    def test_validate_input_variadic_socket_can_receive_multiple_inputs(self):
+        pipe = PipelineBase()
+        comp1 = component_class("Comp1", output_types={"y": int})()
+        comp2 = component_class("Comp2", input_types={"a": Variadic[int]}, output_types={"b": int})()
+        pipe.add_component("comp1", comp1)
+        pipe.add_component("comp2", comp2)
+        pipe.connect("comp1.y", "comp2.a")
+        # Should not raise an error, as variadic sockets can accept multiple inputs
+        pipe._validate_input(data={"comp2": {"a": 1}})
\ No newline at end of file
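Taken together, the validation these tests pin down reduces to three passes over a component's sockets. The sketch below is an illustration, not the patched implementation: the Socket record is a simplified stand-in for Haystack's InputSocket, and only the fields the checks actually consult are modeled.

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class Socket:
    is_mandatory: bool = True
    is_variadic: bool = False
    senders: List[str] = field(default_factory=list)

def validate(component: str, sockets: Dict[str, Socket], given: Dict[str, Any]) -> None:
    # Pass 1: every mandatory socket must be fed by a sender or by `given`.
    for name, sock in sockets.items():
        if sock.is_mandatory and not sock.senders and name not in given:
            raise ValueError(f"Missing mandatory input '{name}' for component '{component}'.")
    # Pass 2: every provided input must correspond to an existing socket.
    for name in given:
        if name not in sockets:
            raise ValueError(f"Unexpected input '{name}' for component '{component}'.")
    # Pass 3: a non-variadic socket already fed by a sender must not also be
    # provided directly; variadic sockets may accept both.
    for name, sock in sockets.items():
        if sock.senders and name in given and not sock.is_variadic:
            raise ValueError(
                f"Input '{name}' for component '{component}' is already provided by component '{sock.senders[0]}'."
            )

validate("comp1", {"x": Socket()}, {"x": 1})  # passes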
" + f"{available_inputs_message}" + ) + # Validate that inputs are not multiply defined (already sent by another component and also provided directly) # unless the socket is variadic for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): @@ -960,8 +969,12 @@ def _validate_input(self, data: Dict[str, Any]) -> None: instance = self.graph.nodes[component_name_in_graph]["instance"] for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): if socket.is_mandatory and not socket.senders: - raise ValueError(f"Missing mandatory input '{socket_name}' for component '{component_name_in_graph}' which was not provided in the input data.") - + error_message = ( + f"Missing mandatory input '{socket_name}' for component '{component_name_in_graph}' " + "(not found in input data)." + ) + raise ValueError(error_message) + def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: """ @@ -1513,4 +1526,4 @@ def _write_to_standard_socket( # Only overwrite if there's no existing value, or we have a new value to provide if current_value is None or value is not _NO_OUTPUT_PRODUCED: - inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}] \ No newline at end of file + inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}] From 2853b7bdebb4fa82420f948b17a76f65f76a80d8 Mon Sep 17 00:00:00 2001 From: carlosrinc Date: Fri, 20 Jun 2025 08:21:08 -0500 Subject: [PATCH 4/5] format --- haystack/core/pipeline/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index aac58f970b..a2289334b0 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -914,7 +914,7 @@ def _validate_component_input(self, component_name: str, component_inputs: Dict[ raise ValueError( f"Component '{component_name}' not found in the pipeline. " f"{available_nodes_message}" - ) + ) instance = self.graph.nodes[component_name]["instance"] # Validate that all mandatory inputs are provided either directly or by senders @@ -930,7 +930,7 @@ def _validate_component_input(self, component_name: str, component_inputs: Dict[ f"Unexpected input '{input_name}' for component '{component_name}'. " f"{available_inputs_message}" ) - + # Validate that inputs are not multiply defined (already sent by another component and also provided directly) # unless the socket is variadic for socket_name, socket in instance.__haystack_input__._sockets_dict.items(): @@ -974,7 +974,7 @@ def _validate_input(self, data: Dict[str, Any]) -> None: "(not found in input data)." ) raise ValueError(error_message) - + def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: """ From 7d943fd9cdcb6024a2be65822f4cc7087b6e5777 Mon Sep 17 00:00:00 2001 From: carlosrinc Date: Fri, 20 Jun 2025 08:30:19 -0500 Subject: [PATCH 5/5] release_notes --- releasenotes/notes/refactor-validate-input.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 releasenotes/notes/refactor-validate-input.yaml diff --git a/releasenotes/notes/refactor-validate-input.yaml b/releasenotes/notes/refactor-validate-input.yaml new file mode 100644 index 0000000000..335f5c3bcd --- /dev/null +++ b/releasenotes/notes/refactor-validate-input.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + - Refactore the PipelineBase._validate_input() method to improve clarity and maintainability. 
+ - break down the method into smaller helper functions and enhance error messages for better specificity. \ No newline at end of file
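As a closing illustration of what this series changes for users, a pipeline with a missing mandatory input now fails with a more specific message. The sketch below assumes a current haystack-ai installation; the Greeter component is invented for the example, and the printed message reflects the refactored wording rather than a guaranteed API contract.

from haystack import Pipeline, component

@component
class Greeter:
    @component.output_types(greeting=str)
    def run(self, name: str):
        # 'name' becomes a mandatory input socket of this component.
        return {"greeting": f"Hello, {name}!"}

pipe = Pipeline()
pipe.add_component("greeter", Greeter())

try:
    pipe.run(data={"greeter": {}})  # 'name' is mandatory but absent
except ValueError as err:
    print(err)  # e.g. Missing mandatory input 'name' for component 'greeter'.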