diff --git a/CHANGELOG.md b/CHANGELOG.md index dcb7bca..6ac977e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ - `write_dataframe()`: New method for writing pandas and polars DataFrames with explicit parameters (`measurement`, `timestamp_column`, `tags`, `timestamp_timezone`). - `query_dataframe()`: New method for querying data directly to a pandas or polars DataFrame via the `frame_type` parameter. - Updated README with clear examples for DataFrame operations. +1. [#179](https://github.com/InfluxCommunity/influxdb3-python/pull/179): Add option to disable gRPC response + compression for Flight queries: + - `disable_grpc_compression` parameter in `InfluxDBClient3` constructor + - `INFLUX_DISABLE_GRPC_COMPRESSION` environment variable support in `from_env()` ### Bug Fixes diff --git a/README.md b/README.md index 237aaf1..13de5af 100644 --- a/README.md +++ b/README.md @@ -204,9 +204,32 @@ print(table.to_pandas().to_markdown()) ``` ### gRPC compression -The Python client supports gRPC response compression. -If the server chooses to compress query responses (e.g., with gzip), the client -will automatically decompress them — no extra configuration is required. + +#### Request compression + +Request compression is not supported by InfluxDB 3 — the client sends uncompressed requests. + +#### Response compression + +Response compression is enabled by default. The client sends the `grpc-accept-encoding: identity, deflate, gzip` +header, and the server returns gzip-compressed responses (if supported). The client automatically +decompresses them — no configuration required. + +To **disable response compression**: + +```python +# Via constructor parameter +client = InfluxDBClient3( + host="your-host", + token="your-token", + database="your-database", + disable_grpc_compression=True +) + +# Or via environment variable +# INFLUX_DISABLE_GRPC_COMPRESSION=true +client = InfluxDBClient3.from_env() +``` ## Windows Users Currently, Windows users require an extra installation when querying via Flight natively. This is due to the fact gRPC cannot locate Windows root certificates. To work around this please follow these steps: diff --git a/influxdb_client_3/__init__.py b/influxdb_client_3/__init__.py index 3451485..3f763b5 100644 --- a/influxdb_client_3/__init__.py +++ b/influxdb_client_3/__init__.py @@ -31,6 +31,7 @@ INFLUX_WRITE_NO_SYNC = "INFLUX_WRITE_NO_SYNC" INFLUX_WRITE_TIMEOUT = "INFLUX_WRITE_TIMEOUT" INFLUX_QUERY_TIMEOUT = "INFLUX_QUERY_TIMEOUT" +INFLUX_DISABLE_GRPC_COMPRESSION = "INFLUX_DISABLE_GRPC_COMPRESSION" def write_client_options(**kwargs): @@ -190,6 +191,7 @@ def __init__( flight_client_options=None, write_port_overwrite=None, query_port_overwrite=None, + disable_grpc_compression=False, **kwargs): """ Initialize an InfluxDB client. @@ -206,6 +208,8 @@ def __init__( :type write_client_options: dict[str, any] :param flight_client_options: dictionary for providing additional arguments for the FlightClient. :type flight_client_options: dict[str, any] + :param disable_grpc_compression: Disable gRPC compression for Flight query responses. Default is False. + :type disable_grpc_compression: bool :key auth_scheme: token authentication scheme. Set to "Bearer" for Edge. :key bool verify_ssl: Set this to false to skip verifying SSL certificate when calling API from https server. :key str ssl_ca_cert: Set this to customize the certificate file to verify the peer. @@ -291,6 +295,8 @@ def __init__( connection_string = f"grpc+tcp://{hostname}:{port}" q_opts_builder = QueryApiOptionsBuilder() + if disable_grpc_compression: + q_opts_builder.disable_grpc_compression(True) if kw_keys.__contains__('ssl_ca_cert'): q_opts_builder.root_certs(kwargs.get('ssl_ca_cert', None)) if kw_keys.__contains__('verify_ssl'): @@ -361,6 +367,12 @@ def from_env(cls, **kwargs: Any) -> 'InfluxDBClient3': if os.getenv(INFLUX_AUTH_SCHEME) is not None: kwargs['auth_scheme'] = os.getenv(INFLUX_AUTH_SCHEME) + disable_grpc_compression = os.getenv(INFLUX_DISABLE_GRPC_COMPRESSION) + if disable_grpc_compression is not None: + disable_grpc_compression = disable_grpc_compression.strip().lower() in ['true', '1', 't', 'y', 'yes'] + else: + disable_grpc_compression = False + org = os.getenv(INFLUX_ORG, "default") return InfluxDBClient3( host=required_vars[INFLUX_HOST], @@ -368,6 +380,7 @@ def from_env(cls, **kwargs: Any) -> 'InfluxDBClient3': database=required_vars[INFLUX_DATABASE], write_client_options=write_client_option, org=org, + disable_grpc_compression=disable_grpc_compression, **kwargs ) diff --git a/influxdb_client_3/query/query_api.py b/influxdb_client_3/query/query_api.py index 80a3dca..0ba92de 100644 --- a/influxdb_client_3/query/query_api.py +++ b/influxdb_client_3/query/query_api.py @@ -19,6 +19,7 @@ class QueryApiOptions(object): proxy (str): URL to a proxy server flight_client_options (dict): base set of flight client options passed to internal pyarrow.flight.FlightClient timeout(float): timeout in seconds to wait for a response + disable_grpc_compression (bool): disable gRPC compression for query responses """ _DEFAULT_TIMEOUT = 300.0 tls_root_certs: bytes = None @@ -26,12 +27,14 @@ class QueryApiOptions(object): proxy: str = None flight_client_options: dict = None timeout: float = None + disable_grpc_compression: bool = False def __init__(self, root_certs_path: str, verify: bool, proxy: str, flight_client_options: dict, - timeout: float = _DEFAULT_TIMEOUT): + timeout: float = _DEFAULT_TIMEOUT, + disable_grpc_compression: bool = False): """ Initialize a set of QueryApiOptions @@ -41,6 +44,7 @@ def __init__(self, root_certs_path: str, :param flight_client_options: set of flight_client_options to be passed to internal pyarrow.flight.FlightClient. :param timeout: timeout in seconds to wait for a response. + :param disable_grpc_compression: disable gRPC compression for query responses. """ if root_certs_path: self.tls_root_certs = self._read_certs(root_certs_path) @@ -48,6 +52,7 @@ def __init__(self, root_certs_path: str, self.proxy = proxy self.flight_client_options = flight_client_options self.timeout = timeout + self.disable_grpc_compression = disable_grpc_compression def _read_certs(self, path: str) -> bytes: with open(path, "rb") as certs_file: @@ -75,6 +80,7 @@ class QueryApiOptionsBuilder(object): _proxy: str = None _flight_client_options: dict = None _timeout: float = None + _disable_grpc_compression: bool = False def root_certs(self, path: str): self._root_certs_path = path @@ -96,6 +102,11 @@ def timeout(self, timeout: float): self._timeout = timeout return self + def disable_grpc_compression(self, disable: bool): + """Disable gRPC compression for query responses.""" + self._disable_grpc_compression = disable + return self + def build(self) -> QueryApiOptions: """Build a QueryApiOptions object with previously set values""" return QueryApiOptions( @@ -104,6 +115,7 @@ def build(self) -> QueryApiOptions: proxy=self._proxy, flight_client_options=self._flight_client_options, timeout=self._timeout, + disable_grpc_compression=self._disable_grpc_compression, ) @@ -162,6 +174,13 @@ def __init__(self, self._flight_client_options["disable_server_verification"] = not options.tls_verify if options.timeout is not None: self._default_timeout = options.timeout + if options.disable_grpc_compression: + # Disable gRPC response compression by only enabling identity algorithm + # Bitset: bit 0 = identity, bit 1 = deflate, bit 2 = gzip + # Setting to 1 (0b001) enables only identity (no compression) + self._flight_client_options["generic_options"].append( + ("grpc.compression_enabled_algorithms_bitset", 1) + ) if self._proxy: self._flight_client_options["generic_options"].append(("grpc.http_proxy", self._proxy)) self._flight_client = FlightClient(connection_string, **self._flight_client_options) diff --git a/setup.py b/setup.py index 575af15..5f06d37 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,13 @@ def get_version(): 'pandas': ['pandas'], 'polars': ['polars'], 'dataframe': ['pandas', 'polars'], - 'test': ['pytest', 'pytest-cov', 'pytest-httpserver'] + 'test': [ + 'pytest', + 'pytest-cov', + 'pytest-httpserver', + 'h2>=4.0.0,<5.0.0', + 'cryptography>=3.4.0', + ] }, install_requires=requires, python_requires='>=3.8', diff --git a/tests/test_influxdb_client_3.py b/tests/test_influxdb_client_3.py index c683012..4b489d6 100644 --- a/tests/test_influxdb_client_3.py +++ b/tests/test_influxdb_client_3.py @@ -302,6 +302,74 @@ def test_parse_invalid_write_timeout_range(self): with self.assertRaisesRegex(ValueError, ".*Must be non-negative.*"): InfluxDBClient3.from_env() + def assertGrpcCompressionDisabled(self, client, disabled): + """Assert whether gRPC compression is disabled for the client.""" + self.assertIsInstance(client, InfluxDBClient3) + generic_options = dict(client._query_api._flight_client_options['generic_options']) + if disabled: + self.assertEqual(generic_options.get('grpc.compression_enabled_algorithms_bitset'), 1) + else: + self.assertIsNone(generic_options.get('grpc.compression_enabled_algorithms_bitset')) + + @patch.dict('os.environ', {'INFLUX_HOST': 'localhost', 'INFLUX_TOKEN': 'test_token', + 'INFLUX_DATABASE': 'test_db', 'INFLUX_DISABLE_GRPC_COMPRESSION': 'true'}) + def test_from_env_disable_grpc_compression_true(self): + client = InfluxDBClient3.from_env() + self.assertGrpcCompressionDisabled(client, True) + + @patch.dict('os.environ', {'INFLUX_HOST': 'localhost', 'INFLUX_TOKEN': 'test_token', + 'INFLUX_DATABASE': 'test_db', 'INFLUX_DISABLE_GRPC_COMPRESSION': 'TrUe'}) + def test_from_env_disable_grpc_compression_true_mixed_case(self): + client = InfluxDBClient3.from_env() + self.assertGrpcCompressionDisabled(client, True) + + @patch.dict('os.environ', {'INFLUX_HOST': 'localhost', 'INFLUX_TOKEN': 'test_token', + 'INFLUX_DATABASE': 'test_db', 'INFLUX_DISABLE_GRPC_COMPRESSION': '1'}) + def test_from_env_disable_grpc_compression_one(self): + client = InfluxDBClient3.from_env() + self.assertGrpcCompressionDisabled(client, True) + + @patch.dict('os.environ', {'INFLUX_HOST': 'localhost', 'INFLUX_TOKEN': 'test_token', + 'INFLUX_DATABASE': 'test_db', 'INFLUX_DISABLE_GRPC_COMPRESSION': 'false'}) + def test_from_env_disable_grpc_compression_false(self): + client = InfluxDBClient3.from_env() + self.assertGrpcCompressionDisabled(client, False) + + @patch.dict('os.environ', {'INFLUX_HOST': 'localhost', 'INFLUX_TOKEN': 'test_token', + 'INFLUX_DATABASE': 'test_db', 'INFLUX_DISABLE_GRPC_COMPRESSION': 'anything-else'}) + def test_from_env_disable_grpc_compression_anything_else_is_false(self): + client = InfluxDBClient3.from_env() + self.assertGrpcCompressionDisabled(client, False) + + def test_disable_grpc_compression_parameter_true(self): + client = InfluxDBClient3( + host="localhost", + org="my_org", + database="my_db", + token="my_token", + disable_grpc_compression=True + ) + self.assertGrpcCompressionDisabled(client, True) + + def test_disable_grpc_compression_parameter_false(self): + client = InfluxDBClient3( + host="localhost", + org="my_org", + database="my_db", + token="my_token", + disable_grpc_compression=False + ) + self.assertGrpcCompressionDisabled(client, False) + + def test_disable_grpc_compression_default_is_false(self): + client = InfluxDBClient3( + host="localhost", + org="my_org", + database="my_db", + token="my_token", + ) + self.assertGrpcCompressionDisabled(client, False) + def test_query_with_arrow_error(self): f = ErrorFlightServer() with InfluxDBClient3(f"http://localhost:{f.port}", "my_org", "my_db", "my_token") as c: diff --git a/tests/test_influxdb_client_3_integration.py b/tests/test_influxdb_client_3_integration.py index 33837b5..554a971 100644 --- a/tests/test_influxdb_client_3_integration.py +++ b/tests/test_influxdb_client_3_integration.py @@ -455,3 +455,130 @@ def retry_cb(args, data, excp): self.assertEqual(lp, ErrorResult["rd"].decode('utf-8')) self.assertIsNotNone(ErrorResult["rx"]) self.assertIsInstance(ErrorResult["rx"], Url3TimeoutError) + + def test_disable_grpc_compression(self): + """ + Test that disable_grpc_compression parameter controls query response compression. + + Uses H2HeaderProxy to intercept and verify gRPC headers over HTTP/2. + Supports both h2c (cleartext) and h2 (TLS) connections. + """ + from urllib.parse import urlparse + from tests.util.h2_proxy import H2HeaderProxy + + # Test cases + test_cases = [ + { + 'name': 'default', + 'disable_grpc_compression': None, + 'expected_req_encoding': 'identity, deflate, gzip', + 'expected_resp_encoding': 'gzip', + }, + { + 'name': 'disabled=False', + 'disable_grpc_compression': False, + 'expected_req_encoding': 'identity, deflate, gzip', + 'expected_resp_encoding': 'gzip', + }, + { + 'name': 'disabled=True', + 'disable_grpc_compression': True, + 'expected_req_encoding': 'identity', + 'expected_resp_encoding': None, + }, + ] + + # Parse upstream host/port from test URL + parsed = urlparse(self.host) + upstream_host = parsed.hostname or '127.0.0.1' + upstream_port = parsed.port or (443 if parsed.scheme == 'https' else 80) + use_tls = parsed.scheme == 'https' + + test_id = time.time_ns() + measurement = f'grpc_compression_test_{random_hex(6)}' + + # Write test data points + num_points = 10 + lines = [ + f'{measurement},type=test value={i}.0,counter={i}i,test_id={test_id}i {test_id + i * 1000000}' + for i in range(num_points) + ] + self.client.write('\n'.join(lines)) + + test_query = f"SELECT * FROM \"{measurement}\" WHERE test_id = {test_id} ORDER BY counter" + + # Wait for data to be available + result = None + start = time.time() + while time.time() - start < 10: + result = self.client.query(test_query, mode="all") + if len(result) >= num_points: + break + time.sleep(0.5) + self.assertEqual(len(result), num_points, "Data not available after write") + + for tc in test_cases: + name = tc['name'] + proxy = None + + try: + # Start proxy - supports both h2c (cleartext) and h2 (TLS) + proxy = H2HeaderProxy( + upstream_host=upstream_host, + upstream_port=upstream_port, + tls=use_tls, + upstream_tls=use_tls + ) + proxy.start() + + # Build client kwargs + client_kwargs = { + 'host': proxy.url, + 'database': self.database, + 'token': self.token, + 'verify_ssl': False, # Accept proxy's self-signed cert + } + if tc['disable_grpc_compression'] is not None: + client_kwargs['disable_grpc_compression'] = tc['disable_grpc_compression'] + + client = InfluxDBClient3(**client_kwargs) + try: + result = client.query(test_query, mode="all") + self.assertEqual(len(result), num_points, f"[{name}] Should return {num_points} rows") + finally: + client.close() + + # Verify headers + req_encoding = proxy.get_last_request_header('grpc-accept-encoding') + resp_encoding = proxy.get_last_response_header('grpc-encoding') + + print(f"\n[{name}] Request grpc-accept-encoding: {req_encoding}") + expected_resp = tc['expected_resp_encoding'] + if expected_resp and resp_encoding != expected_resp: + print(f"[{name}] Response grpc-encoding: {resp_encoding} " + f"(expected: {expected_resp})") + else: + print(f"[{name}] Response grpc-encoding: {resp_encoding}") + + self.assertEqual(req_encoding, tc['expected_req_encoding'], + f"[{name}] Unexpected request encoding") + + if tc['expected_resp_encoding']: + # Note: InfluxDB 3 Core may not compress responses even when client + # advertises gzip support. Per gRPC spec, servers may choose not to + # compress regardless of client settings. InfluxDB Cloud typically + # compresses, but Core may not. We warn instead of failing. + # See: https://grpc.io/docs/guides/compression/ + if resp_encoding != tc['expected_resp_encoding']: + import warnings + warnings.warn( + f"[{name}] Server returned '{resp_encoding}' instead of " + f"'{tc['expected_resp_encoding']}'. This is normal for " + f"InfluxDB 3 Core which may not compress responses." + ) + else: + self.assertTrue(resp_encoding is None or resp_encoding == 'identity', + f"[{name}] Expected no compression, got: {resp_encoding}") + finally: + if proxy: + proxy.stop() diff --git a/tests/util/h2_proxy.py b/tests/util/h2_proxy.py new file mode 100644 index 0000000..a7c3d8d --- /dev/null +++ b/tests/util/h2_proxy.py @@ -0,0 +1,442 @@ +""" +HTTP/2 proxy for capturing gRPC headers in tests. + +This module provides a lightweight HTTP/2 proxy that supports both: +- h2c (HTTP/2 cleartext with prior knowledge) - which mitmproxy does not support +- h2 (HTTP/2 over TLS) - with runtime-generated self-signed certificates + +It uses the hyper-h2 library to parse HTTP/2 frames and capture request/response headers. + +Usage (h2c - cleartext): + with H2HeaderProxy(upstream_host='127.0.0.1', upstream_port=8181) as proxy: + client = InfluxDBClient3( + host=proxy.url, + token='...', + database='...' + ) + client.query("SELECT 1") + assert proxy.get_request_header('grpc-accept-encoding') == 'identity, deflate, gzip' + +Usage (h2 - TLS): + with H2HeaderProxy(upstream_host='cloud.influxdata.com', upstream_port=443, + tls=True, upstream_tls=True) as proxy: + client = InfluxDBClient3( + host=proxy.url, + token='...', + database='...', + verify_ssl=False # Accept proxy's self-signed cert + ) + client.query("SELECT 1") + assert proxy.get_request_header('grpc-accept-encoding') == 'identity, deflate, gzip' +""" + +import datetime +import ipaddress +import socket +import ssl +import threading +import select +import tempfile +import os + +import h2.connection +import h2.config +import h2.events +import h2.exceptions + +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa + + +def _find_free_port(): + """Find a free port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + s.listen(1) + return s.getsockname()[1] + + +def _generate_self_signed_cert(): + """Generate a self-signed certificate and private key in memory.""" + # Generate RSA key + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + # Generate certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + # Serialize to PEM format + cert_pem = cert.public_bytes(serialization.Encoding.PEM) + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + return cert_pem, key_pem + + +class H2HeaderProxy: + """ + HTTP/2 proxy that captures request and response headers. + + This proxy supports both: + - h2c (HTTP/2 cleartext with prior knowledge) - for HTTP endpoints + - h2 (HTTP/2 over TLS) - for HTTPS endpoints + + For TLS mode, generates a self-signed certificate at runtime. + Use verify_ssl=False on the client to accept the self-signed cert. + + Attributes: + port: The port the proxy is listening on + captured: Dict with 'request' and 'response' lists of captured headers + tls: Whether the proxy accepts TLS connections from clients + upstream_tls: Whether the proxy uses TLS to connect to upstream + """ + + def __init__(self, upstream_host='127.0.0.1', upstream_port=8181, listen_port=None, + tls=False, upstream_tls=False): + """ + Initialize the HTTP/2 proxy. + + Args: + upstream_host: The upstream server hostname + upstream_port: The upstream server port + listen_port: Port to listen on (auto-assigned if None) + tls: Accept TLS connections from clients (generates self-signed cert) + upstream_tls: Use TLS when connecting to upstream server + """ + self.upstream = (upstream_host, upstream_port) + self.upstream_host = upstream_host + self.port = listen_port or _find_free_port() + self.captured = {'request': [], 'response': []} + self.tls = tls + self.upstream_tls = upstream_tls + self._server_sock = None + self._thread = None + self._running = False + self._ssl_context = None + self._cert_file = None + self._key_file = None + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + self.stop() + + def start(self): + """Start the proxy server.""" + # Set up TLS if enabled + if self.tls: + self._setup_tls() + + self._server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._server_sock.bind(('127.0.0.1', self.port)) + self._server_sock.listen(5) + self._server_sock.settimeout(1.0) + self._running = True + + self._thread = threading.Thread(target=self._accept_loop, daemon=True) + self._thread.start() + + def _setup_tls(self): + """Set up TLS context with a self-signed certificate.""" + # Generate self-signed certificate + cert_pem, key_pem = _generate_self_signed_cert() + + # Write cert and key to temporary files (ssl.SSLContext needs files) + self._cert_file = tempfile.NamedTemporaryFile(mode='wb', suffix='.pem', delete=False) + self._cert_file.write(cert_pem) + self._cert_file.close() + + self._key_file = tempfile.NamedTemporaryFile(mode='wb', suffix='.pem', delete=False) + self._key_file.write(key_pem) + self._key_file.close() + + # Create server SSL context + self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self._ssl_context.set_alpn_protocols(['h2']) + self._ssl_context.load_cert_chain(self._cert_file.name, self._key_file.name) + + def stop(self): + """Stop the proxy server.""" + self._running = False + if self._server_sock: + self._server_sock.close() + self._server_sock = None + + # Clean up temp certificate files + if self._cert_file: + try: + os.unlink(self._cert_file.name) + except OSError: + pass + self._cert_file = None + if self._key_file: + try: + os.unlink(self._key_file.name) + except OSError: + pass + self._key_file = None + + def clear(self): + """Clear captured headers.""" + self.captured = {'request': [], 'response': []} + + def _accept_loop(self): + """Accept and handle incoming connections.""" + while self._running: + try: + client_sock, _ = self._server_sock.accept() + + # Wrap with TLS if enabled + if self.tls and self._ssl_context: + try: + client_sock = self._ssl_context.wrap_socket( + client_sock, server_side=True + ) + except ssl.SSLError: + client_sock.close() + continue + + # Handle each connection in a new thread + threading.Thread( + target=self._handle_connection, + args=(client_sock,), + daemon=True + ).start() + except socket.timeout: + continue + except OSError: + break + + def _handle_connection(self, client_sock): + """Handle a single client connection.""" + upstream_sock = None + try: + # Connect to upstream + upstream_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + upstream_sock.connect(self.upstream) + + # Wrap upstream with TLS if enabled + if self.upstream_tls: + upstream_ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + upstream_ssl_ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + upstream_ssl_ctx.set_alpn_protocols(['h2']) + upstream_ssl_ctx.check_hostname = False + upstream_ssl_ctx.verify_mode = ssl.CERT_NONE + upstream_sock = upstream_ssl_ctx.wrap_socket( + upstream_sock, server_hostname=self.upstream_host + ) + + # Create h2 connection state machines + client_h2 = h2.connection.H2Connection( + config=h2.config.H2Configuration(client_side=False) + ) + server_h2 = h2.connection.H2Connection( + config=h2.config.H2Configuration(client_side=True) + ) + + # Initialize connections + client_h2.initiate_connection() + client_sock.sendall(client_h2.data_to_send()) + server_h2.initiate_connection() + upstream_sock.sendall(server_h2.data_to_send()) + + client_sock.setblocking(False) + upstream_sock.setblocking(False) + + self._proxy_loop(client_sock, upstream_sock, client_h2, server_h2) + + except Exception: + pass # Connection errors are expected when client disconnects + finally: + if client_sock: + try: + client_sock.close() + except Exception: + pass + if upstream_sock: + try: + upstream_sock.close() + except Exception: + pass + + def _proxy_loop(self, client_sock, upstream_sock, client_h2, server_h2): + """Main proxy loop - forward data between client and upstream.""" + + def safe_end_stream(h2conn, stream_id): + """End stream, ignoring errors if already closed.""" + try: + h2conn.end_stream(stream_id) + except h2.exceptions.StreamClosedError: + pass + + for _ in range(1000): # Max iterations to prevent infinite loop + readable, _, _ = select.select([client_sock, upstream_sock], [], [], 0.05) + + for sock in readable: + try: + data = sock.recv(65535) + except BlockingIOError: + continue + except ssl.SSLWantReadError: + continue + except ssl.SSLWantWriteError: + continue + + if not data: + return + + if sock == client_sock: + # Client -> Proxy -> Upstream + events = client_h2.receive_data(data) + for ev in events: + if isinstance(ev, h2.events.RequestReceived): + hdrs = dict(ev.headers) + self.captured['request'].append(hdrs) + # Rewrite :authority and :scheme for upstream + fwd_headers = [] + for k, v in ev.headers: + if k in (b':authority', ':authority'): + # Use upstream host with port if not standard + if self.upstream_tls and self.upstream[1] == 443: + v = self.upstream_host.encode() if isinstance(k, bytes) else self.upstream_host + elif not self.upstream_tls and self.upstream[1] == 80: + v = self.upstream_host.encode() if isinstance(k, bytes) else self.upstream_host + else: + v = f"{self.upstream_host}:{self.upstream[1]}" + v = v.encode() if isinstance(k, bytes) else v + elif k in (b':scheme', ':scheme') and self.upstream_tls: + v = b'https' if isinstance(k, bytes) else 'https' + fwd_headers.append((k, v)) + server_h2.send_headers(ev.stream_id, fwd_headers) + elif isinstance(ev, h2.events.DataReceived): + server_h2.send_data(ev.stream_id, ev.data) + client_h2.acknowledge_received_data(len(ev.data), ev.stream_id) + elif isinstance(ev, h2.events.StreamEnded): + safe_end_stream(server_h2, ev.stream_id) + + to_send = server_h2.data_to_send() + if to_send: + upstream_sock.sendall(to_send) + to_send = client_h2.data_to_send() + if to_send: + client_sock.sendall(to_send) + + else: + # Upstream -> Proxy -> Client + events = server_h2.receive_data(data) + + # Detect trailers-only response (ResponseReceived + StreamEnded, no data) + # This happens when server sends HEADERS with END_STREAM + stream_events = {} + for ev in events: + sid = getattr(ev, 'stream_id', None) + if sid is not None: + if sid not in stream_events: + stream_events[sid] = [] + stream_events[sid].append(ev) + + for ev in events: + if isinstance(ev, h2.events.ResponseReceived): + hdrs = dict(ev.headers) + self.captured['response'].append(hdrs) + # Check if this is a trailers-only response + stream_evs = stream_events.get(ev.stream_id, []) + has_stream_ended = any(isinstance(e, h2.events.StreamEnded) for e in stream_evs) + has_data = any(isinstance(e, h2.events.DataReceived) for e in stream_evs) + if has_stream_ended and not has_data: + # Trailers-only: send headers with END_STREAM + client_h2.send_headers(ev.stream_id, ev.headers, end_stream=True) + else: + client_h2.send_headers(ev.stream_id, ev.headers) + elif isinstance(ev, h2.events.DataReceived): + client_h2.send_data(ev.stream_id, ev.data) + server_h2.acknowledge_received_data(len(ev.data), ev.stream_id) + elif isinstance(ev, h2.events.StreamEnded): + # Only end stream if we didn't already (trailers-only case) + stream_evs = stream_events.get(ev.stream_id, []) + has_response = any(isinstance(e, h2.events.ResponseReceived) for e in stream_evs) + has_data = any(isinstance(e, h2.events.DataReceived) for e in stream_evs) + if not (has_response and not has_data): + # Normal case: end stream separately + safe_end_stream(client_h2, ev.stream_id) + elif isinstance(ev, h2.events.TrailersReceived): + hdrs = dict(ev.headers) + self.captured['response'].append(hdrs) + try: + client_h2.send_headers(ev.stream_id, ev.headers, end_stream=True) + except h2.exceptions.StreamClosedError: + pass + + to_send = client_h2.data_to_send() + if to_send: + client_sock.sendall(to_send) + to_send = server_h2.data_to_send() + if to_send: + upstream_sock.sendall(to_send) + + def get_last_request_header(self, name): + """ + Get a header value from the last captured request. + + Args: + name: Header name (case-sensitive, typically lowercase) + + Returns: + Header value as string, or None if not found + """ + for hdrs in reversed(self.captured['request']): + # Try both bytes and string keys + for key in [name.encode() if isinstance(name, str) else name, name]: + if key in hdrs: + v = hdrs[key] + return v.decode() if isinstance(v, bytes) else v + return None + + def get_last_response_header(self, name): + """ + Get a header value from the last captured response. + + Args: + name: Header name (case-sensitive, typically lowercase) + + Returns: + Header value as string, or None if not found + """ + for hdrs in reversed(self.captured['response']): + for key in [name.encode() if isinstance(name, str) else name, name]: + if key in hdrs: + v = hdrs[key] + return v.decode() if isinstance(v, bytes) else v + return None + + @property + def url(self): + """Get the proxy URL for client configuration.""" + scheme = "https" if self.tls else "http" + return f"{scheme}://127.0.0.1:{self.port}"