Skip to content

Commit

Permalink
GODRIVER-3302 Handle malformatted message length properly. (#1758) [m…
Browse files Browse the repository at this point in the history
…aster] (#1817)
  • Loading branch information
qingyang-hu authored Sep 17, 2024
1 parent 32ff39b commit e556841
Show file tree
Hide file tree
Showing 4 changed files with 462 additions and 115 deletions.
186 changes: 102 additions & 84 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package topology
import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -80,9 +81,9 @@ type connection struct {
// accessTokens in the OIDC authenticator cache.
oidcTokenGenID uint64

// awaitingResponse indicates that the server response was not completely
// awaitRemainingBytes indicates the size of server response that was not completely
// read before returning the connection to the pool.
awaitingResponse bool
awaitRemainingBytes *int32
}

// newConnection handles the creation of a connection. It does not connect the connection.
Expand Down Expand Up @@ -111,11 +112,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
return c
}

// DriverConnectionID returns the driver connection ID.
func (c *connection) DriverConnectionID() int64 {
return c.driverConnectionID
}

// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
// configuration.
func (c *connection) setGenerationNumber() {
Expand All @@ -137,6 +133,39 @@ func (c *connection) hasGenerationNumber() bool {
return driverutil.IsServerLoadBalanced(c.desc)
}

func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}

hostname = hostname[:colonPos]
config.ServerName = hostname
}

client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}

// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
// handshakes. All errors returned by connect are considered "before the handshake completes" and
// must be handled by calling the appropriate SDAM handshake error handler.
Expand Down Expand Up @@ -291,6 +320,10 @@ func (c *connection) closeConnectContext() {
}
}

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
if originalError == nil {
return nil
Expand All @@ -313,10 +346,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
return originalError
}

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if atomic.LoadInt64(&c.state) != connConnected {
Expand Down Expand Up @@ -377,14 +406,9 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {

dst, errMsg, err := c.read(ctx)
if err != nil {
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() {
// If the error was a timeout error, instead of closing the
// connection mark it as awaiting response so the pool can read the
// response before making it available to other operations.
c.awaitingResponse = true
} else {
// Otherwise, and close the connection because we don't know what
// the connection state is.
if c.awaitRemainingBytes == nil {
// If the connection was not marked as awaiting response, close the
// connection because we don't know what the connection state is.
c.close()
}
message := errMsg
Expand All @@ -401,6 +425,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
return dst, nil
}

func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
// read the length as an int32
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:]))

if size < 4 {
return 0, fmt.Errorf("malformed message length: %d", size)
}
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return 0, errResponseTooLarge
}

return size, nil
}

func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
defer func() {
Expand All @@ -414,36 +458,42 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
}
}()

isCSOTTimeout := func(err error) bool {
// If the error was a timeout error, instead of closing the
// connection mark it as awaiting response so the pool can read the
// response before making it available to other operations.
nerr := net.Error(nil)
return errors.As(err, &nerr) && nerr.Timeout()
}

// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
// reslice dst once instead of twice.
var sizeBuf [4]byte

// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
// because there might be more than one wire message waiting to be read, for example when
// reading messages from an exhaust cursor.
_, err = io.ReadFull(c.nc, sizeBuf[:])
n, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
if l := int32(n); l == 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &l
}
return nil, "incomplete read of message header", err
}

// read the length as an int32
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)

// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return nil, errResponseTooLarge.Error(), errResponseTooLarge
size, err := c.parseWmSizeBytes(sizeBuf)
if err != nil {
return nil, err.Error(), err
}

dst := make([]byte, size)
copy(dst, sizeBuf[:])

_, err = io.ReadFull(c.nc, dst[4:])
n, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
remainingBytes := size - 4 - int32(n)
if remainingBytes > 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &remainingBytes
}
return dst, "incomplete read of full message", err
}

Expand Down Expand Up @@ -496,10 +546,6 @@ func (c *connection) setCanStream(canStream bool) {
c.canStream = canStream
}

func (c initConnection) supportsStreaming() bool {
return c.canStream
}

func (c *connection) setStreaming(streaming bool) {
c.currentlyStreaming = streaming
}
Expand All @@ -508,6 +554,14 @@ func (c *connection) getCurrentlyStreaming() bool {
return c.currentlyStreaming
}

func (c *connection) previousCanceled() bool {
if val := c.prevCanceled.Load(); val != nil {
return val.(bool)
}

return false
}

func (c *connection) ID() string {
return c.id
}
Expand All @@ -516,12 +570,17 @@ func (c *connection) ServerConnectionID() *int64 {
return c.serverConnectionID
}

func (c *connection) previousCanceled() bool {
if val := c.prevCanceled.Load(); val != nil {
return val.(bool)
}
// DriverConnectionID returns the driver connection ID.
func (c *connection) DriverConnectionID() int64 {
return c.driverConnectionID
}

return false
func (c *connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
}

func (c *connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}

// initConnection is an adapter used during connection initialization. It has the minimum
Expand Down Expand Up @@ -562,7 +621,7 @@ func (c initConnection) CurrentlyStreaming() bool {
return c.getCurrentlyStreaming()
}
func (c initConnection) SupportsStreaming() bool {
return c.supportsStreaming()
return c.canStream
}

// Connection implements the driver.Connection interface to allow reading and writing wire
Expand Down Expand Up @@ -797,39 +856,6 @@ func (c *Connection) DriverConnectionID() int64 {
return c.connection.DriverConnectionID()
}

func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}

hostname = hostname[:colonPos]
config.ServerName = hostname
}

client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}

// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

// OIDCTokenGenID returns the OIDC token generation ID.
func (c *Connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
Expand All @@ -839,11 +865,3 @@ func (c *Connection) OIDCTokenGenID() uint64 {
func (c *Connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}

func (c *connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
}

func (c *connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}
17 changes: 17 additions & 0 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,23 @@ func TestConnection(t *testing.T) {
}
listener.assertCalledOnce(t)
})
t.Run("size too small errors", func(t *testing.T) {
err := errors.New("malformed message length: 3")
tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()}
_, got := conn.readWireMessage(context.Background())
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
if !tnc.closed {
t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
}
listener.assertCalledOnce(t)
})
t.Run("full message read errors", func(t *testing.T) {
err := errors.New("Read error")
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
Expand Down
Loading

0 comments on commit e556841

Please sign in to comment.