diff --git a/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py b/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py index 8ee770d61691..495e888a5167 100644 --- a/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py +++ b/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py @@ -62,6 +62,7 @@ import dis from enum import Enum import functools +import hashlib import io import itertools import logging @@ -98,7 +99,7 @@ _DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary() _DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary() _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS = weakref.WeakKeyDictionary() -_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock() +_DYNAMIC_CLASS_TRACKER_LOCK = threading.RLock() PYPY = platform.python_implementation() == "PyPy" @@ -168,6 +169,7 @@ class CloudPickleConfig: DEFAULT_CONFIG = CloudPickleConfig() +_GENERATING_SENTINEL = object() builtin_code_type = None if PYPY: # builtin-code objects only exist in pypy @@ -179,10 +181,21 @@ class CloudPickleConfig: def _get_or_create_tracker_id(class_def, id_generator): with _DYNAMIC_CLASS_TRACKER_LOCK: class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def) + if class_tracker_id is _GENERATING_SENTINEL and id_generator: + raise RuntimeError( + f"Recursive ID generation detected for {class_def}. " + f"The id_generator cannot recursively request an ID for the same class." + ) + if class_tracker_id is None and id_generator is not None: - class_tracker_id = id_generator(class_def) - _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id - _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def + _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = _GENERATING_SENTINEL + try: + class_tracker_id = id_generator(class_def) + _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id + _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def + except Exception: + _DYNAMIC_CLASS_TRACKER_BY_CLASS.pop(class_def, None) + raise return class_tracker_id @@ -1720,3 +1733,10 @@ def dumps( # Backward compat alias. CloudPickler = Pickler + + +def hash_dynamic_classdef(classdef): + """Generates a deterministic ID by hashing the pickled class definition.""" + hexdigest = hashlib.sha256( + dumps(classdef, config=CloudPickleConfig(id_generator=None))).hexdigest() + return hexdigest