diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 26795b8a9833..6ef06abb7436 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -866,6 +866,100 @@ def _infer_result_type( transform: ptransform.PTransform, inputs: Sequence[Union[pvalue.PBegin, pvalue.PCollection]], result_pcollection: Union[pvalue.PValue, pvalue.DoOutputsTuple]) -> None: + """Infer and set the output element type for a PCollection. + + This function determines the output types of transforms by combining: + 1. Concrete input types from previous transforms + 2. Type hints declared on the current transform + 3. Type variable binding and substitution + + TYPE VARIABLE BINDING + --------------------- + Type variables (K, V, T, etc.) act as placeholders that get bound to + concrete types through pattern matching. This requires both an input + pattern and an output template: + + Input Pattern (from .with_input_types()): + Defines where in the input to find each type variable + Example: Tuple[K, V] means "K is the first element, V is the second" + + Output Template (from .with_output_types()): + Defines how to use the bound variables in the output + Example: Tuple[V, K] means "swap the positions" + + CONCRETE TYPES VS TYPE VARIABLES + --------------------------------- + The system handles these differently: + + Concrete Types (e.g., str, int, Tuple[str, int]): + - Used as-is without any binding + - Do not fall back to Any + - Example: .with_output_types(Tuple[str, int]) → Tuple[str, int] + + Type Variables (e.g., K, V, T): + - Must be bound through pattern matching + - Require .with_input_types() to provide the pattern + - Fall back to Any if not bound + - Example without pattern: Tuple[K, V] → Tuple[Any, Any] + - Example with pattern: Tuple[K, V] → Tuple[str, int] + + BINDING ALGORITHM + ----------------- + 1. 
Match: Compare input pattern to concrete input + Pattern: Tuple[K, V] + Concrete: Tuple[str, int] + Result: {K: str, V: int} ← Bindings created + + 2. Substitute: Apply bindings to output template + Template: Tuple[V, K] ← Note: swapped! + Bindings: {K: str, V: int} + Result: Tuple[int, str] ← Swapped concrete types + + Each transform operates in its own type inference scope. Type variables + declared in a parent composite transform do NOT automatically propagate + to child transforms. + + Parent scope (composite): + @with_input_types(Tuple[K, V]) ← K, V defined here + class MyComposite(PTransform): + def expand(self, pcoll): + # Child scope - parent's K, V are NOT available + return pcoll | ChildTransform() + + Type variables that remain unbound after inference fall back to Any. + + EXAMPLES + -------- + Example 1: Concrete types (no variables) + Input: Tuple[str, int] + Transform: .with_output_types(Tuple[str, int]) + Output: Tuple[str, int] ← Used as-is + + Example 2: Type variables with pattern (correct) + Input: Tuple[str, int] + Transform: .with_input_types(Tuple[K, V]) + .with_output_types(Tuple[V, K]) + Binding: {K: str, V: int} + Output: Tuple[int, str] ← Swapped! + + Example 3: Type variables without pattern (falls back to Any) + Input: Tuple[str, int] + Transform: .with_output_types(Tuple[K, V]) ← No input pattern! + Binding: None (can't match) + Output: Tuple[Any, Any] ← Fallback + + Example 4: Mixed concrete and variables + Input: Tuple[str, int] + Transform: .with_input_types(Tuple[str, V]) + .with_output_types(Tuple[str, V]) + Binding: {V: int} ← Only V needs binding + Output: Tuple[str, int] ← str passed through, V bound to int + + Args: + transform: The PTransform being applied + inputs: Input PCollections (provides concrete types) + result_pcollection: Output PCollection to set type on + """ # TODO(robertwb): Multi-input inference. 
type_options = self._options.view_as(TypeOptions) if type_options is None or not type_options.pipeline_type_check: @@ -881,6 +975,7 @@ def _infer_result_type( else typehints.Union[input_element_types_tuple]) type_hints = transform.get_type_hints() declared_output_type = type_hints.simple_output_type(transform.label) + if declared_output_type: input_types = type_hints.input_types if input_types and input_types[0]: @@ -893,6 +988,7 @@ def _infer_result_type( result_element_type = declared_output_type else: result_element_type = transform.infer_output_type(input_element_type) + # Any remaining type variables have no bindings higher than this scope. result_pcollection.element_type = typehints.bind_type_variables( result_element_type, {'*': typehints.Any}) diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index e70fd3db0b88..9a9bf6ff0a74 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -1402,6 +1402,105 @@ def process(self, element, five): assert_that(d, equal_to([6, 7, 8])) self.p.run() + def test_child_with_both_input_and_output_hints_binds_typevars_correctly( + self): + """ + When a child transform has both input and output type hints with type + variables, those variables bind correctly from the actual input data. + + Example: Child with .with_input_types(Tuple[K, V]) + .with_output_types(Tuple[K, V]) receiving ('a', 'hello') will bind + K=str, V=str correctly. 
+ """ + K = typehints.TypeVariable('K') + V = typehints.TypeVariable('V') + + @typehints.with_input_types(typehints.Tuple[K, V]) + @typehints.with_output_types(typehints.Tuple[K, V]) + class TransformWithoutChildHints(beam.PTransform): + class MyDoFn(beam.DoFn): + def process(self, element): + k, v = element + yield (k, v.upper()) + + def expand(self, pcoll): + return ( + pcoll + | beam.ParDo(self.MyDoFn()).with_input_types( + tuple[K, V]).with_output_types(tuple[K, V])) + + with TestPipeline() as p: + result = ( + p + | beam.Create([('a', 'hello'), ('b', 'world')]) + | TransformWithoutChildHints()) + + self.assertEqual(result.element_type, typehints.Tuple[str, str]) + + def test_child_without_input_hints_fails_to_bind_typevars(self): + """ + When a child transform lacks input type hints, type variables in its output + hints cannot bind and default to Any, even when parent composite has + decorated type hints. + + This test demonstrates the current limitation: without explicit input hints + on the child, the type variable K in .with_output_types(Tuple[K, str]) + remains unbound, resulting in Tuple[Any, str] instead of the expected + Tuple[str, str]. 
+ """ + K = typehints.TypeVariable('K') + + @typehints.with_input_types(typehints.Tuple[K, str]) + @typehints.with_output_types(typehints.Tuple[K, str]) + class TransformWithoutChildHints(beam.PTransform): + class MyDoFn(beam.DoFn): + def process(self, element): + k, v = element + yield (k, v.upper()) + + def expand(self, pcoll): + return ( + pcoll + | beam.ParDo(self.MyDoFn()).with_output_types(tuple[K, str])) + + with TestPipeline() as p: + result = ( + p + | beam.Create([('a', 'hello'), ('b', 'world')]) + | TransformWithoutChildHints()) + + self.assertEqual(result.element_type, typehints.Tuple[typehints.Any, str]) + + def test_child_without_output_hints_infers_partial_types_from_dofn(self): + """ + When a child transform has input hints but no output hints, type inference + from the DoFn's process method produces partially inferred types. + + Type inference is able to infer the first element of the tuple as str, but + not the type of v.upper(), which falls back to Any. + """ + K = typehints.TypeVariable('K') + V = typehints.TypeVariable('V') + + @typehints.with_input_types(typehints.Tuple[K, V]) + @typehints.with_output_types(typehints.Tuple[K, V]) + class TransformWithoutChildHints(beam.PTransform): + class MyDoFn(beam.DoFn): + def process(self, element): + k, v = element + yield (k, v.upper()) + + def expand(self, pcoll): + return (pcoll | beam.ParDo(self.MyDoFn()).with_input_types(tuple[K, V])) + + with TestPipeline() as p: + result = ( + p + | beam.Create([('a', 'hello'), ('b', 'world')]) + | TransformWithoutChildHints()) + + self.assertEqual(result.element_type, typehints.Tuple[str, typing.Any]) + def test_do_fn_pipeline_pipeline_type_check_violated(self): @with_input_types(str, str) @with_output_types(str)