Skip to content

Commit 83a8e2e

Browse files
authored
Add AWS AssumeRole support to AWS KMS (#1359)
* Add AWS AssumeRole support to AWS KMS * Minor cleanup * Fix golint
1 parent 0827f16 commit 83a8e2e

File tree

1 file changed

+57
-5
lines changed

1 file changed

+57
-5
lines changed

schemaregistry/rules/encryption/awskms/aws_driver.go

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,26 @@
1717
package awskms
1818

1919
import (
20+
"context"
2021
"github.com/aws/aws-sdk-go-v2/aws"
22+
"github.com/aws/aws-sdk-go-v2/config"
2123
"github.com/aws/aws-sdk-go-v2/credentials"
24+
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
25+
"github.com/aws/aws-sdk-go-v2/service/sts"
2226
"github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/rules/encryption"
2327
"github.com/tink-crypto/tink-go/v2/core/registry"
28+
"os"
29+
"strings"
2430
)
2531

2632
const (
2733
prefix = "aws-kms://"
2834
accessKeyID = "access.key.id"
2935
secretAccessKey = "secret.access.key"
36+
profile = "profile"
37+
roleArn = "role.arn"
38+
roleSessionName = "role.session.name"
39+
roleExternalID = "role.external.id"
3040
)
3141

3242
func init() {
@@ -51,13 +61,55 @@ func (l *awsDriver) NewKMSClient(conf map[string]string, keyURL *string) (regist
5161
if keyURL != nil {
5262
uriPrefix = *keyURL
5363
}
64+
arn := conf[roleArn]
65+
if arn == "" {
66+
arn = os.Getenv("AWS_ROLE_ARN")
67+
}
68+
sessionName := conf[roleSessionName]
69+
if sessionName == "" {
70+
sessionName = os.Getenv("AWS_ROLE_SESSION_NAME")
71+
}
72+
externalID := conf[roleExternalID]
73+
if externalID == "" {
74+
externalID = os.Getenv("AWS_ROLE_EXTERNAL_ID")
75+
}
5476
var creds aws.CredentialsProvider
55-
key, ok := conf[accessKeyID]
56-
if ok {
57-
secret, ok := conf[secretAccessKey]
58-
if ok {
59-
creds = credentials.NewStaticCredentialsProvider(key, secret, "")
77+
key := conf[accessKeyID]
78+
secret := conf[secretAccessKey]
79+
sourceProfile := conf[profile]
80+
if key != "" && secret != "" {
81+
creds = credentials.NewStaticCredentialsProvider(key, secret, "")
82+
} else if sourceProfile != "" {
83+
cfg, err := config.LoadDefaultConfig(context.Background(),
84+
config.WithSharedConfigProfile(sourceProfile),
85+
)
86+
if err != nil {
87+
return nil, err
88+
}
89+
creds = cfg.Credentials
90+
}
91+
if arn != "" {
92+
region, err := getRegion(strings.TrimPrefix(uriPrefix, prefix))
93+
if err != nil {
94+
return nil, err
95+
}
96+
stsSvc := sts.New(sts.Options{
97+
Credentials: creds,
98+
Region: region,
99+
})
100+
if sessionName == "" {
101+
sessionName = "confluent-encrypt"
60102
}
103+
var extID *string
104+
if externalID != "" {
105+
extID = &externalID
106+
}
107+
creds = stscreds.NewAssumeRoleProvider(stsSvc, arn, func(o *stscreds.AssumeRoleOptions) {
108+
o.RoleSessionName = sessionName
109+
o.ExternalID = extID
110+
})
111+
creds = aws.NewCredentialsCache(creds)
61112
}
113+
62114
return NewClient(uriPrefix, creds)
63115
}

0 commit comments

Comments
 (0)