Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement TCP-endpoint-based throttling #46

Merged
merged 1 commit into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion dpithrottle.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type DPIThrottleTrafficForTLSSNI struct {
// Logger is the MANDATORY logger to use.
Logger Logger

// PLR is the OPTIONAL extra packet loss rate to apply to the packet
// PLR is the OPTIONAL extra packet loss rate to apply to the packet.
PLR float64

// SNI is the OPTIONAL offending SNI
Expand Down Expand Up @@ -70,3 +70,55 @@ func (r *DPIThrottleTrafficForTLSSNI) Filter(
}
return policy, true
}

// DPIThrottleTrafficForTCPEndpoint is a [DPIRule] that throttles traffic
// for a given TCP endpoint. The zero value is not valid. Make sure
// you initialize all fields marked as MANDATORY.
type DPIThrottleTrafficForTCPEndpoint struct {
// Delay is the OPTIONAL extra delay to add to the flow.
Delay time.Duration

// Logger is the MANDATORY logger to use.
Logger Logger

// PLR is the OPTIONAL extra packet loss rate to apply to the packet.
PLR float64

// ServerIPAddress is the MANDATORY server endpoint IP address.
ServerIPAddress string

// ServerPort is the MANDATORY server endpoint port.
ServerPort uint16
}

var _ DPIRule = &DPIThrottleTrafficForTCPEndpoint{}

// Filter implements DPIRule
func (r *DPIThrottleTrafficForTCPEndpoint) Filter(
direction DPIDirection, packet *DissectedPacket) (*DPIPolicy, bool) {
// short circuit for the return path
if direction != DPIDirectionClientToServer {
return nil, false
}

// make sure the packet is TCP and for the proper endpoint
if !packet.MatchesDestination(layers.IPProtocolTCP, r.ServerIPAddress, r.ServerPort) {
return nil, false
}

r.Logger.Infof(
"netem: dpi: throttling flow %s:%d %s:%d/%s because the endpoint is filtered",
packet.SourceIPAddress(),
packet.SourcePort(),
packet.DestinationIPAddress(),
packet.DestinationPort(),
packet.TransportProtocol(),
)
policy := &DPIPolicy{
Delay: r.Delay,
Flags: 0,
PLR: r.PLR,
Spoofed: nil,
}
return policy, true
}
155 changes: 154 additions & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ func TestDPITCPThrottleForSNI(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Log("checking for TLS flow throttling", tc.name)

// throttle the offending SNI to have high latency and hig losses
// throttle the offending SNI to have high latency and high losses
dpiEngine := netem.NewDPIEngine(log.Log)
dpiEngine.AddRule(&netem.DPIThrottleTrafficForTLSSNI{
Delay: 10 * time.Millisecond,
Expand Down Expand Up @@ -581,6 +581,159 @@ func TestDPITCPThrottleForSNI(t *testing.T) {
}
}

// TestDPITCPThrottleForTCPEndpoint verifies we can use the DPI to throttle
// connections using a specific TCP endpoint.
func TestDPITCPThrottleForTCPEndpoint(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}

// testcase describes a test case
type testcase struct {
// name is the name of the test case
name string

// endpointAddress is the address of the endpoint to block.
endpointAddress string

// endpointPort is the port of the endpoint to block.
endpointPort uint16

// checkAvgSpeed is a function the check whether
// the speed is consistent with expectations
checkAvgSpeed func(t *testing.T, speed float64)
}

var testcases = []testcase{{
name: "when the client is using a throttled endpoint",
endpointAddress: "10.0.0.1",
endpointPort: 443,
checkAvgSpeed: func(t *testing.T, speed float64) {
// See above comment regarding expected performance
// under the given RTT, MSS, and PLR constraints
const expectation = 5
if speed > expectation {
t.Fatal("goodput", speed, "above expectation", expectation)
}
},
}, {
name: "when the client is not using a throttled endpoint",
endpointAddress: "10.0.0.1",
endpointPort: 555, // different port
checkAvgSpeed: func(t *testing.T, speed float64) {
// See above comment regarding expected performance
// under the given RTT, MSS, and PLR constraints
const expectation = 5
if speed < expectation {
t.Fatal("goodput", speed, "below expectation", expectation)
}
},
}}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
t.Log("checking for TLS flow throttling", tc.name)

// throttle the offending endpoint to have high latency and high losses
dpiEngine := netem.NewDPIEngine(log.Log)
dpiEngine.AddRule(&netem.DPIThrottleTrafficForTCPEndpoint{
Delay: 10 * time.Millisecond,
Logger: log.Log,
PLR: 0.1,
ServerIPAddress: tc.endpointAddress,
ServerPort: tc.endpointPort,
})
lc := &netem.LinkConfig{
DPIEngine: dpiEngine,
LeftToRightDelay: 100 * time.Microsecond,
RightToLeftDelay: 100 * time.Microsecond,
}

// create a point-to-point topology, which consists of a single
// [Link] connecting two userspace network stacks.
topology := netem.MustNewPPPTopology(
"10.0.0.2",
"10.0.0.1",
log.Log,
lc,
)
defer topology.Close()

// make sure we have a deadline bound context
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// add DNS server to resolve the clientSNI domain
dnsConfig := netem.NewDNSConfig()
dnsConfig.AddRecord("ndt0.local", "", "10.0.0.1")
dnsServer, err := netem.NewDNSServer(log.Log, topology.Server, "10.0.0.1", dnsConfig)
if err != nil {
t.Fatal(err)
}
defer dnsServer.Close()

// start an NDT0 server in the background
ready, serverErrorCh := make(chan net.Listener, 1), make(chan error, 1)
go netem.RunNDT0Server(
ctx,
topology.Server,
net.ParseIP("10.0.0.1"),
443,
log.Log,
ready,
serverErrorCh,
true,
"ndt0.local",
"ndt0.xyz",
)

// await for the NDT0 server to be listening
listener := <-ready
defer listener.Close()

// run NDT0 client in the background and measure speed
clientErrorCh := make(chan error, 1)
perfch := make(chan *netem.NDT0PerformanceSample)
go netem.RunNDT0Client(
ctx,
topology.Client,
net.JoinHostPort("ndt0.local", "443"),
log.Log,
true,
clientErrorCh,
perfch,
)

// collect the average speed
var avgSpeed float64
for p := range perfch {
if p.Final {
avgSpeed = p.AvgSpeedMbps()
}
}

// make sure we have collected samples
if avgSpeed <= 0 {
t.Fatal("did not collect the average speed")
}

// make sure that neither the client nor the server
// reported a fundamental error
if err := <-clientErrorCh; err != nil {
t.Fatal(err)
}
if err := <-serverErrorCh; err != nil {
t.Fatal(err)
}

t.Log("measured goodput", avgSpeed)

// make sure that the speed is consistent with expectations
tc.checkAvgSpeed(t, avgSpeed)
})
}
}

// TestDPITCPResetForSNI verifies we can use the DPI to reset TCP
// connections using specific TLS SNI values.
func TestDPITCPResetForSNI(t *testing.T) {
Expand Down
Loading