diff --git a/conn.go b/conn.go index b0d8cde4f..b1f552537 100644 --- a/conn.go +++ b/conn.go @@ -1026,6 +1026,10 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh } else { switch { case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed): + case errors.Is(err, recordlayer.ErrInvalidPacketLength): + // Decode error must be silently discarded + // [RFC6347 Section-4.1.2.7] + continue default: if c.isHandshakeCompletedSuccessfully() { // Keep read loop and pass the read error to Read() diff --git a/conn_test.go b/conn_test.go index d3226d876..f0264df15 100644 --- a/conn_test.go +++ b/conn_test.go @@ -389,7 +389,7 @@ func TestHandshakeWithAlert(t *testing.T) { clientErr <- err }() - _, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), ca.RemoteAddr(), testCase.configServer, true) + _, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), testCase.configServer, true) if !errors.Is(errServer, testCase.errServer) { t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer) } @@ -402,6 +402,71 @@ func TestHandshakeWithAlert(t *testing.T) { } } +func TestHandshakeWithInvalidRecord(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientErr := make(chan result, 1) + ca, cb := dpipe.Pipe() + caWithInvalidRecord := &connWithCallback{Conn: ca} + + var msgSeq atomic.Int32 + // Send invalid record after first message + caWithInvalidRecord.onWrite = func(b []byte) { + if msgSeq.Add(1) == 2 { + if _, err := ca.Write([]byte{0x01, 0x02}); err != nil { + t.Fatal(err) + } + } + } + go func() { + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(caWithInvalidRecord), caWithInvalidRecord.RemoteAddr(), &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + }, true) + clientErr <- result{client, err} + }() + + server, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + }, true) + + errClient := <-clientErr + + defer func() { + if server != nil { + if err := server.Close(); err != nil { + t.Fatal(err) + } + } + + if errClient.c != nil { + if err := errClient.c.Close(); err != nil { + t.Fatal(err) + } + } + }() + + if errServer != nil { + t.Fatalf("Server failed(%v)", errServer) + } + + if errClient.err != nil { + t.Fatalf("Client failed(%v)", errClient.err) + } +} + func TestExportKeyingMaterial(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) @@ -3096,3 +3161,15 @@ func TestSkipHelloVerify(t *testing.T) { t.Error(err) } } + +type connWithCallback struct { + net.Conn + onWrite func([]byte) +} + +func (c *connWithCallback) Write(b []byte) (int, error) { + if c.onWrite != nil { + c.onWrite(b) + } + return c.Conn.Write(b) +} diff --git a/pkg/protocol/recordlayer/errors.go b/pkg/protocol/recordlayer/errors.go index cd4cb60a5..1c1898844 100644 --- a/pkg/protocol/recordlayer/errors.go +++ b/pkg/protocol/recordlayer/errors.go @@ -11,9 +11,11 @@ import ( ) var ( - errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113 - errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 - errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 - errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 + // ErrInvalidPacketLength is returned when the packet length too small or declared length do not match + ErrInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113 + + errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 + errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 + errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 + errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 ) diff --git a/pkg/protocol/recordlayer/recordlayer.go b/pkg/protocol/recordlayer/recordlayer.go index 213a7976a..4acacf074 100644 --- a/pkg/protocol/recordlayer/recordlayer.go +++ b/pkg/protocol/recordlayer/recordlayer.go @@ -100,12 +100,12 @@ func UnpackDatagram(buf []byte) ([][]byte, error) { for offset := 0; len(buf) != offset; { if len(buf)-offset <= FixedHeaderSize { - return nil, errInvalidPacketLength + return nil, ErrInvalidPacketLength } pktLen := (FixedHeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:]))) if offset+pktLen > len(buf) { - return nil, errInvalidPacketLength + return nil, ErrInvalidPacketLength } out = append(out, buf[offset:offset+pktLen]) @@ -129,12 +129,12 @@ func ContentAwareUnpackDatagram(buf []byte, cidLength int) ([][]byte, error) { lenIdx += cidLength } if len(buf)-offset <= headerSize { - return nil, errInvalidPacketLength + return nil, ErrInvalidPacketLength } pktLen := (headerSize + int(binary.BigEndian.Uint16(buf[offset+lenIdx:]))) if offset+pktLen > len(buf) { - return nil, errInvalidPacketLength + return nil, ErrInvalidPacketLength } out = append(out, buf[offset:offset+pktLen]) diff --git a/pkg/protocol/recordlayer/recordlayer_test.go b/pkg/protocol/recordlayer/recordlayer_test.go index 2e16c0104..760a73040 100644 --- a/pkg/protocol/recordlayer/recordlayer_test.go +++ b/pkg/protocol/recordlayer/recordlayer_test.go @@ -39,12 +39,12 @@ func TestUDPDecode(t *testing.T) { { Name: "Invalid packet length", Data: []byte{0x14, 0xfe}, - WantError: errInvalidPacketLength, + WantError: ErrInvalidPacketLength, }, { Name: "Packet declared invalid length", Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01}, - WantError: errInvalidPacketLength, + WantError: ErrInvalidPacketLength, }, } { dtlsPkts, err := UnpackDatagram(test.Data)