diff --git a/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py b/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py index ab066b954b66..8ee770d61691 100644 --- a/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py +++ b/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py @@ -168,7 +168,6 @@ class CloudPickleConfig: DEFAULT_CONFIG = CloudPickleConfig() - builtin_code_type = None if PYPY: # builtin-code objects only exist in pypy diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py b/sdks/python/apache_beam/internal/cloudpickle_pickler.py index eebba178e7c3..53cd7aace868 100644 --- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py +++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py @@ -95,6 +95,27 @@ def _get_proto_enum_descriptor_class(): _LOGGER = logging.getLogger(__name__) +# Helper to return an object directly during unpickling. +def _return_obj(obj): + return obj + + +# Optional import for Python 3.12 TypeAliasType +try: # pragma: no cover - dependent on Python version + from typing import TypeAliasType as _TypeAliasType # type: ignore[attr-defined] +except Exception: + _TypeAliasType = None + + +def _typealias_reduce(obj): + # Unwrap typing.TypeAliasType to its underlying value for robust pickling. + underlying = getattr(obj, '__value__', None) + if underlying is None: + # Fallback: return the object itself; lets default behavior handle it. + return _return_obj, (obj, ) + return _return_obj, (underlying, ) + + def _reconstruct_enum_descriptor(full_name): for _, module in list(sys.modules.items()): if not hasattr(module, 'DESCRIPTOR'): @@ -171,6 +192,9 @@ def _dumps( pickler.dispatch_table[type(flags.FLAGS)] = _pickle_absl_flags except NameError: pass + # Register Python 3.12 `type` alias reducer to unwrap to underlying value. + if _TypeAliasType is not None: + pickler.dispatch_table[_TypeAliasType] = _typealias_reduce try: pickler.dispatch_table[RLOCK_TYPE] = _pickle_rlock except NameError: diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index ea736dceddb1..e70fd3db0b88 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -25,6 +25,7 @@ import pickle import random import re +import sys import typing import unittest from functools import reduce @@ -2910,6 +2911,37 @@ def test_threshold(self): use_subprocess=self.use_subprocess)) +class PTransformTypeAliasTest(unittest.TestCase): + @unittest.skipIf(sys.version_info < (3, 12), "Python 3.12 required") + def test_type_alias_statement_supported_in_with_output_types(self): + ns = {} + exec("type InputType = tuple[int, ...]", ns) # pylint: disable=exec-used + InputType = ns["InputType"] + + def print_element(element: InputType) -> InputType: + return element + + with beam.Pipeline() as p: + _ = ( + p + | beam.Create([(1, 2)]) + | beam.Map(lambda x: x) + | beam.Map(print_element)) + + @unittest.skipIf(sys.version_info < (3, 12), "Python 3.12 required") + def test_type_alias_supported_in_ptransform_with_output_types(self): + ns = {} + exec("type OutputType = tuple[int, int]", ns) # pylint: disable=exec-used + OutputType = ns["OutputType"] + + with beam.Pipeline() as p: + _ = ( + p + | beam.Create([(1, 2)]) + | beam.Map(lambda x: x) + | beam.Map(lambda x: x).with_output_types(OutputType)) + + class TestPTransformFn(TypeHintTestCase): def test_type_checking_fail(self): @beam.ptransform_fn diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index b6bf6d37fe02..2360df142167 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -35,6 +35,14 @@ except ImportError: from typing_extensions import is_typeddict +# Python 3.12 adds TypeAliasType for `type` statements; keep optional import. +# pylint: disable=ungrouped-imports +# isort: off +try: + from typing import TypeAliasType # type: ignore[attr-defined] +except Exception: # pragma: no cover - pre-3.12 + TypeAliasType = None # type: ignore[assignment] + T = TypeVar('T') _LOGGER = logging.getLogger(__name__) @@ -332,6 +340,14 @@ def convert_to_beam_type(typ): sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)): typ = typing.Union[typ] + # Unwrap Python 3.12 `type` aliases (TypeAliasType) to their underlying value. + # This ensures Beam sees the actual aliased type (e.g., tuple[int, ...]). + if sys.version_info >= (3, 12) and TypeAliasType is not None: + if isinstance(typ, TypeAliasType): # pylint: disable=isinstance-second-argument-not-valid-type + underlying = getattr(typ, '__value__', None) + if underlying is not None: + typ = underlying + if getattr(typ, '__module__', None) == 'typing': typ = convert_typing_to_builtin(typ) diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index f6a13d7795a0..0e933b0d4925 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -491,6 +491,24 @@ def test_convert_typing_to_builtin(self): builtin_type = convert_typing_to_builtin(typing_type) self.assertEqual(builtin_type, expected_builtin_type, description) + def test_type_alias_type_unwrapped(self): + # Only applicable on Python 3.12+, where typing.TypeAliasType exists + # and the `type` statement is available. + TypeAliasType = getattr(typing, 'TypeAliasType', None) + if TypeAliasType is None: + self.skipTest('TypeAliasType not available') + + ns = {} + try: + exec('type AliasTuple = tuple[int, ...]', {}, ns) # pylint: disable=exec-used + except SyntaxError: + self.skipTest('type statement not supported') + + AliasTuple = ns['AliasTuple'] + self.assertTrue(isinstance(AliasTuple, TypeAliasType)) # pylint: disable=isinstance-second-argument-not-valid-type + self.assertEqual( + typehints.Tuple[int, ...], convert_to_beam_type(AliasTuple)) + if __name__ == '__main__': unittest.main()