Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3302 Handle malformatted message length properly. #1758

Merged
merged 8 commits into from
Sep 17, 2024

Conversation

qingyang-hu
Copy link
Collaborator

@qingyang-hu qingyang-hu commented Aug 14, 2024

GODRIVER-3302

Summary

  • Add a sanity check for the message size
  • Update the background read logic

Background & Motivation

connection.read has two places that can raise timeout errors, one is in reading the message header, and the other is in reading the remaining body of the message.
The customer has observed an "incomplete read of full message" before the panic. In this scenario, we should not call connection.read directly in the bgRead to start over the reading because the message header with the message size has already been read. Instead, we ought to head towards the message body with the given size.
On the other hand, we need to start over in the bgRead if the header is not read at all, but cease processing if the header is only partially read to reduce the complexity.

@mongodb-drivers-pr-bot mongodb-drivers-pr-bot bot added the priority-3-low Low Priority PR for Review label Aug 14, 2024
Copy link
Contributor

mongodb-drivers-pr-bot bot commented Aug 14, 2024

API Change Report

No changes found!

@qingyang-hu qingyang-hu marked this pull request as ready for review August 15, 2024 13:33
@qingyang-hu qingyang-hu marked this pull request as draft August 15, 2024 17:41
@qingyang-hu qingyang-hu force-pushed the godriver3302 branch 5 times, most recently from 03121d0 to 7a55a9d Compare August 20, 2024 17:09
@qingyang-hu qingyang-hu marked this pull request as ready for review August 20, 2024 19:54
}
size -= 4
}
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ioutil is deprecated, should use io.Discard

Suggested change
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size))
_, err = io.CopyN(io.Discard, conn.nc, int64(size))

if err != nil {
if l := size - 4 - int32(n); l > 0 && needToWait(err) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest renaming l to remainingBytes.

// read before returning the connection to the pool.
awaitingResponse bool
awaitingResponse *int32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest renaming awaitingResponse to awaitRemainingBytes.

Comment on lines +830 to +831
var sizeBuf [4]byte
_, err = io.ReadFull(conn.nc, sizeBuf[:])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest making this logic a function that can be used here and in the connection read method:

func readWMSize(r io.Reader) (int32, error) {
	const wireMessageSizePrefix = 4

	var wmSizeBytes [wireMessageSizePrefix]byte
	if _, err := io.ReadFull(r, wmSizeBytes[:]); err != nil {
		return 0, fmt.Errorf("error reading the message size: %w", err)
	}

	size := (int32(wmSizeBytes[0])) |
		(int32(wmSizeBytes[1]) << 8) |
		(int32(wmSizeBytes[2]) << 16) |
		(int32(wmSizeBytes[3]) << 24)

	if size < 4 {
		return 0, fmt.Errorf("malformed message length: %d", size)
	}

	return size, nil
}

@@ -115,12 +115,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
return c
}

// DriverConnectionID returns the driver connection ID.
Copy link
Collaborator Author

@qingyang-hu qingyang-hu Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved DriverConnectionID down with other public methods.

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved transformNetworkError closer to the caller.

@@ -537,10 +589,6 @@ func (c *connection) setCanStream(canStream bool) {
c.canStream = canStream
}

func (c initConnection) supportsStreaming() bool {
Copy link
Collaborator Author

@qingyang-hu qingyang-hu Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merged in (initConnection).SupportsStreaming().

@@ -833,39 +895,6 @@ func (c *Connection) DriverConnectionID() uint64 {
return c.connection.DriverConnectionID()
}

func configureTLS(ctx context.Context,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved closer to the caller.

@@ -919,11 +948,3 @@ func (c *cancellListener) StopListening() bool {
c.done <- struct{}{}
return c.aborted
}

func (c *connection) OIDCTokenGenID() uint64 {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved closer with other *connection methods.

prestonvasquez
prestonvasquez previously approved these changes Aug 27, 2024
Copy link
Collaborator

@prestonvasquez prestonvasquez left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, great tests!

@blink1073 blink1073 added the priority-2-medium Medium Priority PR for Review label Aug 29, 2024
@blink1073
Copy link
Member

drivers-pr-bot please backport to master

@qingyang-hu qingyang-hu removed the priority-3-low Low Priority PR for Review label Aug 30, 2024
@blink1073 blink1073 added priority-1-high High Priority PR for Review and removed priority-2-medium Medium Priority PR for Review labels Sep 12, 2024
@@ -461,36 +506,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
}
}()

needToWait := func(err error) bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Consider a more descriptive name, like isCSOTTimeout.

Comment on lines 475 to 478
size := (int32(wmSizeBytes[0])) |
(int32(wmSizeBytes[1]) << 8) |
(int32(wmSizeBytes[2]) << 16) |
(int32(wmSizeBytes[3]) << 24)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Use the binary package instead.

Suggested change
size := (int32(wmSizeBytes[0])) |
(int32(wmSizeBytes[1]) << 8) |
(int32(wmSizeBytes[2]) << 16) |
(int32(wmSizeBytes[3]) << 24)
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:]))

errCh = make(chan error)

var err error
socket, err = net.Listen("unix", sockPath)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we should use TCP sockets to more accurately test the real-world use case of the connection pool. Is it possible to use TCP instead of Unix sockets?

Consider using the bootstrapConnections helper used elsewhere in pool_test.go, which also uses a connection teardown pattern that doesn't require using sleeps at the end of the connection handler.

Comment on lines 1132 to 1139
var errCh chan error
BGReadCallback = func(addr string, start, read time.Time, errs []error, connClosed bool) {
defer close(errCh)

for _, err := range errs {
errCh <- err
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sharing this logic between all the subtests will make it easy to introduce unintentional interaction between subtests. We should set the callback for each subtest using test-specific channels. We should also set BGReadCallback back to nil at the end of each subtest to prevent interaction between tests.

E.g.

t.Run("subtest", func(t *testing.T) {
	var errCh chan error
	BGReadCallback = func(...) {
		// ...
	}
	t.Cleanup(func() {
		BGReadCallback = nil
	})
	// Make test assertions.
})

}
_, err = io.CopyN(io.Discard, conn.nc, int64(size))
if err != nil {
err = fmt.Errorf("error reading message of %d: %w", size, err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: This error message is a bit confusing. Consider a clearer error message.

Suggested change
err = fmt.Errorf("error reading message of %d: %w", size, err)
err = fmt.Errorf("error discarding %d byte message: %w", size, err)

@@ -1122,6 +1126,226 @@ func TestPool(t *testing.T) {
p.close(context.Background())
})
})
t.Run("bgRead", func(t *testing.T) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Consider moving the bgRead subtests to a separate test function to allow de-indenting them one level.

Copy link
Collaborator

@matthewdale matthewdale left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend refactoring the subtests in TestBackgroundRead to simplify the assertions and speed up the tests. Here is the general recommended pattern:

cleanup := make(chan struct{})
defer close(cleanup)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
	defer func() {
		<-cleanup
		_ = nc.Close()
	}()

	// Write to the connection here.
})

// Test logic here.

var bgErr error
select {
case bgErr = <-errCh:
case <-time.After(100 * time.Millisecond):
	t.Fatal("did not receive expected error after waiting for 100ms")
}
assert.EqualError(t, bgErr, "<expected error string>")

p.close(context.Background())

Comment on lines 1214 to 1215
go func(t *testing.T) {
}(t)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This empty func can be removed.

Suggested change
go func(t *testing.T) {
}(t)

wg := &sync.WaitGroup{}
wg.Add(1)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
t.Helper()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this func is only part of this one test, we shouldn't use t.Helper() here because it will obscure the actual error location. That applies for all other anonymous funcs passed to bootstrapConnections in this test.

var err error
_, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1})
noerr(t, err)
time.Sleep(1500 * time.Millisecond)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're using all in-process connections, we should be able to reduce the sleeps and timeouts, which will speed up these tests significantly. I recommend a 50ms sleep here and 10ms timeout in csot.MakeTimeoutContext, which is similar to the timeout values in other pool tests.

The same recommendation applies to all subtests in TestBackgroundRead.

time.Sleep(1500 * time.Millisecond)
_, err = nc.Write([]byte{2, 3, 4})
noerr(t, err)
time.Sleep(1500 * time.Millisecond)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can replace the sleep at the end of the connection handler with a "cleanup" signal, which will speed up the tests significantly.

E.g.

cleanup := make(chan struct{})
defer close(cleanup)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
	defer func() {
		<-cleanup
		_ = nc.Close()
	}()
	var err error
	_, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1})
	// ...
})

The same recommendation applies to all subtests in TestBackgroundRead.

Comment on lines 1268 to 1269
wg := &sync.WaitGroup{}
wg.Add(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use a "cleanup" channel as previously recommended, this WaitGroup will be unnecessary and can be removed.

Comment on lines 1311 to 1317
wg.Wait()
p.close(context.Background())
errs := []string{
"error discarding 3 byte message: EOF",
}
for i := 0; true; i++ {
err, ok := <-errCh
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we apply the previous recommended changes, we can simplify the assertion logic here.

var bgErr error
select {
case bgErr = <-errCh:
case <-time.After(100 * time.Millisecond):
	t.Fatal("did not receive expected error after waiting for 100ms")
}
assert.EqualError(t, bgErr, "error discarding 3 byte message: EOF")

p.close(context.Background())

regex := regexp.MustCompile(
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Consider printing the regex in the error message to make troubleshooting these failures easier.

This recommendation applies to all similar assertions in the TestBackgroundRead subtests.

Suggested change
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err)
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex)

Comment on lines 1280 to 1282
conn, err := net.Dial("tcp", addr.String())
noerr(t, err)
return newLimitConn(conn, 10), nil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of using newLimitConn here? Is the test trying to create timeout conditions or connection closed conditions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are correct. We no longer need it.

Comment on lines 1310 to 1312
for i, err := range bgErrs {
if i < len(wantErrs) {
assert.EqualError(t, err, wantErrs[i], "mismatched err: %v", err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion loop will pass if len(bgErrs) == 0, even if wantErrs contains expected error messages. We should fix and simplify the assertion logic.

E.g.

require.Len(t, bgErrs, "expected 1 error")
assert.EqualError(t, bgErrs[0], "error discarding 3 byte message: EOF")

Comment on lines 1242 to 1244
for i, err := range bgErrs {
if i < len(wantErrs) {
assert.True(t, wantErrs[i].MatchString(err.Error()), "error %q does not match pattern %q", err, wantErrs[i])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion loop will pass if len(bgErrs) == 0, even if wantErrs contains expected error messages. We should fix and simplify the assertion logic.

E.g.

require.Len(t, bgErrs, "expected 1 error")
wantErrPattern := regexp.MustCompile(`^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`)
gotErr := bgErrs[0]
assert.True(t, wantErrPattern.MatchString(gotErr), "error %q does not match pattern %q", gotErr, wantErrPattern)

Comment on lines 1208 to 1215
p := newPool(
poolConfig{},
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return net.Dial("tcp", addr.String())
})
}),
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Since there's no special dialing logic added here, we can pass addr in the poolConfig instead of overriding the dialer.

Suggested change
p := newPool(
poolConfig{},
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return net.Dial("tcp", addr.String())
})
}),
)
p := newPool(poolConfig{
Address: address.Address(addr.String()),
})

Comment on lines 1161 to 1168
p := newPool(
poolConfig{},
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return net.Dial("tcp", addr.String())
})
}),
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Since there's no special dialing logic added here, we can pass addr in the poolConfig instead of overriding the dialer.

Suggested change
p := newPool(
poolConfig{},
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return net.Dial("tcp", addr.String())
})
}),
)
p := newPool(poolConfig{
Address: address.Address(addr.String()),
})

assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil")
close(errsCh) // this line causes a double close if BGReadCallback is ever called.
})
t.Run("timeout on reading the message header", func(t *testing.T) {
Copy link
Collaborator

@matthewdale matthewdale Sep 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test seems to cover the scenario where the operation times out while waiting for the header, then bgRead times out waiting for the full message (because not enough bytes are written). We should also test the scenario where the operation times out while waiting for the header, but then bgRead succeeds.

E.g.

t.Run("timeout reading message header, successful background read", func(t *testing.T) {
	// ...
	addr := bootstrapConnections(t, 1, func(nc net.Conn) {
		defer func() {
			<-cleanup
			_ = nc.Close()
		}()

		// Wait until the operation times out, then write the full message
		// length.
		time.Sleep(timeout * 2)
		_, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0, 0, 0})
		noerr(t, err)
	})
	// ...
	assert.Len(t, bgErrs, 0, "expected no errors from bgRead")
})

To differentiate it from the existing test, I recommend updating the test description and adding a comment:

t.Run("timeout reading message header, incomplete background read", func(t *testing.T) {
	// ...
	addr := bootstrapConnections(t, 1, func(nc net.Conn) {
		defer func() {
			<-cleanup
			_ = nc.Close()
		}()

		// Wait until the operation times out, then write an incomplete
		// message.
		time.Sleep(timeout * 5)
		_, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0})
		noerr(t, err)
	})
	// ...
})

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will also add cases for "timeout reading message header, incomplete head during background read" and "timeout reading full message, successful background read".

Copy link
Collaborator

@matthewdale matthewdale left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! 👍

@qingyang-hu qingyang-hu merged commit be25b9a into mongodb:v1 Sep 17, 2024
30 of 33 checks passed
Copy link
Contributor

Sorry, unable to cherry-pick to master, please backport manually. Here are approximate instructions:

  1. Checkout backport branch and update it.
git checkout -b cherry-pick-master-be25b9a26aff54fe27210f86d00db5d573452643 master

git fetch origin be25b9a26aff54fe27210f86d00db5d573452643
  1. Cherry pick the first parent branch of the this PR on top of the older branch:
git cherry-pick -x -m1 be25b9a26aff54fe27210f86d00db5d573452643
  1. You will likely have some merge/cherry-pick conflicts here, fix them and commit:
git commit -am {message}
  1. Push to a named branch:
git push origin cherry-pick-master-be25b9a26aff54fe27210f86d00db5d573452643
  1. Create a PR against branch master. I would have named this PR:

"GODRIVER-3302 Handle malformatted message length properly. (#1758) [master]"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
priority-1-high High Priority PR for Review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants