Skip to content

Add support for IPv6 targets to network-latency and network-packet-loss faults #4645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 10 additions & 4 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ const (
tcAddQdiscRootCommandString = "tc qdisc add dev %s root handle 1: prio priomap 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2"
tcAddQdiscLatencyCommandString = "tc qdisc add dev %s parent 1:1 handle 10: netem delay %dms %dms"
tcAddQdiscLossCommandString = "tc qdisc add dev %s parent 1:1 handle 10: netem loss %d%%"
tcAllowlistIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 1 u32 match ip dst %s flowid 1:3"
tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 2 u32 match ip dst %s flowid 1:1"
tcAllowlistIPCommandString = "tc filter add dev %s protocol all parent 1:0 prio 1 u32 match %s dst %s flowid 1:3"
tcAddFilterForIPCommandString = "tc filter add dev %s protocol all parent 1:0 prio 2 u32 match %s dst %s flowid 1:1"
tcDeleteQdiscParentCommandString = "tc qdisc del dev %s parent 1:1 handle 10:"
tcDeleteQdiscRootCommandString = "tc qdisc del dev %s root handle 1: prio"
ip4 = "ip" // For matching IPv4 packets in a tc filter
ip6 = "ip6" // For matching IPv6 packets in a tc filter
allIPv4CIDR = "0.0.0.0/0"
allIPv6CIDR = "::/0"
dropTarget = "DROP"
Expand Down Expand Up @@ -1560,8 +1562,12 @@ func checkPacketLossFault(outputUnmarshalled []map[string]interface{}) (bool, er
func (h *FaultHandler) addIPAddressesToFilter(
ctx context.Context, ipAddressList []*string, taskMetadata *state.TaskResponse,
nsenterPrefix, commandString, interfaceName string) error {
for _, ip := range ipAddressList {
commandComposed := nsenterPrefix + fmt.Sprintf(commandString, interfaceName, aws.ToString(ip))
for _, ipPtr := range ipAddressList {
ip := aws.ToString(ipPtr)
commandComposed := nsenterPrefix + fmt.Sprintf(commandString, interfaceName, ip4, ip)
if utils.IsIPv6(ip) || utils.IsIPv6CIDR(ip) {
commandComposed = nsenterPrefix + fmt.Sprintf(commandString, interfaceName, ip6, ip)
}
cmdList := strings.Split(commandComposed, " ")
cmdOutput, err := h.runExecCommand(ctx, cmdList)
if err != nil {
Expand Down
20 changes: 10 additions & 10 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ var (
"SourcesToFilter": ipSourcesToFilter,
}

ipSources = []string{"52.95.154.1", "52.95.154.2"}
ipSources = []string{"52.95.154.1", "52.95.154.2", "1:2:3:4::"}

ipSourcesToFilter = []string{"8.8.8.8"}
ipSourcesToFilter = []string{"8.8.8.8", "5:5:5:5::"}

startNetworkBlackHolePortTestPrefix = fmt.Sprintf(startFaultRequestType, types.BlackHolePortFaultType)
stopNetworkBlackHolePortTestPrefix = fmt.Sprintf(stopFaultRequestType, types.BlackHolePortFaultType)
Expand Down Expand Up @@ -1825,8 +1825,8 @@ func generateStartNetworkLatencyTestCases() []networkFaultInjectionTestCase {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(5).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(7).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(7).Return([]byte(tcCommandEmptyOutput), nil)
},
expectedResponseJSON: happyFaultRunningResponse,
},
Expand Down Expand Up @@ -1886,8 +1886,8 @@ func generateStartNetworkLatencyTestCases() []networkFaultInjectionTestCase {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(5).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil)
},
expectedResponseJSON: happyFaultRunningResponse,
},
Expand Down Expand Up @@ -2413,8 +2413,8 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(5).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(7).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(7).Return([]byte(tcCommandEmptyOutput), nil)
},
expectedResponseJSON: happyFaultRunningResponse,
},
Expand Down Expand Up @@ -2473,8 +2473,8 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(5).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil)
},
expectedResponseJSON: happyFaultRunningResponse,
},
Expand Down
42 changes: 4 additions & 38 deletions ecs-agent/tmds/handlers/fault/v1/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@ package types
import (
"encoding/json"
"fmt"
"net"
"strconv"

"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils"
"github.com/aws/aws-sdk-go-v2/aws"
)
Expand Down Expand Up @@ -135,10 +132,10 @@ func (request NetworkLatencyRequest) ValidateRequest() error {
if len(request.Sources) == 0 {
return fmt.Errorf(MissingRequiredFieldError, "Sources")
}
if err := validateNetworkFaultRequestSources(request.Sources, "Sources"); err != nil {
if err := requireIPInRequestSources(request.Sources, "Sources"); err != nil {
return err
}
if err := validateNetworkFaultRequestSources(request.SourcesToFilter, "SourcesToFilter"); err != nil {
if err := requireIPInRequestSources(request.SourcesToFilter, "SourcesToFilter"); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -174,10 +171,10 @@ func (request NetworkPacketLossRequest) ValidateRequest() error {
if len(request.Sources) == 0 {
return fmt.Errorf(MissingRequiredFieldError, "Sources")
}
if err := validateNetworkFaultRequestSources(request.Sources, "Sources"); err != nil {
if err := requireIPInRequestSources(request.Sources, "Sources"); err != nil {
return err
}
if err := validateNetworkFaultRequestSources(request.SourcesToFilter, "SourcesToFilter"); err != nil {
if err := requireIPInRequestSources(request.SourcesToFilter, "SourcesToFilter"); err != nil {
return err
}
return nil
Expand All @@ -203,37 +200,6 @@ func NewNetworkFaultInjectionErrorResponse(err string) NetworkFaultInjectionResp
}
}

// validateNetworkFaultRequestSources validates each source is IPv4 or IPv4 CIDR block.
func validateNetworkFaultRequestSources(sources []*string, sourcesType string) error {
for _, element := range sources {
if err := validateNetworkFaultRequestSource(aws.ToString(element), sourcesType); err != nil {
return err
}
}
return nil
}

// validateNetworkFaultRequestSource validates the source is IPv4 or IPv4 CIDR block.
func validateNetworkFaultRequestSource(source string, sourceType string) error {
ip := net.ParseIP(source)
if ip != nil && ip.To4() != nil {
return nil // IPv4 successful
}

_, ipnet, err := net.ParseCIDR(source)
if err == nil && ipnet.IP.To4() != nil {
return nil // IPv4 CIDR successful
}
if err != nil {
logger.Info("Failed to parse fault source as IPv4 CIDR block", logger.Fields{
"source": source,
field.Error: err,
})
}

return fmt.Errorf(InvalidValueError, source, sourceType)
}

// requireIPInRequestSources requires each source is IPv4/IPv6 or IPv4/IPv6 CIDR block.
func requireIPInRequestSources(sources []*string, sourcesType string) error {
for _, element := range sources {
Expand Down
29 changes: 0 additions & 29 deletions ecs-agent/tmds/handlers/fault/v1/types/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,35 +43,6 @@ func TestNetworkBlackholePortAddSourceToFilterIfNotAlready(t *testing.T) {
})
}

// Tests for validateNetworkFaultRequestSource function that parses IPv4 and IPv4 CIDR blocks.
func TestValidateNetworkFaultRequestSources(t *testing.T) {
tcs := []struct {
Name string
Input string
ShouldSucceed bool
}{
{"Valid IPv4", "1.2.3.4", true},
{"Valid IPv4 CIDR", "1.2.3.4/10", true},
{"Valid IPv6", "2001:db8::1", false},
{"Valid full IPv6", "2001:0db8:0000:0000:0000:0000:0000:0001", false},
{"IPv6 CIDR", "::1/128", false},
{"Invalid input", "invalid", false},
{"IPv4 with invalid CIDR", "192.168.1.0/", false},
{"IPv6 with invalid CIDR", "2001:db8::/129", false},
{"Empty input", "", false},
}
for _, tc := range tcs {
t.Run(tc.Name, func(t *testing.T) {
err := validateNetworkFaultRequestSource(tc.Input, "input")
if tc.ShouldSucceed {
require.NoError(t, err)
} else {
require.EqualError(t, err, fmt.Sprintf("invalid value %s for parameter input", tc.Input))
}
})
}
}

func TestRequireIPInRequestSources(t *testing.T) {
tcs := []struct {
Name string
Expand Down
Loading