Skip to content

Commit 4385ba1

Browse files
committed
Changing back TaskProtectionResponse to use pointers for nested struct fields
fix unit test to include asserting response JSON body
1 parent a71d156 commit 4385ba1

File tree

6 files changed

+78
-41
lines changed

6 files changed

+78
-41
lines changed

agent/handlers/task_server_setup_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3408,7 +3408,7 @@ func TestGetTaskProtection(t *testing.T) {
34083408
}, nil),
34093409
expectedStatusCode: http.StatusOK,
34103410
expectedResponseBody: tptypes.TaskProtectionResponse{
3411-
Failure: ecstypes.Failure{
3411+
Failure: &ecstypes.Failure{
34123412
Arn: aws.String(taskARN),
34133413
Reason: aws.String("ecs failure"),
34143414
},
@@ -3451,7 +3451,7 @@ func TestGetTaskProtection(t *testing.T) {
34513451
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil),
34523452
expectedStatusCode: http.StatusOK,
34533453
expectedResponseBody: tptypes.TaskProtectionResponse{
3454-
Protection: protectedTask,
3454+
Protection: &protectedTask,
34553455
},
34563456
})
34573457
})
@@ -3684,7 +3684,7 @@ func TestUpdateTaskProtection(t *testing.T) {
36843684
}, nil),
36853685
expectedStatusCode: http.StatusOK,
36863686
expectedResponseBody: tptypes.TaskProtectionResponse{
3687-
Failure: ecstypes.Failure{
3687+
Failure: &ecstypes.Failure{
36883688
Arn: aws.String(taskARN),
36893689
Reason: aws.String("ecs failure"),
36903690
},
@@ -3772,7 +3772,7 @@ func TestUpdateTaskProtection(t *testing.T) {
37723772
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil),
37733773
expectedStatusCode: http.StatusOK,
37743774
expectedResponseBody: tptypes.TaskProtectionResponse{
3775-
Protection: protectedTask,
3775+
Protection: &protectedTask,
37763776
},
37773777
}))
37783778
}

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types/types.go

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func GetTaskProtectionHandler(
137137

138138
// ECS call was successful
139139
utils.WriteJSONResponse(w, http.StatusOK,
140-
types.NewTaskProtectionResponseProtection(responseBody.ProtectedTasks[0]), requestType)
140+
types.NewTaskProtectionResponseProtection(&responseBody.ProtectedTasks[0]), requestType)
141141
successMetric.WithCount(1).Done(nil)
142142
}
143143
}
@@ -251,7 +251,7 @@ func UpdateTaskProtectionHandler(
251251

252252
// ECS call was successful
253253
utils.WriteJSONResponse(w, http.StatusOK,
254-
types.NewTaskProtectionResponseProtection(response.ProtectedTasks[0]), requestType)
254+
types.NewTaskProtectionResponseProtection(&response.ProtectedTasks[0]), requestType)
255255
successMetric.WithCount(1).Done(nil)
256256
}
257257
}
@@ -351,7 +351,7 @@ func logAndValidateECSResponse(
351351
return http.StatusInternalServerError, &response
352352
}
353353

354-
response := types.NewTaskProtectionResponseFailure(failures[0])
354+
response := types.NewTaskProtectionResponseFailure(&failures[0])
355355
return http.StatusOK, &response
356356
}
357357

ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_test.go

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,26 @@ import (
5050
)
5151

5252
const (
53-
cluster = "cluster"
54-
endpointId = "endpointId"
55-
ecsCallTimeout = 5 * time.Second
56-
taskARN = "taskARN"
57-
taskRoleCredsID = "taskRoleCredsID"
53+
cluster = "cluster"
54+
endpointId = "endpointId"
55+
ecsCallTimeout = 5 * time.Second
56+
taskARN = "taskARN"
57+
taskRoleCredsID = "taskRoleCredsID"
58+
updateTaskProtectionDecodeError = "UpdateTaskProtection: failed to decode request"
59+
)
60+
61+
var (
62+
taskMetadataErrorResponse = `{"error":{"Code":"ServerException","Message":"Failed to find a task for the request"}}`
63+
noCredentialsErrorResponse = fmt.Sprintf(`{"error":{"Arn":"%s","Code":"AccessDeniedException","Message":"Invalid Request: no task IAM role credentials available for task"}}`, taskARN)
64+
requestFailureErrorResponse = `{"requestID":"%s","error":{"Arn":"%s","Code":"AccessDeniedException","Message":"%s"}}`
65+
timeoutErrorResponse = fmt.Sprintf(`{"error":{"Arn":"%s","Code":"RequestCanceled","Message":"Timed out calling ECS Task Protection API"}}`, taskARN)
66+
nonRequestAWSErrorResponse = `{"error":{"Arn":"%s","Code":"InvalidParameterException","Message":"%s"}}`
67+
nonAWSErrorResponse = `{"error":{"Arn":"%s","Code":"ServerException","Message":"%s"}}`
68+
ecsErrorResponse = fmt.Sprintf(`{"failure":{"Arn":"%s","Detail":null,"Reason":"ecs failure 1"}}`, taskARN)
69+
multipleECSErrorResponse = fmt.Sprintf(`{"error":{"Arn":"%s","Code":"ServerException","Message":"Unexpected error occurred"}}`, taskARN)
70+
happyEnabledTaskProtectionResponse = fmt.Sprintf(`{"protection":{"ExpirationDate":null,"ProtectionEnabled":true,"TaskArn":"%s"}}`, taskARN)
71+
malformedRequestResponse = `{"error":{"Code":"InvalidParameterException","Message":"%s"}}`
72+
missingTaskProtectionFieldResponse = fmt.Sprintf(`{"error":{"Arn":"%s","Code":"InvalidParameterException","Message":"Invalid request: does not contain 'ProtectionEnabled' field"}}`, taskARN)
5873
)
5974

6075
// Tests the path for UpdateTaskProtection API
@@ -70,6 +85,7 @@ type TestCase struct {
7085
setMetricsExpectations func(ctrl *gomock.Controller, metricsFactory *mock_metrics.MockEntryFactory)
7186
expectedStatusCode int
7287
expectedResponseBody types.TaskProtectionResponse
88+
expectedResonseBodyJSON string
7389
}
7490

7591
func testTaskProtectionRequest(t *testing.T, tc TestCase) {
@@ -127,6 +143,7 @@ func testTaskProtectionRequest(t *testing.T, tc TestCase) {
127143
var actualResponseBody types.TaskProtectionResponse
128144
err = json.Unmarshal(recorder.Body.Bytes(), &actualResponseBody)
129145
require.NoError(t, err)
146+
assert.Equal(t, tc.expectedResonseBodyJSON, recorder.Body.String())
130147

131148
// Assert status code and body
132149
assert.Equal(t, tc.expectedStatusCode, recorder.Code)
@@ -162,12 +179,12 @@ func TestGetTaskProtection(t *testing.T) {
162179
testTaskProtectionRequest(t, taskMetadataFetchErrorCase(
163180
state.NewErrorMetadataFetchFailure(""), metricName, nil))
164181
})
165-
t.Run("task metadata uknown error", func(t *testing.T) {
182+
t.Run("task metadata unknown error", func(t *testing.T) {
166183
testTaskProtectionRequest(t, taskMetadataFetchErrorCase(
167184
errors.New("unknown"), metricName, nil))
168185
})
169186
t.Run("task role creds not found", func(t *testing.T) {
170-
testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, nil))
187+
testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, nil, noCredentialsErrorResponse))
171188
})
172189
t.Run("request failure", func(t *testing.T) {
173190
ecsRequestID := "reqID"
@@ -198,6 +215,7 @@ func TestGetTaskProtection(t *testing.T) {
198215
Message: ecsErrMessage,
199216
},
200217
},
218+
expectedResonseBodyJSON: fmt.Sprintf(requestFailureErrorResponse, ecsRequestID, taskARN, ecsErrMessage),
201219
})
202220
})
203221
t.Run("agent timeout", func(t *testing.T) {
@@ -216,6 +234,7 @@ func TestGetTaskProtection(t *testing.T) {
216234
Message: "Timed out calling ECS Task Protection API",
217235
},
218236
},
237+
expectedResonseBodyJSON: timeoutErrorResponse,
219238
})
220239
})
221240
t.Run("non-request-failure aws error", func(t *testing.T) {
@@ -234,6 +253,7 @@ func TestGetTaskProtection(t *testing.T) {
234253
Message: ecsErrMessage,
235254
},
236255
},
256+
expectedResonseBodyJSON: fmt.Sprintf(nonRequestAWSErrorResponse, taskARN, ecsErrMessage),
237257
})
238258
})
239259
t.Run("non-aws error", func(t *testing.T) {
@@ -249,6 +269,7 @@ func TestGetTaskProtection(t *testing.T) {
249269
Arn: taskARN, Code: apierrors.ErrCodeServerException, Message: err.Error(),
250270
},
251271
},
272+
expectedResonseBodyJSON: fmt.Sprintf(nonAWSErrorResponse, taskARN, err.Error()),
252273
})
253274
})
254275
t.Run("ecs failure", func(t *testing.T) {
@@ -262,8 +283,9 @@ func TestGetTaskProtection(t *testing.T) {
262283
setMetricsExpectations: metricsExpectations(metricName, 0),
263284
expectedStatusCode: http.StatusOK,
264285
expectedResponseBody: types.TaskProtectionResponse{
265-
Failure: ecsFailure,
286+
Failure: &ecsFailure,
266287
},
288+
expectedResonseBodyJSON: ecsErrorResponse,
267289
})
268290
})
269291
t.Run("more than one ecs failure", func(t *testing.T) {
@@ -282,6 +304,7 @@ func TestGetTaskProtection(t *testing.T) {
282304
Message: "Unexpected error occurred",
283305
},
284306
},
307+
expectedResonseBodyJSON: multipleECSErrorResponse,
285308
})
286309
})
287310
t.Run("happy case", func(t *testing.T) {
@@ -292,9 +315,10 @@ func TestGetTaskProtection(t *testing.T) {
292315
setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.GetTaskProtectionOutput{
293316
ProtectedTasks: []ecstypes.ProtectedTask{protectedTask},
294317
}, nil),
295-
setMetricsExpectations: metricsExpectations(metricName, 1),
296-
expectedStatusCode: http.StatusOK,
297-
expectedResponseBody: types.TaskProtectionResponse{Protection: protectedTask},
318+
setMetricsExpectations: metricsExpectations(metricName, 1),
319+
expectedStatusCode: http.StatusOK,
320+
expectedResponseBody: types.TaskProtectionResponse{Protection: &protectedTask},
321+
expectedResonseBodyJSON: happyEnabledTaskProtectionResponse,
298322
})
299323
})
300324
}
@@ -351,9 +375,10 @@ func TestUpdateTaskProtection(t *testing.T) {
351375
expectedResponseBody: types.TaskProtectionResponse{
352376
Error: &types.ErrorResponse{
353377
Code: apierrors.ErrCodeInvalidParameterException,
354-
Message: "UpdateTaskProtection: failed to decode request",
378+
Message: updateTaskProtectionDecodeError,
355379
},
356380
},
381+
expectedResonseBodyJSON: fmt.Sprintf(malformedRequestResponse, updateTaskProtectionDecodeError),
357382
})
358383
})
359384
t.Run("invalid type in the request", func(t *testing.T) {
@@ -364,9 +389,10 @@ func TestUpdateTaskProtection(t *testing.T) {
364389
expectedResponseBody: types.TaskProtectionResponse{
365390
Error: &types.ErrorResponse{
366391
Code: apierrors.ErrCodeInvalidParameterException,
367-
Message: "UpdateTaskProtection: failed to decode request",
392+
Message: updateTaskProtectionDecodeError,
368393
},
369394
},
395+
expectedResonseBodyJSON: fmt.Sprintf(malformedRequestResponse, updateTaskProtectionDecodeError),
370396
})
371397
})
372398
t.Run("ProtectionEnabled field not found on the request", func(t *testing.T) {
@@ -386,10 +412,11 @@ func TestUpdateTaskProtection(t *testing.T) {
386412
Message: "Invalid request: does not contain 'ProtectionEnabled' field",
387413
},
388414
},
415+
expectedResonseBodyJSON: missingTaskProtectionFieldResponse,
389416
})
390417
})
391418
t.Run("task role creds not found", func(t *testing.T) {
392-
testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, happyRequestBody))
419+
testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, happyRequestBody, noCredentialsErrorResponse))
393420
})
394421
t.Run("request failure", func(t *testing.T) {
395422
ecsRequestID := "reqID"
@@ -421,6 +448,7 @@ func TestUpdateTaskProtection(t *testing.T) {
421448
Message: ecsErrMessage,
422449
},
423450
},
451+
expectedResonseBodyJSON: fmt.Sprintf(requestFailureErrorResponse, ecsRequestID, taskARN, ecsErrMessage),
424452
})
425453
})
426454
t.Run("agent timeout", func(t *testing.T) {
@@ -440,6 +468,7 @@ func TestUpdateTaskProtection(t *testing.T) {
440468
Message: "Timed out calling ECS Task Protection API",
441469
},
442470
},
471+
expectedResonseBodyJSON: timeoutErrorResponse,
443472
})
444473
})
445474
t.Run("non-request-failure aws error", func(t *testing.T) {
@@ -459,6 +488,7 @@ func TestUpdateTaskProtection(t *testing.T) {
459488
Message: ecsErrMessage,
460489
},
461490
},
491+
expectedResonseBodyJSON: fmt.Sprintf(nonRequestAWSErrorResponse, taskARN, ecsErrMessage),
462492
})
463493
})
464494
t.Run("non-aws error", func(t *testing.T) {
@@ -475,6 +505,7 @@ func TestUpdateTaskProtection(t *testing.T) {
475505
Arn: taskARN, Code: apierrors.ErrCodeServerException, Message: err.Error(),
476506
},
477507
},
508+
expectedResonseBodyJSON: fmt.Sprintf(nonAWSErrorResponse, taskARN, err.Error()),
478509
})
479510
})
480511
t.Run("ecs failure", func(t *testing.T) {
@@ -489,8 +520,9 @@ func TestUpdateTaskProtection(t *testing.T) {
489520
setMetricsExpectations: metricsExpectations(metricName, 0),
490521
expectedStatusCode: http.StatusOK,
491522
expectedResponseBody: types.TaskProtectionResponse{
492-
Failure: ecsFailure,
523+
Failure: &ecsFailure,
493524
},
525+
expectedResonseBodyJSON: ecsErrorResponse,
494526
})
495527
})
496528
t.Run("more than one ecs failure", func(t *testing.T) {
@@ -510,6 +542,7 @@ func TestUpdateTaskProtection(t *testing.T) {
510542
Message: "Unexpected error occurred",
511543
},
512544
},
545+
expectedResonseBodyJSON: multipleECSErrorResponse,
513546
})
514547
})
515548
t.Run("happy case", func(t *testing.T) {
@@ -521,9 +554,10 @@ func TestUpdateTaskProtection(t *testing.T) {
521554
setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.UpdateTaskProtectionOutput{
522555
ProtectedTasks: []ecstypes.ProtectedTask{protectedTask},
523556
}, nil),
524-
setMetricsExpectations: metricsExpectations(metricName, 1),
525-
expectedStatusCode: http.StatusOK,
526-
expectedResponseBody: types.TaskProtectionResponse{Protection: protectedTask},
557+
setMetricsExpectations: metricsExpectations(metricName, 1),
558+
expectedStatusCode: http.StatusOK,
559+
expectedResponseBody: types.TaskProtectionResponse{Protection: &protectedTask},
560+
expectedResonseBodyJSON: happyEnabledTaskProtectionResponse,
527561
})
528562
})
529563
}
@@ -582,6 +616,7 @@ func taskMetadataFetchErrorCase(err error, metricName string, reqBody interface{
582616
Message: "Failed to find a task for the request",
583617
},
584618
},
619+
expectedResonseBodyJSON: taskMetadataErrorResponse,
585620
}
586621
}
587622

@@ -604,11 +639,12 @@ func taskMetadataLookupFailureCase(metricName string, reqBody interface{}) TestC
604639
Message: "Failed to find a task for the request",
605640
},
606641
},
642+
expectedResonseBodyJSON: `{"error":{"Code":"ResourceNotFoundException","Message":"Failed to find a task for the request"}}`,
607643
}
608644
}
609645

610646
// Creates a test case for Task Role credentials not found case.
611-
func taskRoleCredsNotFoundCase(metricName string, reqBody interface{}) TestCase {
647+
func taskRoleCredsNotFoundCase(metricName string, reqBody interface{}, expectedResponseJSON string) TestCase {
612648
return TestCase{
613649
setAgentStateExpectations: happyStateExpectations,
614650
setCredsManagerExpectations: func(credsManager *mock_credentials.MockManager) {
@@ -625,6 +661,7 @@ func taskRoleCredsNotFoundCase(metricName string, reqBody interface{}) TestCase
625661
Message: "Invalid Request: no task IAM role credentials available for task",
626662
},
627663
},
664+
expectedResonseBodyJSON: expectedResponseJSON,
628665
}
629666
}
630667

0 commit comments

Comments
 (0)