Skip to content

Changing back TaskProtectionResponse to use pointers for nested struct fields #4559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3408,7 +3408,7 @@ func TestGetTaskProtection(t *testing.T) {
}, nil),
expectedStatusCode: http.StatusOK,
expectedResponseBody: tptypes.TaskProtectionResponse{
Failure: ecstypes.Failure{
Failure: &ecstypes.Failure{
Arn: aws.String(taskARN),
Reason: aws.String("ecs failure"),
},
Expand Down Expand Up @@ -3451,7 +3451,7 @@ func TestGetTaskProtection(t *testing.T) {
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil),
expectedStatusCode: http.StatusOK,
expectedResponseBody: tptypes.TaskProtectionResponse{
Protection: protectedTask,
Protection: &protectedTask,
},
})
})
Expand Down Expand Up @@ -3684,7 +3684,7 @@ func TestUpdateTaskProtection(t *testing.T) {
}, nil),
expectedStatusCode: http.StatusOK,
expectedResponseBody: tptypes.TaskProtectionResponse{
Failure: ecstypes.Failure{
Failure: &ecstypes.Failure{
Arn: aws.String(taskARN),
Reason: aws.String("ecs failure"),
},
Expand Down Expand Up @@ -3772,7 +3772,7 @@ func TestUpdateTaskProtection(t *testing.T) {
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil),
expectedStatusCode: http.StatusOK,
expectedResponseBody: tptypes.TaskProtectionResponse{
Protection: protectedTask,
Protection: &protectedTask,
},
}))
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func GetTaskProtectionHandler(

// ECS call was successful
utils.WriteJSONResponse(w, http.StatusOK,
types.NewTaskProtectionResponseProtection(responseBody.ProtectedTasks[0]), requestType)
types.NewTaskProtectionResponseProtection(&responseBody.ProtectedTasks[0]), requestType)
successMetric.WithCount(1).Done(nil)
}
}
Expand Down Expand Up @@ -251,7 +251,7 @@ func UpdateTaskProtectionHandler(

// ECS call was successful
utils.WriteJSONResponse(w, http.StatusOK,
types.NewTaskProtectionResponseProtection(response.ProtectedTasks[0]), requestType)
types.NewTaskProtectionResponseProtection(&response.ProtectedTasks[0]), requestType)
successMetric.WithCount(1).Done(nil)
}
}
Expand Down Expand Up @@ -351,7 +351,7 @@ func logAndValidateECSResponse(
return http.StatusInternalServerError, &response
}

response := types.NewTaskProtectionResponseFailure(failures[0])
response := types.NewTaskProtectionResponseFailure(&failures[0])
return http.StatusOK, &response
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,26 @@ import (
)

const (
cluster = "cluster"
endpointId = "endpointId"
ecsCallTimeout = 5 * time.Second
taskARN = "taskARN"
taskRoleCredsID = "taskRoleCredsID"
cluster = "cluster"
endpointId = "endpointId"
ecsCallTimeout = 5 * time.Second
taskARN = "taskARN"
taskRoleCredsID = "taskRoleCredsID"
updateTaskProtectionDecodeError = "UpdateTaskProtection: failed to decode request"
)

var (
taskMetadataErrorResponse = `{"error":{"Code":"ServerException","Message":"Failed to find a task for the request"}}`
noCredentialsErrorResponse = fmt.Sprintf(`{"error":{"Arn":"%s","Code":"AccessDeniedException","Message":"Invalid Request: no task IAM role credentials available for task"}}`, taskARN)
requestFailureErrorResponse = `{"requestID":"%s","error":{"Arn":"%s","Code":"AccessDeniedException","Message":"%s"}}`
timeoutErrorResponse = fmt.Sprintf(`{"error":{"Arn":"%s","Code":"RequestCanceled","Message":"Timed out calling ECS Task Protection API"}}`, taskARN)
nonRequestAWSErrorResponse = `{"error":{"Arn":"%s","Code":"InvalidParameterException","Message":"%s"}}`
nonAWSErrorResponse = `{"error":{"Arn":"%s","Code":"ServerException","Message":"%s"}}`
ecsErrorResponse = fmt.Sprintf(`{"failure":{"Arn":"%s","Detail":null,"Reason":"ecs failure 1"}}`, taskARN)
multipleECSErrorResponse = fmt.Sprintf(`{"error":{"Arn":"%s","Code":"ServerException","Message":"Unexpected error occurred"}}`, taskARN)
happyEnabledTaskProtectionResponse = fmt.Sprintf(`{"protection":{"ExpirationDate":null,"ProtectionEnabled":true,"TaskArn":"%s"}}`, taskARN)
malformedRequestResponse = `{"error":{"Code":"InvalidParameterException","Message":"%s"}}`
missingTaskProtectionFieldResponse = fmt.Sprintf(`{"error":{"Arn":"%s","Code":"InvalidParameterException","Message":"Invalid request: does not contain 'ProtectionEnabled' field"}}`, taskARN)
)

// Tests the path for UpdateTaskProtection API
Expand All @@ -70,6 +85,7 @@ type TestCase struct {
setMetricsExpectations func(ctrl *gomock.Controller, metricsFactory *mock_metrics.MockEntryFactory)
expectedStatusCode int
expectedResponseBody types.TaskProtectionResponse
expectedResonseBodyJSON string
}

func testTaskProtectionRequest(t *testing.T, tc TestCase) {
Expand Down Expand Up @@ -127,6 +143,7 @@ func testTaskProtectionRequest(t *testing.T, tc TestCase) {
var actualResponseBody types.TaskProtectionResponse
err = json.Unmarshal(recorder.Body.Bytes(), &actualResponseBody)
require.NoError(t, err)
assert.Equal(t, tc.expectedResonseBodyJSON, recorder.Body.String())

// Assert status code and body
assert.Equal(t, tc.expectedStatusCode, recorder.Code)
Expand Down Expand Up @@ -162,12 +179,12 @@ func TestGetTaskProtection(t *testing.T) {
testTaskProtectionRequest(t, taskMetadataFetchErrorCase(
state.NewErrorMetadataFetchFailure(""), metricName, nil))
})
t.Run("task metadata uknown error", func(t *testing.T) {
t.Run("task metadata unknown error", func(t *testing.T) {
testTaskProtectionRequest(t, taskMetadataFetchErrorCase(
errors.New("unknown"), metricName, nil))
})
t.Run("task role creds not found", func(t *testing.T) {
testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, nil))
testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, nil, noCredentialsErrorResponse))
})
t.Run("request failure", func(t *testing.T) {
ecsRequestID := "reqID"
Expand Down Expand Up @@ -198,6 +215,7 @@ func TestGetTaskProtection(t *testing.T) {
Message: ecsErrMessage,
},
},
expectedResonseBodyJSON: fmt.Sprintf(requestFailureErrorResponse, ecsRequestID, taskARN, ecsErrMessage),
})
})
t.Run("agent timeout", func(t *testing.T) {
Expand All @@ -216,6 +234,7 @@ func TestGetTaskProtection(t *testing.T) {
Message: "Timed out calling ECS Task Protection API",
},
},
expectedResonseBodyJSON: timeoutErrorResponse,
})
})
t.Run("non-request-failure aws error", func(t *testing.T) {
Expand All @@ -234,6 +253,7 @@ func TestGetTaskProtection(t *testing.T) {
Message: ecsErrMessage,
},
},
expectedResonseBodyJSON: fmt.Sprintf(nonRequestAWSErrorResponse, taskARN, ecsErrMessage),
})
})
t.Run("non-aws error", func(t *testing.T) {
Expand All @@ -249,6 +269,7 @@ func TestGetTaskProtection(t *testing.T) {
Arn: taskARN, Code: apierrors.ErrCodeServerException, Message: err.Error(),
},
},
expectedResonseBodyJSON: fmt.Sprintf(nonAWSErrorResponse, taskARN, err.Error()),
})
})
t.Run("ecs failure", func(t *testing.T) {
Expand All @@ -262,8 +283,9 @@ func TestGetTaskProtection(t *testing.T) {
setMetricsExpectations: metricsExpectations(metricName, 0),
expectedStatusCode: http.StatusOK,
expectedResponseBody: types.TaskProtectionResponse{
Failure: ecsFailure,
Failure: &ecsFailure,
},
expectedResonseBodyJSON: ecsErrorResponse,
})
})
t.Run("more than one ecs failure", func(t *testing.T) {
Expand All @@ -282,6 +304,7 @@ func TestGetTaskProtection(t *testing.T) {
Message: "Unexpected error occurred",
},
},
expectedResonseBodyJSON: multipleECSErrorResponse,
})
})
t.Run("happy case", func(t *testing.T) {
Expand All @@ -292,9 +315,10 @@ func TestGetTaskProtection(t *testing.T) {
setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.GetTaskProtectionOutput{
ProtectedTasks: []ecstypes.ProtectedTask{protectedTask},
}, nil),
setMetricsExpectations: metricsExpectations(metricName, 1),
expectedStatusCode: http.StatusOK,
expectedResponseBody: types.TaskProtectionResponse{Protection: protectedTask},
setMetricsExpectations: metricsExpectations(metricName, 1),
expectedStatusCode: http.StatusOK,
expectedResponseBody: types.TaskProtectionResponse{Protection: &protectedTask},
expectedResonseBodyJSON: happyEnabledTaskProtectionResponse,
})
})
}
Expand Down Expand Up @@ -351,9 +375,10 @@ func TestUpdateTaskProtection(t *testing.T) {
expectedResponseBody: types.TaskProtectionResponse{
Error: &types.ErrorResponse{
Code: apierrors.ErrCodeInvalidParameterException,
Message: "UpdateTaskProtection: failed to decode request",
Message: updateTaskProtectionDecodeError,
},
},
expectedResonseBodyJSON: fmt.Sprintf(malformedRequestResponse, updateTaskProtectionDecodeError),
})
})
t.Run("invalid type in the request", func(t *testing.T) {
Expand All @@ -364,9 +389,10 @@ func TestUpdateTaskProtection(t *testing.T) {
expectedResponseBody: types.TaskProtectionResponse{
Error: &types.ErrorResponse{
Code: apierrors.ErrCodeInvalidParameterException,
Message: "UpdateTaskProtection: failed to decode request",
Message: updateTaskProtectionDecodeError,
},
},
expectedResonseBodyJSON: fmt.Sprintf(malformedRequestResponse, updateTaskProtectionDecodeError),
})
})
t.Run("ProtectionEnabled field not found on the request", func(t *testing.T) {
Expand All @@ -386,10 +412,11 @@ func TestUpdateTaskProtection(t *testing.T) {
Message: "Invalid request: does not contain 'ProtectionEnabled' field",
},
},
expectedResonseBodyJSON: missingTaskProtectionFieldResponse,
})
})
t.Run("task role creds not found", func(t *testing.T) {
testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, happyRequestBody))
testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, happyRequestBody, noCredentialsErrorResponse))
})
t.Run("request failure", func(t *testing.T) {
ecsRequestID := "reqID"
Expand Down Expand Up @@ -421,6 +448,7 @@ func TestUpdateTaskProtection(t *testing.T) {
Message: ecsErrMessage,
},
},
expectedResonseBodyJSON: fmt.Sprintf(requestFailureErrorResponse, ecsRequestID, taskARN, ecsErrMessage),
})
})
t.Run("agent timeout", func(t *testing.T) {
Expand All @@ -440,6 +468,7 @@ func TestUpdateTaskProtection(t *testing.T) {
Message: "Timed out calling ECS Task Protection API",
},
},
expectedResonseBodyJSON: timeoutErrorResponse,
})
})
t.Run("non-request-failure aws error", func(t *testing.T) {
Expand All @@ -459,6 +488,7 @@ func TestUpdateTaskProtection(t *testing.T) {
Message: ecsErrMessage,
},
},
expectedResonseBodyJSON: fmt.Sprintf(nonRequestAWSErrorResponse, taskARN, ecsErrMessage),
})
})
t.Run("non-aws error", func(t *testing.T) {
Expand All @@ -475,6 +505,7 @@ func TestUpdateTaskProtection(t *testing.T) {
Arn: taskARN, Code: apierrors.ErrCodeServerException, Message: err.Error(),
},
},
expectedResonseBodyJSON: fmt.Sprintf(nonAWSErrorResponse, taskARN, err.Error()),
})
})
t.Run("ecs failure", func(t *testing.T) {
Expand All @@ -489,8 +520,9 @@ func TestUpdateTaskProtection(t *testing.T) {
setMetricsExpectations: metricsExpectations(metricName, 0),
expectedStatusCode: http.StatusOK,
expectedResponseBody: types.TaskProtectionResponse{
Failure: ecsFailure,
Failure: &ecsFailure,
},
expectedResonseBodyJSON: ecsErrorResponse,
})
})
t.Run("more than one ecs failure", func(t *testing.T) {
Expand All @@ -510,6 +542,7 @@ func TestUpdateTaskProtection(t *testing.T) {
Message: "Unexpected error occurred",
},
},
expectedResonseBodyJSON: multipleECSErrorResponse,
})
})
t.Run("happy case", func(t *testing.T) {
Expand All @@ -521,9 +554,10 @@ func TestUpdateTaskProtection(t *testing.T) {
setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.UpdateTaskProtectionOutput{
ProtectedTasks: []ecstypes.ProtectedTask{protectedTask},
}, nil),
setMetricsExpectations: metricsExpectations(metricName, 1),
expectedStatusCode: http.StatusOK,
expectedResponseBody: types.TaskProtectionResponse{Protection: protectedTask},
setMetricsExpectations: metricsExpectations(metricName, 1),
expectedStatusCode: http.StatusOK,
expectedResponseBody: types.TaskProtectionResponse{Protection: &protectedTask},
expectedResonseBodyJSON: happyEnabledTaskProtectionResponse,
})
})
}
Expand Down Expand Up @@ -582,6 +616,7 @@ func taskMetadataFetchErrorCase(err error, metricName string, reqBody interface{
Message: "Failed to find a task for the request",
},
},
expectedResonseBodyJSON: taskMetadataErrorResponse,
}
}

Expand All @@ -604,11 +639,12 @@ func taskMetadataLookupFailureCase(metricName string, reqBody interface{}) TestC
Message: "Failed to find a task for the request",
},
},
expectedResonseBodyJSON: `{"error":{"Code":"ResourceNotFoundException","Message":"Failed to find a task for the request"}}`,
}
}

// Creates a test case for Task Role credentials not found case.
func taskRoleCredsNotFoundCase(metricName string, reqBody interface{}) TestCase {
func taskRoleCredsNotFoundCase(metricName string, reqBody interface{}, expectedResponseJSON string) TestCase {
return TestCase{
setAgentStateExpectations: happyStateExpectations,
setCredsManagerExpectations: func(credsManager *mock_credentials.MockManager) {
Expand All @@ -625,6 +661,7 @@ func taskRoleCredsNotFoundCase(metricName string, reqBody interface{}) TestCase
Message: "Invalid Request: no task IAM role credentials available for task",
},
},
expectedResonseBodyJSON: expectedResponseJSON,
}
}

Expand Down
Loading
Loading