Skip to content

Commit 1e1db68

Browse files
jkroepkebahkauv70
andauthored
Accept custom transport of custom http client (#1627)
* Respect transport of custom http client Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de> * Add CHANGELOG Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de> * Add test cases Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de> * avoid testify Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de> --------- Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de> Co-authored-by: Rüdiger Schmitz <ruediger.schmitz@inovex.de>
1 parent aef2fa0 commit 1e1db68

12 files changed

+586
-152
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
- **New:** API for application load balancer
44
- `cdn`: [v0.1.0](services/cdn/CHANGELOG.md#v010-2025-03-19)
55
- **New:** Introduce new API for content delivery
6+
- `core`: [v0.16.2](core/CHANGELOG.md#v0162-2025-03-21)
7+
- **New:** If a custom http.Client is provided, the http.Transport is respected. This allows customizing the http.Client with custom timeouts or instrumentation.
68
- `serverupdate`: [v1.0.0](services/serverupdate/CHANGELOG.md#v100-2025-03-19)
79
- **Breaking Change:** The region is no longer specified within the client configuration. Instead, the region must be passed as a parameter to any region-specific request.
810
- `serverbackup`: [v1.0.0](services/serverbackup/CHANGELOG.md#v100-2025-03-19)

core/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
## v0.16.2 (2025-03-21)
2+
- **New:** If a custom http.Client is provided, the http.Transport is respected. This allows customizing the http.Client with custom timeouts or instrumentation.
3+
14
## v0.16.1 (2025-02-25)
25

36
- **Bugfix:** STACKIT_PRIVATE_KEY and STACKIT_SERVICE_ACCOUNT_KEY can be set via environment variable or via credentials file.

core/auth/auth.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
4545
if cfg.CustomAuth != nil {
4646
return cfg.CustomAuth, nil
4747
} else if cfg.NoAuth {
48-
noAuthRoundTripper, err := NoAuth()
48+
noAuthRoundTripper, err := NoAuth(cfg)
4949
if err != nil {
5050
return nil, fmt.Errorf("configuring no auth client: %w", err)
5151
}
@@ -98,9 +98,22 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
9898

9999
// NoAuth configures a flow without authentication and returns an http.RoundTripper
100100
// that can be used to make unauthenticated requests
101-
func NoAuth() (rt http.RoundTripper, err error) {
101+
func NoAuth(cfgs ...*config.Configuration) (rt http.RoundTripper, err error) {
102102
noAuthConfig := clients.NoAuthFlowConfig{}
103103
noAuthRoundTripper := &clients.NoAuthFlow{}
104+
105+
var cfg *config.Configuration
106+
107+
if len(cfgs) > 0 {
108+
cfg = cfgs[0]
109+
} else {
110+
cfg = &config.Configuration{}
111+
}
112+
113+
if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
114+
noAuthConfig.HTTPTransport = cfg.HTTPClient.Transport
115+
}
116+
104117
if err := noAuthRoundTripper.Init(noAuthConfig); err != nil {
105118
return nil, fmt.Errorf("initializing client: %w", err)
106119
}
@@ -130,6 +143,10 @@ func TokenAuth(cfg *config.Configuration) (http.RoundTripper, error) {
130143
ServiceAccountToken: cfg.Token,
131144
}
132145

146+
if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
147+
tokenCfg.HTTPTransport = cfg.HTTPClient.Transport
148+
}
149+
133150
client := &clients.TokenFlow{}
134151
if err := client.Init(&tokenCfg); err != nil {
135152
return nil, fmt.Errorf("error initializing client: %w", err)
@@ -187,6 +204,10 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
187204
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
188205
}
189206

207+
if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
208+
keyCfg.HTTPTransport = cfg.HTTPClient.Transport
209+
}
210+
190211
client := &clients.KeyFlow{}
191212
if err := client.Init(&keyCfg); err != nil {
192213
return nil, fmt.Errorf("error initializing client: %w", err)

core/auth/auth_test.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"crypto/x509"
77
"encoding/json"
88
"encoding/pem"
9+
"net/http"
910
"os"
1011
"reflect"
1112
"testing"
@@ -125,6 +126,7 @@ func TestSetupAuth(t *testing.T) {
125126
t.Fatalf("Creating temporary file: %s", err)
126127
}
127128
defer func() {
129+
_ = credentialsKeyFile.Close()
128130
err := os.Remove(credentialsKeyFile.Name())
129131
if err != nil {
130132
t.Fatalf("Removing temporary file: %s", err)
@@ -361,6 +363,7 @@ func TestDefaultAuth(t *testing.T) {
361363
t.Fatalf("Creating temporary file: %s", err)
362364
}
363365
defer func() {
366+
_ = saKeyFile.Close()
364367
err := os.Remove(saKeyFile.Name())
365368
if err != nil {
366369
t.Fatalf("Removing temporary file: %s", err)
@@ -377,19 +380,13 @@ func TestDefaultAuth(t *testing.T) {
377380
t.Fatalf("Writing private key to temporary file: %s", err)
378381
}
379382

380-
defer func() {
381-
err := saKeyFile.Close()
382-
if err != nil {
383-
t.Fatalf("Removing temporary file: %s", err)
384-
}
385-
}()
386-
387383
// create a credentials file with saKey and private key
388384
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
389385
if errs != nil {
390386
t.Fatalf("Creating temporary file: %s", err)
391387
}
392388
defer func() {
389+
_ = credentialsKeyFile.Close()
393390
err := os.Remove(credentialsKeyFile.Name())
394391
if err != nil {
395392
t.Fatalf("Removing temporary file: %s", err)
@@ -693,6 +690,28 @@ func TestNoAuth(t *testing.T) {
693690
}
694691
}
695692

693+
func TestNoAuthWithConfig(t *testing.T) {
694+
for _, test := range []struct {
695+
desc string
696+
}{
697+
{
698+
desc: "valid_case",
699+
},
700+
} {
701+
t.Run(test.desc, func(t *testing.T) {
702+
setTemporaryHome(t) // Get the default authentication client and ensure that it's not nil
703+
authClient, err := NoAuth(&config.Configuration{HTTPClient: http.DefaultClient})
704+
if err != nil {
705+
t.Fatalf("Test returned error on valid test case: %v", err)
706+
}
707+
708+
if authClient == nil {
709+
t.Fatalf("Client returned is nil for valid test case")
710+
}
711+
})
712+
}
713+
}
714+
696715
func TestGetServiceAccountEmail(t *testing.T) {
697716
for _, test := range []struct {
698717
description string

core/clients/key_flow.go

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ const (
3434

3535
// KeyFlow handles auth with SA key
3636
type KeyFlow struct {
37-
client *http.Client
37+
rt http.RoundTripper
38+
authClient *http.Client
3839
config *KeyFlowConfig
39-
doer func(req *http.Request) (resp *http.Response, err error)
4040
key *ServiceAccountKeyResponse
4141
privateKey *rsa.PrivateKey
4242
privateKeyPEM []byte
@@ -53,6 +53,8 @@ type KeyFlowConfig struct {
5353
ClientRetry *RetryConfig
5454
TokenUrl string
5555
BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil
56+
HTTPTransport http.RoundTripper
57+
AuthHTTPClient *http.Client
5658
}
5759

5860
// TokenResponseBody is the API response
@@ -124,7 +126,18 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
124126
if c.config.TokenUrl == "" {
125127
c.config.TokenUrl = tokenAPI
126128
}
127-
c.configureHTTPClient()
129+
130+
if c.rt = cfg.HTTPTransport; c.rt == nil {
131+
c.rt = http.DefaultTransport
132+
}
133+
134+
if c.authClient = cfg.AuthHTTPClient; cfg.AuthHTTPClient == nil {
135+
c.authClient = &http.Client{
136+
Transport: c.rt,
137+
Timeout: DefaultClientTimeout,
138+
}
139+
}
140+
128141
err := c.validate()
129142
if err != nil {
130143
return err
@@ -163,7 +176,7 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {
163176

164177
// Roundtrip performs the request
165178
func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
166-
if c.client == nil {
179+
if c.rt == nil {
167180
return nil, fmt.Errorf("please run Init()")
168181
}
169182

@@ -172,17 +185,21 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
172185
return nil, err
173186
}
174187
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
175-
return c.doer(req)
188+
return c.rt.RoundTrip(req)
176189
}
177190

178191
// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
179192
func (c *KeyFlow) GetAccessToken() (string, error) {
180-
if c.client == nil {
181-
return "", fmt.Errorf("nil http client, please run Init()")
193+
if c.rt == nil {
194+
return "", fmt.Errorf("nil http round tripper, please run Init()")
182195
}
183196

197+
var accessToken string
198+
184199
c.tokenMutex.RLock()
185-
accessToken := c.token.AccessToken
200+
if c.token != nil {
201+
accessToken = c.token.AccessToken
202+
}
186203
c.tokenMutex.RUnlock()
187204

188205
accessTokenExpired, err := tokenExpired(accessToken)
@@ -203,14 +220,6 @@ func (c *KeyFlow) GetAccessToken() (string, error) {
203220
return accessToken, nil
204221
}
205222

206-
// configureHTTPClient configures the HTTP client
207-
func (c *KeyFlow) configureHTTPClient() {
208-
client := &http.Client{}
209-
client.Timeout = DefaultClientTimeout
210-
c.client = client
211-
c.doer = c.client.Do
212-
}
213-
214223
// validate the client is configured well
215224
func (c *KeyFlow) validate() error {
216225
if c.config.ServiceAccountKey == nil {
@@ -242,8 +251,12 @@ func (c *KeyFlow) validate() error {
242251
// recreateAccessToken is used to create a new access token
243252
// when the existing one isn't valid anymore
244253
func (c *KeyFlow) recreateAccessToken() error {
254+
var refreshToken string
255+
245256
c.tokenMutex.RLock()
246-
refreshToken := c.token.RefreshToken
257+
if c.token != nil {
258+
refreshToken = c.token.RefreshToken
259+
}
247260
c.tokenMutex.RUnlock()
248261

249262
refreshTokenExpired, err := tokenExpired(refreshToken)
@@ -279,10 +292,6 @@ func (c *KeyFlow) createAccessToken() (err error) {
279292
// createAccessTokenWithRefreshToken creates an access token using
280293
// an existing pre-validated refresh token
281294
func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) {
282-
if c.client == nil {
283-
return fmt.Errorf("nil http client, please run Init()")
284-
}
285-
286295
c.tokenMutex.RLock()
287296
refreshToken := c.token.RefreshToken
288297
c.tokenMutex.RUnlock()
@@ -334,7 +343,8 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error)
334343
return nil, err
335344
}
336345
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
337-
return c.doer(req)
346+
347+
return c.authClient.Do(req)
338348
}
339349

340350
// parseTokenResponse parses the response from the server

core/clients/key_flow_continuous_refresh.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error {
4646
// Compute timestamp where we'll refresh token
4747
// Access token may be empty at this point, we have to check it
4848
var startRefreshTimestamp time.Time
49+
var accessToken string
4950

5051
refresher.keyFlow.tokenMutex.RLock()
51-
accessToken := refresher.keyFlow.token.AccessToken
52+
if refresher.keyFlow.token != nil {
53+
accessToken = refresher.keyFlow.token.AccessToken
54+
}
5255
refresher.keyFlow.tokenMutex.RUnlock()
5356
if accessToken == "" {
5457
startRefreshTimestamp = time.Now()

core/clients/key_flow_continuous_refresh_test.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,9 @@ func TestContinuousRefreshToken(t *testing.T) {
137137
config: &KeyFlowConfig{
138138
BackgroundTokenRefreshContext: ctx,
139139
},
140-
client: &http.Client{},
141-
doer: mockDo,
140+
authClient: &http.Client{
141+
Transport: mockTransportFn{mockDo},
142+
},
142143
token: &TokenResponseBody{
143144
AccessToken: accessToken,
144145
RefreshToken: refreshToken,
@@ -328,11 +329,13 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
328329
}
329330

330331
keyFlow := &KeyFlow{
331-
client: &http.Client{},
332332
config: &KeyFlowConfig{
333333
BackgroundTokenRefreshContext: ctx,
334334
},
335-
doer: mockDo,
335+
authClient: &http.Client{
336+
Transport: mockTransportFn{mockDo},
337+
},
338+
rt: mockTransportFn{mockDo},
336339
token: &TokenResponseBody{
337340
AccessToken: accessTokenFirst,
338341
RefreshToken: refreshToken,

0 commit comments

Comments
 (0)