diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd index f24b75a720e0..3179385d740f 100644 --- a/sdks/python/apache_beam/runners/worker/operations.pxd +++ b/sdks/python/apache_beam/runners/worker/operations.pxd @@ -117,6 +117,7 @@ cdef class DoOperation(Operation): cdef dict timer_specs cdef public object input_info cdef object fn + cdef readonly object scoped_timer_processing_state cdef class SdfProcessSizedElements(DoOperation): diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index 2b20bebe0940..432e5052ccb3 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -49,6 +49,7 @@ from apache_beam.runners.worker import opcounters from apache_beam.runners.worker import operation_specs from apache_beam.runners.worker import sideinputs +from apache_beam.runners.worker import statesampler from apache_beam.runners.worker.data_sampler import DataSampler from apache_beam.transforms import sideinputs as apache_sideinputs from apache_beam.transforms import combiners @@ -808,8 +809,14 @@ def __init__( self.user_state_context = user_state_context self.tagged_receivers = None # type: Optional[_TaggedReceivers] # A mapping of timer tags to the input "PCollections" they come in on. + # Force clean rebuild self.input_info = None # type: Optional[OpInputInfo] - + self.scoped_timer_processing_state = statesampler.NOOP_SCOPED_STATE + if self.state_sampler: + self.scoped_timer_processing_state = self.state_sampler.scoped_state( + self.name_context, + 'process-timers', + metrics_container=self.metrics_container) # See fn_data in dataflow_runner.py # TODO: Store all the items from spec? self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn)) @@ -971,14 +978,21 @@ def add_timer_info(self, timer_family_id, timer_info): self.user_state_context.add_timer_info(timer_family_id, timer_info) def process_timer(self, tag, timer_data): - timer_spec = self.timer_specs[tag] - self.dofn_runner.process_user_timer( - timer_spec, - timer_data.user_key, - timer_data.windows[0], - timer_data.fire_timestamp, - timer_data.paneinfo, - timer_data.dynamic_timer_tag) + def process_timer_logic(): + timer_spec = self.timer_specs[tag] + self.dofn_runner.process_user_timer( + timer_spec, + timer_data.user_key, + timer_data.windows[0], + timer_data.fire_timestamp, + timer_data.paneinfo, + timer_data.dynamic_timer_tag) + + if self.scoped_timer_processing_state: + with self.scoped_timer_processing_state: + process_timer_logic() + else: + process_timer_logic() def finish(self): # type: () -> None diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py index b9c75f4de93d..53c3d8055101 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler.py +++ b/sdks/python/apache_beam/runners/worker/statesampler.py @@ -134,8 +134,8 @@ def scoped_state( name_context: Union[str, 'common.NameContext'], state_name: str, io_target=None, - metrics_container: Optional['MetricsContainer'] = None - ) -> statesampler_impl.ScopedState: + metrics_container: Optional['MetricsContainer'] = None, + suffix: str = '-msecs') -> statesampler_impl.ScopedState: """Returns a ScopedState object associated to a Step and a State. Args: @@ -152,7 +152,7 @@ def scoped_state( name_context = common.NameContext(name_context) counter_name = CounterName( - state_name + '-msecs', + state_name + suffix, stage_name=self._prefix, step_name=name_context.metrics_name(), io_target=io_target) @@ -170,3 +170,17 @@ def commit_counters(self) -> None: for state in self._states_by_name.values(): state_msecs = int(1e-6 * state.nsecs) state.counter.update(state_msecs - state.counter.value()) + + +class NoOpScopedState: + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def sampled_msecs_int(self): + return 0 + + +NOOP_SCOPED_STATE = NoOpScopedState() diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py index c9ea7e8eef97..8e3a2e6f1202 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_test.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py @@ -28,6 +28,10 @@ from apache_beam.runners.worker import statesampler from apache_beam.utils.counters import CounterFactory from apache_beam.utils.counters import CounterName +from apache_beam.runners.worker import operation_specs +from apache_beam.runners.worker import operations +from apache_beam.internal import pickler +from apache_beam.transforms import core _LOGGER = logging.getLogger(__name__) @@ -127,6 +131,118 @@ def test_sampler_transition_overhead(self): # debug mode). self.assertLess(overhead_us, 20.0) + @retry(reraise=True, stop=stop_after_attempt(3)) + def test_timer_sampler(self): + # Set up state sampler. + counter_factory = CounterFactory() + sampler = statesampler.StateSampler( + 'timer', counter_factory, sampling_period_ms=1) + + # Duration of the timer processing. + state_duration_ms = 100 + margin_of_error = 0.25 + + sampler.start() + with sampler.scoped_state('step1', 'process-timers'): + time.sleep(state_duration_ms / 1000) + sampler.stop() + sampler.commit_counters() + + if not statesampler.FAST_SAMPLER: + # The slow sampler does not implement sampling, so we won't test it. + return + + # Test that sampled state timings are close to their expected values. + c = CounterName( + 'process-timers-msecs', step_name='step1', stage_name='timer') + expected_counter_values = { + c: state_duration_ms, + } + for counter in counter_factory.get_counters(): + self.assertIn(counter.name, expected_counter_values) + expected_value = expected_counter_values[counter.name] + actual_value = counter.value() + deviation = float(abs(actual_value - expected_value)) / expected_value + _LOGGER.info('Sampling deviation from expectation: %f', deviation) + self.assertGreater(actual_value, expected_value * (1.0 - margin_of_error)) + self.assertLess(actual_value, expected_value * (1.0 + margin_of_error)) + + @retry(reraise=True, stop=stop_after_attempt(3)) + def test_process_timers_metric_is_recorded(self): + """ + Tests that the 'process-timers-msecs' metric is correctly recorded + when a state sampler is active. + """ + # Set up a real state sampler and counter factory. + counter_factory = CounterFactory() + sampler = statesampler.StateSampler( + 'test_stage', counter_factory, sampling_period_ms=1) + + state_duration_ms = 100 + margin_of_error = 0.25 + + # Run a workload inside the 'process-timers' scoped state. + sampler.start() + with sampler.scoped_state('test_step', 'process-timers'): + time.sleep(state_duration_ms / 1000.0) + sampler.stop() + sampler.commit_counters() + + if not statesampler.FAST_SAMPLER: + return + + # Verify that the counter was created with the correct name and value. + expected_counter_name = CounterName( + 'process-timers-msecs', step_name='test_step', stage_name='test_stage') + + # Find the specific counter we are looking for. + found_counter = None + for counter in counter_factory.get_counters(): + if counter.name == expected_counter_name: + found_counter = counter + break + + self.assertIsNotNone( + found_counter, + f"The expected counter '{expected_counter_name}' was not created.") + + # Check that its value is approximately correct. + actual_value = found_counter.value() + expected_value = state_duration_ms + self.assertGreater( + actual_value, + expected_value * (1.0 - margin_of_error), + "The timer metric was lower than expected.") + self.assertLess( + actual_value, + expected_value * (1.0 + margin_of_error), + "The timer metric was higher than expected.") + + def test_do_operation_with_sampler(self): + """ + Tests that a DoOperation with an active state_sampler correctly + creates a real ScopedState object for timer processing. + """ + mock_spec = operation_specs.WorkerDoFn( + serialized_fn=pickler.dumps((core.DoFn(), None, None, None, None)), + output_tags=[], + input=None, + side_inputs=[], + output_coders=[]) + + sampler = statesampler.StateSampler( + 'test_stage', CounterFactory(), sampling_period_ms=1) + + # 1. Create the operation WITHOUT the unexpected keyword argument + op = operations.create_operation( + name_context='test_op', + spec=mock_spec, + counter_factory=CounterFactory(), + state_sampler=sampler) + + self.assertIsNot( + op.scoped_timer_processing_state, statesampler.NOOP_SCOPED_STATE) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)