1717package awskms
1818
1919import (
20+ "context"
2021 "github.com/aws/aws-sdk-go-v2/aws"
22+ "github.com/aws/aws-sdk-go-v2/config"
2123 "github.com/aws/aws-sdk-go-v2/credentials"
24+ "github.com/aws/aws-sdk-go-v2/credentials/stscreds"
25+ "github.com/aws/aws-sdk-go-v2/service/sts"
2226 "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/rules/encryption"
2327 "github.com/tink-crypto/tink-go/v2/core/registry"
28+ "os"
29+ "strings"
2430)
2531
2632const (
2733 prefix = "aws-kms://"
2834 accessKeyID = "access.key.id"
2935 secretAccessKey = "secret.access.key"
36+ profile = "profile"
37+ roleArn = "role.arn"
38+ roleSessionName = "role.session.name"
39+ roleExternalID = "role.external.id"
3040)
3141
3242func init () {
@@ -51,13 +61,55 @@ func (l *awsDriver) NewKMSClient(conf map[string]string, keyURL *string) (regist
5161 if keyURL != nil {
5262 uriPrefix = * keyURL
5363 }
64+ arn := conf [roleArn ]
65+ if arn == "" {
66+ arn = os .Getenv ("AWS_ROLE_ARN" )
67+ }
68+ sessionName := conf [roleSessionName ]
69+ if sessionName == "" {
70+ sessionName = os .Getenv ("AWS_ROLE_SESSION_NAME" )
71+ }
72+ externalID := conf [roleExternalID ]
73+ if externalID == "" {
74+ externalID = os .Getenv ("AWS_ROLE_EXTERNAL_ID" )
75+ }
5476 var creds aws.CredentialsProvider
55- key , ok := conf [accessKeyID ]
56- if ok {
57- secret , ok := conf [secretAccessKey ]
58- if ok {
59- creds = credentials .NewStaticCredentialsProvider (key , secret , "" )
77+ key := conf [accessKeyID ]
78+ secret := conf [secretAccessKey ]
79+ sourceProfile := conf [profile ]
80+ if key != "" && secret != "" {
81+ creds = credentials .NewStaticCredentialsProvider (key , secret , "" )
82+ } else if sourceProfile != "" {
83+ cfg , err := config .LoadDefaultConfig (context .Background (),
84+ config .WithSharedConfigProfile (sourceProfile ),
85+ )
86+ if err != nil {
87+ return nil , err
88+ }
89+ creds = cfg .Credentials
90+ }
91+ if arn != "" {
92+ region , err := getRegion (strings .TrimPrefix (uriPrefix , prefix ))
93+ if err != nil {
94+ return nil , err
95+ }
96+ stsSvc := sts .New (sts.Options {
97+ Credentials : creds ,
98+ Region : region ,
99+ })
100+ if sessionName == "" {
101+ sessionName = "confluent-encrypt"
60102 }
103+ var extID * string
104+ if externalID != "" {
105+ extID = & externalID
106+ }
107+ creds = stscreds .NewAssumeRoleProvider (stsSvc , arn , func (o * stscreds.AssumeRoleOptions ) {
108+ o .RoleSessionName = sessionName
109+ o .ExternalID = extID
110+ })
111+ creds = aws .NewCredentialsCache (creds )
61112 }
113+
62114 return NewClient (uriPrefix , creds )
63115}
0 commit comments