Skip to content
15 changes: 11 additions & 4 deletions google/auth/compute_engine/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,7 @@ def _get_metadata_ip_root(use_mtls: bool):

# Timeout in seconds to wait for the GCE metadata server when detecting the
# GCE environment.
try:
_METADATA_DEFAULT_TIMEOUT = int(os.getenv(environment_vars.GCE_METADATA_TIMEOUT, 3))
except ValueError: # pragma: NO COVER
_METADATA_DEFAULT_TIMEOUT = 3
_METADATA_PING_DEFAULT_TIMEOUT = 3

# The number of tries to perform when waiting for the GCE metadata server
# when detecting the GCE environment.
Expand Down Expand Up @@ -209,6 +206,16 @@ def ping(
# could lead to false negatives in the event that we are on GCE, but
# the metadata resolution was particularly slow. The latter case is
# "unlikely".

if timeout is None:
try:
timeout = float(os.getenv(
environment_vars.GCE_METADATA_TIMEOUT,
str(_METADATA_PING_DEFAULT_TIMEOUT)))
except ValueError:
timeout = _METADATA_PING_DEFAULT_TIMEOUT

retries = 0
headers = _METADATA_HEADERS.copy()
headers[metrics.API_CLIENT_HEADER] = metrics.mds_ping()

Expand Down
3 changes: 3 additions & 0 deletions google/auth/environment_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@
Used to distinguish between GAE gen1 and GAE gen2+.
"""

GCE_METADATA_TIMEOUT = "GCE_METADATA_TIMEOUT"
"""Environment variable for setting timeouts in seconds for metadata queries."""

# AWS environment variables used with AWS workload identity pools to retrieve
# AWS security credentials and the AWS region needed to create a serialized
# signed requests to the AWS STS GetCalledIdentity API that can be exchanged
Expand Down
45 changes: 41 additions & 4 deletions tests/compute_engine/test__metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,46 @@ def test_ping_success(mock_metrics_header_value):

request.assert_called_once_with(
method="GET",
url="http://169.254.169.254",
headers=MDS_PING_REQUEST_HEADER,
timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
url=_metadata._METADATA_IP_ROOT,
headers=_metadata._METADATA_HEADER,
timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT,
)

@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
def test_ping_success_with_gce_metadata_timeout(mock_metrics_header_value):
request = make_request("", headers=_metadata._METADATA_HEADERS)
gce_metadata_timeout = .5
os.environ[
environment_vars.GCE_METADATA_TIMEOUT] = str(gce_metadata_timeout)

try:
assert _metadata.ping(request)
finally:
del os.environ[environment_vars.GCE_METADATA_TIMEOUT]

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_IP_ROOT,
headers=_metadata._METADATA_HEADER,
timeout=gce_metadata_timeout,
)

@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
def test_ping_success_with_invalid_gce_metadata_timeout(mock_metrics_header_value):
request = make_request("", headers=_metadata._METADATA_HEADERS)
os.environ[
environment_vars.GCE_METADATA_TIMEOUT] = "Not a valid float value!"

try:
assert _metadata.ping(request)
finally:
del os.environ[environment_vars.GCE_METADATA_TIMEOUT]

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_IP_ROOT,
headers=_metadata._METADATA_HEADERS,
timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT, # Fallback value.
)


Expand Down Expand Up @@ -197,7 +234,7 @@ def test_ping_success_custom_root(mock_metrics_header_value):
method="GET",
url="http://" + fake_ip,
headers=MDS_PING_REQUEST_HEADER,
timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT,
)


Expand Down
Loading