Skip to content

Commit 1e095ed

Browse files
authored
fix get aws creds from environment (#3617)
Signed-off-by: Fabian Martinez <[email protected]>
1 parent f48b412 commit 1e095ed

File tree

4 files changed

+47
-25
lines changed

4 files changed

+47
-25
lines changed

common/authentication/aws/aws.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,6 @@ type Provider interface {
112112
Close() error
113113
}
114114

115-
func isX509Auth(m map[string]string) bool {
116-
tp, _ := m["trustProfileArn"]
117-
ta, _ := m["trustAnchorArn"]
118-
ar, _ := m["assumeRoleArn"]
119-
return tp != "" && ta != "" && ar != ""
120-
}
121-
122115
func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, error) {
123116
if isX509Auth(opts.Properties) {
124117
return newX509(ctx, opts, cfg)

common/authentication/aws/static.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ type StaticAuth struct {
3838
endpoint *string
3939
accessKey *string
4040
secretKey *string
41-
sessionToken *string
41+
sessionToken string
4242

4343
assumeRoleARN *string
44-
sessionName *string
44+
sessionName string
4545

4646
session *session.Session
4747
cfg *aws.Config
@@ -50,15 +50,7 @@ type StaticAuth struct {
5050

5151
func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) {
5252
auth := &StaticAuth{
53-
logger: opts.Logger,
54-
region: &opts.Region,
55-
endpoint: &opts.Endpoint,
56-
accessKey: &opts.AccessKey,
57-
secretKey: &opts.SecretKey,
58-
sessionToken: &opts.SessionToken,
59-
assumeRoleARN: &opts.AssumeRoleARN,
60-
sessionName: &opts.SessionName,
61-
53+
logger: opts.Logger,
6254
cfg: func() *aws.Config {
6355
// if nil is passed or it's just a default cfg,
6456
// then we use the options to build the aws cfg.
@@ -70,7 +62,29 @@ func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth
7062
clients: newClients(),
7163
}
7264

73-
initialSession, err := auth.getTokenClient()
65+
if opts.Region != "" {
66+
auth.region = &opts.Region
67+
}
68+
if opts.Endpoint != "" {
69+
auth.endpoint = &opts.Endpoint
70+
}
71+
if opts.AccessKey != "" {
72+
auth.accessKey = &opts.AccessKey
73+
}
74+
if opts.SecretKey != "" {
75+
auth.secretKey = &opts.SecretKey
76+
}
77+
if opts.SessionToken != "" {
78+
auth.sessionToken = opts.SessionToken
79+
}
80+
if opts.AssumeRoleARN != "" {
81+
auth.assumeRoleARN = &opts.AssumeRoleARN
82+
}
83+
if opts.SessionName != "" {
84+
auth.sessionName = opts.SessionName
85+
}
86+
87+
initialSession, err := auth.createSession()
7488
if err != nil {
7589
return nil, fmt.Errorf("failed to get token client: %v", err)
7690
}
@@ -231,8 +245,8 @@ func (a *StaticAuth) Kafka(opts KafkaOptions) (*KafkaClients, error) {
231245
if a.assumeRoleARN != nil {
232246
tokenProvider.awsIamRoleArn = *a.assumeRoleARN
233247
}
234-
if a.sessionName != nil {
235-
tokenProvider.awsStsSessionName = *a.sessionName
248+
if a.sessionName != "" {
249+
tokenProvider.awsStsSessionName = a.sessionName
236250
}
237251

238252
err := a.clients.kafka.New(a.session, &tokenProvider)
@@ -243,7 +257,7 @@ func (a *StaticAuth) Kafka(opts KafkaOptions) (*KafkaClients, error) {
243257
return a.clients.kafka, nil
244258
}
245259

246-
func (a *StaticAuth) getTokenClient() (*session.Session, error) {
260+
func (a *StaticAuth) createSession() (*session.Session, error) {
247261
var awsConfig *aws.Config
248262
if a.cfg == nil {
249263
awsConfig = aws.NewConfig()
@@ -257,13 +271,15 @@ func (a *StaticAuth) getTokenClient() (*session.Session, error) {
257271

258272
if a.accessKey != nil && a.secretKey != nil {
259273
// session token is an option field
260-
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken))
274+
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, a.sessionToken))
261275
}
262276

263277
if a.endpoint != nil {
264278
awsConfig = awsConfig.WithEndpoint(*a.endpoint)
265279
}
266280

281+
// TODO support assume role for all aws components
282+
267283
awsSession, err := session.NewSessionWithOptions(session.Options{
268284
Config: *awsConfig,
269285
SharedConfigState: session.SharedConfigEnable,

common/authentication/aws/static_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,22 @@ func TestGetTokenClient(t *testing.T) {
4848
awsInstance: &StaticAuth{
4949
accessKey: aws.String("testAccessKey"),
5050
secretKey: aws.String("testSecretKey"),
51-
sessionToken: aws.String("testSessionToken"),
51+
sessionToken: "testSessionToken",
5252
region: aws.String("us-west-2"),
5353
endpoint: aws.String("https://test.endpoint.com"),
5454
},
5555
},
56+
{
57+
name: "creds from environment",
58+
awsInstance: &StaticAuth{
59+
region: aws.String("us-west-2"),
60+
},
61+
},
5662
}
5763

5864
for _, tt := range tests {
5965
t.Run(tt.name, func(t *testing.T) {
60-
session, err := tt.awsInstance.getTokenClient()
66+
session, err := tt.awsInstance.createSession()
6167
require.NotNil(t, session)
6268
require.NoError(t, err)
6369
assert.Equal(t, tt.awsInstance.region, session.Config.Region)

common/authentication/aws/x509.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ import (
4141
"github.com/dapr/kit/ptr"
4242
)
4343

44+
func isX509Auth(m map[string]string) bool {
45+
tp := m["trustProfileArn"]
46+
ta := m["trustAnchorArn"]
47+
ar := m["assumeRoleArn"]
48+
return tp != "" && ta != "" && ar != ""
49+
}
50+
4451
type x509Options struct {
4552
TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"`
4653
TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"`

0 commit comments

Comments
 (0)