diff --git a/.github/workflows/miniflare.yml b/.github/workflows/miniflare.yml index 5b64c581..8bf2744a 100644 --- a/.github/workflows/miniflare.yml +++ b/.github/workflows/miniflare.yml @@ -33,9 +33,6 @@ jobs: docker pull localstack/localstack-pro & pip install localstack localstack-ext - # TODO remove - mkdir ~/.localstack; echo '{"token":"test"}' > ~/.localstack/auth.json - branchName=${GITHUB_HEAD_REF##*/} if [ "$branchName" = "" ]; then branchName=main; fi echo "Installing from branch name $branchName" diff --git a/.github/workflows/utils.yml b/.github/workflows/utils.yml new file mode 100644 index 00000000..f6741174 --- /dev/null +++ b/.github/workflows/utils.yml @@ -0,0 +1,66 @@ +name: LocalStack Extensions Utils Tests + +on: + push: + paths: + - utils/** + branches: + - main + pull_request: + paths: + - .github/workflows/utils.yml + - utils/** + workflow_dispatch: + +jobs: + unit-tests: + name: Run Unit Tests + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + cd utils + pip install -e .[dev,test] + + - name: Lint + run: | + cd utils + make lint + + - name: Run unit tests + run: | + cd utils + make test-unit + + integration-tests: + name: Run Integration Tests + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + cd utils + pip install -e .[dev,test] + + - name: Run integration tests + run: | + docker pull moul/grpcbin & + cd utils + make test-integration diff --git a/.github/workflows/wiremock.yml b/.github/workflows/wiremock.yml index 845f28cc..955a25be 100644 --- a/.github/workflows/wiremock.yml +++ b/.github/workflows/wiremock.yml @@ -39,6 +39,7 @@ jobs: pip install localstack terraform-local awscli-local[ver1] make install + make lint make dist localstack extensions -v install file://$(ls ./dist/localstack_wiremock-*.tar.gz) diff --git a/paradedb/Makefile b/paradedb/Makefile index dea93833..9c4db1ec 100644 --- a/paradedb/Makefile +++ b/paradedb/Makefile @@ -34,7 +34,7 @@ entrypoints: venv ## Generate plugin entrypoints for Python package $(VENV_RUN); python -m plux entrypoints format: ## Run ruff to format the codebase - $(VENV_RUN); python -m ruff format .; make lint + $(VENV_RUN); python -m ruff format .; python -m ruff check --fix . lint: ## Run ruff to lint the codebase $(VENV_RUN); python -m ruff check --output-format=full . diff --git a/paradedb/localstack_paradedb/extension.py b/paradedb/localstack_paradedb/extension.py index 845dbbcd..8adb23ff 100644 --- a/paradedb/localstack_paradedb/extension.py +++ b/paradedb/localstack_paradedb/extension.py @@ -1,9 +1,8 @@ import os -import logging +import socket -from localstack_paradedb.utils.docker import DatabaseDockerContainerExtension - -LOG = logging.getLogger(__name__) +from localstack_extensions.utils.docker import ProxiedDockerContainerExtension +from localstack import config # Environment variables for configuration ENV_POSTGRES_USER = "PARADEDB_POSTGRES_USER" @@ -18,7 +17,7 @@ DEFAULT_POSTGRES_PORT = 5432 -class ParadeDbExtension(DatabaseDockerContainerExtension): +class ParadeDbExtension(ProxiedDockerContainerExtension): name = "paradedb" # Name of the Docker image to spin up @@ -33,6 +32,12 @@ def __init__(self): postgres_db = os.environ.get(ENV_POSTGRES_DB, DEFAULT_POSTGRES_DB) postgres_port = int(os.environ.get(ENV_POSTGRES_PORT, DEFAULT_POSTGRES_PORT)) + # Store configuration for connection info + self.postgres_user = postgres_user + self.postgres_password = postgres_password + self.postgres_db = postgres_db + self.postgres_port = postgres_port + # Environment variables to pass to the container env_vars = { "POSTGRES_USER": postgres_user, @@ -40,31 +45,70 @@ def __init__(self): "POSTGRES_DB": postgres_db, } + def _tcp_health_check(): + """Check if ParadeDB port is accepting connections.""" + self._check_tcp_port(self.container_host, self.postgres_port) + super().__init__( image_name=self.DOCKER_IMAGE, container_ports=[postgres_port], env_vars=env_vars, + health_check_fn=_tcp_health_check, + tcp_ports=[postgres_port], # Enable TCP proxying through gateway ) - # Store configuration for connection info - self.postgres_user = postgres_user - self.postgres_password = postgres_password - self.postgres_db = postgres_db - self.postgres_port = postgres_port + def tcp_connection_matcher(self, data: bytes) -> bool: + """ + Identify PostgreSQL/ParadeDB connections by protocol handshake. + + PostgreSQL can start with either: + 1. SSL request: protocol code 80877103 (0x04D2162F) + 2. Startup message: protocol version 3.0 (0x00030000) + + Both use the same format: + - 4 bytes: message length + - 4 bytes: protocol version/code + """ + if len(data) < 8: + return False + + # Check for SSL request (80877103 = 0x04D2162F) + if data[4:8] == b"\x04\xd2\x16\x2f": + return True + + # Check for protocol version 3.0 (0x00030000) + if data[4:8] == b"\x00\x03\x00\x00": + return True + + return False + + def _check_tcp_port(self, host: str, port: int, timeout: float = 2.0) -> None: + """Check if a TCP port is accepting connections.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + try: + sock.connect((host, port)) + sock.close() + except (socket.timeout, socket.error) as e: + raise AssertionError(f"Port {port} not ready: {e}") def get_connection_info(self) -> dict: """Return connection information for ParadeDB.""" - info = super().get_connection_info() - info.update( - { - "database": self.postgres_db, - "user": self.postgres_user, - "password": self.postgres_password, - "port": self.postgres_port, - "connection_string": ( - f"postgresql://{self.postgres_user}:{self.postgres_password}" - f"@{self.container_host}:{self.postgres_port}/{self.postgres_db}" - ), - } - ) - return info + # Clients should connect through the LocalStack gateway + gateway_host = "paradedb.localhost.localstack.cloud" + gateway_port = config.LOCALSTACK_HOST.port + + return { + "host": gateway_host, + "database": self.postgres_db, + "user": self.postgres_user, + "password": self.postgres_password, + "port": gateway_port, + "connection_string": ( + f"postgresql://{self.postgres_user}:{self.postgres_password}" + f"@{gateway_host}:{gateway_port}/{self.postgres_db}" + ), + # Also include container connection details for debugging + "container_host": self.container_host, + "container_port": self.postgres_port, + } diff --git a/paradedb/localstack_paradedb/utils/docker.py b/paradedb/localstack_paradedb/utils/docker.py deleted file mode 100644 index 6643e6d9..00000000 --- a/paradedb/localstack_paradedb/utils/docker.py +++ /dev/null @@ -1,144 +0,0 @@ -import re -import socket -import logging -from functools import cache -from typing import Callable - -from localstack import config -from localstack.utils.docker_utils import DOCKER_CLIENT -from localstack.extensions.api import Extension -from localstack.utils.container_utils.container_client import PortMappings -from localstack.utils.net import get_addressable_container_host -from localstack.utils.sync import retry - -LOG = logging.getLogger(__name__) -logging.getLogger("localstack_paradedb").setLevel( - logging.DEBUG if config.DEBUG else logging.INFO -) -logging.basicConfig() - - -class DatabaseDockerContainerExtension(Extension): - """ - Utility class to create a LocalStack Extension which runs a Docker container - for a database service that uses a native protocol (e.g., PostgreSQL). - - Unlike HTTP-based services, database connections are made directly to the - exposed container port rather than through the LocalStack gateway. - """ - - name: str - """Name of this extension, which must be overridden in a subclass.""" - image_name: str - """Docker image name""" - container_ports: list[int] - """List of network ports of the Docker container spun up by the extension""" - command: list[str] | None - """Optional command (and flags) to execute in the container.""" - env_vars: dict[str, str] | None - """Optional environment variables to pass to the container.""" - health_check_port: int | None - """Port to use for health check (defaults to first port in container_ports).""" - health_check_fn: Callable[[], bool] | None - """Optional custom health check function.""" - - def __init__( - self, - image_name: str, - container_ports: list[int], - command: list[str] | None = None, - env_vars: dict[str, str] | None = None, - health_check_port: int | None = None, - health_check_fn: Callable[[], bool] | None = None, - ): - self.image_name = image_name - if not container_ports: - raise ValueError("container_ports is required") - self.container_ports = container_ports - self.container_name = re.sub(r"\W", "-", f"ls-ext-{self.name}") - self.command = command - self.env_vars = env_vars - self.health_check_port = health_check_port or container_ports[0] - self.health_check_fn = health_check_fn - self.container_host = get_addressable_container_host() - - def on_extension_load(self): - LOG.info("Loading ParadeDB extension") - - def on_platform_start(self): - LOG.info("Starting ParadeDB extension - launching container") - self.start_container() - - def on_platform_shutdown(self): - self._remove_container() - - @cache - def start_container(self) -> None: - LOG.debug("Starting extension container %s", self.container_name) - - port_mapping = PortMappings() - for port in self.container_ports: - port_mapping.add(port) - - kwargs = {} - if self.command: - kwargs["command"] = self.command - if self.env_vars: - kwargs["env_vars"] = self.env_vars - - try: - DOCKER_CLIENT.run_container( - self.image_name, - detach=True, - remove=True, - name=self.container_name, - ports=port_mapping, - **kwargs, - ) - except Exception as e: - LOG.debug("Failed to start container %s: %s", self.container_name, e) - raise - - def _check_health(): - if self.health_check_fn: - assert self.health_check_fn() - else: - # Default: TCP socket check - self._check_tcp_port(self.container_host, self.health_check_port) - - try: - retry(_check_health, retries=60, sleep=1) - except Exception as e: - LOG.info("Failed to connect to container %s: %s", self.container_name, e) - self._remove_container() - raise - - LOG.info( - "Successfully started extension container %s on %s:%s", - self.container_name, - self.container_host, - self.health_check_port, - ) - - def _check_tcp_port(self, host: str, port: int, timeout: float = 2.0) -> None: - """Check if a TCP port is accepting connections.""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(timeout) - try: - sock.connect((host, port)) - sock.close() - except (socket.timeout, socket.error) as e: - raise AssertionError(f"Port {port} not ready: {e}") - - def _remove_container(self): - LOG.debug("Stopping extension container %s", self.container_name) - DOCKER_CLIENT.remove_container( - self.container_name, force=True, check_existence=False - ) - - def get_connection_info(self) -> dict: - """Return connection information for the database.""" - return { - "host": self.container_host, - "ports": {port: port for port in self.container_ports}, - } diff --git a/paradedb/pyproject.toml b/paradedb/pyproject.toml index 291d8576..76a469e2 100644 --- a/paradedb/pyproject.toml +++ b/paradedb/pyproject.toml @@ -13,7 +13,11 @@ authors = [ ] keywords = ["LocalStack", "ParadeDB", "PostgreSQL", "Search", "Analytics"] classifiers = [] -dependencies = [] +dependencies = [ + # TODO remove / replace prior to merge! +# "localstack-extensions-utils", + "localstack-extensions-utils @ git+https://github.com/localstack/localstack-extensions.git@extract-utils-package#subdirectory=utils" +] [project.urls] Homepage = "https://github.com/localstack/localstack-extensions" diff --git a/paradedb/tests/test_extension.py b/paradedb/tests/test_extension.py index bd1277e6..b816682c 100644 --- a/paradedb/tests/test_extension.py +++ b/paradedb/tests/test_extension.py @@ -3,8 +3,9 @@ # Connection details for ParadeDB -HOST = "localhost" -PORT = 5432 +# Connect through LocalStack gateway with TCP proxying +HOST = "paradedb.localhost.localstack.cloud" +PORT = 4566 USER = "myuser" PASSWORD = "mypassword" DATABASE = "mydatabase" diff --git a/typedb/localstack_typedb/extension.py b/typedb/localstack_typedb/extension.py index 21d8c815..83321e59 100644 --- a/typedb/localstack_typedb/extension.py +++ b/typedb/localstack_typedb/extension.py @@ -2,7 +2,7 @@ import shlex from localstack.config import is_env_not_false -from localstack_typedb.utils.docker import ProxiedDockerContainerExtension +from localstack_extensions.utils import ProxiedDockerContainerExtension from rolo import Request from werkzeug.datastructures import Headers @@ -37,7 +37,7 @@ def __init__(self): http2_ports=http2_ports, ) - def should_proxy_request(self, headers: Headers) -> bool: + def http2_request_matcher(self, headers: Headers) -> bool: # determine if this is a gRPC request targeting TypeDB content_type = headers.get("content-type") or "" req_path = headers.get(":path") or "" diff --git a/typedb/localstack_typedb/utils/__init__.py b/typedb/localstack_typedb/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/typedb/pyproject.toml b/typedb/pyproject.toml index 6e8f703d..66d3f55e 100644 --- a/typedb/pyproject.toml +++ b/typedb/pyproject.toml @@ -14,9 +14,9 @@ authors = [ keywords = ["LocalStack", "TypeDB"] classifiers = [] dependencies = [ - "httpx", - "h2", - "priority", + # TODO remove / replace prior to merge! +# "localstack-extensions-utils", + "localstack-extensions-utils @ git+https://github.com/localstack/localstack-extensions.git@extract-utils-package#subdirectory=utils" ] [project.urls] diff --git a/typedb/tests/test_extension.py b/typedb/tests/test_extension.py index efe3fc4c..f385f5a7 100644 --- a/typedb/tests/test_extension.py +++ b/typedb/tests/test_extension.py @@ -1,10 +1,6 @@ import requests import httpx from localstack.utils.strings import short_uid -from localstack_typedb.utils.h2_proxy import ( - get_frames_from_http2_stream, - get_headers_from_frames, -) from typedb.driver import TypeDB, Credentials, DriverOptions, TransactionType @@ -98,17 +94,3 @@ def test_connect_to_h2_endpoint_non_typedb(): assert response.status_code == 200 assert response.http_version == "HTTP/2" assert "myext.` + Can be either a static hostname, or a pattern like `myext.` """ path: str | None """Optional path on which to expose the container endpoints.""" command: list[str] | None """Optional command (and flags) to execute in the container.""" + env_vars: dict[str, str] | None + """Optional environment variables to pass to the container.""" + volumes: list[SimpleVolumeBind] | None + """Optional volumes to mount into the container.""" + health_check_fn: Callable[[], None] | None + """ + Optional custom health check function. If not provided, defaults to HTTP GET on main_port. + The function should raise an exception if the health check fails. + """ + health_check_retries: int + """Number of times to retry the health check before giving up.""" + health_check_sleep: float + """Time in seconds to sleep between health check retries.""" request_to_port_router: Callable[[Request], int] | None """Callable that returns the target port for a given request, for routing purposes""" http2_ports: list[int] | None """List of ports for which HTTP2 proxy forwarding into the container should be enabled.""" + tcp_ports: list[int] | None + """ + List of container ports for raw TCP proxying through the gateway. + Enables transparent TCP forwarding for protocols that don't use HTTP (e.g., native DB protocols). + + When tcp_ports is set, the extension must implement tcp_connection_matcher() to identify + its traffic by inspecting initial connection bytes. + """ + + tcp_connection_matcher: Callable[[bytes], bool] | None + """ + Optional function to identify TCP connections belonging to this extension. + + Called with initial connection bytes (up to 512 bytes) to determine if this extension + should handle the connection. Return True to claim the connection, False otherwise. + """ def __init__( self, @@ -65,8 +94,14 @@ def __init__( host: str | None = None, path: str | None = None, command: list[str] | None = None, + env_vars: dict[str, str] | None = None, + volumes: list[SimpleVolumeBind] | None = None, + health_check_fn: Callable[[], None] | None = None, + health_check_retries: int = 60, + health_check_sleep: float = 1.0, request_to_port_router: Callable[[Request], int] | None = None, http2_ports: list[int] | None = None, + tcp_ports: list[int] | None = None, ): self.image_name = image_name if not container_ports: @@ -76,8 +111,14 @@ def __init__( self.path = path self.container_name = re.sub(r"\W", "-", f"ls-ext-{self.name}") self.command = command + self.env_vars = env_vars + self.volumes = volumes + self.health_check_fn = health_check_fn + self.health_check_retries = health_check_retries + self.health_check_sleep = health_check_sleep self.request_to_port_router = request_to_port_router self.http2_ports = http2_ports + self.tcp_ports = tcp_ports self.main_port = self.container_ports[0] self.container_host = get_addressable_container_host() @@ -97,12 +138,63 @@ def update_gateway_routes(self, router: http.Router[http.RouteHandler]): # apply patches to serve HTTP/2 requests for port in self.http2_ports or []: apply_http2_patches_for_grpc_support( - self.container_host, port, self.should_proxy_request + self.container_host, port, self.http2_request_matcher ) - @abstractmethod - def should_proxy_request(self, headers: Headers) -> bool: - """Define whether a request should be proxied, based on request headers.""" + # set up raw TCP proxies with protocol detection + if self.tcp_ports: + self._setup_tcp_protocol_routing() + + def _setup_tcp_protocol_routing(self): + """ + Set up TCP routing on the LocalStack gateway for this extension. + + This method patches the gateway's HTTP protocol handler to intercept TCP + connections and allow this extension to claim them via tcp_connection_matcher(). + This enables multiple TCP protocols to share the main gateway port (4566). + + Uses monkeypatching to intercept dataReceived() before HTTP processing. + """ + from localstack_extensions.utils.tcp_protocol_router import ( + patch_gateway_for_tcp_routing, + register_tcp_extension, + ) + + # Get the connection matcher from the extension + matcher = getattr(self, "tcp_connection_matcher", None) + if not matcher: + LOG.warning( + f"Extension {self.name} has tcp_ports but no tcp_connection_matcher(). " + "TCP routing will not work without a matcher." + ) + return + + # Apply gateway patches (only happens once globally) + patch_gateway_for_tcp_routing() + + # Register this extension for TCP routing + # Use first port as the default target port + target_port = self.tcp_ports[0] if self.tcp_ports else self.main_port + + register_tcp_extension( + extension_name=self.name, + matcher=matcher, + backend_host=self.container_host, + backend_port=target_port, + ) + + LOG.info( + f"Registered TCP extension {self.name} -> {self.container_host}:{target_port} on gateway" + ) + + def http2_request_matcher(self, headers: Headers) -> bool: + """ + Define whether an HTTP2 request should be proxied, based on request headers. + + Default implementation returns False (no HTTP2 proxying). + Override this method in subclasses that need HTTP2 proxying. + """ + return False def on_platform_shutdown(self): self._remove_container() @@ -118,6 +210,10 @@ def start_container(self) -> None: kwargs = {} if self.command: kwargs["command"] = self.command + if self.env_vars: + kwargs["env_vars"] = self.env_vars + if self.volumes: + kwargs["volumes"] = self.volumes try: DOCKER_CLIENT.run_container( @@ -130,23 +226,28 @@ def start_container(self) -> None: ) except Exception as e: LOG.debug("Failed to start container %s: %s", self.container_name, e) - # allow running TypeDB in a local server in dev mode, if TYPEDB_DEV_MODE is enabled - if not is_env_true("TYPEDB_DEV_MODE"): + # allow running the container in a local server in dev mode + if not is_env_true(f"{self.name.upper().replace('-', '_')}_DEV_MODE"): raise - def _ping_endpoint(): - # TODO: allow defining a custom healthcheck endpoint ... - response = requests.get(f"http://{self.container_host}:{self.main_port}/") - assert response.ok + # Use custom health check if provided, otherwise default to HTTP GET + health_check = self.health_check_fn or self._default_health_check try: - retry(_ping_endpoint, retries=40, sleep=1) + retry( + health_check, + retries=self.health_check_retries, + sleep=self.health_check_sleep, + ) except Exception as e: LOG.info("Failed to connect to container %s: %s", self.container_name, e) self._remove_container() raise - LOG.debug("Successfully started extension container %s", self.container_name) + def _default_health_check(self) -> None: + """Default health check: HTTP GET request to the main port.""" + response = requests.get(f"http://{self.container_host}:{self.main_port}/") + assert response.ok def _remove_container(self): LOG.debug("Stopping extension container %s", self.container_name) diff --git a/typedb/localstack_typedb/utils/h2_proxy.py b/utils/localstack_extensions/utils/h2_proxy.py similarity index 91% rename from typedb/localstack_typedb/utils/h2_proxy.py rename to utils/localstack_extensions/utils/h2_proxy.py index 5231541d..84ed0cb0 100644 --- a/typedb/localstack_typedb/utils/h2_proxy.py +++ b/utils/localstack_extensions/utils/h2_proxy.py @@ -29,16 +29,29 @@ def __init__(self, port: int, host: str = "localhost"): self.host = host self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.connect((self.host, self.port)) + self._closed = False def receive_loop(self, callback): - while data := self._socket.recv(self.buffer_size): + while data := self.recv(self.buffer_size): callback(data) + def recv(self, length): + try: + return self._socket.recv(length) + except OSError as e: + if self._closed: + return None + else: + raise e + def send(self, data): self._socket.sendall(data) def close(self): + if self._closed: + return LOG.debug(f"Closing connection to upstream HTTP2 server on port {self.port}") + self._closed = True try: self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() @@ -51,7 +64,7 @@ def close(self): def apply_http2_patches_for_grpc_support( - target_host: str, target_port: int, should_proxy_request: ProxyRequestMatcher + target_host: str, target_port: int, http2_request_matcher: ProxyRequestMatcher ): """ Apply some patches to proxy incoming gRPC requests and forward them to a target port. @@ -93,7 +106,6 @@ def __init__(self, http_response_stream): ) def received_from_backend(self, data): - LOG.debug(f"Received {len(data)} bytes from backend") self.http_response_stream.write(data) def received_from_http2_client(self, data, default_handler: Callable): @@ -111,11 +123,8 @@ def received_from_http2_client(self, data, default_handler: Callable): buffered_data = b"".join(self.buffer) self.buffer = [] - if should_proxy_request(headers): + if http2_request_matcher(headers): self.state = ForwardingState.FORWARDING - LOG.debug( - f"Forwarding {len(buffered_data)} bytes to backend" - ) self.backend.send(buffered_data) else: self.state = ForwardingState.PASSTHROUGH diff --git a/utils/localstack_extensions/utils/tcp_protocol_router.py b/utils/localstack_extensions/utils/tcp_protocol_router.py new file mode 100644 index 00000000..bb3045b1 --- /dev/null +++ b/utils/localstack_extensions/utils/tcp_protocol_router.py @@ -0,0 +1,179 @@ +""" +Protocol-detecting TCP router for LocalStack Gateway. + +This module provides a Twisted protocol that detects the protocol from initial +connection bytes and routes to the appropriate backend, enabling multiple TCP +protocols to share a single gateway port. +""" + +import logging +from twisted.internet import reactor +from twisted.protocols.portforward import ProxyClient, ProxyClientFactory +from twisted.web.http import HTTPChannel + +from localstack.utils.patch import patch +from localstack import config + +LOG = logging.getLogger(__name__) +LOG.setLevel(logging.DEBUG if config.DEBUG else logging.INFO) + +# Global registry of extensions with TCP matchers +# List of tuples: (extension_name, matcher_func, backend_host, backend_port) +_tcp_extensions = [] +_gateway_patched = False + + +class TcpProxyClient(ProxyClient): + """Backend TCP connection for protocol-detected connections.""" + + def connectionMade(self): + """Called when backend connection is established.""" + server = self.factory.server + + # Set up peer relationship + server.set_tcp_peer(self) + + # Unregister any existing producer on server transport (HTTPChannel may have one) + try: + server.transport.unregisterProducer() + except Exception: + pass # No producer was registered, which is fine + + # Enable flow control + self.transport.registerProducer(server.transport, True) + server.transport.registerProducer(self.transport, True) + + # Send buffered data from detection phase + if hasattr(self.factory, "initial_data"): + initial_data = self.factory.initial_data + self.transport.write(initial_data) + del self.factory.initial_data + + def dataReceived(self, data): + """Forward data from backend to client.""" + self.factory.server.transport.write(data) + + def connectionLost(self, reason): + """Backend connection closed.""" + self.factory.server.transport.loseConnection() + + +def patch_gateway_for_tcp_routing(): + """ + Patch the LocalStack gateway to enable protocol detection and TCP routing. + + This monkeypatches the HTTPChannel class used by the gateway to intercept + connections and detect TCP protocols before HTTP processing. + """ + global _gateway_patched + + if _gateway_patched: + return + + # Patch HTTPChannel to use our protocol-detecting version + @patch(HTTPChannel.__init__) + def _patched_init(fn, self, *args, **kwargs): + # Call original init + fn(self, *args, **kwargs) + # Add our detection attributes + self._detection_buffer = [] + self._detecting = True + self._tcp_peer = None + + @patch(HTTPChannel.dataReceived) + def _patched_dataReceived(fn, self, data): + """Intercept data to allow extensions to claim TCP connections.""" + if not getattr(self, "_detecting", False): + # Already decided - either proxying TCP or processing HTTP + if getattr(self, "_tcp_peer", None): + # TCP proxying mode + self._tcp_peer.transport.write(data) + else: + # HTTP mode - pass to original + fn(self, data) + return + + # Still detecting - buffer data + if not hasattr(self, "_detection_buffer"): + self._detection_buffer = [] + self._detection_buffer.append(data) + buffered_data = b"".join(self._detection_buffer) + + # Try each registered extension's matcher + if len(buffered_data) >= 8: + for ext_name, matcher, backend_host, backend_port in _tcp_extensions: + try: + if matcher(buffered_data): + # Switch to TCP proxy mode + self._detecting = False + self.transport.pauseProducing() + + # Create backend connection + client_factory = ProxyClientFactory() + client_factory.protocol = TcpProxyClient + client_factory.server = self + client_factory.initial_data = buffered_data + + reactor.connectTCP(backend_host, backend_port, client_factory) + return + except Exception as e: + LOG.debug(f"Error in matcher for {ext_name}: {e}") + continue + + # No extension claimed the connection + self._detecting = False + # Feed buffered data to HTTP handler + for chunk in self._detection_buffer: + fn(self, chunk) + self._detection_buffer = [] + + @patch(HTTPChannel.connectionLost) + def _patched_connectionLost(fn, self, reason): + """Handle connection close.""" + tcp_peer = getattr(self, "_tcp_peer", None) + if tcp_peer: + tcp_peer.transport.loseConnection() + self._tcp_peer = None + fn(self, reason) + + # Monkey-patch the set_tcp_peer method onto HTTPChannel + def set_tcp_peer(self, peer): + """Called when backend TCP connection is established.""" + self._tcp_peer = peer + self.transport.resumeProducing() + + HTTPChannel.set_tcp_peer = set_tcp_peer + + _gateway_patched = True + + +def register_tcp_extension( + extension_name: str, + matcher: callable, + backend_host: str, + backend_port: int, +): + """ + Register an extension for TCP connection routing. + + Args: + extension_name: Name of the extension + matcher: Function that takes bytes and returns bool to claim connection + backend_host: Backend host to route to + backend_port: Backend port to route to + """ + _tcp_extensions.append((extension_name, matcher, backend_host, backend_port)) + LOG.info( + f"Registered TCP extension {extension_name} -> {backend_host}:{backend_port}" + ) + + +def unregister_tcp_extension(extension_name: str): + """Unregister an extension from TCP routing.""" + global _tcp_extensions + _tcp_extensions = [ + (name, matcher, host, port) + for name, matcher, host, port in _tcp_extensions + if name != extension_name + ] + LOG.info(f"Unregistered TCP extension {extension_name}") diff --git a/utils/pyproject.toml b/utils/pyproject.toml new file mode 100644 index 00000000..257d6d32 --- /dev/null +++ b/utils/pyproject.toml @@ -0,0 +1,55 @@ +[build-system] +requires = ["setuptools", "wheel", "plux>=1.3.1"] +build-backend = "setuptools.build_meta" + +[project] +name = "localstack-extensions-utils" +version = "0.1.0" +description = "Utility library for LocalStack Extensions" +readme = {file = "README.md", content-type = "text/markdown; charset=UTF-8"} +requires-python = ">=3.10" +authors = [ + { name = "LocalStack Team" } +] +keywords = ["LocalStack", "Extensions", "Utils"] +classifiers = [] +dependencies = [ + "httpx", + "h2", + "hpack", + "hyperframe", + "priority", + "requests", + "rolo", + "twisted", +] + +[project.urls] +Homepage = "https://github.com/localstack/localstack-extensions" + +[project.optional-dependencies] +dev = [ + "boto3", + "build", + "localstack", + "jsonpatch", + "pytest", + "ruff", +] +test = [ + "pytest>=7.0", + "pytest-timeout>=2.0", + "localstack", + "jsonpatch", + "grpcio>=1.60.0", + "grpcio-tools>=1.60.0", +] + +[tool.setuptools.packages.find] +include = ["localstack_extensions*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +filterwarnings = [ + "ignore::DeprecationWarning", +] diff --git a/utils/tests/__init__.py b/utils/tests/__init__.py new file mode 100644 index 00000000..9e3a0996 --- /dev/null +++ b/utils/tests/__init__.py @@ -0,0 +1 @@ +# Utils package tests diff --git a/utils/tests/integration/__init__.py b/utils/tests/integration/__init__.py new file mode 100644 index 00000000..9f0458b1 --- /dev/null +++ b/utils/tests/integration/__init__.py @@ -0,0 +1 @@ +# Integration tests - Docker required, no LocalStack diff --git a/utils/tests/integration/conftest.py b/utils/tests/integration/conftest.py new file mode 100644 index 00000000..6702c5c9 --- /dev/null +++ b/utils/tests/integration/conftest.py @@ -0,0 +1,151 @@ +""" +Integration test fixtures for utils package. + +Provides fixtures for running tests against the grpcbin Docker container. +grpcbin is a neutral gRPC test service that supports various RPC types. + +Uses ProxiedDockerContainerExtension to manage the grpcbin container, +providing realistic test coverage of the Docker container management infrastructure. +""" + +import socket +import threading +import time + +import pytest +from hyperframe.frame import Frame +from localstack.utils.net import get_free_tcp_port +from rolo import Router +from rolo.gateway import Gateway +from twisted.internet import reactor +from twisted.web import server as twisted_server + +from localstack_extensions.utils.docker import ProxiedDockerContainerExtension + +GRPCBIN_IMAGE = "moul/grpcbin" +GRPCBIN_INSECURE_PORT = 9000 # HTTP/2 without TLS +GRPCBIN_SECURE_PORT = 9001 # HTTP/2 with TLS + +# HTTP/2 protocol constants +HTTP2_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" +SETTINGS_FRAME = b"\x00\x00\x00\x04\x00\x00\x00\x00\x00" # Empty SETTINGS frame + + +class GrpcbinExtension(ProxiedDockerContainerExtension): + """ + Test extension for grpcbin that uses ProxiedDockerContainerExtension. + + This extension demonstrates using ProxiedDockerContainerExtension for + a gRPC/HTTP2 service. While grpcbin doesn't use the HTTP gateway routing + (it's accessed via direct TCP), this tests the Docker container management + capabilities of ProxiedDockerContainerExtension. + """ + + name = "grpcbin-test" + + def __init__(self): + def _tcp_health_check(): + """Check if grpcbin insecure port is accepting TCP connections.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2.0) + try: + # Use container_host from the parent class + sock.connect((self.container_host, GRPCBIN_INSECURE_PORT)) + sock.close() + except (socket.timeout, socket.error) as e: + raise AssertionError(f"Port {GRPCBIN_INSECURE_PORT} not ready: {e}") + + super().__init__( + image_name=GRPCBIN_IMAGE, + container_ports=[GRPCBIN_INSECURE_PORT, GRPCBIN_SECURE_PORT], + health_check_fn=_tcp_health_check, + tcp_ports=[GRPCBIN_INSECURE_PORT], # Enable raw TCP proxying for gRPC/HTTP2 + ) + + def tcp_connection_matcher(self, data: bytes) -> bool: + """Detect HTTP/2 connection preface to route gRPC/HTTP2 traffic.""" + # HTTP/2 connections start with the connection preface + if len(data) >= len(HTTP2_PREFACE): + return data.startswith(HTTP2_PREFACE) + # Also match if we have partial preface data (for early detection) + return len(data) > 0 and HTTP2_PREFACE.startswith(data) + + +@pytest.fixture(scope="session") +def grpcbin_extension_server(): + """ + Start grpcbin using ProxiedDockerContainerExtension with a test gateway server. + + This tests the Docker container management and proxy capabilities by: + 1. Starting the grpcbin container via the extension + 2. Setting up a Gateway with the extension's routes and TCP patches + 3. Serving the Gateway on a test port via Twisted + 4. Returning server info for end-to-end testing + """ + extension = GrpcbinExtension() + + # Create router and update with extension routes + # This will start the grpcbin container and apply TCP protocol patches + router = Router() + extension.update_gateway_routes(router) + + # Create a Gateway with proper TCP support + # The TCP patches are applied by update_gateway_routes above + gateway = Gateway(router) + + # Start gateway on a test port using Twisted + test_port = get_free_tcp_port() + site = twisted_server.Site(gateway) + listener = reactor.listenTCP(test_port, site) + + # Run reactor in background thread + def run_reactor(): + reactor.run(installSignalHandlers=False) + + reactor_thread = threading.Thread(target=run_reactor, daemon=True) + reactor_thread.start() + + # Wait for reactor to start - not ideal, but should work as a simple solution + time.sleep(0.5) + + # Return server information for tests + server_info = { + "port": test_port, + "url": f"http://localhost:{test_port}", + "extension": extension, + "listener": listener, + } + + yield server_info + + # Cleanup + reactor.callFromThread(reactor.stop) + time.sleep(0.5) + extension.on_platform_shutdown() + + +@pytest.fixture(scope="session") +def grpcbin_extension(grpcbin_extension_server): + """Return the extension instance from the server fixture.""" + return grpcbin_extension_server["extension"] + + +def parse_server_frames(data: bytes) -> list: + """Parse HTTP/2 frames from server response data (no preface expected). + + Server responses don't include the HTTP/2 preface - they start with frames directly. + This function parses raw frame data using hyperframe directly. + """ + frames = [] + pos = 0 + while pos + 9 <= len(data): # Frame header is 9 bytes + try: + frame, length = Frame.parse_frame_header(memoryview(data[pos : pos + 9])) + if pos + 9 + length > len(data): + break # Incomplete frame + frame.parse_body(memoryview(data[pos + 9 : pos + 9 + length])) + frames.append(frame) + pos += 9 + length + except Exception: + break + return frames diff --git a/utils/tests/integration/test_extension_integration.py b/utils/tests/integration/test_extension_integration.py new file mode 100644 index 00000000..56d8c060 --- /dev/null +++ b/utils/tests/integration/test_extension_integration.py @@ -0,0 +1,53 @@ +""" +Integration tests for ProxiedDockerContainerExtension with grpcbin. + +These tests verify that ProxiedDockerContainerExtension properly manages +Docker containers in a realistic scenario, using grpcbin as a test service. +""" + +import socket + +from werkzeug.datastructures import Headers + + +class TestProxiedDockerContainerExtension: + """Tests for ProxiedDockerContainerExtension using the GrpcbinExtension.""" + + def test_extension_starts_container(self, grpcbin_extension): + """Test that the extension successfully starts the Docker container.""" + assert grpcbin_extension.container_name == "ls-ext-grpcbin-test" + assert grpcbin_extension.image_name == "moul/grpcbin" + assert len(grpcbin_extension.container_ports) == 2 + + def test_extension_container_host_is_accessible(self, grpcbin_extension): + """Test that the container_host is set and accessible.""" + assert grpcbin_extension.container_host is not None + # container_host should be localhost, localhost.localstack.cloud, or a docker bridge IP + assert grpcbin_extension.container_host in ( + "localhost", + "127.0.0.1", + "localhost.localstack.cloud", + ) or grpcbin_extension.container_host.startswith("172.") + + def test_extension_ports_are_reachable(self, grpcbin_extension_server): + """Test that the gateway port is reachable via TCP.""" + gateway_port = grpcbin_extension_server["port"] + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2.0) + try: + sock.connect(("localhost", gateway_port)) + sock.close() + # Connection successful + except (socket.timeout, socket.error) as e: + raise AssertionError(f"Could not connect to gateway port: {e}") + + def test_extension_implements_required_methods(self, grpcbin_extension): + """Test that the extension properly implements the required abstract methods.""" + # http2_request_matcher should be callable + result = grpcbin_extension.http2_request_matcher(Headers()) + assert result is False, "gRPC services should not proxy through HTTP gateway" + + def test_multiple_ports_configured(self, grpcbin_extension): + """Test that the extension properly handles multiple ports.""" + assert 9000 in grpcbin_extension.container_ports # insecure port + assert 9001 in grpcbin_extension.container_ports # secure port diff --git a/utils/tests/integration/test_grpc_e2e.py b/utils/tests/integration/test_grpc_e2e.py new file mode 100644 index 00000000..eab6db15 --- /dev/null +++ b/utils/tests/integration/test_grpc_e2e.py @@ -0,0 +1,132 @@ +""" +End-to-end gRPC tests using grpcbin services. + +These tests make actual gRPC calls to grpcbin to verify that the full +HTTP/2 stack works correctly, including proper request/response handling. + +grpcbin provides services like: Empty, Index, HeadersUnary, etc. +We use the Empty service which returns an empty response. +""" + +import grpc + + +class TestGrpcEndToEnd: + """End-to-end tests making actual gRPC calls to grpcbin.""" + + def test_grpc_empty_call(self, grpcbin_extension_server): + """Test making a gRPC call to grpcbin's Empty service via the gateway.""" + # Create a channel to grpcbin through the gateway + gateway_port = grpcbin_extension_server["port"] + channel = grpc.insecure_channel(f"localhost:{gateway_port}") + + try: + # Use grpc.channel_ready_future to verify connection + grpc.channel_ready_future(channel).result(timeout=5) + + # grpcbin provides /grpcbin.GRPCBin/Empty which returns empty response + method = "/grpcbin.GRPCBin/Empty" + + # Empty message is just empty bytes in protobuf + request = b"" + + # Make the unary-unary call + response = channel.unary_unary( + method, + request_serializer=lambda x: x, + response_deserializer=lambda x: x, + )(request, timeout=5) + + # Empty service returns empty response + assert response is not None + assert response == b"" or len(response) == 0 + + finally: + channel.close() + + def test_grpc_index_call(self, grpcbin_extension_server): + """Test calling grpcbin's Index service which returns server info.""" + gateway_port = grpcbin_extension_server["port"] + channel = grpc.insecure_channel(f"localhost:{gateway_port}") + + try: + # Verify channel is ready + grpc.channel_ready_future(channel).result(timeout=5) + + # grpcbin's Index service returns information about the server + method = "/grpcbin.GRPCBin/Index" + request = b"" + + response = channel.unary_unary( + method, + request_serializer=lambda x: x, + response_deserializer=lambda x: x, + )(request, timeout=5) + + # Index returns a non-empty protobuf message with server info + assert response is not None + assert len(response) > 0, "Index service should return server information" + + finally: + channel.close() + + def test_grpc_concurrent_calls(self, grpcbin_extension_server): + """Test making multiple concurrent gRPC calls.""" + gateway_port = grpcbin_extension_server["port"] + channel = grpc.insecure_channel(f"localhost:{gateway_port}") + + try: + # Verify channel is ready + grpc.channel_ready_future(channel).result(timeout=5) + + method = "/grpcbin.GRPCBin/Empty" + request = b"" + + # Make multiple concurrent calls + responses = [] + for i in range(3): + response = channel.unary_unary( + method, + request_serializer=lambda x: x, + response_deserializer=lambda x: x, + )(request, timeout=5) + responses.append(response) + + # Verify all calls completed + assert len(responses) == 3, "All concurrent calls should complete" + for i, response in enumerate(responses): + assert response is not None, f"Call {i} should return a response" + + finally: + channel.close() + + def test_grpc_connection_reuse(self, grpcbin_extension_server): + """Test that a single gRPC channel can handle multiple sequential calls.""" + gateway_port = grpcbin_extension_server["port"] + channel = grpc.insecure_channel(f"localhost:{gateway_port}") + + try: + # Verify channel is ready + grpc.channel_ready_future(channel).result(timeout=5) + + # Alternate between Empty and Index calls + methods = ["/grpcbin.GRPCBin/Empty", "/grpcbin.GRPCBin/Index"] + request = b"" + + # Make multiple sequential calls on the same channel + for i in range(6): + method = methods[i % 2] + response = channel.unary_unary( + method, + request_serializer=lambda x: x, + response_deserializer=lambda x: x, + )(request, timeout=5) + + assert response is not None, f"Call {i} to {method} should succeed" + + # Index should return data, Empty should return empty + if "Index" in method: + assert len(response) > 0, "Index should return server info" + + finally: + channel.close() diff --git a/utils/tests/integration/test_http2_proxy.py b/utils/tests/integration/test_http2_proxy.py new file mode 100644 index 00000000..a6ba6003 --- /dev/null +++ b/utils/tests/integration/test_http2_proxy.py @@ -0,0 +1,355 @@ +""" +Integration tests for HTTP/2 proxy utilities against a live server. + +These tests verify that the TcpForwarder utility and HTTP/2 frame parsing functions +work correctly with real HTTP/2 traffic. We use grpcbin as a neutral HTTP/2 test +server to validate the utility functionality. +""" + +import threading +import pytest + +from hyperframe.frame import SettingsFrame + +from localstack_extensions.utils.h2_proxy import ( + get_headers_from_frames, + TcpForwarder, +) + +# Import from conftest - pytest automatically loads conftest.py +from .conftest import HTTP2_PREFACE, SETTINGS_FRAME, parse_server_frames + + +class TestTcpForwarderConnection: + """Tests for TcpForwarder connection management.""" + + def test_connect_to_grpcbin(self, grpcbin_extension_server): + """Test that TcpForwarder can connect to grpcbin.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + try: + # Connection is made in __init__, so if we get here, it worked + assert forwarder.port == gateway_port + assert forwarder.host == "localhost" + finally: + forwarder.close() + + def test_connect_and_close(self, grpcbin_extension_server): + """Test connect and close cycle.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + assert forwarder.port == gateway_port + forwarder.close() + # Verify close succeeded without raising an exception + + def test_multiple_connect_close_cycles(self, grpcbin_extension_server): + """Test multiple connect/close cycles.""" + gateway_port = grpcbin_extension_server["port"] + for _ in range(3): + forwarder = TcpForwarder(port=gateway_port, host="localhost") + forwarder.close() + + +class TestTcpForwarderSendReceive: + """Tests for TcpForwarder send/receive operations.""" + + def test_send_and_receive(self, grpcbin_extension_server): + """Test sending data and receiving response.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + received_data = [] + receive_complete = threading.Event() + + def callback(data): + received_data.append(data) + receive_complete.set() + + try: + # Start receive loop in background thread + receive_thread = threading.Thread( + target=forwarder.receive_loop, args=(callback,), daemon=True + ) + receive_thread.start() + + # Send HTTP/2 preface + forwarder.send(HTTP2_PREFACE) + + # Wait for response (with timeout) + if not receive_complete.wait(timeout=5.0): + pytest.fail("Did not receive response within timeout") + + # Should have received at least one chunk + assert len(received_data) > 0 + # Response should contain data (at least a SETTINGS frame) + total_bytes = sum(len(d) for d in received_data) + assert total_bytes >= 9, ( + "Should receive at least one frame header (9 bytes)" + ) + + finally: + forwarder.close() + + def test_bidirectional_http2_exchange(self, grpcbin_extension_server): + """Test bidirectional HTTP/2 settings exchange.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + received_data = [] + first_response = threading.Event() + + def callback(data): + received_data.append(data) + first_response.set() + + try: + # Start receive loop + receive_thread = threading.Thread( + target=forwarder.receive_loop, args=(callback,), daemon=True + ) + receive_thread.start() + + # Send HTTP/2 preface + forwarder.send(HTTP2_PREFACE) + + # Wait for initial response + first_response.wait(timeout=5.0) + assert len(received_data) > 0 + + # Send SETTINGS frame + forwarder.send(SETTINGS_FRAME) + + finally: + forwarder.close() + + +class TestTcpForwarderErrorHandling: + """Tests for error handling in TcpForwarder.""" + + def test_connection_to_invalid_port(self, grpcbin_extension_server): + """Test connecting to a port that's not listening.""" + with pytest.raises((ConnectionRefusedError, OSError)): + TcpForwarder(port=59999, host="localhost") + + def test_close_after_failed_connection(self, grpcbin_extension_server): + """Test that close works even after error conditions.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + forwarder.close() + # Close again should not raise + forwarder.close() + + def test_send_after_close(self, grpcbin_extension_server): + """Test sending after close raises appropriate error.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + forwarder.close() + + with pytest.raises(OSError): + forwarder.send(b"data") + + +class TestTcpForwarderConcurrency: + """Tests for concurrent operations in TcpForwarder.""" + + def test_multiple_sends(self, grpcbin_extension_server): + """Test multiple sequential sends (no exception = success).""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + try: + forwarder.send(HTTP2_PREFACE) + forwarder.send(SETTINGS_FRAME) + forwarder.send(b"\x00\x00\x00\x04\x01\x00\x00\x00\x00") # SETTINGS ACK + finally: + forwarder.close() + + def test_concurrent_connections(self, grpcbin_extension_server): + """Test multiple concurrent TcpForwarder connections.""" + gateway_port = grpcbin_extension_server["port"] + forwarders = [] + try: + for _ in range(3): + forwarder = TcpForwarder(port=gateway_port, host="localhost") + forwarders.append(forwarder) + + # All connections should be established + assert len(forwarders) == 3 + + # Send preface to all + for forwarder in forwarders: + forwarder.send(HTTP2_PREFACE) + + finally: + for forwarder in forwarders: + forwarder.close() + + +class TestHttp2FrameParsing: + """Tests for HTTP/2 frame parsing with live server traffic.""" + + def test_capture_settings_frame(self, grpcbin_extension_server): + """Test capturing a SETTINGS frame from grpcbin.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + received_data = [] + done = threading.Event() + thread_started = threading.Event() + + def callback(data): + received_data.append(data) + done.set() + + def receive_with_signal(): + thread_started.set() + forwarder.receive_loop(callback) + + try: + receive_thread = threading.Thread(target=receive_with_signal, daemon=True) + receive_thread.start() + thread_started.wait(timeout=1.0) # Wait for receive thread to be ready + + forwarder.send(HTTP2_PREFACE + SETTINGS_FRAME) + done.wait(timeout=5.0) + + # Parse the server response (no preface expected in server data) + server_data = b"".join(received_data) + frames = parse_server_frames(server_data) + + # Check that we got frames + assert len(frames) > 0 + + # First frame should be SETTINGS + settings_frames = [f for f in frames if isinstance(f, SettingsFrame)] + assert len(settings_frames) > 0, ( + "Should receive at least one SETTINGS frame" + ) + finally: + forwarder.close() + + def test_parse_server_settings(self, grpcbin_extension_server): + """Test parsing the server's SETTINGS values.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + received_data = [] + done = threading.Event() + thread_started = threading.Event() + + def callback(data): + received_data.append(data) + done.set() + + def receive_with_signal(): + thread_started.set() + forwarder.receive_loop(callback) + + try: + receive_thread = threading.Thread(target=receive_with_signal, daemon=True) + receive_thread.start() + thread_started.wait(timeout=1.0) # Wait for receive thread to be ready + + forwarder.send(HTTP2_PREFACE + SETTINGS_FRAME) + done.wait(timeout=5.0) + + server_data = b"".join(received_data) + frames = parse_server_frames(server_data) + + settings_frames = [f for f in frames if isinstance(f, SettingsFrame)] + assert len(settings_frames) > 0 + + # SETTINGS frame should have settings attribute + settings_frame = settings_frames[0] + assert hasattr(settings_frame, "settings") + finally: + forwarder.close() + + def test_http2_handshake_completes(self, grpcbin_extension_server): + """Test that we can complete an HTTP/2 handshake with settings exchange.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + received_data = [] + first_response = threading.Event() + + def callback(data): + received_data.append(data) + first_response.set() + + try: + receive_thread = threading.Thread( + target=forwarder.receive_loop, args=(callback,), daemon=True + ) + receive_thread.start() + + # Send HTTP/2 preface and SETTINGS + forwarder.send(HTTP2_PREFACE + SETTINGS_FRAME) + + # Wait for server's initial frames + first_response.wait(timeout=5.0) + assert len(received_data) > 0, "Should receive server SETTINGS" + + # Send SETTINGS ACK to complete handshake + forwarder.send(b"\x00\x00\x00\x04\x01\x00\x00\x00\x00") # SETTINGS ACK + finally: + forwarder.close() + + def test_full_connection_sequence(self, grpcbin_extension_server): + """Test a full HTTP/2 connection sequence with grpcbin.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + received_data = [] + first_response = threading.Event() + + def callback(data): + received_data.append(data) + first_response.set() + + try: + receive_thread = threading.Thread( + target=forwarder.receive_loop, args=(callback,), daemon=True + ) + receive_thread.start() + + # Send preface and SETTINGS frame together + forwarder.send(HTTP2_PREFACE + SETTINGS_FRAME) + first_response.wait(timeout=5.0) + + # Parse server response frames + server_data = b"".join(received_data) + frames = parse_server_frames(server_data) + + assert len(frames) >= 1, "Should receive at least one frame from server" + + # Verify frame types + frame_types = [type(f).__name__ for f in frames] + assert "SettingsFrame" in frame_types, ( + f"Expected SettingsFrame, got: {frame_types}" + ) + + finally: + forwarder.close() + + def test_headers_extraction_from_raw_traffic(self, grpcbin_extension_server): + """Test that get_headers_from_frames works with live traffic.""" + gateway_port = grpcbin_extension_server["port"] + forwarder = TcpForwarder(port=gateway_port, host="localhost") + received_data = [] + done = threading.Event() + + def callback(data): + received_data.append(data) + done.set() + + try: + receive_thread = threading.Thread( + target=forwarder.receive_loop, args=(callback,), daemon=True + ) + receive_thread.start() + + forwarder.send(HTTP2_PREFACE + SETTINGS_FRAME) + done.wait(timeout=5.0) + + server_data = b"".join(received_data) + frames = parse_server_frames(server_data) + headers = get_headers_from_frames(frames) + + # Server response has SETTINGS, not HEADERS, so headers should be empty + assert len(headers) == 0, "SETTINGS frames should not produce headers" + finally: + forwarder.close() diff --git a/utils/tests/unit/__init__.py b/utils/tests/unit/__init__.py new file mode 100644 index 00000000..8ce7db9e --- /dev/null +++ b/utils/tests/unit/__init__.py @@ -0,0 +1 @@ +# Unit tests - no Docker or LocalStack required diff --git a/utils/tests/unit/test_h2_frame_parsing.py b/utils/tests/unit/test_h2_frame_parsing.py new file mode 100644 index 00000000..6d7e3b14 --- /dev/null +++ b/utils/tests/unit/test_h2_frame_parsing.py @@ -0,0 +1,192 @@ +""" +Unit tests for HTTP/2 frame parsing utilities. + +These tests verify the parsing of HTTP/2 frames from raw byte streams, +including the HTTP/2 preface, settings frames, and headers frames. +No Docker or network access required. +""" + +from hyperframe.frame import SettingsFrame, HeadersFrame, WindowUpdateFrame + +from localstack_extensions.utils.h2_proxy import ( + get_frames_from_http2_stream, + get_headers_from_frames, + get_headers_from_data_stream, +) + + +# HTTP/2 connection preface (24 bytes) +HTTP2_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + +class TestParseHttp2PrefaceAndFrames: + """Tests for parsing HTTP/2 frames from captured data.""" + + # This data is a dump taken from a browser request - includes preface, settings, and headers + SAMPLE_HTTP2_DATA = ( + b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n\x00\x00\x18\x04\x00\x00\x00\x00\x00\x00\x01\x00\x01" + b"\x00\x00\x00\x02\x00\x00\x00\x00\x00\x04\x00\x02\x00\x00\x00\x05\x00\x00@\x00\x00\x00" + b"\x04\x08\x00\x00\x00\x00\x00\x00\xbf\x00\x01\x00\x01V\x01%\x00\x00\x00\x03\x00\x00\x00" + b"\x00\x15C\x87\xd5\xaf~MZw\x7f\x05\x8eb*\x0eA\xd0\x84\x8c\x9dX\x9c\xa3\xa13\xffA\x96" + b"\xa0\xe4\x1d\x13\x9d\t^\x83\x90t!#'U\xc9A\xed\x92\xe3M\xb8\xe7\x87z\xbe\xd0\x7ff\xa2" + b"\x81\xb0\xda\xe0S\xfa\xd02\x1a\xa4\x9d\x13\xfd\xa9\x92\xa4\x96\x854\x0c\x8aj\xdc\xa7" + b"\xe2\x81\x02\xe1o\xedK;\xdc\x0bM.\x0f\xedLE'S\xb0 \x04\x00\x08\x02\xa6\x13XYO\xe5\x80" + b"\xb4\xd2\xe0S\x83\xf9c\xe7Q\x8b-Kp\xdd\xf4Z\xbe\xfb@\x05\xdbP\x92\x9b\xd9\xab\xfaRB" + b"\xcb@\xd2_\xa5#\xb3\xe9OhL\x9f@\x94\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb5" + b"%L\xe7\x93\x83\xc5\x83\x7f@\x95\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb4\xe5" + b"\x1c\x85\xb1\x1f\x89\x1d\xa9\x9c\xf6\x1b\xd8\xd2c\xd5s\x95\x9d)\xad\x17\x18`u\xd6\xbd" + b"\x07 \xe8BFN\xab\x92\x83\xdb#\x1f@\x85=\x86\x98\xd5\x7f\x94\x9d)\xad\x17\x18`u\xd6\xbd" + b"\x07 \xe8BFN\xab\x92\x83\xdb'@\x8aAH\xb4\xa5I'ZB\xa1?\x84-5\xa7\xd7@\x8aAH\xb4\xa5I'" + b"Z\x93\xc8_\x83!\xecG@\x8aAH\xb4\xa5I'Y\x06I\x7f\x86@\xe9*\xc82K@\x86\xae\xc3\x1e\xc3'" + b"\xd7\x83\xb6\x06\xbf@\x82I\x7f\x86M\x835\x05\xb1\x1f\x00\x00\x04\x08\x00\x00\x00\x00" + b"\x03\x00\xbe\x00\x00" + ) + + def test_parse_http2_frames_from_captured_data(self): + """Test parsing HTTP/2 frames from a real captured browser request.""" + frames = list(get_frames_from_http2_stream(self.SAMPLE_HTTP2_DATA)) + + assert len(frames) > 0, "Should parse at least one frame" + + # First frame after preface should be a SETTINGS frame + frame_types = [type(f) for f in frames] + assert SettingsFrame in frame_types, "Should contain SETTINGS frame" + + def test_frames_contain_headers_frame(self): + """Test that parsed frames include a HEADERS frame.""" + frames = list(get_frames_from_http2_stream(self.SAMPLE_HTTP2_DATA)) + frame_types = [type(f) for f in frames] + assert HeadersFrame in frame_types, "Should contain HEADERS frame" + + def test_parse_preface_only(self): + """Test parsing just the HTTP/2 preface (no frames expected).""" + frames = list(get_frames_from_http2_stream(HTTP2_PREFACE)) + # The preface alone doesn't produce frames (it's consumed as preface) + assert frames == [], "HTTP/2 preface alone should not produce frames" + + def test_parse_preface_with_settings(self): + """Test parsing preface followed by a SETTINGS frame.""" + # SETTINGS frame: type=0x04, flags=0x00, stream=0, length=0 (empty settings) + settings_frame = b"\x00\x00\x00\x04\x00\x00\x00\x00\x00" + data = HTTP2_PREFACE + settings_frame + + frames = list(get_frames_from_http2_stream(data)) + assert len(frames) == 1 + assert isinstance(frames[0], SettingsFrame) + + +class TestExtractHeaders: + """Tests for extracting headers from HTTP/2 frames.""" + + SAMPLE_HTTP2_DATA = TestParseHttp2PrefaceAndFrames.SAMPLE_HTTP2_DATA + + def test_extract_headers_from_frames(self): + """Test extracting headers from parsed frames.""" + frames = list(get_frames_from_http2_stream(self.SAMPLE_HTTP2_DATA)) + headers = get_headers_from_frames(frames) + + assert len(headers) > 0, "Should extract at least one header" + + def test_extract_pseudo_headers(self): + """Test that HTTP/2 pseudo-headers are correctly extracted.""" + frames = list(get_frames_from_http2_stream(self.SAMPLE_HTTP2_DATA)) + headers = get_headers_from_frames(frames) + + # HTTP/2 pseudo-headers start with ':' + assert headers.get(":scheme") == "https" + assert headers.get(":method") == "OPTIONS" + assert headers.get(":path") == "/_localstack/health" + + def test_get_headers_from_data_stream(self): + """Test the convenience function that combines frame parsing and header extraction.""" + # Use the same data but as a list of chunks + data_chunks = [self.SAMPLE_HTTP2_DATA[:100], self.SAMPLE_HTTP2_DATA[100:]] + headers = get_headers_from_data_stream(data_chunks) + + assert headers.get(":scheme") == "https" + assert headers.get(":method") == "OPTIONS" + + def test_headers_case_insensitive(self): + """Test that headers object is case-insensitive for non-pseudo headers.""" + frames = list(get_frames_from_http2_stream(self.SAMPLE_HTTP2_DATA)) + headers = get_headers_from_frames(frames) + + # werkzeug.Headers is case-insensitive + origin = headers.get("origin") + if origin: + assert headers.get("Origin") == origin + assert headers.get("ORIGIN") == origin + + +class TestEmptyAndInvalidData: + """Tests for edge cases with empty or invalid data.""" + + def test_empty_data(self): + """Test parsing empty data returns empty list.""" + frames = list(get_frames_from_http2_stream(b"")) + assert frames == [] + + def test_invalid_data(self): + """Test parsing invalid/random data returns empty list (no crash).""" + frames = list(get_frames_from_http2_stream(b"not http2 data at all")) + assert frames == [] + + def test_truncated_frame(self): + """Test parsing truncated frame data returns empty list.""" + # Start of a valid HTTP/2 preface but truncated + truncated = b"PRI * HTTP/2.0\r\n" + frames = list(get_frames_from_http2_stream(truncated)) + assert frames == [] + + def test_headers_from_empty_frames(self): + """Test extracting headers from empty frame list.""" + headers = get_headers_from_frames([]) + assert len(headers) == 0 + + def test_headers_from_non_header_frames(self): + """Test extracting headers when no HEADERS frames present.""" + # SETTINGS frame only + settings_frame = b"\x00\x00\x00\x04\x00\x00\x00\x00\x00" + data = HTTP2_PREFACE + settings_frame + + frames = list(get_frames_from_http2_stream(data)) + headers = get_headers_from_frames(frames) + + assert len(headers) == 0, "SETTINGS frame should not produce headers" + + def test_get_headers_from_empty_data_stream(self): + """Test get_headers_from_data_stream with empty input.""" + headers = get_headers_from_data_stream([]) + assert len(headers) == 0 + + def test_get_headers_from_data_stream_with_empty_chunks(self): + """Test get_headers_from_data_stream with list of empty chunks.""" + headers = get_headers_from_data_stream([b"", b"", b""]) + assert len(headers) == 0 + + +class TestHttp2FrameTypes: + """Tests for identifying different HTTP/2 frame types.""" + + def test_window_update_frame(self): + """Test parsing WINDOW_UPDATE frames.""" + # WINDOW_UPDATE frame: type=0x08, flags=0x00, stream=0, length=4 + # Window size increment: 0x00010000 (65536) + window_update = b"\x00\x00\x04\x08\x00\x00\x00\x00\x00\x00\x01\x00\x00" + data = HTTP2_PREFACE + window_update + + frames = list(get_frames_from_http2_stream(data)) + assert len(frames) == 1 + assert isinstance(frames[0], WindowUpdateFrame) + + def test_multiple_frame_types(self): + """Test parsing multiple different frame types.""" + # SETTINGS frame followed by WINDOW_UPDATE frame + settings_frame = b"\x00\x00\x00\x04\x00\x00\x00\x00\x00" + window_update = b"\x00\x00\x04\x08\x00\x00\x00\x00\x00\x00\x01\x00\x00" + data = HTTP2_PREFACE + settings_frame + window_update + + frames = list(get_frames_from_http2_stream(data)) + assert len(frames) == 2 + assert isinstance(frames[0], SettingsFrame) + assert isinstance(frames[1], WindowUpdateFrame) diff --git a/wiremock/Makefile b/wiremock/Makefile index d9b9e68b..30c4203c 100644 --- a/wiremock/Makefile +++ b/wiremock/Makefile @@ -34,6 +34,10 @@ entrypoints: venv # Generate plugin entrypoints for Python package format: ## Run ruff to format the whole codebase $(VENV_RUN); python -m ruff format .; python -m ruff check --output-format=full --fix . + +lint: ## Run ruff to lint the codebase + $(VENV_RUN); python -m ruff check --output-format=full . + sample-oss: ## Deploy sample app (OSS mode) echo "Creating stubs in WireMock ..." bin/create-stubs.sh diff --git a/wiremock/localstack_wiremock/extension.py b/wiremock/localstack_wiremock/extension.py index b58c9be6..f69de8a4 100644 --- a/wiremock/localstack_wiremock/extension.py +++ b/wiremock/localstack_wiremock/extension.py @@ -1,9 +1,11 @@ import logging import os from pathlib import Path +import requests from localstack import config, constants -from localstack_wiremock.utils.docker import ProxiedDockerContainerExtension +from localstack.utils.net import get_addressable_container_host +from localstack_extensions.utils.docker import ProxiedDockerContainerExtension LOG = logging.getLogger(__name__) @@ -71,15 +73,23 @@ def __init__(self): health_check_port = ADMIN_PORT if api_token else SERVICE_PORT self._is_runner_mode = bool(api_token) + def _health_check(): + """Custom health check for WireMock.""" + container_host = get_addressable_container_host() + health_url = ( + f"http://{container_host}:{health_check_port}{health_check_path}" + ) + LOG.debug("Health check: %s", health_url) + response = requests.get(health_url, timeout=5) + assert response.ok + super().__init__( image_name=image_name, container_ports=container_ports, - container_name=self.CONTAINER_NAME, host=self.HOST, env_vars=env_vars if env_vars else None, volumes=volumes, - health_check_path=health_check_path, - health_check_port=health_check_port, + health_check_fn=_health_check, health_check_retries=health_check_retries, health_check_sleep=health_check_sleep, ) diff --git a/wiremock/localstack_wiremock/utils/__init__.py b/wiremock/localstack_wiremock/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/wiremock/localstack_wiremock/utils/docker.py b/wiremock/localstack_wiremock/utils/docker.py deleted file mode 100644 index c102e831..00000000 --- a/wiremock/localstack_wiremock/utils/docker.py +++ /dev/null @@ -1,220 +0,0 @@ -import re -import logging -from functools import cache -from typing import Callable -import requests - -from localstack.utils.docker_utils import DOCKER_CLIENT -from localstack.extensions.api import Extension, http -from localstack.http import Request -from localstack.utils.container_utils.container_client import ( - PortMappings, - SimpleVolumeBind, -) -from localstack.utils.net import get_addressable_container_host -from localstack.utils.sync import retry -from rolo import route -from rolo.proxy import Proxy -from rolo.routing import RuleAdapter, WithHost - -LOG = logging.getLogger(__name__) -logging.basicConfig() - -# TODO: merge utils with code in TypeDB extension over time ... - - -class ProxiedDockerContainerExtension(Extension): - name: str - """Name of this extension""" - image_name: str - """Docker image name""" - container_name: str | None - """Name of the Docker container spun up by the extension""" - container_ports: list[int] - """List of network ports of the Docker container spun up by the extension""" - host: str | None - """ - Optional host on which to expose the container endpoints. - Can be either a static hostname, or a pattern like `myext.` - """ - path: str | None - """Optional path on which to expose the container endpoints.""" - command: list[str] | None - """Optional command (and flags) to execute in the container.""" - - request_to_port_router: Callable[[Request], int] | None - """Callable that returns the target port for a given request, for routing purposes""" - http2_ports: list[int] | None - """List of ports for which HTTP2 proxy forwarding into the container should be enabled.""" - - volumes: list[SimpleVolumeBind] | None = None - """Optional volumes to mount into the container host.""" - - env_vars: dict[str, str] | None = None - """Optional environment variables to pass to the container.""" - - health_check_path: str = "/__admin/health" - """Health check endpoint path to verify container is ready.""" - - health_check_port: int | None = None - """Port to use for health check. If None, uses the first container port.""" - - health_check_retries: int = 40 - """Number of retries for health check.""" - - health_check_sleep: float = 1 - """Sleep time between health check retries in seconds.""" - - def __init__( - self, - image_name: str, - container_ports: list[int], - host: str | None = None, - path: str | None = None, - container_name: str | None = None, - command: list[str] | None = None, - request_to_port_router: Callable[[Request], int] | None = None, - http2_ports: list[int] | None = None, - volumes: list[SimpleVolumeBind] | None = None, - env_vars: dict[str, str] | None = None, - health_check_path: str = "/__admin/health", - health_check_port: int | None = None, - health_check_retries: int = 40, - health_check_sleep: float = 1, - ): - self.image_name = image_name - self.container_ports = container_ports - self.host = host - self.path = path - self.container_name = container_name - self.command = command - self.request_to_port_router = request_to_port_router - self.http2_ports = http2_ports - self.volumes = volumes - self.env_vars = env_vars - self.health_check_path = health_check_path - self.health_check_port = health_check_port - self.health_check_retries = health_check_retries - self.health_check_sleep = health_check_sleep - - def update_gateway_routes(self, router: http.Router[http.RouteHandler]): - if self.path: - raise NotImplementedError( - "Path-based routing not yet implemented for this extension" - ) - self.start_container() - # add resource for HTTP/1.1 requests - resource = RuleAdapter(ProxyResource(self)) - if self.host: - resource = WithHost(self.host, [resource]) - router.add(resource) - - def on_platform_shutdown(self): - self._remove_container() - - def _get_container_name(self) -> str: - if self.container_name: - return self.container_name - name = f"ls-ext-{self.name}" - name = re.sub(r"\W", "-", name) - return name - - @cache - def start_container(self) -> None: - container_name = self._get_container_name() - LOG.debug("Starting extension container %s", container_name) - - ports = PortMappings() - for port in self.container_ports: - ports.add(port) - - kwargs = {} - if self.command: - kwargs["command"] = self.command - if self.env_vars: - kwargs["env_vars"] = self.env_vars - - try: - DOCKER_CLIENT.run_container( - self.image_name, - detach=True, - remove=True, - name=container_name, - ports=ports, - volumes=self.volumes, - **kwargs, - ) - except Exception as e: - LOG.debug("Failed to start container %s: %s", container_name, e) - raise - - health_port = self.health_check_port or self.container_ports[0] - container_host = get_addressable_container_host() - health_url = f"http://{container_host}:{health_port}{self.health_check_path}" - - def _ping_endpoint(): - LOG.debug("Health check: %s", health_url) - response = requests.get(health_url, timeout=5) - assert response.ok - - try: - retry( - _ping_endpoint, - retries=self.health_check_retries, - sleep=self.health_check_sleep, - ) - except Exception as e: - LOG.info("Failed to connect to container %s: %s", container_name, e) - # Log container output for debugging - try: - logs = DOCKER_CLIENT.get_container_logs(container_name) - LOG.info("Container logs for %s:\n%s", container_name, logs) - except Exception: - pass - self._remove_container() - raise - - LOG.debug("Successfully started extension container %s", container_name) - - def _remove_container(self): - container_name = self._get_container_name() - LOG.debug("Stopping extension container %s", container_name) - DOCKER_CLIENT.remove_container( - container_name, force=True, check_existence=False - ) - - -class ProxyResource: - """ - Simple proxy resource that forwards incoming requests from the - LocalStack Gateway to the target Docker container. - """ - - extension: ProxiedDockerContainerExtension - - def __init__(self, extension: ProxiedDockerContainerExtension): - self.extension = extension - - @route("/") - def index(self, request: Request, path: str, *args, **kwargs): - return self._proxy_request(request, forward_path=f"/{path}") - - def _proxy_request(self, request: Request, forward_path: str, *args, **kwargs): - self.extension.start_container() - - port = self.extension.container_ports[0] - container_host = get_addressable_container_host() - base_url = f"http://{container_host}:{port}" - proxy = Proxy(forward_base_url=base_url) - - # update content length (may have changed due to content compression) - if request.method not in ("GET", "OPTIONS"): - request.headers["Content-Length"] = str(len(request.data)) - - # make sure we're forwarding the correct Host header - request.headers["Host"] = f"localhost:{port}" - - # forward the request to the target - result = proxy.forward(request, forward_path=forward_path) - - return result diff --git a/wiremock/pyproject.toml b/wiremock/pyproject.toml index 2ee19c86..196f196d 100644 --- a/wiremock/pyproject.toml +++ b/wiremock/pyproject.toml @@ -14,7 +14,10 @@ authors = [ keywords = ["LocalStack", "WireMock"] classifiers = [] dependencies = [ - "priority" + "priority", + # TODO remove / replace prior to merge! +# "localstack-extensions-utils", + "localstack-extensions-utils @ git+https://github.com/localstack/localstack-extensions.git@extract-utils-package#subdirectory=utils" ] [project.urls]