diff --git a/dpithrottle.go b/dpithrottle.go index afe5546..0b0f98a 100644 --- a/dpithrottle.go +++ b/dpithrottle.go @@ -20,7 +20,7 @@ type DPIThrottleTrafficForTLSSNI struct { // Logger is the MANDATORY logger to use. Logger Logger - // PLR is the OPTIONAL extra packet loss rate to apply to the packet + // PLR is the OPTIONAL extra packet loss rate to apply to the packet. PLR float64 // SNI is the OPTIONAL offending SNI @@ -70,3 +70,55 @@ func (r *DPIThrottleTrafficForTLSSNI) Filter( } return policy, true } + +// DPIThrottleTrafficForTCPEndpoint is a [DPIRule] that throttles traffic +// for a given TCP endpoint. The zero value is not valid. Make sure +// you initialize all fields marked as MANDATORY. +type DPIThrottleTrafficForTCPEndpoint struct { + // Delay is the OPTIONAL extra delay to add to the flow. + Delay time.Duration + + // Logger is the MANDATORY logger to use. + Logger Logger + + // PLR is the OPTIONAL extra packet loss rate to apply to the packet. + PLR float64 + + // ServerIPAddress is the MANDATORY server endpoint IP address. + ServerIPAddress string + + // ServerPort is the MANDATORY server endpoint port. + ServerPort uint16 +} + +var _ DPIRule = &DPIThrottleTrafficForTCPEndpoint{} + +// Filter implements DPIRule +func (r *DPIThrottleTrafficForTCPEndpoint) Filter( + direction DPIDirection, packet *DissectedPacket) (*DPIPolicy, bool) { + // short circuit for the return path + if direction != DPIDirectionClientToServer { + return nil, false + } + + // make sure the packet is TCP and for the proper endpoint + if !packet.MatchesDestination(layers.IPProtocolTCP, r.ServerIPAddress, r.ServerPort) { + return nil, false + } + + r.Logger.Infof( + "netem: dpi: throttling flow %s:%d %s:%d/%s because the endpoint is filtered", + packet.SourceIPAddress(), + packet.SourcePort(), + packet.DestinationIPAddress(), + packet.DestinationPort(), + packet.TransportProtocol(), + ) + policy := &DPIPolicy{ + Delay: r.Delay, + Flags: 0, + PLR: r.PLR, + Spoofed: nil, + } + return policy, true +} diff --git a/integration_test.go b/integration_test.go index f9d7db8..a56421c 100644 --- a/integration_test.go +++ b/integration_test.go @@ -482,7 +482,7 @@ func TestDPITCPThrottleForSNI(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Log("checking for TLS flow throttling", tc.name) - // throttle the offending SNI to have high latency and hig losses + // throttle the offending SNI to have high latency and high losses dpiEngine := netem.NewDPIEngine(log.Log) dpiEngine.AddRule(&netem.DPIThrottleTrafficForTLSSNI{ Delay: 10 * time.Millisecond, @@ -581,6 +581,159 @@ func TestDPITCPThrottleForSNI(t *testing.T) { } } +// TestDPITCPThrottleForTCPEndpoint verifies we can use the DPI to throttle +// connections using a specific TCP endpoint. +func TestDPITCPThrottleForTCPEndpoint(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // testcase describes a test case + type testcase struct { + // name is the name of the test case + name string + + // endpointAddress is the address of the endpoint to block. + endpointAddress string + + // endpointPort is the port of the endpoint to block. + endpointPort uint16 + + // checkAvgSpeed is a function the check whether + // the speed is consistent with expectations + checkAvgSpeed func(t *testing.T, speed float64) + } + + var testcases = []testcase{{ + name: "when the client is using a throttled endpoint", + endpointAddress: "10.0.0.1", + endpointPort: 443, + checkAvgSpeed: func(t *testing.T, speed float64) { + // See above comment regarding expected performance + // under the given RTT, MSS, and PLR constraints + const expectation = 5 + if speed > expectation { + t.Fatal("goodput", speed, "above expectation", expectation) + } + }, + }, { + name: "when the client is not using a throttled endpoint", + endpointAddress: "10.0.0.1", + endpointPort: 555, // different port + checkAvgSpeed: func(t *testing.T, speed float64) { + // See above comment regarding expected performance + // under the given RTT, MSS, and PLR constraints + const expectation = 5 + if speed < expectation { + t.Fatal("goodput", speed, "below expectation", expectation) + } + }, + }} + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + t.Log("checking for TLS flow throttling", tc.name) + + // throttle the offending endpoint to have high latency and high losses + dpiEngine := netem.NewDPIEngine(log.Log) + dpiEngine.AddRule(&netem.DPIThrottleTrafficForTCPEndpoint{ + Delay: 10 * time.Millisecond, + Logger: log.Log, + PLR: 0.1, + ServerIPAddress: tc.endpointAddress, + ServerPort: tc.endpointPort, + }) + lc := &netem.LinkConfig{ + DPIEngine: dpiEngine, + LeftToRightDelay: 100 * time.Microsecond, + RightToLeftDelay: 100 * time.Microsecond, + } + + // create a point-to-point topology, which consists of a single + // [Link] connecting two userspace network stacks. + topology := netem.MustNewPPPTopology( + "10.0.0.2", + "10.0.0.1", + log.Log, + lc, + ) + defer topology.Close() + + // make sure we have a deadline bound context + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // add DNS server to resolve the clientSNI domain + dnsConfig := netem.NewDNSConfig() + dnsConfig.AddRecord("ndt0.local", "", "10.0.0.1") + dnsServer, err := netem.NewDNSServer(log.Log, topology.Server, "10.0.0.1", dnsConfig) + if err != nil { + t.Fatal(err) + } + defer dnsServer.Close() + + // start an NDT0 server in the background + ready, serverErrorCh := make(chan net.Listener, 1), make(chan error, 1) + go netem.RunNDT0Server( + ctx, + topology.Server, + net.ParseIP("10.0.0.1"), + 443, + log.Log, + ready, + serverErrorCh, + true, + "ndt0.local", + "ndt0.xyz", + ) + + // await for the NDT0 server to be listening + listener := <-ready + defer listener.Close() + + // run NDT0 client in the background and measure speed + clientErrorCh := make(chan error, 1) + perfch := make(chan *netem.NDT0PerformanceSample) + go netem.RunNDT0Client( + ctx, + topology.Client, + net.JoinHostPort("ndt0.local", "443"), + log.Log, + true, + clientErrorCh, + perfch, + ) + + // collect the average speed + var avgSpeed float64 + for p := range perfch { + if p.Final { + avgSpeed = p.AvgSpeedMbps() + } + } + + // make sure we have collected samples + if avgSpeed <= 0 { + t.Fatal("did not collect the average speed") + } + + // make sure that neither the client nor the server + // reported a fundamental error + if err := <-clientErrorCh; err != nil { + t.Fatal(err) + } + if err := <-serverErrorCh; err != nil { + t.Fatal(err) + } + + t.Log("measured goodput", avgSpeed) + + // make sure that the speed is consistent with expectations + tc.checkAvgSpeed(t, avgSpeed) + }) + } +} + // TestDPITCPResetForSNI verifies we can use the DPI to reset TCP // connections using specific TLS SNI values. func TestDPITCPResetForSNI(t *testing.T) {