Skip to content

Commit 4a551b6

Browse files
sfc-gh-eqinsfc-gh-pmansour
authored andcommitted
Support WIF Impersonation on GCP workloads (#2496)
Co-authored-by: Peter Mansour <peter.mansour@snowflake.com>
1 parent ca43f9c commit 4a551b6

File tree

7 files changed

+237
-11
lines changed

7 files changed

+237
-11
lines changed

src/snowflake/connector/auth/workload_identity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def __init__(
5555
provider: AttestationProvider,
5656
token: str | None = None,
5757
entra_resource: str | None = None,
58+
impersonation_path: list[str] | None = None,
5859
**kwargs,
5960
) -> None:
6061
super().__init__(**kwargs)
6162
self.provider = provider
6263
self.token = token
6364
self.entra_resource = entra_resource
65+
self.impersonation_path = impersonation_path
6466

6567
self.attestation: WorkloadIdentityAttestation | None = None
6668

@@ -85,6 +87,7 @@ def prepare(
8587
self.provider,
8688
self.entra_resource,
8789
self.token,
90+
self.impersonation_path,
8891
session_manager=(
8992
conn._session_manager.clone(max_retries=0) if conn else None
9093
),

src/snowflake/connector/connection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def _get_private_bytes_from_file(
214214
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
215215
"workload_identity_provider": (None, (type(None), AttestationProvider)),
216216
"workload_identity_entra_resource": (None, (type(None), str)),
217+
"workload_identity_impersonation_path": (None, (type(None), list[str])),
217218
"mfa_callback": (None, (type(None), Callable)),
218219
"password_callback": (None, (type(None), Callable)),
219220
"auth_class": (None, (type(None), AuthByPlugin)),
@@ -1355,10 +1356,24 @@ def __open_connection(self):
13551356
"errno": ER_INVALID_WIF_SETTINGS,
13561357
},
13571358
)
1359+
if (
1360+
self._workload_identity_impersonation_path
1361+
and self._workload_identity_provider != AttestationProvider.GCP
1362+
):
1363+
Error.errorhandler_wrapper(
1364+
self,
1365+
None,
1366+
ProgrammingError,
1367+
{
1368+
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
1369+
"errno": ER_INVALID_WIF_SETTINGS,
1370+
},
1371+
)
13581372
self.auth_class = AuthByWorkloadIdentity(
13591373
provider=self._workload_identity_provider,
13601374
token=self._token,
13611375
entra_resource=self._workload_identity_entra_resource,
1376+
impersonation_path=self._workload_identity_impersonation_path,
13621377
)
13631378
else:
13641379
# okta URL, e.g., https://<account>.okta.com/
@@ -1531,6 +1546,7 @@ def __config(self, **kwargs):
15311546
workload_identity_dependent_options = [
15321547
"workload_identity_provider",
15331548
"workload_identity_entra_resource",
1549+
"workload_identity_impersonation_path",
15341550
]
15351551
for dependent_option in workload_identity_dependent_options:
15361552
if (

src/snowflake/connector/wif_util.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
logger = logging.getLogger(__name__)
2121
SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"
2222
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad"
23+
GCP_METADATA_SERVICE_ACCOUNT_BASE_URL = (
24+
"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default"
25+
)
2326

2427

2528
@unique
@@ -184,29 +187,103 @@ def create_aws_attestation(
184187
)
185188

186189

187-
def create_gcp_attestation(
188-
session_manager: SessionManager | None = None,
189-
) -> WorkloadIdentityAttestation:
190-
"""Tries to create a workload identity attestation for GCP.
190+
def get_gcp_access_token(session_manager: SessionManager) -> str:
191+
"""Gets a GCP access token from the metadata server.
192+
193+
If the application isn't running on GCP or no credentials were found, raises an error.
194+
"""
195+
try:
196+
res = session_manager.request(
197+
method="GET",
198+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/token",
199+
headers={
200+
"Metadata-Flavor": "Google",
201+
},
202+
)
203+
res.raise_for_status()
204+
return res.json()["access_token"]
205+
except Exception as e:
206+
raise ProgrammingError(
207+
msg=f"Error fetching GCP access token: {e}. Ensure the application is running on GCP.",
208+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
209+
)
210+
211+
212+
def get_gcp_identity_token_via_impersonation(
213+
impersonation_path: list[str], session_manager: SessionManager
214+
) -> str:
215+
"""Gets a GCP identity token from the metadata server.
216+
217+
If the application isn't running on GCP or no credentials were found, raises an error.
218+
"""
219+
if not impersonation_path:
220+
raise ProgrammingError(
221+
msg="Error: impersonation_path cannot be empty.",
222+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
223+
)
224+
225+
current_sa_token = get_gcp_access_token(session_manager)
226+
impersonation_path = [
227+
f"projects/-/serviceAccounts/{client_id}" for client_id in impersonation_path
228+
]
229+
try:
230+
res = session_manager.post(
231+
url=f"https://iamcredentials.googleapis.com/v1/{impersonation_path[-1]}:generateIdToken",
232+
headers={
233+
"Authorization": f"Bearer {current_sa_token}",
234+
"Content-Type": "application/json",
235+
},
236+
json={
237+
"delegates": impersonation_path[:-1],
238+
"audience": SNOWFLAKE_AUDIENCE,
239+
},
240+
)
241+
res.raise_for_status()
242+
return res.json()["token"]
243+
except Exception as e:
244+
raise ProgrammingError(
245+
msg=f"Error fetching GCP identity token for impersonated GCP service account '{impersonation_path[-1]}': {e}. Ensure the application is running on GCP.",
246+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
247+
)
248+
249+
250+
def get_gcp_identity_token(session_manager: SessionManager) -> str:
251+
"""Gets a GCP identity token from the metadata server.
191252
192253
If the application isn't running on GCP or no credentials were found, raises an error.
193254
"""
194255
try:
195256
res = session_manager.request(
196257
method="GET",
197-
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
258+
url=f"{GCP_METADATA_SERVICE_ACCOUNT_BASE_URL}/identity?audience={SNOWFLAKE_AUDIENCE}",
198259
headers={
199260
"Metadata-Flavor": "Google",
200261
},
201262
)
202263
res.raise_for_status()
264+
return res.content.decode("utf-8")
203265
except Exception as e:
204266
raise ProgrammingError(
205-
msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.",
267+
msg=f"Error fetching GCP identity token: {e}. Ensure the application is running on GCP.",
206268
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
207269
)
208270

209-
jwt_str = res.content.decode("utf-8")
271+
272+
def create_gcp_attestation(
273+
session_manager: SessionManager,
274+
impersonation_path: list[str] | None = None,
275+
) -> WorkloadIdentityAttestation:
276+
"""Tries to create a workload identity attestation for GCP.
277+
278+
If the application isn't running on GCP or no credentials were found, raises an error.
279+
"""
280+
if impersonation_path:
281+
jwt_str = get_gcp_identity_token_via_impersonation(
282+
impersonation_path, session_manager
283+
)
284+
else:
285+
jwt_str = get_gcp_identity_token(session_manager)
286+
210287
_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
211288
return WorkloadIdentityAttestation(
212289
AttestationProvider.GCP, jwt_str, {"sub": subject}
@@ -295,6 +372,7 @@ def create_attestation(
295372
provider: AttestationProvider,
296373
entra_resource: str | None = None,
297374
token: str | None = None,
375+
impersonation_path: list[str] | None = None,
298376
session_manager: SessionManager | None = None,
299377
) -> WorkloadIdentityAttestation:
300378
"""Entry point to create an attestation using the given provider.
@@ -313,7 +391,7 @@ def create_attestation(
313391
elif provider == AttestationProvider.AZURE:
314392
return create_azure_attestation(entra_resource, session_manager)
315393
elif provider == AttestationProvider.GCP:
316-
return create_gcp_attestation(session_manager)
394+
return create_gcp_attestation(session_manager, impersonation_path)
317395
elif provider == AttestationProvider.OIDC:
318396
return create_oidc_attestation(token)
319397
else:

test/csp_helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def gen_dummy_id_token(
4040
)
4141

4242

43+
def gen_dummy_access_token(sub="test-subject") -> str:
44+
"""Generates a dummy access token using the given subject."""
45+
key = "secret"
46+
logger.debug(f"Generating dummy access token for subject {sub}")
47+
return (sub + key).encode("utf-8").hex()
48+
49+
4350
def build_response(content: bytes, status_code: int = 200, headers=None) -> Response:
4451
"""Builds a requests.Response object with the given status code and content."""
4552
response = Response()
@@ -285,6 +292,19 @@ def handle_request(self, method, parsed_url, headers, timeout):
285292
audience = query_string["audience"][0]
286293
self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience)
287294
return build_response(self.token.encode("utf-8"))
295+
elif (
296+
method == "GET"
297+
and parsed_url.path
298+
== "/computeMetadata/v1/instance/service-accounts/default/token"
299+
and headers.get("Metadata-Flavor") == "Google"
300+
):
301+
self.token = gen_dummy_access_token(sub=self.sub)
302+
ret = {
303+
"access_token": self.token,
304+
"expires_in": 3599,
305+
"token_type": "Bearer",
306+
}
307+
return build_response(json.dumps(ret).encode("utf-8"))
288308
else:
289309
# Reject malformed requests.
290310
raise HTTPError()

test/unit/test_auth_workload_identity.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
)
1818
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname
1919

20-
from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token
20+
from ..csp_helpers import (
21+
FakeAwsEnvironment,
22+
FakeGceMetadataService,
23+
build_response,
24+
gen_dummy_access_token,
25+
gen_dummy_id_token,
26+
)
2127

2228
logger = logging.getLogger(__name__)
2329

@@ -289,7 +295,7 @@ def test_explicit_gcp_metadata_server_error_bubbles_up(exception):
289295
with pytest.raises(ProgrammingError) as excinfo:
290296
auth_class.prepare(conn=None)
291297

292-
assert "Error fetching GCP metadata:" in str(excinfo.value)
298+
assert "Error fetching GCP identity token:" in str(excinfo.value)
293299
assert "Ensure the application is running on GCP." in str(excinfo.value)
294300

295301

@@ -317,6 +323,44 @@ def test_explicit_gcp_generates_unique_assertion_content(
317323
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}'
318324

319325

326+
@mock.patch("snowflake.connector.session_manager.SessionManager.post")
327+
def test_gcp_calls_correct_apis_and_populates_auth_data_for_final_sa(
328+
mock_post_request, fake_gce_metadata_service: FakeGceMetadataService
329+
):
330+
fake_gce_metadata_service.sub = "sa1"
331+
impersonation_path = ["sa2", "sa3"]
332+
sa1_access_token = gen_dummy_access_token("sa1")
333+
sa3_id_token = gen_dummy_id_token("sa3")
334+
335+
mock_post_request.return_value = build_response(
336+
json.dumps({"token": sa3_id_token}).encode("utf-8")
337+
)
338+
339+
auth_class = AuthByWorkloadIdentity(
340+
provider=AttestationProvider.GCP, impersonation_path=impersonation_path
341+
)
342+
auth_class.prepare(conn=None)
343+
344+
mock_post_request.assert_called_once_with(
345+
url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa3:generateIdToken",
346+
headers={
347+
"Authorization": f"Bearer {sa1_access_token}",
348+
"Content-Type": "application/json",
349+
},
350+
json={
351+
"delegates": ["projects/-/serviceAccounts/sa2"],
352+
"audience": "snowflakecomputing.com",
353+
},
354+
)
355+
356+
assert auth_class.assertion_content == '{"_provider":"GCP","sub":"sa3"}'
357+
assert extract_api_data(auth_class) == {
358+
"AUTHENTICATOR": "WORKLOAD_IDENTITY",
359+
"PROVIDER": "GCP",
360+
"TOKEN": sa3_id_token,
361+
}
362+
363+
320364
# -- Azure Tests --
321365

322366

test/unit/test_connection.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ def test_otel_error_message(caplog, mock_post_requests):
632632
"workload_identity_entra_resource",
633633
"api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
634634
),
635+
("workload_identity_impersonation_path", ["subject-b", "subject-c"]),
635636
],
636637
)
637638
def test_cannot_set_dependent_params_without_wlid_authenticator(
@@ -680,6 +681,71 @@ def test_workload_identity_provider_is_required_for_wif_authenticator(
680681
assert expected_error_msg in str(excinfo.value)
681682

682683

684+
@pytest.mark.parametrize(
685+
"provider_param",
686+
[
687+
# Strongly-typed values.
688+
AttestationProvider.AWS,
689+
AttestationProvider.AZURE,
690+
AttestationProvider.OIDC,
691+
# String values.
692+
"AWS",
693+
"AZURE",
694+
"OIDC",
695+
],
696+
)
697+
def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
698+
monkeypatch, provider_param
699+
):
700+
with monkeypatch.context() as m:
701+
m.setattr(
702+
"snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
703+
)
704+
705+
with pytest.raises(ProgrammingError) as excinfo:
706+
snowflake.connector.connect(
707+
account="account",
708+
authenticator="WORKLOAD_IDENTITY",
709+
workload_identity_provider=provider_param,
710+
workload_identity_impersonation_path=[
711+
"sa2@project.iam.gserviceaccount.com"
712+
],
713+
)
714+
assert (
715+
"workload_identity_impersonation_path is currently only supported for GCP."
716+
in str(excinfo.value)
717+
)
718+
719+
720+
@pytest.mark.parametrize(
721+
"provider_param",
722+
[
723+
AttestationProvider.GCP,
724+
"GCP",
725+
],
726+
)
727+
def test_workload_identity_impersonation_path_supported_for_gcp_provider(
728+
monkeypatch, provider_param
729+
):
730+
with monkeypatch.context() as m:
731+
m.setattr(
732+
"snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
733+
)
734+
735+
conn = snowflake.connector.connect(
736+
account="account",
737+
authenticator="WORKLOAD_IDENTITY",
738+
workload_identity_provider=provider_param,
739+
workload_identity_impersonation_path=[
740+
"sa2@project.iam.gserviceaccount.com"
741+
],
742+
)
743+
assert conn.auth_class.provider == AttestationProvider.GCP
744+
assert conn.auth_class.impersonation_path == [
745+
"sa2@project.iam.gserviceaccount.com"
746+
]
747+
748+
683749
@pytest.mark.parametrize(
684750
"provider_param, parsed_provider",
685751
[

test/wif/test_wif.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def test_should_authenticate_using_oidc():
5959

6060

6161
@pytest.mark.wif
62-
@pytest.mark.skip("Impersonation is still being developed")
6362
def test_should_authenticate_with_impersonation():
6463
if not isinstance(IMPERSONATION_PATH, str) or not IMPERSONATION_PATH:
6564
pytest.skip("Skipping test - IMPERSONATION_PATH is not set")

0 commit comments

Comments
 (0)