diff --git a/bench_test.go b/bench_test.go index 9f27bb71b..08b6c3c7b 100644 --- a/bench_test.go +++ b/bench_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" + dtlsnet "github.com/pion/dtls/v2/pkg/net" "github.com/pion/logging" "github.com/pion/transport/v2/dpipe" "github.com/pion/transport/v2/test" @@ -31,7 +31,7 @@ func TestSimpleReadWrite(t *testing.T) { gotHello := make(chan struct{}) go func() { - server, sErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{ + server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) @@ -49,7 +49,7 @@ func TestSimpleReadWrite(t *testing.T) { } }() - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{ + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) @@ -79,7 +79,7 @@ func benchmarkConn(b *testing.B, n int64) { certificate, err := selfsign.GenerateSelfSigned() server := make(chan *Conn) go func() { - s, sErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{ + s, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, }, false) if err != nil { @@ -95,7 +95,7 @@ func benchmarkConn(b *testing.B, n int64) { b.ReportAllocs() b.SetBytes(int64(len(hw))) go func() { - client, cErr := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false) + client, cErr := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false) if cErr != nil { b.Error(err) } diff --git a/cipher_suite_test.go b/cipher_suite_test.go index 0d2d83d09..38bfd516e 100644 --- a/cipher_suite_test.go +++ b/cipher_suite_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/pion/dtls/v2/internal/ciphersuite" - "github.com/pion/dtls/v2/internal/util" + dtlsnet "github.com/pion/dtls/v2/pkg/net" "github.com/pion/transport/v2/dpipe" "github.com/pion/transport/v2/test" ) @@ -71,14 +71,14 @@ func TestCustomCipherSuite(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{ + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) c <- result{client, err} }() - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{ + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) diff --git a/conn_go_test.go b/conn_go_test.go index 2978eb340..cfd0e32f4 100644 --- a/conn_go_test.go +++ b/conn_go_test.go @@ -15,8 +15,8 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" + dtlsnet "github.com/pion/dtls/v2/pkg/net" "github.com/pion/transport/v2/dpipe" "github.com/pion/transport/v2/test" ) @@ -86,7 +86,7 @@ func TestContextConfig(t *testing.T) { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return Client(util.FromConn(ca), ca.RemoteAddr(), config) + return Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) }, func() { _ = ca.Close() } @@ -98,7 +98,7 @@ func TestContextConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return ClientWithContext(ctx, util.FromConn(ca), ca.RemoteAddr(), config) + return ClientWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) }, func() { cancel() _ = ca.Close() @@ -110,7 +110,7 @@ func TestContextConfig(t *testing.T) { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return Server(util.FromConn(ca), ca.RemoteAddr(), config) + return Server(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) }, func() { _ = ca.Close() } @@ -122,7 +122,7 @@ func TestContextConfig(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { - return ServerWithContext(ctx, util.FromConn(ca), ca.RemoteAddr(), config) + return ServerWithContext(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) }, func() { cancel() _ = ca.Close() diff --git a/conn_test.go b/conn_test.go index 4ec47571b..447526841 100644 --- a/conn_test.go +++ b/conn_test.go @@ -25,12 +25,12 @@ import ( "time" "github.com/pion/dtls/v2/internal/ciphersuite" - "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/hash" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/dtls/v2/pkg/crypto/signature" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" + dtlsnet "github.com/pion/dtls/v2/pkg/net" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/extension" @@ -266,12 +266,12 @@ func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { // Setup client go func() { - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) c <- result{client, err} }() // Setup server - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) if err != nil { return nil, nil, err } @@ -385,11 +385,11 @@ func TestHandshakeWithAlert(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - _, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), testCase.configClient, true) + _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), testCase.configClient, true) clientErr <- err }() - _, errServer := testServer(ctx, util.FromConn(cb), ca.RemoteAddr(), testCase.configServer, true) + _, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), ca.RemoteAddr(), testCase.configServer, true) if !errors.Is(errServer, testCase.errServer) { t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer) } @@ -552,7 +552,7 @@ func TestPSK(t *testing.T) { VerifyConnection: test.ClientVerifyConnection, } - c, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, false) + c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) clientRes <- result{c, err} }() @@ -568,7 +568,7 @@ func TestPSK(t *testing.T) { VerifyConnection: test.ServerVerifyConnection, } - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, false) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false) if test.WantFail { res := <-clientRes if err == nil || !strings.Contains(err.Error(), test.ExpectedServerErr) { @@ -627,7 +627,7 @@ func TestPSKHintFail(t *testing.T) { CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } - _, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, false) + _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) clientErr <- err }() @@ -639,7 +639,7 @@ func TestPSKHintFail(t *testing.T) { CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } - if _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, false); !errors.Is(err, serverAlertError) { + if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false); !errors.Is(err, serverAlertError) { t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err) } @@ -666,7 +666,7 @@ func TestClientTimeout(t *testing.T) { go func() { conf := &Config{} - c, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, true) + c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, true) if err == nil { _ = c.Close() //nolint:contextcheck } @@ -754,11 +754,11 @@ func TestSRTPConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ClientSRTP}, true) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ClientSRTP}, true) c <- result{client, err} }() - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ServerSRTP}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{SRTPProtectionProfiles: test.ServerSRTP}, true) if !errors.Is(err, test.WantServerError) { t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) } @@ -962,11 +962,11 @@ func TestClientCertificate(t *testing.T) { c := make(chan result) go func() { - client, err := Client(util.FromConn(ca), ca.RemoteAddr(), tt.clientCfg) + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) c <- result{client, err} }() - server, err := Server(util.FromConn(cb), cb.RemoteAddr(), tt.serverCfg) + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) res := <-c defer func() { if err == nil { @@ -1119,11 +1119,11 @@ func TestConnectionID(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) c <- result{client, err} }() - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) if err != nil { t.Fatalf("Unexpected server error: %v", err) } @@ -1272,11 +1272,11 @@ func TestExtendedMasterSecret(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) c <- result{client, err} }() - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) res := <-c defer func() { if err == nil { @@ -1382,11 +1382,11 @@ func TestServerCertificate(t *testing.T) { } srvCh := make(chan result) go func() { - s, err := Server(util.FromConn(cb), cb.RemoteAddr(), tt.serverCfg) + s, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) srvCh <- result{s, err} }() - cli, err := Client(util.FromConn(ca), ca.RemoteAddr(), tt.clientCfg) + cli, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) if err == nil { _ = cli.Close() } @@ -1486,11 +1486,11 @@ func TestCipherSuiteConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: test.ClientCipherSuites}, true) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: test.ClientCipherSuites}, true) c <- result{client, err} }() - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{CipherSuites: test.ServerCipherSuites}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{CipherSuites: test.ServerCipherSuites}, true) if err == nil { defer func() { _ = server.Close() @@ -1555,7 +1555,7 @@ func TestCertificateAndPSKServer(t *testing.T) { config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256} } - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), config, false) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) c <- result{client, err} }() @@ -1566,7 +1566,7 @@ func TestCertificateAndPSKServer(t *testing.T) { }, } - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) if err == nil { defer func() { _ = server.Close() @@ -1658,11 +1658,11 @@ func TestPSKConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate) c <- result{client, err} }() - _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate) + _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate) if err != nil || test.WantServerError != nil { if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) @@ -1792,7 +1792,7 @@ func TestServerTimeout(t *testing.T) { FlightInterval: 100 * time.Millisecond, } - _, serverErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true) + _, serverErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) var netErr net.Error if !errors.As(serverErr, &netErr) || !netErr.Timeout() { t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr) @@ -1907,7 +1907,7 @@ func TestProtocolVersionValidation(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - if _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true); !errors.Is(err, errUnsupportedProtocolVersion) { + if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); !errors.Is(err, errUnsupportedProtocolVersion) { t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) } }() @@ -1997,7 +1997,7 @@ func TestProtocolVersionValidation(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - if _, err := testClient(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true); !errors.Is(err, errUnsupportedProtocolVersion) { + if _, err := testClient(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); !errors.Is(err, errUnsupportedProtocolVersion) { t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) } }() @@ -2095,7 +2095,7 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - _, _ = testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{}, false) + _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{}, false) }() for i, cookie := range cookies { @@ -2167,7 +2167,7 @@ func TestRenegotationInfo(t *testing.T) { defer cancel() go func() { - if _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is(err, context.Canceled) { + if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is(err, context.Canceled) { t.Error(err) } }() @@ -2279,7 +2279,7 @@ func TestServerNameIndicationExtension(t *testing.T) { ServerName: test.ServerName, } - _, _ = testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, false) + _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello @@ -2397,7 +2397,7 @@ func TestALPNExtension(t *testing.T) { conf := &Config{ SupportedProtocols: test.ClientProtocolNameList, } - _, _ = testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), conf, false) + _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello @@ -2415,7 +2415,7 @@ func TestALPNExtension(t *testing.T) { conf := &Config{ SupportedProtocols: test.ServerProtocolNameList, } - if _, err2 := testServer(ctx2, util.FromConn(cb2), cb2.RemoteAddr(), conf, true); !errors.Is(err2, context.Canceled) { + if _, err2 := testServer(ctx2, dtlsnet.PacketConnFromConn(cb2), cb2.RemoteAddr(), conf, true); !errors.Is(err2, context.Canceled) { if test.ExpectAlertFromServer { //nolint // Assert the error type? } else { @@ -2562,7 +2562,7 @@ func TestSupportedGroupsExtension(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - if _, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is(err, context.Canceled) { + if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is(err, context.Canceled) { t.Error(err) } }() @@ -2671,7 +2671,7 @@ func TestSessionResume(t *testing.T) { SessionStore: ss, MTU: 100, } - c, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), config, false) + c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) clientRes <- result{c, err} }() @@ -2681,7 +2681,7 @@ func TestSessionResume(t *testing.T) { SessionStore: ss, MTU: 100, } - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) if err != nil { t.Fatalf("TestSessionResume: Server failed(%v)", err) } @@ -2725,14 +2725,14 @@ func TestSessionResume(t *testing.T) { ServerName: "example.com", SessionStore: s1, } - c, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), config, false) + c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) clientRes <- result{c, err} }() config := &Config{ SessionStore: s2, } - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), config, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) if err != nil { t.Fatalf("TestSessionResumetion: Server failed(%v)", err) } @@ -2830,7 +2830,7 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - c, err := testClient(context.TODO(), util.FromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: test.cipherList}, false) + c, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: test.cipherList}, false) clientErr <- err client <- c }() @@ -2855,7 +2855,7 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { t.Fatal(err) } - if s, err := testServer(context.TODO(), util.FromConn(cb), cb.RemoteAddr(), &Config{ + if s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: test.cipherList, Certificates: []tls.Certificate{serverCert}, }, false); err != nil { @@ -2920,7 +2920,7 @@ func TestMultipleServerCertificates(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - c, err := testClient(context.TODO(), util.FromConn(ca), ca.RemoteAddr(), &Config{ + c, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ RootCAs: caPool, ServerName: test.RequestServerName, VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { @@ -2940,7 +2940,7 @@ func TestMultipleServerCertificates(t *testing.T) { client <- c }() - if s, err := testServer(context.TODO(), util.FromConn(cb), cb.RemoteAddr(), &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil { + if s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil { t.Fatal(err) } else if err = s.Close(); err != nil { t.Fatal(err) @@ -2992,11 +2992,11 @@ func TestEllipticCurveConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) c <- result{client, err} }() - server, err := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) if err != nil { t.Fatalf("Server error: %v", err) } @@ -3048,7 +3048,7 @@ func TestSkipHelloVerify(t *testing.T) { gotHello := make(chan struct{}) go func() { - server, sErr := testServer(ctx, util.FromConn(cb), cb.RemoteAddr(), &Config{ + server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerifyHello: true, @@ -3067,7 +3067,7 @@ func TestSkipHelloVerify(t *testing.T) { } }() - client, err := testClient(ctx, util.FromConn(ca), ca.RemoteAddr(), &Config{ + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) diff --git a/e2e/e2e_lossy_test.go b/e2e/e2e_lossy_test.go index c694287da..59231e893 100644 --- a/e2e/e2e_lossy_test.go +++ b/e2e/e2e_lossy_test.go @@ -11,8 +11,8 @@ import ( "time" "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" + dtlsnet "github.com/pion/dtls/v2/pkg/net" transportTest "github.com/pion/transport/v2/test" ) @@ -145,7 +145,7 @@ func TestPionE2ELossy(t *testing.T) { cfg.Certificates = []tls.Certificate{clientCert} } - client, startupErr := dtls.Client(util.FromConn(br.GetConn0()), br.GetConn0().RemoteAddr(), cfg) + client, startupErr := dtls.Client(dtlsnet.PacketConnFromConn(br.GetConn0()), br.GetConn0().RemoteAddr(), cfg) clientDone <- runResult{client, startupErr} }() @@ -160,7 +160,7 @@ func TestPionE2ELossy(t *testing.T) { cfg.ClientAuth = dtls.RequireAnyClientCert } - server, startupErr := dtls.Server(util.FromConn(br.GetConn1()), br.GetConn1().RemoteAddr(), cfg) + server, startupErr := dtls.Server(dtlsnet.PacketConnFromConn(br.GetConn1()), br.GetConn1().RemoteAddr(), cfg) serverDone <- runResult{server, startupErr} }() diff --git a/examples/dial/cid/main.go b/examples/dial/cid/main.go new file mode 100644 index 000000000..0ab8f4b78 --- /dev/null +++ b/examples/dial/cid/main.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package main implements an example DTLS client using a pre-shared key. +package main + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/pion/dtls/v2" + "github.com/pion/dtls/v2/examples/util" +) + +func main() { + // Prepare the IP to connect to + addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} + + // + // Everything below is the pion-DTLS API! Thanks for using it ❤️. + // + + // Prepare the configuration of the DTLS connection + config := &dtls.Config{ + PSK: func(hint []byte) ([]byte, error) { + fmt.Printf("Server's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil + }, + PSKIdentityHint: []byte("Pion DTLS Server"), + CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + ConnectionIDGenerator: dtls.OnlySendCIDGenerator(), + } + + // Connect to a DTLS server + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + util.Check(err) + defer func() { + util.Check(dtlsConn.Close()) + }() + + fmt.Println("Connected; type 'exit' to shutdown gracefully") + + // Simulate a chat session + util.Chat(dtlsConn) +} diff --git a/examples/listen/cid/main.go b/examples/listen/cid/main.go new file mode 100644 index 000000000..5c85fe81b --- /dev/null +++ b/examples/listen/cid/main.go @@ -0,0 +1,77 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package main implements a DTLS server using a pre-shared key. +package main + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/pion/dtls/v2" + "github.com/pion/dtls/v2/examples/util" +) + +func main() { + // Prepare the IP to connect to + addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} + + // Create parent context to cleanup handshaking connections on exit. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // + // Everything below is the pion-DTLS API! Thanks for using it ❤️. + // + + // Prepare the configuration of the DTLS connection + config := &dtls.Config{ + PSK: func(hint []byte) ([]byte, error) { + fmt.Printf("Client's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil + }, + PSKIdentityHint: []byte("Pion DTLS Client"), + CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + // Create timeout context for accepted connection. + ConnectContextMaker: func() (context.Context, func()) { + return context.WithTimeout(ctx, 30*time.Second) + }, + ConnectionIDGenerator: dtls.RandomCIDGenerator(8), + } + + // Connect to a DTLS server + listener, err := dtls.Listen("udp", addr, config) + util.Check(err) + defer func() { + util.Check(listener.Close()) + }() + + fmt.Println("Listening") + + // Simulate a chat session + hub := util.NewHub() + + go func() { + for { + // Wait for a connection. + conn, err := listener.Accept() + util.Check(err) + // defer conn.Close() // TODO: graceful shutdown + + // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` + // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose + // functions like `ConnectionState` etc. + + // Register the connection with the chat hub + if err == nil { + hub.Register(conn) + } + } + }() + + // Start chatting + hub.Chat() +} diff --git a/go.mod b/go.mod index b35260c67..9025bfd20 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,9 @@ module github.com/pion/dtls/v2 require ( github.com/pion/logging v0.2.2 + github.com/pion/transport v0.14.1 github.com/pion/transport/v2 v2.2.2-0.20230802201558-f2dffd80896b + github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.12.0 golang.org/x/net v0.13.0 ) diff --git a/go.sum b/go.sum index b73379cfc..cf4923b39 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/transport v0.14.1 h1:XSM6olwW+o8J4SCmOBb/BpwZypkHeyM0PGFCxNQBr40= +github.com/pion/transport v0.14.1/go.mod h1:4tGmbk00NeYA3rUa9+n+dzCCoKkcy3YlYb99Jn2fNnI= github.com/pion/transport/v2 v2.2.2-0.20230802201558-f2dffd80896b h1:g/axuqY9eU5L6YeAQSq+yW4CU5fPqOb90EaWI+8xeiI= github.com/pion/transport/v2 v2.2.2-0.20230802201558-f2dffd80896b/go.mod h1:OJg3ojoBJopjEeECq2yJdXH9YVrUJ1uQ++NjXLOUorc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -25,6 +27,7 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.13.0 h1:Nvo8UFsZ8X3BhAC9699Z1j7XQ3rsZnUUm7jfBEk1ueY= @@ -37,12 +40,15 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= @@ -50,6 +56,7 @@ golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= diff --git a/internal/net/buffer.go b/internal/net/buffer.go new file mode 100644 index 000000000..cab9cc0c3 --- /dev/null +++ b/internal/net/buffer.go @@ -0,0 +1,222 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package net implements DTLS specific networking primitives. +// NOTE: this package is an adaption of pion/transport/packetio. If possible, +// the updates made in this repository will be reflected back upstream. If not, +// it is likely that this will be moved to a public package in this repository. +package net + +import ( + "bytes" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/pion/transport/deadline" +) + +// ErrTimeout indicates that deadline was reached before operation could be +// completed. +var ErrTimeout = errors.New("buffer: i/o timeout") + +// AddrPacket is a packet payload and the associated remote address from which +// it was received. +type AddrPacket struct { + addr net.Addr + data bytes.Buffer +} + +// PacketBuffer is a circular buffer for network packets. Each slot in the +// buffer supports a +type PacketBuffer struct { + mutex sync.Mutex + + packets []AddrPacket + write, read int + + // full indicates whether the buffer is full, which is needed to distinguish + // when the write pointer and read pointer are at the same index. + full bool + + notify chan struct{} + closed bool + + readDeadline *deadline.Deadline +} + +// NewPacketBuffer creates a new PacketBuffer. +func NewPacketBuffer() *PacketBuffer { + return &PacketBuffer{ + readDeadline: deadline.New(), + // In the narrow context in which this package is currently used, there + // will always be at least one packet written to the buffer. Therefore, + // we opt to allocate with size of 1 during construction, rather than + // waiting until that first packet is written. + packets: make([]AddrPacket, 1), + full: false, + } +} + +// WriteTo writes a single packet to the buffer. The supplied address will +// remain associated with the packet. +func (b *PacketBuffer) WriteTo(p []byte, addr net.Addr) (int, error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + return 0, io.ErrClosedPipe + } + + var notify chan struct{} + if b.notify != nil { + notify = b.notify + b.notify = nil + } + + // Check to see if we are full. + if b.full { + // If so, grow AddrPacket buffer. + var newPackets int + if len(b.packets) < 128 { + // Double the number of packets. + newPackets = len(b.packets) + } else { + // Increase the number of packets by 25%. + newPackets = len(b.packets) / 4 + } + b.packets = append(b.packets[:b.write], append(make([]AddrPacket, newPackets), b.packets[b.write:]...)...) + + // If write pointer is behind (wrapped around) or even with read + // pointer, must move read pointer forward. + b.full = false + if b.write <= b.read { + b.read += newPackets + } + } + + // Store the packet at the write pointer. + packet := &b.packets[b.write] + packet.data.Reset() + n, err := packet.data.Write(p) + if err != nil { + b.mutex.Unlock() + return n, err + } + packet.addr = addr + + // Increment write pointer. + b.write++ + + // If the write pointer is equal to the length of the buffer, wrap around. + if len(b.packets) == b.write { + b.write = 0 + } + + // If a write resulted in making write and read pointers equivalent, then we + // are full. + if b.write == b.read { + b.full = true + } + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return n, nil +} + +// ReadFrom reads a single packet from the buffer, or blocks until one is +// available. +func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) { + select { + case <-b.readDeadline.Done(): + return 0, nil, ErrTimeout + default: + } + + for { + b.mutex.Lock() + + if b.read != b.write || b.full { + ap := b.packets[b.read] + if len(packet) < ap.data.Len() { + b.mutex.Unlock() + return 0, nil, io.ErrShortBuffer + } + + // Copy packet data from buffer. + n, err := ap.data.Read(packet) + if err != nil { + b.mutex.Unlock() + return n, nil, err + } + + // Advance read pointer. + b.read++ + if len(b.packets) == b.read { + b.read = 0 + } + + // If we were full before reading and have successfully read, we are + // no longer full. + if b.full { + b.full = false + } + + b.mutex.Unlock() + + return n, ap.addr, nil + } + + if b.closed { + b.mutex.Unlock() + return 0, nil, io.EOF + } + + if b.notify == nil { + b.notify = make(chan struct{}) + } + notify := b.notify + b.mutex.Unlock() + + select { + case <-b.readDeadline.Done(): + return 0, nil, ErrTimeout + case <-notify: + } + } +} + +// Close closes the buffer, allowing unread packets to be read, but erroring on +// any new writes. +func (b *PacketBuffer) Close() (err error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + return nil + } + + notify := b.notify + b.notify = nil + b.closed = true + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return nil +} + +// SetReadDeadline sets the read deadline for the buffer. +func (b *PacketBuffer) SetReadDeadline(t time.Time) error { + b.readDeadline.Set(t) + return nil +} diff --git a/internal/net/buffer_test.go b/internal/net/buffer_test.go new file mode 100644 index 000000000..c269102cd --- /dev/null +++ b/internal/net/buffer_test.go @@ -0,0 +1,401 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package net implements DTLS specific networking primitives. +package net + +import ( + "bytes" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBuffer(t *testing.T) { + assert := assert.New(t) + + buffer := NewPacketBuffer() + packet := make([]byte, 4) + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + assert.NoError(err) + + // Write once. + n, err := buffer.WriteTo([]byte{0, 1}, addr) + assert.NoError(err) + assert.Equal(2, n) + + // Read once. + var raddr net.Addr + n, raddr, err = buffer.ReadFrom(packet) + assert.NoError(err) + assert.Equal(2, n) + assert.Equal([]byte{0, 1}, packet[:n]) + assert.Equal(addr, raddr) + + // Read deadline. + err = buffer.SetReadDeadline(time.Unix(0, 1)) + assert.NoError(err) + n, raddr, err = buffer.ReadFrom(packet) + assert.EqualError(err, ErrTimeout.Error()) + assert.Equal(0, n) + assert.Equal(nil, raddr) + + // Reset deadline. + err = buffer.SetReadDeadline(time.Time{}) + assert.NoError(err) + + // Write twice. + n, err = buffer.WriteTo([]byte{2, 3, 4}, addr) + assert.NoError(err) + assert.Equal(3, n) + + n, err = buffer.WriteTo([]byte{5, 6, 7}, addr) + assert.NoError(err) + assert.Equal(3, n) + + // Read twice. + n, raddr, err = buffer.ReadFrom(packet) + assert.NoError(err) + assert.Equal(3, n) + assert.Equal([]byte{2, 3, 4}, packet[:n]) + assert.Equal(addr, raddr) + + n, raddr, err = buffer.ReadFrom(packet) + assert.NoError(err) + assert.Equal(3, n) + assert.Equal([]byte{5, 6, 7}, packet[:n]) + assert.Equal(addr, raddr) + + // Write once prior to close. + _, err = buffer.WriteTo([]byte{3}, addr) + assert.NoError(err) + + // Close. + assert.NoError(buffer.Close()) + + // Future writes will error. + _, err = buffer.WriteTo([]byte{4}, addr) + assert.Error(err) + + // But we can read the remaining data. + n, raddr, err = buffer.ReadFrom(packet) + assert.NoError(err) + assert.Equal(1, n) + assert.Equal([]byte{3}, packet[:n]) + assert.Equal(addr, raddr) + + // Until EOF. + _, _, err = buffer.ReadFrom(packet) + assert.Equal(io.EOF, err) +} + +func TestShortBuffer(t *testing.T) { + assert := assert.New(t) + + buffer := NewPacketBuffer() + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + assert.NoError(err) + + // Write once. + n, err := buffer.WriteTo([]byte{0, 1, 2, 3}, addr) + assert.NoError(err) + assert.Equal(4, n) + + // Try to read with a short buffer. + packet := make([]byte, 3) + var raddr net.Addr + n, raddr, err = buffer.ReadFrom(packet) + assert.Equal(io.ErrShortBuffer, err) + assert.Equal(nil, raddr) + assert.Equal(0, n) + + // Close. + assert.NoError(buffer.Close()) + + // Make sure you can Close twice. + assert.NoError(buffer.Close()) +} + +func TestWraparound(t *testing.T) { + assert := assert.New(t) + + buffer := NewPacketBuffer() + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + assert.NoError(err) + + // Write multiple. + n, err := buffer.WriteTo([]byte{0, 1, 2, 3}, addr) + assert.NoError(err) + assert.Equal(4, n) + + n, err = buffer.WriteTo([]byte{4, 5}, addr) + assert.NoError(err) + assert.Equal(2, n) + + n, err = buffer.WriteTo([]byte{6, 7, 8}, addr) + assert.NoError(err) + assert.Equal(3, n) + + // Verify underlying buffer length. + // Packet 1: buffer does not grow. + // Packet 2: buffer doubles from 1 to 2. + // Packet 3: buffer doubles from 2 to 4. + assert.Equal(4, len(buffer.packets)) + + // Read once. + packet := make([]byte, 4) + var raddr net.Addr + n, raddr, err = buffer.ReadFrom(packet) + assert.NoError(err) + assert.Equal(4, n) + assert.Equal([]byte{0, 1, 2, 3}, packet[:n]) + assert.Equal(addr, raddr) + + // Write again. + n, err = buffer.WriteTo([]byte{9, 10, 11}, addr) + assert.NoError(err) + assert.Equal(3, n) + + // Verify underlying buffer length. + // No change in buffer size. + assert.Equal(4, len(buffer.packets)) + + // Write again and verify buffer grew. + n, err = buffer.WriteTo([]byte{12, 13, 14, 15, 16, 17, 18, 19}, addr) + assert.NoError(err) + assert.Equal(8, n) + assert.Equal(4, len(buffer.packets)) + + // Close. + assert.NoError(buffer.Close()) +} + +func TestBufferAsync(t *testing.T) { + assert := assert.New(t) + + buffer := NewPacketBuffer() + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + assert.NoError(err) + + // Start up a goroutine to start a blocking read. + done := make(chan struct{}) + go func() { + packet := make([]byte, 4) + + n, raddr, rErr := buffer.ReadFrom(packet) + assert.NoError(rErr) + assert.Equal(2, n) + assert.Equal([]byte{0, 1}, packet[:n]) + assert.Equal(addr, raddr) + + _, _, err = buffer.ReadFrom(packet) + assert.Equal(io.EOF, err) + + close(done) + }() + + // Wait for the reader to start reading. + time.Sleep(time.Millisecond) + + // Write once + n, err := buffer.WriteTo([]byte{0, 1}, addr) + assert.NoError(err) + assert.Equal(2, n) + + // Wait for the reader to start reading again. + time.Sleep(time.Millisecond) + + // Close will unblock the reader. + assert.NoError(buffer.Close()) + + <-done +} + +func TestBufferAlloc(t *testing.T) { + packet := make([]byte, 1024) + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + t.Fatalf("net.ResolveUDPAddr: %v", err) + } + + test := func(f func(count int) func(), count int, max float64) func(t *testing.T) { + return func(t *testing.T) { + allocs := testing.AllocsPerRun(3, f(count)) + if allocs > max { + t.Errorf("count=%v, max=%v, got %v", + count, max, allocs, + ) + } + } + } + + w := func(count int) func() { + return func() { + buffer := NewPacketBuffer() + for i := 0; i < count; i++ { + if _, err := buffer.WriteTo(packet, addr); err != nil { + t.Errorf("WriteTo: %v", err) + break + } + } + } + } + + // NOTE: these are noticeably higher than packetio.Buffer as each packet's + // bytes.Buffer will allocate at least once. + t.Run("100 writes", test(w, 100, 127)) + t.Run("200 writes", test(w, 200, 232)) + t.Run("400 writes", test(w, 400, 442)) + t.Run("1000 writes", test(w, 1000, 1051)) + + wr := func(count int) func() { + return func() { + buffer := NewPacketBuffer() + for i := 0; i < count; i++ { + if _, err := buffer.WriteTo(packet, addr); err != nil { + t.Fatalf("Write: %v", err) + } + if _, _, err := buffer.ReadFrom(packet); err != nil { + t.Fatalf("ReadFrom: %v", err) + } + } + } + } + + t.Run("100 writes and reads", test(wr, 100, 7)) + t.Run("1000 writes and reads", test(wr, 1000, 7)) + t.Run("10000 writes and reads", test(wr, 10000, 7)) +} + +func benchmarkBufferWR(b *testing.B, size int64, write bool, grow int) { // nolint:unparam + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + b.Fatalf("net.ResolveUDPAddr: %v", err) + } + buffer := NewPacketBuffer() + packet := make([]byte, size) + + // Grow the buffer first + pad := make([]byte, 1022) + for len(buffer.packets) < grow { + if _, err := buffer.WriteTo(pad, addr); err != nil { + b.Fatalf("Write: %v", err) + } + } + for buffer.read != buffer.write { + if _, _, err := buffer.ReadFrom(pad); err != nil { + b.Fatalf("ReadFrom: %v", err) + } + } + + if write { + if _, err := buffer.WriteTo(packet, addr); err != nil { + b.Fatalf("Write: %v", err) + } + } + + b.SetBytes(size) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, err := buffer.WriteTo(packet, addr); err != nil { + b.Fatalf("Write: %v", err) + } + if _, _, err := buffer.ReadFrom(packet); err != nil { + b.Fatalf("Write: %v", err) + } + } +} + +// In this benchmark, the buffer is often empty, which is hopefully +// typical of real usage. +func BenchmarkBufferWR14(b *testing.B) { + benchmarkBufferWR(b, 14, false, 128) +} + +func BenchmarkBufferWR140(b *testing.B) { + benchmarkBufferWR(b, 140, false, 128) +} + +func BenchmarkBufferWR1400(b *testing.B) { + benchmarkBufferWR(b, 1400, false, 128) +} + +// Here, the buffer never becomes empty, which forces wraparound +func BenchmarkBufferWWR14(b *testing.B) { + benchmarkBufferWR(b, 14, true, 128) +} + +func BenchmarkBufferWWR140(b *testing.B) { + benchmarkBufferWR(b, 140, true, 128) +} + +func BenchmarkBufferWWR1400(b *testing.B) { + benchmarkBufferWR(b, 1400, true, 128) +} + +func benchmarkBuffer(b *testing.B, size int64) { + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + b.Fatalf("net.ResolveUDPAddr: %v", err) + } + buffer := NewPacketBuffer() + b.SetBytes(size) + + done := make(chan struct{}) + go func() { + packet := make([]byte, size) + + for { + _, _, err := buffer.ReadFrom(packet) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + b.Error(err) + break + } + } + + close(done) + }() + + packet := make([]byte, size) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var err error + for { + _, err = buffer.WriteTo(packet, addr) + if !errors.Is(err, bytes.ErrTooLarge) { + break + } + time.Sleep(time.Microsecond) + } + if err != nil { + b.Fatal(err) + } + } + + if err := buffer.Close(); err != nil { + b.Fatal(err) + } + + <-done +} + +func BenchmarkBuffer14(b *testing.B) { + benchmarkBuffer(b, 14) +} + +func BenchmarkBuffer140(b *testing.B) { + benchmarkBuffer(b, 140) +} + +func BenchmarkBuffer1400(b *testing.B) { + benchmarkBuffer(b, 1400) +} diff --git a/internal/net/udp/packet_conn.go b/internal/net/udp/packet_conn.go new file mode 100644 index 000000000..2bec3696f --- /dev/null +++ b/internal/net/udp/packet_conn.go @@ -0,0 +1,398 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package udp implements DTLS specific UDP networking primitives. +// NOTE: this package is an adaption of pion/transport/udp. If possible, the +// updates made in this repository will be reflected back upstream. If not, it +// is likely that this will be moved to a public package in this repository. +package udp + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + idtlsnet "github.com/pion/dtls/v2/internal/net" + dtlsnet "github.com/pion/dtls/v2/pkg/net" + "github.com/pion/transport/v2/deadline" +) + +const ( + receiveMTU = 8192 + defaultListenBacklog = 128 // same as Linux default +) + +// Typed errors +var ( + ErrClosedListener = errors.New("udp: listener closed") + ErrListenQueueExceeded = errors.New("udp: listen queue exceeded") +) + +// listener augments a connection-oriented Listener over a UDP PacketConn +type listener struct { + pConn *net.UDPConn + + accepting atomic.Value // bool + acceptCh chan *PacketConn + doneCh chan struct{} + doneOnce sync.Once + acceptFilter func([]byte) bool + connResolver func([]byte, net.Addr) string + connIdentifier func([]byte, net.Addr) (string, bool) + + connLock sync.Mutex + conns map[string]*PacketConn + connWG *sync.WaitGroup + + readWG sync.WaitGroup + errClose atomic.Value // error + + readDoneCh chan struct{} + errRead atomic.Value // error +} + +// Accept waits for and returns the next connection to the listener. +func (l *listener) Accept() (net.PacketConn, net.Addr, error) { + select { + case c := <-l.acceptCh: + l.connWG.Add(1) + return c, c.raddr, nil + + case <-l.readDoneCh: + err, _ := l.errRead.Load().(error) + return nil, nil, err + + case <-l.doneCh: + return nil, nil, ErrClosedListener + } +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l *listener) Close() error { + var err error + l.doneOnce.Do(func() { + l.accepting.Store(false) + close(l.doneCh) + + l.connLock.Lock() + // Close unaccepted connections + lclose: + for { + select { + case c := <-l.acceptCh: + close(c.doneCh) + // If we have an alternate identifier, remove it from the connection + // map. + if id := c.id.Load(); id != nil { + delete(l.conns, id.(string)) //nolint:forcetypeassert + } + // If we haven't already removed the remote address, remove it + // from the connection map. + if !c.rmraddr.Load() { + delete(l.conns, c.raddr.String()) + c.rmraddr.Store(true) + } + default: + break lclose + } + } + nConns := len(l.conns) + l.connLock.Unlock() + + l.connWG.Done() + + if nConns == 0 { + // Wait if this is the final connection. + l.readWG.Wait() + if errClose, ok := l.errClose.Load().(error); ok { + err = errClose + } + } else { + err = nil + } + }) + + return err +} + +// Addr returns the listener's network address. +func (l *listener) Addr() net.Addr { + return l.pConn.LocalAddr() +} + +// ListenConfig stores options for listening to an address. +type ListenConfig struct { + // Backlog defines the maximum length of the queue of pending + // connections. It is equivalent of the backlog argument of + // POSIX listen function. + // If a connection request arrives when the queue is full, + // the request will be silently discarded, unlike TCP. + // Set zero to use default value 128 which is same as Linux default. + Backlog int + + // AcceptFilter determines whether the new conn should be made for + // the incoming packet. If not set, any packet creates new conn. + AcceptFilter func([]byte) bool + + // ConnectionResolver resolves an incoming packet to a connection by + // extracting an identifier from the packet contents. + ConnectionResolver func([]byte, net.Addr) string + + // ConnectionIdentifier extracts an identifier from an outgoing packet. If + // the identifier is not already associated with the connection, it will be + // added. + ConnectionIdentifier func([]byte, net.Addr) (string, bool) +} + +// Listen creates a new listener based on the ListenConfig. +func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) { + if lc.Backlog == 0 { + lc.Backlog = defaultListenBacklog + } + + conn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, err + } + + l := &listener{ + pConn: conn, + acceptCh: make(chan *PacketConn, lc.Backlog), + conns: make(map[string]*PacketConn), + doneCh: make(chan struct{}), + acceptFilter: lc.AcceptFilter, + connResolver: lc.ConnectionResolver, + connIdentifier: lc.ConnectionIdentifier, + connWG: &sync.WaitGroup{}, + readDoneCh: make(chan struct{}), + } + + l.accepting.Store(true) + l.connWG.Add(1) + l.readWG.Add(2) // wait readLoop and Close execution routine + + go l.readLoop() + go func() { + l.connWG.Wait() + if err := l.pConn.Close(); err != nil { + l.errClose.Store(err) + } + l.readWG.Done() + }() + + return l, nil +} + +// Listen creates a new listener using default ListenConfig. +func Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) { + return (&ListenConfig{}).Listen(network, laddr) +} + +// readLoop dispatches packets to the proper connection, creating a new one if +// necessary, until all connections are closed. +func (l *listener) readLoop() { + defer l.readWG.Done() + defer close(l.readDoneCh) + + buf := make([]byte, receiveMTU) + + for { + n, raddr, err := l.pConn.ReadFrom(buf) + if err != nil { + l.errRead.Store(err) + return + } + conn, ok, err := l.getConn(raddr, buf[:n]) + if err != nil { + continue + } + if ok { + _, _ = conn.buffer.WriteTo(buf[:n], raddr) + } + } +} + +// getConn gets an existing connection or creates a new one. +func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error) { + l.connLock.Lock() + defer l.connLock.Unlock() + // If we have a custom resolver, use it. + if l.connResolver != nil { + conn, ok := l.conns[l.connResolver(buf, raddr)] + if ok { + return conn, true, nil + } + } + + // If we don't have a custom resolver, or we were unable to find an + // associated connection, fall back to remote address. + conn, ok := l.conns[raddr.String()] + if !ok { + if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok { + return nil, false, ErrClosedListener + } + if l.acceptFilter != nil { + if !l.acceptFilter(buf) { + return nil, false, nil + } + } + conn = l.newPacketConn(raddr) + select { + case l.acceptCh <- conn: + l.conns[raddr.String()] = conn + default: + return nil, false, ErrListenQueueExceeded + } + } + return conn, true, nil +} + +// PacketConn is a net.PacketConn implementation that is able to dictate its +// routing ID via an alternate identifier from its remote address. Internal +// buffering is performed for reads, and writes are passed through to the +// underlying net.PacketConn. +type PacketConn struct { + listener *listener + + raddr net.Addr + rmraddr atomic.Bool + id atomic.Value + + buffer *idtlsnet.PacketBuffer + + doneCh chan struct{} + doneOnce sync.Once + + writeDeadline *deadline.Deadline +} + +// newPacketConn constructs a new PacketConn. +func (l *listener) newPacketConn(raddr net.Addr) *PacketConn { + return &PacketConn{ + listener: l, + raddr: raddr, + buffer: idtlsnet.NewPacketBuffer(), + doneCh: make(chan struct{}), + writeDeadline: deadline.New(), + } +} + +// ReadFrom reads a single packet payload and its associated remote address from +// the underlying buffer. +func (c *PacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + return c.buffer.ReadFrom(p) +} + +// WriteTo writes len(p) bytes from p to the specified address. +func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + // If we have a connection identifier, check to see if the outgoing packet + // sets it. + if c.listener.connIdentifier != nil { + id := c.id.Load() + candidate, ok := c.listener.connIdentifier(p, addr) + // If this is a new identifier, add entry to connection map. + if ok && id != candidate { + c.listener.connLock.Lock() + c.listener.conns[candidate] = c + if id != nil { + delete(c.listener.conns, id.(string)) //nolint:forcetypeassert + } + c.listener.connLock.Unlock() + c.id.Store(candidate) + } + // If we are writing to a remote address that differs from the initial, + // we have an alternate identifier established, and we haven't already + // freed the remote address, free the remote address to be used by + // another connection. + // Note: this strategy results in holding onto a remote address after it + // is potentially no longer in use by the client. However, releasing + // earlier means that we could miss some packets that should have been + // routed to this connection. Ideally, we would drop the connection + // entry for the remote address as soon as the client starts sending + // using an alternate identifier, but in practice this proves + // challenging because any client could spoof a connection identifier, + // resulting in the remote address entry being dropped prior to the + // "real" client transitioning to sending using the alternate + // identifier. + if id != nil && !c.rmraddr.Load() && addr.String() != c.raddr.String() { + c.listener.connLock.Lock() + delete(c.listener.conns, c.raddr.String()) + c.rmraddr.Store(true) + c.listener.connLock.Unlock() + } + } + + select { + case <-c.writeDeadline.Done(): + return 0, context.DeadlineExceeded + default: + } + return c.listener.pConn.WriteTo(p, addr) +} + +// Close closes the conn and releases any Read calls +func (c *PacketConn) Close() error { + var err error + c.doneOnce.Do(func() { + c.listener.connWG.Done() + close(c.doneCh) + c.listener.connLock.Lock() + // If we have an alternate identifier, remove it from the connection + // map. + if id := c.id.Load(); id != nil { + delete(c.listener.conns, id.(string)) //nolint:forcetypeassert + } + // If we haven't already removed the remote address, remove it from the + // connection map. + if !c.rmraddr.Load() { + delete(c.listener.conns, c.raddr.String()) + c.rmraddr.Store(true) + } + nConns := len(c.listener.conns) + c.listener.connLock.Unlock() + + if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok { + // Wait if this is the final connection + c.listener.readWG.Wait() + if errClose, ok := c.listener.errClose.Load().(error); ok { + err = errClose + } + } else { + err = nil + } + + if errBuf := c.buffer.Close(); errBuf != nil && err == nil { + err = errBuf + } + }) + + return err +} + +// LocalAddr implements net.PacketConn.LocalAddr. +func (c *PacketConn) LocalAddr() net.Addr { + return c.listener.pConn.LocalAddr() +} + +// SetDeadline implements net.PacketConn.SetDeadline. +func (c *PacketConn) SetDeadline(t time.Time) error { + c.writeDeadline.Set(t) + return c.SetReadDeadline(t) +} + +// SetReadDeadline implements net.PacketConn.SetReadDeadline. +func (c *PacketConn) SetReadDeadline(t time.Time) error { + return c.buffer.SetReadDeadline(t) +} + +// SetWriteDeadline implements net.PacketConn.SetWriteDeadline. +func (c *PacketConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + // Write deadline of underlying connection should not be changed + // since the connection can be shared. + return nil +} diff --git a/internal/net/udp/packet_conn_test.go b/internal/net/udp/packet_conn_test.go new file mode 100644 index 000000000..10ea85851 --- /dev/null +++ b/internal/net/udp/packet_conn_test.go @@ -0,0 +1,678 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +// Package udp implements DTLS specific UDP networking primitives. +package udp + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + dtlsnet "github.com/pion/dtls/v2/pkg/net" + "github.com/pion/transport/test" +) + +var errHandshakeFailed = errors.New("handshake failed") + +func TestStressDuplex(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + // Run the test + stressDuplex(t) +} + +type rw struct { + p net.PacketConn + raddr net.Addr +} + +func fromPC(p net.PacketConn, raddr net.Addr) *rw { + return &rw{ + p: p, + raddr: raddr, + } +} + +func (r *rw) Read(p []byte) (int, error) { + n, _, err := r.p.ReadFrom(p) + return n, err +} + +func (r *rw) Write(p []byte) (int, error) { + return r.p.WriteTo(p, r.raddr) +} + +func stressDuplex(t *testing.T) { + listener, ca, cb, err := pipe() + if err != nil { + t.Fatal(err) + } + + defer func() { + if ca.Close() != nil { + t.Fatal(err) + } + if cb.Close() != nil { + t.Fatal(err) + } + if listener.Close() != nil { + t.Fatal(err) + } + }() + + opt := test.Options{ + MsgSize: 2048, + MsgCount: 1, // Can't rely on UDP message order in CI + } + + if err := test.StressDuplex(fromPC(ca, cb.LocalAddr()), cb, opt); err != nil { + t.Fatal(err) + } +} + +func TestListenerCloseTimeout(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + listener, ca, _, err := pipe() + if err != nil { + t.Fatal(err) + } + + err = listener.Close() + if err != nil { + t.Fatal(err) + } + + // Close client after server closes to cleanup + err = ca.Close() + if err != nil { + t.Fatal(err) + } +} + +func TestListenerCloseUnaccepted(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + const backlog = 2 + + network, addr := getConfig() + listener, err := (&ListenConfig{ + Backlog: backlog, + }).Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < backlog; i++ { + conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if dErr != nil { + t.Error(dErr) + continue + } + if _, wErr := conn.Write([]byte{byte(i)}); wErr != nil { + t.Error(wErr) + } + if cErr := conn.Close(); cErr != nil { + t.Error(cErr) + } + } + + time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop + + // Unaccepted connections must be closed by listener.Close() + if err = listener.Close(); err != nil { + t.Fatal(err) + } +} + +func TestListenerAcceptFilter(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + testCases := map[string]struct { + packet []byte + accept bool + }{ + "CreateConn": { + packet: []byte{0xAA}, + accept: true, + }, + "Discarded": { + packet: []byte{0x00}, + accept: false, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + network, addr := getConfig() + listener, err := (&ListenConfig{ + AcceptFilter: func(pkt []byte) bool { + return pkt[0] == 0xAA + }, + }).Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + var wgAcceptLoop sync.WaitGroup + wgAcceptLoop.Add(1) + defer func() { + if lErr := listener.Close(); lErr != nil { + t.Fatal(lErr) + } + wgAcceptLoop.Wait() + }() + + conn, err := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if err != nil { + t.Fatal(err) + } + if _, err := conn.Write(testCase.packet); err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + chAccepted := make(chan struct{}) + go func() { + defer wgAcceptLoop.Done() + + conn, _, aArr := listener.Accept() + if aArr != nil { + if !errors.Is(aArr, ErrClosedListener) { + t.Error(aArr) + } + return + } + close(chAccepted) + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + var accepted bool + select { + case <-chAccepted: + accepted = true + case <-time.After(10 * time.Millisecond): + } + + if accepted != testCase.accept { + if testCase.accept { + t.Error("Packet should create new conn") + } else { + t.Error("Packet should not create new conn") + } + } + }) + } +} + +func TestListenerConcurrent(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + const backlog = 2 + + network, addr := getConfig() + listener, err := (&ListenConfig{ + Backlog: backlog, + }).Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < backlog+1; i++ { + conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if dErr != nil { + t.Error(dErr) + continue + } + if _, wErr := conn.Write([]byte{byte(i)}); wErr != nil { + t.Error(wErr) + } + if cErr := conn.Close(); cErr != nil { + t.Error(cErr) + } + } + + time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop + + for i := 0; i < backlog; i++ { + conn, _, lErr := listener.Accept() + if lErr != nil { + t.Error(lErr) + continue + } + b := make([]byte, 1) + n, _, lErr := conn.ReadFrom(b) + if lErr != nil { + t.Error(lErr) + } else if !bytes.Equal([]byte{byte(i)}, b[:n]) { + t.Errorf("Packet from connection %d is wrong, expected: [%d], got: %v", i, i, b[:n]) + } + if lErr = conn.Close(); lErr != nil { + t.Error(lErr) + } + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if conn, _, lErr := listener.Accept(); !errors.Is(lErr, ErrClosedListener) { + t.Errorf("Connection exceeding backlog limit must be discarded: %v", lErr) + if lErr == nil { + _ = conn.Close() + } + } + }() + + time.Sleep(100 * time.Millisecond) // Last Accept should be discarded + err = listener.Close() + if err != nil { + t.Fatal(err) + } + + wg.Wait() +} + +func pipe() (dtlsnet.PacketListener, net.PacketConn, *net.UDPConn, error) { + // Start listening + network, addr := getConfig() + listener, err := Listen(network, addr) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to listen: %w", err) + } + + // Open a connection + var dConn *net.UDPConn + dConn, err = net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to dial: %w", err) + } + + // Write to the connection to initiate it + handshake := "hello" + _, err = dConn.Write([]byte(handshake)) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to write to dialed Conn: %w", err) + } + + // Accept the connection + var lConn net.PacketConn + lConn, _, err = listener.Accept() + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to accept Conn: %w", err) + } + + var n int + buf := make([]byte, len(handshake)) + if n, _, err = lConn.ReadFrom(buf); err != nil { + return nil, nil, nil, fmt.Errorf("failed to read handshake: %w", err) + } + + result := string(buf[:n]) + if handshake != result { + return nil, nil, nil, fmt.Errorf("%w: %s != %s", errHandshakeFailed, handshake, result) + } + + return listener, lConn, dConn, nil +} + +func getConfig() (string, *net.UDPAddr) { + return "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} +} + +func TestConnClose(t *testing.T) { + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + t.Run("Close", func(t *testing.T) { + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + l, ca, cb, errPipe := pipe() + if errPipe != nil { + t.Fatal(errPipe) + } + if err := ca.Close(); err != nil { + t.Errorf("Failed to close A side: %v", err) + } + if err := cb.Close(); err != nil { + t.Errorf("Failed to close B side: %v", err) + } + if err := l.Close(); err != nil { + t.Errorf("Failed to close listener: %v", err) + } + }) + t.Run("CloseError1", func(t *testing.T) { + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + l, ca, cb, errPipe := pipe() + if errPipe != nil { + t.Fatal(errPipe) + } + // Close l.pConn to inject error. + if err := l.(*listener).pConn.Close(); err != nil { //nolint:forcetypeassert + t.Error(err) + } + + if err := cb.Close(); err != nil { + t.Errorf("Failed to close A side: %v", err) + } + if err := ca.Close(); err != nil { + t.Errorf("Failed to close B side: %v", err) + } + if err := l.Close(); err == nil { + t.Errorf("Error is not propagated to Listener.Close") + } + }) + t.Run("CloseError2", func(t *testing.T) { + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + l, ca, cb, errPipe := pipe() + if errPipe != nil { + t.Fatal(errPipe) + } + // Close l.pConn to inject error. + if err := l.(*listener).pConn.Close(); err != nil { //nolint:forcetypeassert + t.Error(err) + } + + if err := cb.Close(); err != nil { + t.Errorf("Failed to close A side: %v", err) + } + if err := l.Close(); err != nil { + t.Errorf("Failed to close listener: %v", err) + } + if err := ca.Close(); err == nil { + t.Errorf("Error is not propagated to Conn.Close") + } + }) + t.Run("CancelRead", func(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + l, ca, cb, errPipe := pipe() + if errPipe != nil { + t.Fatal(errPipe) + } + + errC := make(chan error, 1) + go func() { + buf := make([]byte, 1024) + // This read will block because we don't write on the other side. + // Calling Close must unblock the call. + _, _, err := ca.ReadFrom(buf) + errC <- err + }() + + if err := ca.Close(); err != nil { // Trigger Read cancellation. + t.Errorf("Failed to close B side: %v", err) + } + + // Main test condition, Read should return + // after ca.Close() by closing the buffer. + if err := <-errC; !errors.Is(err, io.EOF) { + t.Errorf("expected err to be io.EOF but got %v", err) + } + + if err := cb.Close(); err != nil { + t.Errorf("Failed to close A side: %v", err) + } + if err := l.Close(); err != nil { + t.Errorf("Failed to close listener: %v", err) + } + }) +} + +func TestListenerCustomConnID(t *testing.T) { + const helloPayload, setPayload = "hello", "set" + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + type pkt struct { + ID int + Payload string + } + network, addr := getConfig() + listener, err := (&ListenConfig{ + ConnectionResolver: func(buf []byte, raddr net.Addr) string { + p := &pkt{} + if err := json.Unmarshal(buf, p); err != nil { + t.Fatal(err) + } + if p.Payload == helloPayload { + return raddr.String() + } + return fmt.Sprint(p.ID) + }, + ConnectionIdentifier: func(buf []byte, _ net.Addr) (string, bool) { + p := &pkt{} + if err := json.Unmarshal(buf, p); err != nil { + t.Fatal(err) + } + if p.Payload == setPayload { + return fmt.Sprint(p.ID), true + } + return "", false + }, + }).Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + clientWg := sync.WaitGroup{} + var phaseOne [5]chan struct{} + for i := range phaseOne { + phaseOne[i] = make(chan struct{}) + } + serverWg := sync.WaitGroup{} + clientMap := map[string]struct{}{} + var clientMapMu sync.Mutex + for i := 0; i < 5; i++ { + serverWg.Add(1) + go func() { + defer serverWg.Done() + conn, _, err := listener.Accept() + if err != nil { + t.Error(err) + } + buf := make([]byte, 100) + n, raddr, rErr := conn.ReadFrom(buf) + if rErr != nil { + t.Error(err) + } + p := &pkt{} + if uErr := json.Unmarshal(buf[:n], p); uErr != nil { + t.Error(err) + } + // First message should be a hello and custom connection + // ID function will use remote address as identifier. + // Connection ID is extracted to signal that we are + // ready for the second message. + if p.Payload != helloPayload { + t.Error("Expected hello message") + } + connID := p.ID + + // Send set message to associate connection ID with this connection. + buf, err = json.Marshal(&pkt{ + ID: connID, + Payload: "set", + }) + if err != nil { + t.Error(err) + } + if _, wErr := conn.WriteTo(buf, raddr); wErr != nil { + t.Error(wErr) + } + close(phaseOne[connID]) + for j := 0; j < 4; j++ { + buf := make([]byte, 100) + n, _, err := conn.ReadFrom(buf) + if err != nil { + t.Error(err) + } + p := &pkt{} + if err := json.Unmarshal(buf[:n], p); err != nil { + t.Error(err) + } + if p.ID != connID { + t.Errorf("Expected connection ID %d, but got %d", connID, p.ID) + } + // Ensure we only ever receive one message from + // a given client. + clientMapMu.Lock() + if _, ok := clientMap[p.Payload]; ok { + t.Errorf("Multiple messages from single client %s", p.Payload) + } + clientMap[p.Payload] = struct{}{} + clientMapMu.Unlock() + } + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + } + + for i := 0; i < 5; i++ { + clientWg.Add(1) + go func(connID int) { + defer clientWg.Done() + conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if dErr != nil { + t.Error(dErr) + } + hbuf, err := json.Marshal(&pkt{ + ID: connID, + Payload: helloPayload, + }) + if err != nil { + t.Error(err) + } + if _, wErr := conn.Write(hbuf); wErr != nil { + t.Error(wErr) + } + + p := &pkt{} + buf := make([]byte, 100) + n, err := conn.Read(buf) + if err != nil { + t.Error(err) + } + if err := json.Unmarshal(buf[:n], p); err != nil { + t.Error(err) + } + // Second message should be a set and custom connection + // function will update the connection ID from remote + // address to the supplied ID. + if p.Payload != "set" { + t.Error("Expected set message") + } + if p.ID != connID { + t.Errorf("Expected connection ID %d, but got %d", connID, p.ID) + } + // Close connection. We will reconnect from a different remote + // address using the same connection ID. + if cErr := conn.Close(); cErr != nil { + t.Error(cErr) + } + }(i) + } + + // Spawn 20 clients sending on 5 connections. + for i := 1; i <= 20; i++ { + clientWg.Add(1) + go func(connID int) { + defer clientWg.Done() + // Ensure that we are using a connection ID for packet + // routing prior to sending any messages. + <-phaseOne[connID] + conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if dErr != nil { + t.Error(dErr) + } + buf, err := json.Marshal(&pkt{ + ID: connID, + Payload: conn.LocalAddr().String(), + }) + if err != nil { + t.Error(err) + } + if _, wErr := conn.Write(buf); wErr != nil { + t.Error(wErr) + } + if cErr := conn.Close(); cErr != nil { + t.Error(cErr) + } + }(i % 5) + } + + // Wait for clients to exit. + clientWg.Wait() + // Wait for servers to exit. + serverWg.Wait() + if err := listener.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/internal/util/net.go b/internal/util/net.go deleted file mode 100644 index 5a94dcf2a..000000000 --- a/internal/util/net.go +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-FileCopyrightText: 2023 The Pion community -// SPDX-License-Identifier: MIT - -// Package util contains small helpers used across the repo -package util - -import ( - "net" - "time" -) - -// packetConn wraps a net.Conn with methods that satisfy net.PacketConn. -type packetConn struct { - conn net.Conn -} - -// FromConn converts a net.Conn into a net.PacketConn. -func FromConn(conn net.Conn) net.PacketConn { - return &packetConn{conn} -} - -// ReadFrom reads from the underlying net.Conn and returns its remote address. -func (cp *packetConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, err := cp.conn.Read(b) - return n, cp.conn.RemoteAddr(), err -} - -// WriteTo writes to the underlying net.Conn. -func (cp *packetConn) WriteTo(b []byte, _ net.Addr) (int, error) { - n, err := cp.conn.Write(b) - return n, err -} - -// Close closes the underlying net.Conn. -func (cp *packetConn) Close() error { - return cp.conn.Close() -} - -// LocalAddr returns the local address of the underlying net.Conn. -func (cp *packetConn) LocalAddr() net.Addr { - return cp.conn.LocalAddr() -} - -// SetDeadline sets the deadline on the underlying net.Conn. -func (cp *packetConn) SetDeadline(t time.Time) error { - return cp.conn.SetDeadline(t) -} - -// SetReadDeadline sets the read deadline on the underlying net.Conn. -func (cp *packetConn) SetReadDeadline(t time.Time) error { - return cp.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the write deadline on the underlying net.Conn. -func (cp *packetConn) SetWriteDeadline(t time.Time) error { - return cp.conn.SetWriteDeadline(t) -} diff --git a/listener.go b/listener.go index 0d281fc4d..3adea73dd 100644 --- a/listener.go +++ b/listener.go @@ -6,12 +6,57 @@ package dtls import ( "net" - "github.com/pion/dtls/v2/internal/util" + "github.com/pion/dtls/v2/internal/net/udp" + dtlsnet "github.com/pion/dtls/v2/pkg/net" "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" - "github.com/pion/transport/v2/udp" ) +// cidConnResolver extracts connection IDs from incoming packets and uses them +// to route to the proper connection. +func cidConnResolver(packet []byte, raddr net.Addr) string { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) < 1 { + return raddr.String() + } + h := &recordlayer.Header{} + if err := h.Unmarshal(pkts[0]); err != nil { + return raddr.String() + } + if h.ContentType != protocol.ContentTypeConnectionID { + return raddr.String() + } + return string(h.ConnectionID) +} + +// cidConnIdentifier extracts connection IDs from outgoing ServerHello packets +// and associates them with the associated connection. +func cidConnIdentifier(packet []byte, _ net.Addr) (string, bool) { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) < 1 { + return "", false + } + h := &recordlayer.Header{} + if err := h.Unmarshal(pkts[0]); err != nil { + return "", false + } + if h.ContentType != protocol.ContentTypeHandshake { + return "", false + } + sh := &handshake.MessageServerHello{} + if err := sh.Unmarshal(pkts[0]); err != nil { + return "", false + } + for _, ext := range sh.Extensions { + if e, ok := ext.(*extension.ConnectionID); ok { + return string(e.CID), true + } + } + return "", false +} + // Listen creates a DTLS listener func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, error) { if err := validateConfig(config); err != nil { @@ -31,6 +76,12 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e return h.ContentType == protocol.ContentTypeHandshake }, } + // If connection ID support is enabled, then they must be supported in + // routing. + if config.ConnectionIDGenerator != nil { + lc.ConnectionResolver = cidConnResolver + lc.ConnectionIdentifier = cidConnIdentifier + } parent, err := lc.Listen(network, laddr) if err != nil { return nil, err @@ -42,7 +93,7 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e } // NewListener creates a DTLS listener which accepts connections from an inner Listener. -func NewListener(inner net.Listener, config *Config) (net.Listener, error) { +func NewListener(inner dtlsnet.PacketListener, config *Config) (net.Listener, error) { if err := validateConfig(config); err != nil { return nil, err } @@ -56,7 +107,7 @@ func NewListener(inner net.Listener, config *Config) (net.Listener, error) { // listener represents a DTLS listener type listener struct { config *Config - parent net.Listener + parent dtlsnet.PacketListener } // Accept waits for and returns the next connection to the listener. @@ -64,11 +115,11 @@ type listener struct { // Connection handshake will timeout using ConnectContextMaker in the Config. // If you want to specify the timeout duration, set ConnectContextMaker. func (l *listener) Accept() (net.Conn, error) { - c, err := l.parent.Accept() + c, raddr, err := l.parent.Accept() if err != nil { return nil, err } - return Server(util.FromConn(c), c.RemoteAddr(), l.config) + return Server(c, raddr, l.config) } // Close closes the listener. diff --git a/pkg/net/net.go b/pkg/net/net.go new file mode 100644 index 000000000..4e1e428a5 --- /dev/null +++ b/pkg/net/net.go @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package net defines packet-oriented primitives that are compatible with net +// in the standard library. +package net + +import ( + "net" + "time" +) + +// A PacketListener is the same as net.Listener but returns a net.PacketConn on +// Accept() rather than a net.Conn. +// +// Multiple goroutines may invoke methods on a PacketListener simultaneously. +type PacketListener interface { + // Accept waits for and returns the next connection to the listener. + Accept() (net.PacketConn, net.Addr, error) + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. + Close() error + + // Addr returns the listener's network address. + Addr() net.Addr +} + +// PacketListenerFromListener converts a net.Listener into a +// dtlsnet.PacketListener. +func PacketListenerFromListener(l net.Listener) PacketListener { + return &plistener{ + l: l, + } +} + +// plistener wraps a net.Listener and implements dtlsnet.PacketListener. +type plistener struct { + l net.Listener +} + +// Accept calls Accept on the underlying net.Listener and converts the returned +// net.Conn into a net.PacketConn. +func (p *plistener) Accept() (net.PacketConn, net.Addr, error) { + c, err := p.l.Accept() + if err != nil { + return PacketConnFromConn(c), nil, err + } + return PacketConnFromConn(c), c.RemoteAddr(), nil +} + +// Close closes the underlying net.Listener. +func (p *plistener) Close() error { + return p.l.Close() +} + +// Addr returns the address of the underlying net.Listener. +func (p *plistener) Addr() net.Addr { + return p.l.Addr() +} + +// PacketConnFromConn converts a net.Conn into a net.PacketConn. +func PacketConnFromConn(conn net.Conn) net.PacketConn { + return &pconn{conn} +} + +// pconn wraps a net.Conn and implements net.PacketConn. +type pconn struct { + conn net.Conn +} + +// ReadFrom reads from the underlying net.Conn and returns its remote address. +func (p *pconn) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := p.conn.Read(b) + return n, p.conn.RemoteAddr(), err +} + +// WriteTo writes to the underlying net.Conn. +func (p *pconn) WriteTo(b []byte, _ net.Addr) (int, error) { + n, err := p.conn.Write(b) + return n, err +} + +// Close closes the underlying net.Conn. +func (p *pconn) Close() error { + return p.conn.Close() +} + +// LocalAddr returns the local address of the underlying net.Conn. +func (p *pconn) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// SetDeadline sets the deadline on the underlying net.Conn. +func (p *pconn) SetDeadline(t time.Time) error { + return p.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying net.Conn. +func (p *pconn) SetReadDeadline(t time.Time) error { + return p.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying net.Conn. +func (p *pconn) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} diff --git a/resume_test.go b/resume_test.go index 740a44e6d..fe78e22e2 100644 --- a/resume_test.go +++ b/resume_test.go @@ -13,8 +13,8 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/internal/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" + dtlsnet "github.com/pion/dtls/v2/pkg/net" "github.com/pion/transport/v2/test" ) @@ -68,7 +68,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.PacketConn, net.Add go func() { var remote *Conn var errR error - remote, errR = newRemote(util.FromConn(remoteConn), remoteConn.RemoteAddr(), config) + remote, errR = newRemote(dtlsnet.PacketConnFromConn(remoteConn), remoteConn.RemoteAddr(), config) if errR != nil { errChan <- errR } @@ -90,7 +90,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.PacketConn, net.Add }() var local *Conn - local, err = newLocal(util.FromConn(localConn1), localConn1.RemoteAddr(), config) + local, err = newLocal(dtlsnet.PacketConnFromConn(localConn1), localConn1.RemoteAddr(), config) if err != nil { fatal(t, errChan, err) } @@ -133,7 +133,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.PacketConn, net.Add // Resume dtls connection var resumed net.Conn - resumed, err = Resume(deserialized, util.FromConn(localConn2), localConn2.RemoteAddr(), config) + resumed, err = Resume(deserialized, dtlsnet.PacketConnFromConn(localConn2), localConn2.RemoteAddr(), config) if err != nil { fatal(t, errChan, err) }