diff --git a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto index 9b32048b4995..7704d4bdd3a5 100644 --- a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto @@ -157,6 +157,7 @@ message InstructionRequest { MonitoringInfosMetadataRequest monitoring_infos = 1005; HarnessMonitoringInfosRequest harness_monitoring_infos = 1006; SampleDataRequest sample_data = 1007; + AiWorkerPoolMetadata ai_worker_pool_metadata = 1008; // DEPRECATED RegisterRequest register = 1000; @@ -529,6 +530,13 @@ message MonitoringInfosMetadataResponse { map<string, org.apache.beam.model.pipeline.v1.MonitoringInfo> monitoring_info = 1; } +message AiWorkerPoolMetadata { + // The external IP address of the AI worker pool. + string external_ip = 1; + // The external port of the AI worker pool. + int32 external_port = 2; +} + // Represents a request to the SDK to split a currently active bundle. 
message ProcessBundleSplitRequest { // (Required) A reference to an active process bundle request with the given diff --git a/sdks/python/apache_beam/examples/ratelimit/__init__.py b/sdks/python/apache_beam/examples/ratelimit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdks/python/apache_beam/examples/ratelimit/beam_example.py b/sdks/python/apache_beam/examples/ratelimit/beam_example.py new file mode 100644 index 000000000000..ce44d83e75e0 --- /dev/null +++ b/sdks/python/apache_beam/examples/ratelimit/beam_example.py @@ -0,0 +1,101 @@ +import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions +from apache_beam.runners.worker.sdk_worker import get_ai_worker_pool_metadata + +import grpc +import logging +import os +import sys + + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'generated_proto'))) + +from envoy.service.ratelimit.v3 import rls_pb2 +from envoy.service.ratelimit.v3 import rls_pb2_grpc +from envoy.extensions.common.ratelimit.v3 import ratelimit_pb2 + +# Set up logging +logging.basicConfig(level=logging.INFO) +_LOGGER = logging.getLogger(__name__) + +class GRPCRateLimitClient(beam.DoFn): + """ + A DoFn that makes gRPC calls to an Envoy Rate Limit Service. + """ + def __init__(self): + self._envoy_address = None + self._channel = None + self._stub = None + + def setup(self): + """ + Initializes the gRPC channel and stub. 
+ """ + ai_worker_pool_metadata = get_ai_worker_pool_metadata() + self._envoy_address = f"{ai_worker_pool_metadata.external_ip}:{ai_worker_pool_metadata.external_port}" + _LOGGER.info(f"Setting up gRPC client for Envoy at {self._envoy_address}") + self._channel = grpc.insecure_channel(self._envoy_address) + self._stub = rls_pb2_grpc.RateLimitServiceStub(self._channel) + + def process(self, element): + client_id = element.get('client_id', 'unknown_client') + request_id = element.get('request_id', 'unknown_request') + + _LOGGER.info(f"Processing element: client_id={client_id}, request_id={request_id}") + + # Create a RateLimitDescriptor + descriptor = ratelimit_pb2.RateLimitDescriptor() + descriptor.entries.add(key="client_id", value=client_id) + descriptor.entries.add(key="request_id", value=request_id) + + # Create a RateLimitRequest + request = rls_pb2.RateLimitRequest( + domain="my_service", + descriptors=[descriptor], + hits_addend=1 + ) + + try: + response = self._stub.ShouldRateLimit(request) + _LOGGER.info(f"RateLimitResponse for client_id={client_id}, request_id={request_id}: {response.overall_code}") + yield { + 'client_id': client_id, + 'request_id': request_id, + 'rate_limit_status': rls_pb2.RateLimitResponse.Code.Name(response.overall_code), + 'response_details': str(response) + } + except grpc.RpcError as e: + _LOGGER.error(f"gRPC call failed for client_id={client_id}, request_id={request_id}: {e.details()}") + yield { + 'client_id': client_id, + 'request_id': request_id, + 'rate_limit_status': 'ERROR', + 'error_details': e.details() + } + + def teardown(self): + if self._channel: + _LOGGER.info("Tearing down gRPC client.") + self._channel.close() + +def run(): + options = PipelineOptions() + options.view_as(StandardOptions).runner = 'DirectRunner' # Use DirectRunner for local testing + + with beam.Pipeline(options=options) as p: + # Sample input data + requests = p | 'CreateRequests' >> beam.Create([ + {'client_id': 'user_1', 'request_id': 'req_a'}, + 
{'client_id': 'user_2', 'request_id': 'req_b'}, + {'client_id': 'user_1', 'request_id': 'req_c'}, + {'client_id': 'user_3', 'request_id': 'req_d'}, + ]) + + # Apply the gRPC client DoFn + rate_limit_results = requests | 'CheckRateLimit' >> beam.ParDo(GRPCRateLimitClient()) + + # Log the results + rate_limit_results | 'LogResults' >> beam.Map(lambda x: _LOGGER.info(f"Result: {x}")) + +if __name__ == '__main__': + run() diff --git a/sdks/python/apache_beam/examples/ratelimit/beam_example2.py b/sdks/python/apache_beam/examples/ratelimit/beam_example2.py new file mode 100644 index 000000000000..c561e43bea11 --- /dev/null +++ b/sdks/python/apache_beam/examples/ratelimit/beam_example2.py @@ -0,0 +1,31 @@ +import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions +import logging +from apache_beam.runners.worker.sdk_worker import get_ai_worker_pool_metadata + +logging.basicConfig(level=logging.INFO) + +class PrintFn(beam.DoFn): + def process(self, element): + logging.info(f"Processing element: {element} and worker metadata {get_ai_worker_pool_metadata()}") + yield element + +pipeline_options = PipelineOptions() +pipeline = beam.Pipeline(options=pipeline_options) + +# Create a PCollection from a list of elements for this batch job. +data = pipeline | 'Create' >> beam.Create([ + 'Hello', + 'World', + 'This', + 'is', + 'a', + 'batch', + 'example', +]) + +# Apply the custom DoFn with resource hints. 
+data | 'PrintWithDoFn' >> beam.ParDo(PrintFn()) + +result = pipeline.run() +result.wait_until_finish() diff --git a/sdks/python/apache_beam/examples/ratelimit/buf.yaml b/sdks/python/apache_beam/examples/ratelimit/buf.yaml new file mode 100644 index 000000000000..113c50eb4d16 --- /dev/null +++ b/sdks/python/apache_beam/examples/ratelimit/buf.yaml @@ -0,0 +1,3 @@ +version: v1 +deps: + - buf.build/envoyproxy/envoy diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 6060ff8d54a8..63c38089d11f 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -69,7 +69,7 @@ from apache_beam.utils import thread_pool_executor from apache_beam.utils.sentinel import Sentinel from apache_beam.version import __version__ as beam_version - +import dataclasses if TYPE_CHECKING: from apache_beam.portability.api import endpoints_pb2 from apache_beam.utils.profiler import Profile @@ -104,6 +104,44 @@ }] }) +@dataclasses.dataclass +class AiWorkerPoolMetadata: + """Runtime metadata about AI worker pool resources, such as external IP and + port. + + Attributes: + external_ip (str): The external IP address of the AI worker pool. + external_port (int): The external port of the AI worker pool. 
+ """ + external_ip: Optional[str] = None + external_port: Optional[int] = None + + @classmethod + def from_proto(cls, proto): + # type: (beam_fn_api_pb2.AiWorkerPoolMetadata) -> AiWorkerPoolMetadata + """Creates an instance from an AiWorkerPoolMetadata proto.""" + return cls( + external_ip=proto.external_ip if proto.external_ip else None, + external_port=proto.external_port if proto.external_port else None) + + +class _AiMetadataHolder: + """Singleton holder for AiWorkerPoolMetadata.""" + _metadata: Optional[AiWorkerPoolMetadata] = None + _lock = threading.Lock() + + @classmethod + def set_metadata(cls, proto): + # type: (beam_fn_api_pb2.AiWorkerPoolMetadata) -> None + with cls._lock: + cls._metadata = AiWorkerPoolMetadata.from_proto(proto) + + @classmethod + def get_metadata(cls) -> Optional[AiWorkerPoolMetadata]: + return cls._metadata + +def get_ai_worker_pool_metadata() -> Optional[AiWorkerPoolMetadata]: + return _AiMetadataHolder.get_metadata() class ShortIdCache(object): """ Cache for MonitoringInfo "short ids" @@ -393,6 +431,13 @@ def task(): _LOGGER.debug( "Currently using %s threads." % len(self._worker_thread_pool._workers)) + def _request_ai_worker_pool_metadata(self, request): + # type: (beam_fn_api_pb2.InstructionRequest) -> None + _AiMetadataHolder.set_metadata(request.ai_worker_pool_metadata) + _LOGGER.info("received metadata for AI worker pool: %s", request.ai_worker_pool_metadata) + self._responses.put( + beam_fn_api_pb2.InstructionResponse(instruction_id=request.instruction_id)) + def _request_sample_data(self, request): # type: (beam_fn_api_pb2.InstructionRequest) -> None