Skip to content

Commit 848cf95

Browse files
authored
feat: Support workload identity federation flow (#4074)
1 parent e284bdf commit 848cf95

19 files changed

+1132
-253
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
- `certificates`: [v1.2.0](services/certificates/CHANGELOG.md#v120)
4040
- **Feature:** Switch from `v2beta` API version to `v2` version.
4141
- **Breaking change:** Rename `CreateCertificateResponse` to `GetCertificateResponse`
42+
- `core`:
43+
- [v0.21.0](core/CHANGELOG.md#v0210)
44+
- **Deprecation:** KeyFlow `SetToken` and `GetToken` will be removed after 2026-07-01. Use GetAccessToken instead and rely on client refresh.
45+
- **Feature:** Support Workload Identity Federation flow
4246
- `sfs`:
4347
- [v0.2.0](services/sfs/CHANGELOG.md)
4448
- **Breaking change:** Remove region configuration in `APIClient`

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,4 +234,4 @@ See the [release documentation](./RELEASE.md) for further information.
234234

235235
## License
236236

237-
Apache 2.0
237+
Apache 2.0

core/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## v0.21.0
2+
- **Deprecation:** KeyFlow `SetToken` and `GetToken` will be removed after 2026-07-01. Use GetAccessToken instead and rely on client refresh.
3+
- **Feature:** Support Workload Identity Federation flow
4+
15
## v0.20.1
26
- **Improvement:** Improve error message when passing a PEM encoded file to as service account key
37

core/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
v0.20.1
1+
v0.21.0

core/auth/auth.go

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
5151
return nil, fmt.Errorf("configuring no auth client: %w", err)
5252
}
5353
return noAuthRoundTripper, nil
54+
} else if cfg.WorkloadIdentityFederation {
55+
wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg)
56+
if err != nil {
57+
return nil, fmt.Errorf("configuring no auth client: %w", err)
58+
}
59+
return wifRoundTripper, nil
5460
} else if cfg.ServiceAccountKey != "" || cfg.ServiceAccountKeyPath != "" {
5561
keyRoundTripper, err := KeyAuth(cfg)
5662
if err != nil {
@@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
8490
cfg = &config.Configuration{}
8591
}
8692

87-
// Key flow
88-
rt, err = KeyAuth(cfg)
93+
// WIF flow
94+
rt, err = WorkloadIdentityFederationAuth(cfg)
8995
if err != nil {
90-
keyFlowErr := err
91-
// Token flow
92-
rt, err = TokenAuth(cfg)
96+
// Key flow
97+
rt, err = KeyAuth(cfg)
9398
if err != nil {
94-
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
99+
keyFlowErr := err
100+
// Token flow
101+
rt, err = TokenAuth(cfg)
102+
if err != nil {
103+
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
104+
}
95105
}
96106
}
97107
return rt, nil
@@ -221,6 +231,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
221231
return client, nil
222232
}
223233

234+
// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper
235+
// that can be used to make authenticated requests using an access token
236+
func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) {
237+
wifConfig := clients.WorkloadIdentityFederationFlowConfig{
238+
TokenUrl: cfg.TokenCustomUrl,
239+
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
240+
ClientID: cfg.ServiceAccountEmail,
241+
TokenExpiration: cfg.ServiceAccountFederatedTokenExpiration,
242+
FederatedTokenFunction: cfg.ServiceAccountFederatedTokenFunc,
243+
}
244+
245+
if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
246+
wifConfig.HTTPTransport = cfg.HTTPClient.Transport
247+
}
248+
249+
client := &clients.WorkloadIdentityFederationFlow{}
250+
if err := client.Init(&wifConfig); err != nil {
251+
return nil, fmt.Errorf("error initializing client: %w", err)
252+
}
253+
254+
return client, nil
255+
}
256+
224257
// readCredentialsFile reads the credentials file from the specified path and returns Credentials
225258
func readCredentialsFile(path string) (*Credentials, error) {
226259
if path == "" {

core/auth/auth_test.go

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414
"time"
1515

16+
"github.com/golang-jwt/jwt/v5"
1617
"github.com/google/uuid"
1718
"github.com/stackitcloud/stackit-sdk-go/core/clients"
1819
"github.com/stackitcloud/stackit-sdk-go/core/config"
@@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) {
121122
}
122123
}()
123124

125+
// create a wif assertion file
126+
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
127+
if errs != nil {
128+
t.Fatalf("Creating temporary file: %s", err)
129+
}
130+
defer func() {
131+
_ = wifAssertionFile.Close()
132+
err := os.Remove(wifAssertionFile.Name())
133+
if err != nil {
134+
t.Fatalf("Removing temporary file: %s", err)
135+
}
136+
}()
137+
138+
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
139+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
140+
Subject: "sub",
141+
}).SignedString([]byte("test"))
142+
if err != nil {
143+
t.Fatalf("Removing temporary file: %s", err)
144+
}
145+
146+
_, errs = wifAssertionFile.WriteString(string(token))
147+
if errs != nil {
148+
t.Fatalf("Writing wif assertion to temporary file: %s", err)
149+
}
150+
124151
// create a credentials file with saKey and private key
125152
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
126153
if errs != nil {
@@ -147,12 +174,19 @@ func TestSetupAuth(t *testing.T) {
147174
desc string
148175
config *config.Configuration
149176
setToken bool
177+
setWorkloadIdentity bool
150178
setKeys bool
151179
setKeyPaths bool
152180
setCredentialsFilePathToken bool
153181
setCredentialsFilePathKey bool
154182
isValid bool
155183
}{
184+
{
185+
desc: "wif_config",
186+
config: nil,
187+
setWorkloadIdentity: true,
188+
isValid: true,
189+
},
156190
{
157191
desc: "token_config",
158192
config: nil,
@@ -241,6 +275,12 @@ func TestSetupAuth(t *testing.T) {
241275
t.Setenv("STACKIT_CREDENTIALS_PATH", "")
242276
}
243277

278+
if test.setWorkloadIdentity {
279+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
280+
} else {
281+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
282+
}
283+
244284
t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")
245285

246286
authRoundTripper, err := SetupAuth(test.config)
@@ -253,7 +293,7 @@ func TestSetupAuth(t *testing.T) {
253293
t.Fatalf("Test didn't return error on invalid test case")
254294
}
255295

256-
if test.isValid && authRoundTripper == nil {
296+
if authRoundTripper == nil && test.isValid {
257297
t.Fatalf("Roundtripper returned is nil for valid test case")
258298
}
259299
})
@@ -381,6 +421,32 @@ func TestDefaultAuth(t *testing.T) {
381421
t.Fatalf("Writing private key to temporary file: %s", err)
382422
}
383423

424+
// create a wif assertion file
425+
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
426+
if errs != nil {
427+
t.Fatalf("Creating temporary file: %s", err)
428+
}
429+
defer func() {
430+
_ = wifAssertionFile.Close()
431+
err := os.Remove(wifAssertionFile.Name())
432+
if err != nil {
433+
t.Fatalf("Removing temporary file: %s", err)
434+
}
435+
}()
436+
437+
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
438+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
439+
Subject: "sub",
440+
}).SignedString([]byte("test"))
441+
if err != nil {
442+
t.Fatalf("Removing temporary file: %s", err)
443+
}
444+
445+
_, errs = wifAssertionFile.WriteString(string(token))
446+
if errs != nil {
447+
t.Fatalf("Writing wif assertion to temporary file: %s", err)
448+
}
449+
384450
// create a credentials file with saKey and private key
385451
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
386452
if errs != nil {
@@ -409,6 +475,7 @@ func TestDefaultAuth(t *testing.T) {
409475
setKeyPaths bool
410476
setKeys bool
411477
setCredentialsFilePathKey bool
478+
setWorkloadIdentity bool
412479
isValid bool
413480
expectedFlow string
414481
}{
@@ -418,6 +485,14 @@ func TestDefaultAuth(t *testing.T) {
418485
isValid: true,
419486
expectedFlow: "token",
420487
},
488+
{
489+
desc: "wif_precedes_key_precedes_token",
490+
setToken: true,
491+
setKeyPaths: true,
492+
setWorkloadIdentity: true,
493+
isValid: true,
494+
expectedFlow: "wif",
495+
},
421496
{
422497
desc: "key_precedes_token",
423498
setToken: true,
@@ -475,6 +550,13 @@ func TestDefaultAuth(t *testing.T) {
475550
} else {
476551
t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "")
477552
}
553+
554+
if test.setWorkloadIdentity {
555+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
556+
} else {
557+
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
558+
}
559+
478560
t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")
479561

480562
// Get the default authentication client and ensure that it's not nil
@@ -501,6 +583,10 @@ func TestDefaultAuth(t *testing.T) {
501583
if _, ok := authClient.(*clients.KeyFlow); !ok {
502584
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
503585
}
586+
case "wif":
587+
if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok {
588+
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
589+
}
504590
}
505591
}
506592
})

core/clients/auth_flow.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package clients
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"time"
10+
11+
"github.com/golang-jwt/jwt/v5"
12+
"github.com/stackitcloud/stackit-sdk-go/core/oapierror"
13+
)
14+
15+
const (
16+
defaultTokenExpirationLeeway = time.Second * 5
17+
)
18+
19+
type AuthFlow interface {
20+
RoundTrip(req *http.Request) (*http.Response, error)
21+
GetAccessToken() (string, error)
22+
getBackgroundTokenRefreshContext() context.Context
23+
refreshAccessToken() error
24+
}
25+
26+
// TokenResponseBody is the API response
27+
// when requesting a new token
28+
type TokenResponseBody struct {
29+
AccessToken string `json:"access_token"`
30+
ExpiresIn int `json:"expires_in"`
31+
// Deprecated: RefreshToken is no longer used and the SDK will not attempt to refresh tokens using it but will instead use the AuthFlow implementation to get new tokens.
32+
// This will be removed after 2026-07-01.
33+
RefreshToken string `json:"refresh_token"`
34+
Scope string `json:"scope"`
35+
TokenType string `json:"token_type"`
36+
}
37+
38+
func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) {
39+
if res == nil {
40+
return nil, fmt.Errorf("received bad response from API")
41+
}
42+
if res.StatusCode != http.StatusOK {
43+
body, err := io.ReadAll(res.Body)
44+
if err != nil {
45+
// Fail silently, omit body from error
46+
// We're trying to show error details, so it's unnecessary to fail because of this err
47+
body = []byte{}
48+
}
49+
return nil, &oapierror.GenericOpenAPIError{
50+
StatusCode: res.StatusCode,
51+
Body: body,
52+
}
53+
}
54+
body, err := io.ReadAll(res.Body)
55+
if err != nil {
56+
return nil, err
57+
}
58+
59+
token := &TokenResponseBody{}
60+
err = json.Unmarshal(body, token)
61+
if err != nil {
62+
return nil, fmt.Errorf("unmarshal token response: %w", err)
63+
}
64+
return token, nil
65+
}
66+
67+
func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) {
68+
if token == "" {
69+
return true, nil
70+
}
71+
72+
// We can safely use ParseUnverified because we are not authenticating the user at this point.
73+
// We're just checking the expiration time
74+
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
75+
if err != nil {
76+
return false, fmt.Errorf("parse token: %w", err)
77+
}
78+
79+
expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
80+
if err != nil {
81+
return false, fmt.Errorf("get expiration timestamp: %w", err)
82+
}
83+
84+
// Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring
85+
// between retrieving the token and upstream systems validating it.
86+
now := time.Now().Add(tokenExpirationLeeway)
87+
return now.After(expirationTimestampNumeric.Time), nil
88+
}

0 commit comments

Comments
 (0)