From 04a2e370b1297f1f8da658943f9e5c86fba7a07c Mon Sep 17 00:00:00 2001 From: dustin12 Date: Mon, 10 Nov 2025 13:45:44 -0800 Subject: [PATCH 1/5] Allow for a custom id function other than the default hashing funciton. --- .../apache_beam/transforms/async_dofn.py | 42 ++++++++++++------- .../apache_beam/transforms/async_dofn_test.py | 35 ++++++++++++++++ 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index d2fa90c85085..a7e608f46521 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -77,6 +77,7 @@ def __init__( max_items_to_buffer=None, timeout=1, max_wait_time=0.5, + id_fn=None, ): """Wraps the sync_fn to create an asynchronous version. @@ -101,6 +102,8 @@ def __init__( locally before it goes in the queue of waiting work. max_wait_time: The maximum amount of sleep time while attempting to schedule an item. Used in testing to ensure timeouts are met. + id_fn: A function that returns a hashable object from an element. This + will be used to track items instead of the element's default hash. """ self._sync_fn = sync_fn self._uuid = uuid.uuid4().hex @@ -108,6 +111,7 @@ def __init__( self._timeout = timeout self._max_wait_time = max_wait_time self._timer_frequency = callback_frequency + self._id_fn = id_fn or (lambda x: x) if max_items_to_buffer is None: self._max_items_to_buffer = max(parallelism * 2, 10) else: @@ -205,7 +209,8 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs): True if the item was scheduled False otherwise. """ with AsyncWrapper._lock: - if element in AsyncWrapper._processing_elements[self._uuid]: + element_id = self._id_fn(element[1]) + if element_id in AsyncWrapper._processing_elements[self._uuid]: logging.info('item %s already in processing elements', element) return True if self.accepting_items() or ignore_buffer: @@ -214,7 +219,8 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs): lambda: self.sync_fn_process(element, *args, **kwargs), ) result.add_done_callback(self.decrement_items_in_buffer) - AsyncWrapper._processing_elements[self._uuid][element] = result + AsyncWrapper._processing_elements[self._uuid][element_id] = ( + element, result) AsyncWrapper._items_in_buffer[self._uuid] += 1 return True else: @@ -362,27 +368,34 @@ def commit_finished_items( # given key. Skip items in processing_elements which are for a different # key. with AsyncWrapper._lock: - for x in AsyncWrapper._processing_elements[self._uuid]: - if x[0] == key and x not in to_process_local: + processing_elements = AsyncWrapper._processing_elements[self._uuid] + to_process_local_ids = {self._id_fn(e[1]) for e in to_process_local} + to_remove_ids = [] + for element_id, ( + element, + future) in processing_elements.items(): + if element[0] == key and element_id not in to_process_local_ids: items_cancelled += 1 - AsyncWrapper._processing_elements[self._uuid][x].cancel() - to_remove.append(x) + future.cancel() + to_remove_ids.append(element_id) logging.info( - 'cancelling item %s which is no longer in processing state', x) - for x in to_remove: - AsyncWrapper._processing_elements[self._uuid].pop(x) + 'cancelling item %s which is no longer in processing state', + element) + for element_id in to_remove_ids: + processing_elements.pop(element_id) # For all elements which have finished processing output their result. to_return = [] finished_items = [] for x in to_process_local: items_in_se_state += 1 - if x in AsyncWrapper._processing_elements[self._uuid]: - if AsyncWrapper._processing_elements[self._uuid][x].done(): - to_return.append( - AsyncWrapper._processing_elements[self._uuid][x].result()) + x_id = self._id_fn(x[1]) + if x_id in processing_elements: + _element, future = processing_elements[x_id] + if future.done(): + to_return.append(future.result()) finished_items.append(x) - AsyncWrapper._processing_elements[self._uuid].pop(x) + processing_elements.pop(x_id) items_finished += 1 else: items_not_yet_finished += 1 @@ -444,3 +457,4 @@ def timer_callback( A generator of elements that have finished processing for this key. """ return self.commit_finished_items(to_process, timer) + \ No newline at end of file diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index 7577e215d1c7..ef3a1330eaaf 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -119,6 +119,41 @@ def check_items_in_buffer(self, async_dofn, expected_count): expected_count, ) + def test_custom_id_fn(self): + class CustomIdObject: + def __init__(self, element_id, value): + self.element_id = element_id + self.value = value + + def __hash__(self): + return hash(self.element_id) + + def __eq__(self, other): + return self.element_id == other.element_id + + dofn = BasicDofn() + async_dofn = async_lib.AsyncWrapper( + dofn, id_fn=lambda x: x.element_id) + async_dofn.setup() + fake_bag_state = FakeBagState([]) + fake_timer = FakeTimer(0) + msg1 = ('key1', CustomIdObject(1, 'a')) + msg2 = ('key1', CustomIdObject(1, 'b')) + + result = async_dofn.process( + msg1, to_process=fake_bag_state, timer=fake_timer) + self.assertEqual(result, []) + + # The second message should be a no-op as it has the same id. + result = async_dofn.process( + msg2, to_process=fake_bag_state, timer=fake_timer) + self.assertEqual(result, []) + + self.wait_for_empty(async_dofn) + result = async_dofn.commit_finished_items(fake_bag_state, fake_timer) + self.check_output(result, [('key1', msg1[1])]) + self.assertEqual(fake_bag_state.items, []) + def test_basic(self): # Setup an async dofn and send a message in to process. dofn = BasicDofn() From cad73a67d8c8251ede7fe1079dc8e8a6146e8169 Mon Sep 17 00:00:00 2001 From: dustin12 Date: Tue, 11 Nov 2025 10:59:15 -0800 Subject: [PATCH 2/5] fix formatting errors --- sdks/python/apache_beam/transforms/async_dofn.py | 3 +-- sdks/python/apache_beam/transforms/async_dofn_test.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index a7e608f46521..2830cd3251ef 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -456,5 +456,4 @@ def timer_callback( Returns: A generator of elements that have finished processing for this key. """ - return self.commit_finished_items(to_process, timer) - \ No newline at end of file + return self.commit_finished_items(to_process, timer) \ No newline at end of file diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index ef3a1330eaaf..fe75de05ccd5 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -132,8 +132,7 @@ def __eq__(self, other): return self.element_id == other.element_id dofn = BasicDofn() - async_dofn = async_lib.AsyncWrapper( - dofn, id_fn=lambda x: x.element_id) + async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id) async_dofn.setup() fake_bag_state = FakeBagState([]) fake_timer = FakeTimer(0) From 45239f8eac5710ad35d4c2bf88174cf4d2d2cc6d Mon Sep 17 00:00:00 2001 From: dustin12 Date: Tue, 11 Nov 2025 11:11:53 -0800 Subject: [PATCH 3/5] Formatting Fix 2 --- sdks/python/apache_beam/transforms/async_dofn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index 2830cd3251ef..3b6489bf3550 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -371,9 +371,7 @@ def commit_finished_items( processing_elements = AsyncWrapper._processing_elements[self._uuid] to_process_local_ids = {self._id_fn(e[1]) for e in to_process_local} to_remove_ids = [] - for element_id, ( - element, - future) in processing_elements.items(): + for element_id, (element, future) in processing_elements.items(): if element[0] == key and element_id not in to_process_local_ids: items_cancelled += 1 future.cancel() From df9a578d72120a776b6a7f2e1bd3dbb32063d685 Mon Sep 17 00:00:00 2001 From: dustin12 Date: Tue, 11 Nov 2025 15:57:47 -0800 Subject: [PATCH 4/5] fix linter errors --- sdks/python/apache_beam/transforms/async_dofn.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index 3b6489bf3550..1c8d2f3fe90e 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -351,9 +351,6 @@ def commit_finished_items( to_process_local = list(to_process.read()) - # For all elements that in local state but not processing state delete them - # from local state and cancel their futures. - to_remove = [] key = None to_reschedule = [] if to_process_local: @@ -454,4 +451,4 @@ def timer_callback( Returns: A generator of elements that have finished processing for this key. """ - return self.commit_finished_items(to_process, timer) \ No newline at end of file + return self.commit_finished_items(to_process, timer) From 5b031167fcfa38e90b855972d013c4c713bd4ad1 Mon Sep 17 00:00:00 2001 From: dustin12 Date: Wed, 12 Nov 2025 21:07:32 +0000 Subject: [PATCH 5/5] change element_ to _ --- sdks/python/apache_beam/transforms/async_dofn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index 1c8d2f3fe90e..5e1c6d219f4b 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -386,7 +386,7 @@ def commit_finished_items( items_in_se_state += 1 x_id = self._id_fn(x[1]) if x_id in processing_elements: - _element, future = processing_elements[x_id] + _, future = processing_elements[x_id] if future.done(): to_return.append(future.result()) finished_items.append(x)