From b9ad291dfc607556a0500e00d64944898f2d3f38 Mon Sep 17 00:00:00 2001 From: Assaf Mauda Date: Wed, 3 Dec 2025 14:52:23 +0000 Subject: [PATCH 1/6] add support for kw_only dataclasses (#36978) --- sdks/python/apache_beam/coders/coder_impl.py | 19 +++++++++++++++---- .../apache_beam/coders/coders_test_common.py | 10 ++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 03514bb50db0..84d97dca812b 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -504,11 +504,14 @@ def encode_special_deterministic(self, value, stream): "for the input of '%s'" % (value, type(value), self.requires_deterministic_step_label)) self.encode_type(type(value), stream) - values = [ - getattr(value, field.name) for field in dataclasses.fields(value) + init_field_names = [ + field.name for field in dataclasses.fields(value) if field.init ] + stream.write_var_int64(len(init_field_names)) try: - self.iterable_coder_impl.encode_to_stream(values, stream, True) + for field_name in init_field_names: + stream.write(field_name.encode('utf-8'), True) + self.encode_to_stream(getattr(value, field_name), stream, True) except Exception as e: raise TypeError(self._deterministic_encoding_error_msg(value)) from e elif isinstance(value, tuple) and hasattr(type(value), '_fields'): @@ -616,7 +619,15 @@ def decode_from_stream(self, stream, nested): msg = cls() msg.ParseFromString(stream.read_all(True)) return msg - elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE: + elif t == DATACLASS_TYPE: + cls = self.decode_type(stream) + vlen = stream.read_var_int64() + fields = {} + for _ in range(vlen): + field_name = stream.read_all(True).decode('utf-8') + fields[field_name] = self.decode_from_stream(stream, True) + return cls(**fields) + elif t == NAMED_TUPLE_TYPE: cls = self.decode_type(stream) return cls(*self.iterable_coder_impl.decode_from_stream(stream, True)) elif t == ENUM_TYPE: diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 8a57d1e63e2c..8f89ab9602c1 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -113,6 +113,11 @@ class FrozenDataClass: a: Any b: int + @dataclasses.dataclass(frozen=True, kw_only=True) + class FrozenKwOnlyDataClass: + c: int + d: int + @dataclasses.dataclass class UnFrozenDataClass: x: int @@ -303,9 +308,11 @@ def test_deterministic_coder(self, compat_version): if dataclasses is not None: self.check_coder(deterministic_coder, FrozenDataClass(1, 2)) + self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2)) with self.assertRaises(TypeError): self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2)) + with self.assertRaises(TypeError): self.check_coder( deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3)) @@ -742,6 +749,7 @@ def test_cross_process_encoding_of_special_types_is_deterministic( from apache_beam.coders.coders_test_common import DefinesGetState from apache_beam.coders.coders_test_common import DefinesGetAndSetState from apache_beam.coders.coders_test_common import FrozenDataClass + from apache_beam.coders.coders_test_common import FrozenKwOnlyDataClass from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message @@ -777,6 +785,8 @@ def test_cross_process_encoding_of_special_types_is_deterministic( test_cases.extend([ ("frozen_dataclass", FrozenDataClass(1, 2)), ("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]), + ("frozen_kwonly_dataclass", FrozenKwOnlyDataClass(c=1, d=2)), + ("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenKwOnlyDataClass(c=3, d=4)]), ]) compat_version = {'"'+ compat_version +'"' if compat_version else None} From 2f44c30919e57078fc4af9d4be01e33d5e2d42e3 Mon Sep 17 00:00:00 2001 From: Assaf Mauda Date: Thu, 18 Dec 2025 08:19:13 +0000 Subject: [PATCH 2/6] use a different type for kw_ony dataclasses --- sdks/python/apache_beam/coders/coder_impl.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 84d97dca812b..3fd87b179a50 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -350,10 +350,11 @@ def decode(self, value): ITERABLE_LIKE_TYPE = 10 PROTO_TYPE = 100 -DATACLASS_TYPE = 101 +DATACLASS_TYPE = 101 # TODO: Deprecate and use only DATACLASS_KW_ONLY_TYPE. NAMED_TUPLE_TYPE = 102 ENUM_TYPE = 103 NESTED_STATE_TYPE = 104 +DATACLASS_KW_ONLY_TYPE = 105 # Types that can be encoded as iterables, but are not literally # lists, etc. due to being lazy. The actual type is not preserved @@ -497,7 +498,7 @@ def encode_special_deterministic(self, value, stream): self.encode_type(type(value), stream) stream.write(value.SerializePartialToString(deterministic=True), True) elif dataclasses and dataclasses.is_dataclass(value): - stream.write_byte(DATACLASS_TYPE) + stream.write_byte(DATACLASS_KW_ONLY_TYPE) if not type(value).__dataclass_params__.frozen: raise TypeError( "Unable to deterministically encode non-frozen '%s' of type '%s' " @@ -619,7 +620,7 @@ def decode_from_stream(self, stream, nested): msg = cls() msg.ParseFromString(stream.read_all(True)) return msg - elif t == DATACLASS_TYPE: + elif t == DATACLASS_KW_ONLY_TYPE: cls = self.decode_type(stream) vlen = stream.read_var_int64() fields = {} @@ -627,7 +628,7 @@ def decode_from_stream(self, stream, nested): field_name = stream.read_all(True).decode('utf-8') fields[field_name] = self.decode_from_stream(stream, True) return cls(**fields) - elif t == NAMED_TUPLE_TYPE: + elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE: cls = self.decode_type(stream) return cls(*self.iterable_coder_impl.decode_from_stream(stream, True)) elif t == ENUM_TYPE: From 24df6a2963291b24188db342982326012983ddab Mon Sep 17 00:00:00 2001 From: Assaf Mauda Date: Thu, 18 Dec 2025 09:01:53 +0000 Subject: [PATCH 3/6] allow passing positional parameters to create dataclasses when possible --- sdks/python/apache_beam/coders/coder_impl.py | 48 +++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 3fd87b179a50..5c0c500d5d75 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -32,6 +32,7 @@ import decimal import enum +import functools import itertools import json import logging @@ -375,6 +376,17 @@ def _verify_dill_compat(): raise RuntimeError(base_error + f". Found dill version '{dill.__version__}") +if dataclasses: + # Cache the result to avoid multiple checks for the same dataclass type. + @functools.cache + def dataclass_uses_kw_only(cls): + return any( + field.init and field.kw_only for field in dataclasses.fields(cls)) + +else: + dataclass_uses_kw_only = lambda cls: False + + class FastPrimitivesCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" def __init__( @@ -498,23 +510,35 @@ def encode_special_deterministic(self, value, stream): self.encode_type(type(value), stream) stream.write(value.SerializePartialToString(deterministic=True), True) elif dataclasses and dataclasses.is_dataclass(value): - stream.write_byte(DATACLASS_KW_ONLY_TYPE) if not type(value).__dataclass_params__.frozen: raise TypeError( "Unable to deterministically encode non-frozen '%s' of type '%s' " "for the input of '%s'" % (value, type(value), self.requires_deterministic_step_label)) - self.encode_type(type(value), stream) - init_field_names = [ - field.name for field in dataclasses.fields(value) if field.init - ] - stream.write_var_int64(len(init_field_names)) - try: - for field_name in init_field_names: - stream.write(field_name.encode('utf-8'), True) - self.encode_to_stream(getattr(value, field_name), stream, True) - except Exception as e: - raise TypeError(self._deterministic_encoding_error_msg(value)) from e + if dataclass_uses_kw_only(type(value)): + stream.write_byte(DATACLASS_KW_ONLY_TYPE) + self.encode_type(type(value), stream) + init_field_names = [ + field.name for field in dataclasses.fields(value) if field.init + ] + stream.write_var_int64(len(init_field_names)) + try: + for field_name in init_field_names: + stream.write(field_name.encode("utf-8"), True) + self.encode_to_stream(getattr(value, field_name), stream, True) + except Exception as e: + raise TypeError(self._deterministic_encoding_error_msg(value)) from e + else: # Not using kw_only, we can pass parameters by position. + stream.write_byte(DATACLASS_TYPE) + self.encode_type(type(value), stream) + values = [ + getattr(value, field.name) for field in dataclasses.fields(value) + if field.init + ] + try: + self.iterable_coder_impl.encode_to_stream(values, stream, True) + except Exception as e: + raise TypeError(self._deterministic_encoding_error_msg(value)) from e elif isinstance(value, tuple) and hasattr(type(value), '_fields'): stream.write_byte(NAMED_TUPLE_TYPE) self.encode_type(type(value), stream) From 04cbd5d99e3763ecb78cc824729d936a5b4f2ad3 Mon Sep 17 00:00:00 2001 From: Assaf Mauda Date: Thu, 18 Dec 2025 09:07:14 +0000 Subject: [PATCH 4/6] remove wrong TODO --- sdks/python/apache_beam/coders/coder_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 5c0c500d5d75..261b6256d523 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -351,7 +351,7 @@ def decode(self, value): ITERABLE_LIKE_TYPE = 10 PROTO_TYPE = 100 -DATACLASS_TYPE = 101 # TODO: Deprecate and use only DATACLASS_KW_ONLY_TYPE. +DATACLASS_TYPE = 101 NAMED_TUPLE_TYPE = 102 ENUM_TYPE = 103 NESTED_STATE_TYPE = 104 From 1900a0adbf6fc14425fb1eaeb109e5b69e5dd7bd Mon Sep 17 00:00:00 2001 From: Assaf Mauda Date: Sun, 21 Dec 2025 11:53:57 +0000 Subject: [PATCH 5/6] add function typehint for pylint --- sdks/python/apache_beam/coders/coder_impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 261b6256d523..978664df93bf 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -376,10 +376,11 @@ def _verify_dill_compat(): raise RuntimeError(base_error + f". Found dill version '{dill.__version__}") +dataclass_uses_kw_only: Callable[[Any], bool] if dataclasses: # Cache the result to avoid multiple checks for the same dataclass type. @functools.cache - def dataclass_uses_kw_only(cls): + def dataclass_uses_kw_only(cls) -> bool: return any( field.init and field.kw_only for field in dataclasses.fields(cls)) From ecfcf8536de42de872efeff59e2c51d6ca02aefd Mon Sep 17 00:00:00 2001 From: Assaf Mauda Date: Sun, 21 Dec 2025 12:20:18 +0000 Subject: [PATCH 6/6] minor refactoring --- sdks/python/apache_beam/coders/coder_impl.py | 38 ++++++++------------ 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 978664df93bf..1e3bb2ece92a 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -516,30 +516,22 @@ def encode_special_deterministic(self, value, stream): "Unable to deterministically encode non-frozen '%s' of type '%s' " "for the input of '%s'" % (value, type(value), self.requires_deterministic_step_label)) - if dataclass_uses_kw_only(type(value)): - stream.write_byte(DATACLASS_KW_ONLY_TYPE) - self.encode_type(type(value), stream) - init_field_names = [ - field.name for field in dataclasses.fields(value) if field.init - ] - stream.write_var_int64(len(init_field_names)) - try: - for field_name in init_field_names: - stream.write(field_name.encode("utf-8"), True) - self.encode_to_stream(getattr(value, field_name), stream, True) - except Exception as e: - raise TypeError(self._deterministic_encoding_error_msg(value)) from e - else: # Not using kw_only, we can pass parameters by position. - stream.write_byte(DATACLASS_TYPE) - self.encode_type(type(value), stream) - values = [ - getattr(value, field.name) for field in dataclasses.fields(value) - if field.init - ] - try: + init_fields = [field for field in dataclasses.fields(value) if field.init] + try: + if dataclass_uses_kw_only(type(value)): + stream.write_byte(DATACLASS_KW_ONLY_TYPE) + self.encode_type(type(value), stream) + stream.write_var_int64(len(init_fields)) + for field in init_fields: + stream.write(field.name.encode("utf-8"), True) + self.encode_to_stream(getattr(value, field.name), stream, True) + else: # Not using kw_only, we can pass parameters by position. + stream.write_byte(DATACLASS_TYPE) + self.encode_type(type(value), stream) + values = [getattr(value, field.name) for field in init_fields] self.iterable_coder_impl.encode_to_stream(values, stream, True) - except Exception as e: - raise TypeError(self._deterministic_encoding_error_msg(value)) from e + except Exception as e: + raise TypeError(self._deterministic_encoding_error_msg(value)) from e elif isinstance(value, tuple) and hasattr(type(value), '_fields'): stream.write_byte(NAMED_TUPLE_TYPE) self.encode_type(type(value), stream)