diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index a28759553..0b023df3f 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -323,6 +323,20 @@ func parseImpersonationChain(chain string) (string, []string) { const iamLoginScope = "https://www.googleapis.com/auth/sqlservice.login" +// iamAuthNEnabled returns true if IAM authentication is enabled globally +// or for any instance in the configuration. +func (c *Config) iamAuthNEnabled() bool { + if c.IAMAuthN { + return true + } + for _, inst := range c.Instances { + if inst.IAMAuthN != nil && *inst.IAMAuthN { + return true + } + } + return false +} + func credentialsOpt(c Config, l cloudsql.Logger) (cloudsqlconn.Option, error) { // If service account impersonation is configured, set up an impersonated // credentials token source. @@ -363,7 +377,8 @@ func credentialsOpt(c Config, l cloudsql.Logger) (cloudsqlconn.Option, error) { if err != nil { return nil, err } - if c.IAMAuthN { + + if c.iamAuthNEnabled() { iamLoginTS, err := impersonate.CredentialsTokenSource( context.Background(), impersonate.CredentialsConfig{ @@ -439,7 +454,7 @@ func (c *Config) DialerOptions(l cloudsql.Logger) ([]cloudsqlconn.Option, error) opts = append(opts, cloudsqlconn.WithUniverseDomain(c.UniverseDomain)) } - if c.IAMAuthN { + if c.iamAuthNEnabled() { opts = append(opts, cloudsqlconn.WithIAMAuthN()) } diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 58afa4b75..fe6400d02 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -280,6 +280,14 @@ func TestPostgresIAMDBAuthn(t *testing.T) { dsn: fmt.Sprintf("host=localhost user=%s database=%s sslmode=disable", impersonatedIAMUser, *postgresDB), }, + { + desc: "using impersonation with query param", + args: []string{ + "--impersonate-service-account", *impersonatedUser, + fmt.Sprintf("%s?auto-iam-authn=true", *postgresConnName)}, + dsn: fmt.Sprintf("host=localhost user=%s password=password database=%s sslmode=disable", + impersonatedIAMUser, *postgresDB), + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) {