GODRIVER-3302 Handle malformatted message length properly. #1758
API Change Report: No changes found!
x/mongo/driver/topology/pool.go
Outdated
} | ||
size -= 4 | ||
} | ||
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size)) |
ioutil is deprecated; we should use io.Discard instead.
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size)) | |
_, err = io.CopyN(io.Discard, conn.nc, int64(size)) |
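As a minimal sketch of the suggested call, the following discards exactly n bytes from a reader using io.CopyN with io.Discard (available since Go 1.16). The helper names discardN and demoDiscard are illustrative assumptions and are not part of the driver.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// discardN reads and throws away exactly n bytes from r.
// It returns the number of bytes actually discarded and wraps any error.
func discardN(r io.Reader, n int64) (int64, error) {
	copied, err := io.CopyN(io.Discard, r, n)
	if err != nil {
		return copied, fmt.Errorf("error discarding %d byte message: %w", n, err)
	}
	return copied, nil
}

// demoDiscard (hypothetical) offers avail zero bytes to discardN and asks it
// to discard want bytes.
func demoDiscard(avail int, want int64) (int64, error) {
	return discardN(bytes.NewReader(make([]byte, avail)), want)
}
```

When the source ends early, io.CopyN reports io.EOF together with the number of bytes it managed to copy, so the caller can tell a short read from a complete one.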
if err != nil { | ||
if l := size - 4 - int32(n); l > 0 && needToWait(err) { |
Suggest renaming l to remainingBytes.
// read before returning the connection to the pool. | ||
awaitingResponse bool | ||
awaitingResponse *int32 |
Suggest renaming awaitingResponse to awaitRemainingBytes.
var sizeBuf [4]byte | ||
_, err = io.ReadFull(conn.nc, sizeBuf[:]) |
Suggest making this logic a function that can be used here and in the connection read method:
func readWMSize(r io.Reader) (int32, error) {
	const wireMessageSizePrefix = 4
	var wmSizeBytes [wireMessageSizePrefix]byte
	if _, err := io.ReadFull(r, wmSizeBytes[:]); err != nil {
		return 0, fmt.Errorf("error reading the message size: %w", err)
	}
	size := (int32(wmSizeBytes[0])) |
		(int32(wmSizeBytes[1]) << 8) |
		(int32(wmSizeBytes[2]) << 16) |
		(int32(wmSizeBytes[3]) << 24)
	if size < 4 {
		return 0, fmt.Errorf("malformed message length: %d", size)
	}
	return size, nil
}
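To illustrate how such a helper behaves, the sketch below repeats the suggested function and runs it over in-memory input. checkSize is a hypothetical wrapper added only for this demonstration.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// readWMSize reads the 4-byte little-endian length prefix of a wire message
// and rejects lengths smaller than the prefix itself.
func readWMSize(r io.Reader) (int32, error) {
	const wireMessageSizePrefix = 4

	var wmSizeBytes [wireMessageSizePrefix]byte
	if _, err := io.ReadFull(r, wmSizeBytes[:]); err != nil {
		return 0, fmt.Errorf("error reading the message size: %w", err)
	}
	size := (int32(wmSizeBytes[0])) |
		(int32(wmSizeBytes[1]) << 8) |
		(int32(wmSizeBytes[2]) << 16) |
		(int32(wmSizeBytes[3]) << 24)
	if size < 4 {
		return 0, fmt.Errorf("malformed message length: %d", size)
	}
	return size, nil
}

// checkSize (hypothetical) runs readWMSize over an in-memory byte slice.
func checkSize(b []byte) (int32, error) {
	return readWMSize(bytes.NewReader(b))
}
```

With this sketch, a prefix of {12, 0, 0, 0} yields 12, a prefix of {3, 0, 0, 0} is rejected as malformed, and a two-byte input fails with an "unexpected EOF" from io.ReadFull.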
@@ -115,12 +115,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection { | |||
return c | |||
} | |||
|
|||
// DriverConnectionID returns the driver connection ID. |
Moved DriverConnectionID down with other public methods.
func (c *connection) cancellationListenerCallback() { | ||
_ = c.close() | ||
} | ||
|
||
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error { |
Moved transformNetworkError closer to the caller.
@@ -537,10 +589,6 @@ func (c *connection) setCanStream(canStream bool) { | |||
c.canStream = canStream | |||
} | |||
|
|||
func (c initConnection) supportsStreaming() bool { |
Merged in (initConnection).SupportsStreaming().
@@ -833,39 +895,6 @@ func (c *Connection) DriverConnectionID() uint64 { | |||
return c.connection.DriverConnectionID() | |||
} | |||
|
|||
func configureTLS(ctx context.Context, |
Moved closer to the caller.
@@ -919,11 +948,3 @@ func (c *cancellListener) StopListening() bool { | |||
c.done <- struct{}{} | |||
return c.aborted | |||
} | |||
|
|||
func (c *connection) OIDCTokenGenID() uint64 { |
Moved closer to the other *connection methods.
LGTM, great tests!
drivers-pr-bot please backport to master |
@@ -461,36 +506,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, | |||
} | |||
}() | |||
|
|||
needToWait := func(err error) bool { |
Optional: Consider a more descriptive name, like isCSOTTimeout.
size := (int32(wmSizeBytes[0])) | | ||
(int32(wmSizeBytes[1]) << 8) | | ||
(int32(wmSizeBytes[2]) << 16) | | ||
(int32(wmSizeBytes[3]) << 24) |
Optional: Use the binary package instead.
size := (int32(wmSizeBytes[0])) | | |
(int32(wmSizeBytes[1]) << 8) | | |
(int32(wmSizeBytes[2]) << 16) | | |
(int32(wmSizeBytes[3]) << 24) | |
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:])) |
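A quick way to see that the two forms agree is to decode the same four bytes both ways. The sketch below uses only encoding/binary from the standard library; the function names sizeByShifts and sizeByBinary are illustrative.

```go
package main

import "encoding/binary"

// sizeByShifts decodes a little-endian int32 with manual shifts,
// as in the original code.
func sizeByShifts(b [4]byte) int32 {
	return (int32(b[0])) |
		(int32(b[1]) << 8) |
		(int32(b[2]) << 16) |
		(int32(b[3]) << 24)
}

// sizeByBinary decodes the same value with encoding/binary,
// as in the suggested change.
func sizeByBinary(b [4]byte) int32 {
	return int32(binary.LittleEndian.Uint32(b[:]))
}
```

Both functions return 12 for {12, 0, 0, 0} and -1 for {0xff, 0xff, 0xff, 0xff}, so a size check such as size < 4 behaves the same with either form.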
x/mongo/driver/topology/pool_test.go
Outdated
errCh = make(chan error) | ||
|
||
var err error | ||
socket, err = net.Listen("unix", sockPath) |
Ideally we should use TCP sockets to more accurately test the real-world use case of the connection pool. Is it possible to use TCP instead of Unix sockets?
Consider using the bootstrapConnections helper used elsewhere in pool_test.go, which also uses a connection teardown pattern that doesn't require sleeps at the end of the connection handler.
x/mongo/driver/topology/pool_test.go
Outdated
var errCh chan error | ||
BGReadCallback = func(addr string, start, read time.Time, errs []error, connClosed bool) { | ||
defer close(errCh) | ||
|
||
for _, err := range errs { | ||
errCh <- err | ||
} | ||
} |
Sharing this logic between all the subtests makes it easy to introduce unintentional interaction between subtests. We should set the callback for each subtest using test-specific channels. We should also set BGReadCallback back to nil at the end of each subtest to prevent interaction between tests.
E.g.
t.Run("subtest", func(t *testing.T) {
	var errCh chan error
	BGReadCallback = func(...) {
		// ...
	}
	t.Cleanup(func() {
		BGReadCallback = nil
	})
	// Make test assertions.
})
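The set-then-reset idea can be tried outside the test suite with a small stand-in. In this sketch, bgReadCallback, cleaner, setBGReadCallback, and fakeT are all hypothetical names; cleaner is the one-method subset of *testing.T that the pattern relies on, so a real *testing.T could be passed in place of fakeT.

```go
package main

// bgReadCallback stands in for a package-level test hook (hypothetical).
var bgReadCallback func(errs []error)

// cleaner is the subset of *testing.T used here; *testing.T satisfies it.
type cleaner interface {
	Cleanup(func())
}

// setBGReadCallback installs cb and registers a cleanup that resets the hook
// to nil, so one subtest cannot leak its callback into the next.
func setBGReadCallback(c cleaner, cb func(errs []error)) {
	bgReadCallback = cb
	c.Cleanup(func() { bgReadCallback = nil })
}

// fakeT is a tiny stand-in for *testing.T that only records cleanups.
type fakeT struct{ cleanups []func() }

func (f *fakeT) Cleanup(fn func()) { f.cleanups = append(f.cleanups, fn) }

// finish runs the cleanups in last-added, first-called order,
// which is the order the testing package uses.
func (f *fakeT) finish() {
	for i := len(f.cleanups) - 1; i >= 0; i-- {
		f.cleanups[i]()
	}
}
```

Wrapping the assignment and the reset in one helper keeps each subtest to a single call and makes it hard to forget the reset.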
x/mongo/driver/topology/pool.go
Outdated
} | ||
_, err = io.CopyN(io.Discard, conn.nc, int64(size)) | ||
if err != nil { | ||
err = fmt.Errorf("error reading message of %d: %w", size, err) |
Optional: This error message is a bit confusing. Consider a clearer error message.
err = fmt.Errorf("error reading message of %d: %w", size, err) | |
err = fmt.Errorf("error discarding %d byte message: %w", size, err) |
x/mongo/driver/topology/pool_test.go
Outdated
@@ -1122,6 +1126,226 @@ func TestPool(t *testing.T) { | |||
p.close(context.Background()) | |||
}) | |||
}) | |||
t.Run("bgRead", func(t *testing.T) { |
Optional: Consider moving the bgRead subtests to a separate test function to allow de-indenting them one level.
I recommend refactoring the subtests in TestBackgroundRead to simplify the assertions and speed up the tests. Here is the general recommended pattern:
cleanup := make(chan struct{})
defer close(cleanup)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
	defer func() {
		<-cleanup
		_ = nc.Close()
	}()
	// Write to the connection here.
})
// Test logic here.
var bgErr error
select {
case bgErr = <-errCh:
case <-time.After(100 * time.Millisecond):
	t.Fatal("did not receive expected error after waiting for 100ms")
}
assert.EqualError(t, bgErr, "<expected error string>")
p.close(context.Background())
x/mongo/driver/topology/pool_test.go
Outdated
go func(t *testing.T) { | ||
}(t) |
This empty func can be removed.
go func(t *testing.T) { | |
}(t) |
x/mongo/driver/topology/pool_test.go
Outdated
wg := &sync.WaitGroup{} | ||
wg.Add(1) | ||
addr := bootstrapConnections(t, 1, func(nc net.Conn) { | ||
t.Helper() |
Since this func is only part of this one test, we shouldn't use t.Helper() here because it will obscure the actual error location. The same applies to all other anonymous funcs passed to bootstrapConnections in this test.
x/mongo/driver/topology/pool_test.go
Outdated
var err error | ||
_, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) | ||
noerr(t, err) | ||
time.Sleep(1500 * time.Millisecond) |
Since we're using all in-process connections, we should be able to reduce the sleeps and timeouts, which will speed up these tests significantly. I recommend a 50ms sleep here and a 10ms timeout in csot.MakeTimeoutContext, which is similar to the timeout values in other pool tests.
The same recommendation applies to all subtests in TestBackgroundRead.
x/mongo/driver/topology/pool_test.go
Outdated
time.Sleep(1500 * time.Millisecond) | ||
_, err = nc.Write([]byte{2, 3, 4}) | ||
noerr(t, err) | ||
time.Sleep(1500 * time.Millisecond) |
We can replace the sleep at the end of the connection handler with a "cleanup" signal, which will speed up the tests significantly.
E.g.
cleanup := make(chan struct{})
defer close(cleanup)
addr := bootstrapConnections(t, 1, func(nc net.Conn) {
	defer func() {
		<-cleanup
		_ = nc.Close()
	}()
	var err error
	_, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1})
	// ...
})
The same recommendation applies to all subtests in TestBackgroundRead.
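The "cleanup" signal can be tried in isolation with an in-memory connection. The sketch below is an assumption-laden stand-in: it uses net.Pipe from the standard library rather than the bootstrapConnections test helper, and exchange is a hypothetical function name. The handler goroutine keeps its end open until the cleanup channel is closed, so no sleep is needed.

```go
package main

import (
	"io"
	"net"
)

// exchange has a handler goroutine write payload to one end of an in-memory
// pipe, reads it back on the other end, then signals the handler to close.
func exchange(payload []byte) ([]byte, error) {
	cleanup := make(chan struct{}) // closed when the "test" is finished
	done := make(chan struct{})    // closed when the handler has exited
	client, server := net.Pipe()

	go func() {
		defer close(done)
		defer func() {
			// Hold the connection open until the caller signals cleanup.
			<-cleanup
			_ = server.Close()
		}()
		_, _ = server.Write(payload)
	}()

	got := make([]byte, len(payload))
	_, err := io.ReadFull(client, got)

	_ = client.Close() // unblocks the handler if its write is still pending
	close(cleanup)
	<-done
	return got, err
}
```

The handler exits as soon as the channel is closed, so the total run time is bounded by the I/O itself rather than by a fixed sleep.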
x/mongo/driver/topology/pool_test.go
Outdated
wg := &sync.WaitGroup{} | ||
wg.Add(1) |
If we use a "cleanup" channel as previously recommended, this WaitGroup will be unnecessary and can be removed.
x/mongo/driver/topology/pool_test.go
Outdated
wg.Wait() | ||
p.close(context.Background()) | ||
errs := []string{ | ||
"error discarding 3 byte message: EOF", | ||
} | ||
for i := 0; true; i++ { | ||
err, ok := <-errCh |
If we apply the previous recommended changes, we can simplify the assertion logic here.
var bgErr error
select {
case bgErr = <-errCh:
case <-time.After(100 * time.Millisecond):
	t.Fatal("did not receive expected error after waiting for 100ms")
}
assert.EqualError(t, bgErr, "error discarding 3 byte message: EOF")
p.close(context.Background())
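The select-with-timeout step can be factored into a small helper, sketched here outside the testing package. The names waitForErr, errTimedOut, and demoWait are hypothetical; in a real test the timeout branch would call t.Fatal instead of returning a sentinel error.

```go
package main

import (
	"errors"
	"time"
)

// errTimedOut reports that nothing arrived on the channel before the timeout.
var errTimedOut = errors.New("did not receive expected error before timeout")

// waitForErr returns the first error received from errCh, or errTimedOut if
// nothing arrives within timeout.
func waitForErr(errCh <-chan error, timeout time.Duration) error {
	select {
	case err := <-errCh:
		return err
	case <-time.After(timeout):
		return errTimedOut
	}
}

// demoWait (hypothetical) optionally queues an error with the given message
// on a buffered channel and reports what waitForErr observed within 50ms.
func demoWait(msg string) string {
	errCh := make(chan error, 1)
	if msg != "" {
		errCh <- errors.New(msg)
	}
	return waitForErr(errCh, 50*time.Millisecond).Error()
}
```

Bounding the wait keeps a missing background error from hanging the test run; it fails after the timeout instead.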
x/mongo/driver/topology/pool_test.go
Outdated
regex := regexp.MustCompile( | ||
`^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, | ||
) | ||
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) |
Optional: Consider printing the regex in the error message to make troubleshooting these failures easier.
This recommendation applies to all similar assertions in the TestBackgroundRead subtests.
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) | |
assert.True(t, regex.MatchString(err.Error()), "error %q does not match pattern %q", err, regex) |
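To show what the suggested message looks like, here is a small sketch that builds the same text with fmt. Both error values and *regexp.Regexp work with the %q verb, since fmt calls their Error and String methods. mismatchMessage and demoMismatch are hypothetical helper names.

```go
package main

import (
	"errors"
	"fmt"
	"regexp"
)

// mismatchMessage returns "" when err matches re, and otherwise a diagnostic
// that quotes both the error and the pattern.
func mismatchMessage(err error, re *regexp.Regexp) string {
	if re.MatchString(err.Error()) {
		return ""
	}
	return fmt.Sprintf("error %q does not match pattern %q", err, re)
}

// demoMismatch (hypothetical) builds the inputs from plain strings.
func demoMismatch(errText, pattern string) string {
	return mismatchMessage(errors.New(errText), regexp.MustCompile(pattern))
}
```

For an error "i/o timeout" checked against the pattern ^EOF$, the message shows both values in quotes, which makes it clear which side of the comparison is wrong.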
x/mongo/driver/topology/pool_test.go
Outdated
conn, err := net.Dial("tcp", addr.String()) | ||
noerr(t, err) | ||
return newLimitConn(conn, 10), nil |
What is the point of using newLimitConn here? Is the test trying to create timeout conditions or connection closed conditions?
You are correct. We no longer need it.
x/mongo/driver/topology/pool_test.go
Outdated
for i, err := range bgErrs { | ||
if i < len(wantErrs) { | ||
assert.EqualError(t, err, wantErrs[i], "mismatched err: %v", err) |
This assertion loop will pass if len(bgErrs) == 0, even if wantErrs contains expected error messages. We should fix and simplify the assertion logic.
E.g.
require.Len(t, bgErrs, 1, "expected 1 error")
assert.EqualError(t, bgErrs[0], "error discarding 3 byte message: EOF")
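The problem with the loop can be reproduced without the test framework. In the sketch below, loopCheck mirrors the shape of the original loop and strictCheck adds the length requirement first; both names, and the toErrs helper, are hypothetical.

```go
package main

import "errors"

// loopCheck mirrors the original assertion loop: it only compares entries
// that exist in got, so an empty got passes vacuously.
func loopCheck(got []error, want []string) bool {
	for i, err := range got {
		if i < len(want) && err.Error() != want[i] {
			return false
		}
	}
	return true
}

// strictCheck requires the lengths to match before comparing messages.
func strictCheck(got []error, want []string) bool {
	if len(got) != len(want) {
		return false
	}
	return loopCheck(got, want)
}

// toErrs (hypothetical) builds a slice of errors from messages.
func toErrs(msgs ...string) []error {
	errs := make([]error, 0, len(msgs))
	for _, m := range msgs {
		errs = append(errs, errors.New(m))
	}
	return errs
}
```

With an empty slice of observed errors and one expected message, loopCheck reports success while strictCheck reports failure, which is the gap the length assertion closes.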
x/mongo/driver/topology/pool_test.go
Outdated
for i, err := range bgErrs { | ||
if i < len(wantErrs) { | ||
assert.True(t, wantErrs[i].MatchString(err.Error()), "error %q does not match pattern %q", err, wantErrs[i]) |
This assertion loop will pass if len(bgErrs) == 0, even if wantErrs contains expected error messages. We should fix and simplify the assertion logic.
E.g.
require.Len(t, bgErrs, 1, "expected 1 error")
wantErrPattern := regexp.MustCompile(`^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`)
gotErr := bgErrs[0]
assert.True(t, wantErrPattern.MatchString(gotErr.Error()), "error %q does not match pattern %q", gotErr, wantErrPattern)
x/mongo/driver/topology/pool_test.go
Outdated
p := newPool( | ||
poolConfig{}, | ||
WithDialer(func(Dialer) Dialer { | ||
return DialerFunc(func(context.Context, string, string) (net.Conn, error) { | ||
return net.Dial("tcp", addr.String()) | ||
}) | ||
}), | ||
) |
Optional: Since there's no special dialing logic added here, we can pass addr in the poolConfig instead of overriding the dialer.
p := newPool( | |
poolConfig{}, | |
WithDialer(func(Dialer) Dialer { | |
return DialerFunc(func(context.Context, string, string) (net.Conn, error) { | |
return net.Dial("tcp", addr.String()) | |
}) | |
}), | |
) | |
p := newPool(poolConfig{ | |
Address: address.Address(addr.String()), | |
}) |
x/mongo/driver/topology/pool_test.go
Outdated
p := newPool( | ||
poolConfig{}, | ||
WithDialer(func(Dialer) Dialer { | ||
return DialerFunc(func(context.Context, string, string) (net.Conn, error) { | ||
return net.Dial("tcp", addr.String()) | ||
}) | ||
}), | ||
) |
Optional: Since there's no special dialing logic added here, we can pass addr in the poolConfig instead of overriding the dialer.
p := newPool( | |
poolConfig{}, | |
WithDialer(func(Dialer) Dialer { | |
return DialerFunc(func(context.Context, string, string) (net.Conn, error) { | |
return net.Dial("tcp", addr.String()) | |
}) | |
}), | |
) | |
p := newPool(poolConfig{ | |
Address: address.Address(addr.String()), | |
}) |
x/mongo/driver/topology/pool_test.go
Outdated
assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil") | ||
close(errsCh) // this line causes a double close if BGReadCallback is ever called. | ||
}) | ||
t.Run("timeout on reading the message header", func(t *testing.T) { |
This test seems to cover the scenario where the operation times out while waiting for the header, then bgRead times out waiting for the full message (because not enough bytes are written). We should also test the scenario where the operation times out while waiting for the header, but then bgRead succeeds.
E.g.
t.Run("timeout reading message header, successful background read", func(t *testing.T) {
	// ...
	addr := bootstrapConnections(t, 1, func(nc net.Conn) {
		defer func() {
			<-cleanup
			_ = nc.Close()
		}()
		// Wait until the operation times out, then write the full message
		// length.
		time.Sleep(timeout * 2)
		_, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0, 0, 0})
		noerr(t, err)
	})
	// ...
	assert.Len(t, bgErrs, 0, "expected no errors from bgRead")
})
To differentiate it from the existing test, I recommend updating the test description and adding a comment:
t.Run("timeout reading message header, incomplete background read", func(t *testing.T) {
	// ...
	addr := bootstrapConnections(t, 1, func(nc net.Conn) {
		defer func() {
			<-cleanup
			_ = nc.Close()
		}()
		// Wait until the operation times out, then write an incomplete
		// message.
		time.Sleep(timeout * 5)
		_, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0})
		noerr(t, err)
	})
	// ...
})
I will also add cases for "timeout reading message header, incomplete head during background read" and "timeout reading full message, successful background read".
Looks good! 👍
Sorry, unable to cherry-pick to master, please backport manually. Here are approximate instructions:
|
(cherry picked from commit be25b9a)
GODRIVER-3302
Summary
Background & Motivation
connection.read has two places that can raise timeout errors: one while reading the message header, and the other while reading the remaining body of the message.
The customer has observed an "incomplete read of full message" error before the panic. In this scenario, bgRead should not call connection.read directly to restart the read, because the message header containing the message size has already been read. Instead, it should proceed to read the message body of the given size.
On the other hand, bgRead needs to start over if the header has not been read at all, but should stop processing if the header has been only partially read, to reduce complexity.