From 12a97f78f889b2cfd9094c558eafe338c2d6bd59 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 12 Jan 2024 17:53:09 -0700 Subject: [PATCH 1/3] GODRIVER-3049 Replace bsoncore.DocumentSequence with a bson type-agnostic analogue (#1492) --- internal/integration/mtest/sent_message.go | 30 +- mongo/batch_cursor.go | 2 +- mongo/cursor.go | 47 ++- mongo/cursor_test.go | 44 ++- x/bsonx/bsoncore/document_sequence.go | 189 --------- x/bsonx/bsoncore/document_sequence_test.go | 422 --------------------- x/bsonx/bsoncore/iterator.go | 115 ++++++ x/bsonx/bsoncore/iterator_test.go | 306 +++++++++++++++ x/mongo/driver/batch_cursor.go | 84 ++-- 9 files changed, 541 insertions(+), 698 deletions(-) delete mode 100644 x/bsonx/bsoncore/document_sequence.go delete mode 100644 x/bsonx/bsoncore/document_sequence_test.go create mode 100644 x/bsonx/bsoncore/iterator.go create mode 100644 x/bsonx/bsoncore/iterator_test.go diff --git a/internal/integration/mtest/sent_message.go b/internal/integration/mtest/sent_message.go index 94eed12257..db1ef33ffa 100644 --- a/internal/integration/mtest/sent_message.go +++ b/internal/integration/mtest/sent_message.go @@ -30,7 +30,7 @@ type SentMessage struct { // The documents sent for an insert, update, or delete command. This is separated into its own field because it's // sent as part of the command document in OP_QUERY and as a document sequence outside the command document in // OP_MSG. - DocumentSequence *bsoncore.DocumentSequence + Batch *bsoncore.Iterator } type sentMsgParseFn func([]byte) (*SentMessage, error) @@ -87,26 +87,25 @@ func parseOpQuery(wm []byte) (*SentMessage, error) { // For OP_QUERY, inserts, updates, and deletes are sent as a BSON array of documents inside the main command // document. Pull these sequences out into an ArrayStyle DocumentSequence. - var docSequence *bsoncore.DocumentSequence + var batch *bsoncore.Iterator cmdElems, _ := commandDoc.Elements() for _, elem := range cmdElems { switch elem.Key() { case "documents", "updates", "deletes": - docSequence = &bsoncore.DocumentSequence{ - Style: bsoncore.ArrayStyle, - Data: elem.Value().Array(), + batch = &bsoncore.Iterator{ + List: elem.Value().Array(), } } - if docSequence != nil { + if batch != nil { // There can only be one of these arrays in a well-formed command, so we exit the loop once one is found. break } } sm := &SentMessage{ - Command: commandDoc, - ReadPreference: rpDoc, - DocumentSequence: docSequence, + Command: commandDoc, + ReadPreference: rpDoc, + Batch: batch, } return sm, nil } @@ -156,7 +155,7 @@ func parseSentOpMsg(wm []byte) (*SentMessage, error) { rpDoc = rpVal.Document() } - var docSequence *bsoncore.DocumentSequence + var batch *bsoncore.Iterator if len(wm) != 0 { // If there are bytes remaining in the wire message, they must correspond to a DocumentSequence section. 
if wm, err = assertMsgSectionType(wm, wiremessage.DocumentSequence); err != nil { @@ -169,16 +168,15 @@ func parseSentOpMsg(wm []byte) (*SentMessage, error) { return nil, errors.New("failed to read document sequence") } - docSequence = &bsoncore.DocumentSequence{ - Style: bsoncore.SequenceStyle, - Data: data, + batch = &bsoncore.Iterator{ + List: data, } } sm := &SentMessage{ - Command: commandDoc, - ReadPreference: rpDoc, - DocumentSequence: docSequence, + Command: commandDoc, + ReadPreference: rpDoc, + Batch: batch, } return sm, nil } diff --git a/mongo/batch_cursor.go b/mongo/batch_cursor.go index 51d59d0ffa..9e87b00ae4 100644 --- a/mongo/batch_cursor.go +++ b/mongo/batch_cursor.go @@ -25,7 +25,7 @@ type batchCursor interface { // Batch will return a DocumentSequence for the current batch of documents. The returned // DocumentSequence is only valid until the next call to Next or Close. - Batch() *bsoncore.DocumentSequence + Batch() *bsoncore.Iterator // Server returns a pointer to the cursor's server. Server() driver.Server diff --git a/mongo/cursor.go b/mongo/cursor.go index 67db1c2953..552c49d550 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -33,7 +33,7 @@ type Cursor struct { Current bson.Raw bc batchCursor - batch *bsoncore.DocumentSequence + batch *bsoncore.Iterator batchLength int bsonOpts *options.BSONOptions registry *bsoncodec.Registry @@ -72,9 +72,10 @@ func newCursorWithSession( c.closeImplicitSession() } - // Initialize just the batchLength here so RemainingBatchLength will return an accurate result. The actual batch - // will be pulled up by the first Next/TryNext call. - c.batchLength = c.bc.Batch().DocumentCount() + // Initialize just the batchLength here so RemainingBatchLength will return an + // accurate result. The actual batch will be pulled up by the first + // Next/TryNext call. + c.batchLength = c.bc.Batch().Count() return c, nil } @@ -91,10 +92,11 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco registry = bson.DefaultRegistry } - // Convert documents slice to a sequence-style byte array. buf := new(bytes.Buffer) enc := new(bson.Encoder) - for _, doc := range documents { + + values := make([]bsoncore.Value, len(documents)) + for i, doc := range documents { switch t := doc.(type) { case nil: return nil, ErrNilDocument @@ -102,20 +104,32 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. doc = bson.Raw(t) } + vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return nil, err } + enc.Reset(vw) enc.SetRegistry(registry) - err = enc.Encode(doc) - if err != nil { + + if err = enc.Encode(doc); err != nil { return nil, err } + + dup := make([]byte, len(buf.Bytes())) + copy(dup, buf.Bytes()) + + values[i] = bsoncore.Value{ + Type: bson.TypeEmbeddedDocument, + Data: dup, + } + + buf.Reset() } c := &Cursor{ - bc: driver.NewBatchCursorFromDocuments(buf.Bytes()), + bc: driver.NewBatchCursorFromList(bsoncore.BuildArray(nil, values...)), registry: registry, err: err, } @@ -123,7 +137,8 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco // Initialize batch and batchLength here. The underlying batch cursor will be preloaded with the // provided contents, and thus already has a batch before calls to Next/TryNext. 
c.batch = c.bc.Batch() - c.batchLength = c.bc.Batch().DocumentCount() + c.batchLength = c.bc.Batch().Count() + return c, nil } @@ -166,12 +181,12 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { if ctx == nil { ctx = context.Background() } - doc, err := c.batch.Next() + val, err := c.batch.Next() switch { case err == nil: // Consume the next document in the current batch. c.batchLength-- - c.Current = bson.Raw(doc) + c.Current = bson.Raw(val.Data) return true case errors.Is(err, io.EOF): // Need to do a getMore default: @@ -209,12 +224,12 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { // Use the new batch to update the batch and batchLength fields. Consume the first document in the batch. c.batch = c.bc.Batch() - c.batchLength = c.batch.DocumentCount() - doc, err = c.batch.Next() + c.batchLength = c.batch.Count() + val, err = c.batch.Next() switch { case err == nil: c.batchLength-- - c.Current = bson.Raw(doc) + c.Current = bson.Raw(val.Data) return true case errors.Is(err, io.EOF): // Empty batch so we continue default: @@ -348,7 +363,7 @@ func (c *Cursor) RemainingBatchLength() int { // addFromBatch adds all documents from batch to sliceVal starting at the given index. It returns the new slice value, // the next empty index in the slice, and an error if one occurs. -func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, batch *bsoncore.DocumentSequence, +func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, batch *bsoncore.Iterator, index int) (reflect.Value, int, error) { docs, err := batch.Documents() diff --git a/mongo/cursor_test.go b/mongo/cursor_test.go index 3781109019..1aaf41555b 100644 --- a/mongo/cursor_test.go +++ b/mongo/cursor_test.go @@ -13,6 +13,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/options" @@ -21,17 +22,17 @@ import ( ) type testBatchCursor struct { - batches []*bsoncore.DocumentSequence - batch *bsoncore.DocumentSequence + batches []*bsoncore.Iterator + batch *bsoncore.Iterator closed bool } func newTestBatchCursor(numBatches, batchSize int) *testBatchCursor { - batches := make([]*bsoncore.DocumentSequence, 0, numBatches) + batches := make([]*bsoncore.Iterator, 0, numBatches) counter := 0 for batch := 0; batch < numBatches; batch++ { - var docSequence []byte + var values []bsoncore.Value for doc := 0; doc < batchSize; doc++ { var elem []byte @@ -40,12 +41,18 @@ func newTestBatchCursor(numBatches, batchSize int) *testBatchCursor { var doc []byte doc = bsoncore.BuildDocumentFromElements(doc, elem) - docSequence = append(docSequence, doc...) + val := bsoncore.Value{ + Type: bsontype.EmbeddedDocument, + Data: doc, + } + + values = append(values, val) } - batches = append(batches, &bsoncore.DocumentSequence{ - Style: bsoncore.SequenceStyle, - Data: docSequence, + arr := bsoncore.BuildArray(nil, values...) 
+ + batches = append(batches, &bsoncore.Iterator{ + List: arr, }) } @@ -72,7 +79,7 @@ func (tbc *testBatchCursor) Next(context.Context) bool { return true } -func (tbc *testBatchCursor) Batch() *bsoncore.DocumentSequence { +func (tbc *testBatchCursor) Batch() *bsoncore.Iterator { return tbc.batch } @@ -262,3 +269,22 @@ func TestNewCursorFromDocuments(t *testing.T) { mockErr, cur.Err()) }) } + +func BenchmarkNewCursorFromDocuments(b *testing.B) { + // Prepare sample data + documents := []interface{}{ + bson.D{{"_id", 0}, {"foo", "bar"}}, + bson.D{{"_id", 1}, {"baz", "qux"}}, + bson.D{{"_id", 2}, {"quux", "quuz"}}, + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := NewCursorFromDocuments(documents, nil, nil) + if err != nil { + b.Fatalf("Error creating cursor: %v", err) + } + } +} diff --git a/x/bsonx/bsoncore/document_sequence.go b/x/bsonx/bsoncore/document_sequence.go deleted file mode 100644 index e35bd0cd9a..0000000000 --- a/x/bsonx/bsoncore/document_sequence.go +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2022-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsoncore - -import ( - "errors" - "io" - - "go.mongodb.org/mongo-driver/bson/bsontype" -) - -// DocumentSequenceStyle is used to represent how a document sequence is laid out in a slice of -// bytes. -type DocumentSequenceStyle uint32 - -// These constants are the valid styles for a DocumentSequence. -const ( - _ DocumentSequenceStyle = iota - SequenceStyle - ArrayStyle -) - -// DocumentSequence represents a sequence of documents. The Style field indicates how the documents -// are laid out inside of the Data field. -type DocumentSequence struct { - Style DocumentSequenceStyle - Data []byte - Pos int -} - -// ErrCorruptedDocument is returned when a full document couldn't be read from the sequence. -var ErrCorruptedDocument = errors.New("invalid DocumentSequence: corrupted document") - -// ErrNonDocument is returned when a DocumentSequence contains a non-document BSON value. -var ErrNonDocument = errors.New("invalid DocumentSequence: a non-document value was found in sequence") - -// ErrInvalidDocumentSequenceStyle is returned when an unknown DocumentSequenceStyle is set on a -// DocumentSequence. -var ErrInvalidDocumentSequenceStyle = errors.New("invalid DocumentSequenceStyle") - -// DocumentCount returns the number of documents in the sequence. -func (ds *DocumentSequence) DocumentCount() int { - if ds == nil { - return 0 - } - switch ds.Style { - case SequenceStyle: - var count int - var ok bool - rem := ds.Data - for len(rem) > 0 { - _, rem, ok = ReadDocument(rem) - if !ok { - return 0 - } - count++ - } - return count - case ArrayStyle: - _, rem, ok := ReadLength(ds.Data) - if !ok { - return 0 - } - - var count int - for len(rem) > 1 { - _, rem, ok = ReadElement(rem) - if !ok { - return 0 - } - count++ - } - return count - default: - return 0 - } -} - -// Empty returns true if the sequence is empty. It always returns true for unknown sequence styles. 
-func (ds *DocumentSequence) Empty() bool { - if ds == nil { - return true - } - - switch ds.Style { - case SequenceStyle: - return len(ds.Data) == 0 - case ArrayStyle: - return len(ds.Data) <= 5 - default: - return true - } -} - -// ResetIterator resets the iteration point for the Next method to the beginning of the document -// sequence. -func (ds *DocumentSequence) ResetIterator() { - if ds == nil { - return - } - ds.Pos = 0 -} - -// Documents returns a slice of the documents. If nil either the Data field is also nil or could not -// be properly read. -func (ds *DocumentSequence) Documents() ([]Document, error) { - if ds == nil { - return nil, nil - } - switch ds.Style { - case SequenceStyle: - rem := ds.Data - var docs []Document - var doc Document - var ok bool - for { - doc, rem, ok = ReadDocument(rem) - if !ok { - if len(rem) == 0 { - break - } - return nil, ErrCorruptedDocument - } - docs = append(docs, doc) - } - return docs, nil - case ArrayStyle: - if len(ds.Data) == 0 { - return nil, nil - } - vals, err := Document(ds.Data).Values() - if err != nil { - return nil, ErrCorruptedDocument - } - docs := make([]Document, 0, len(vals)) - for _, v := range vals { - if v.Type != bsontype.EmbeddedDocument { - return nil, ErrNonDocument - } - docs = append(docs, v.Data) - } - return docs, nil - default: - return nil, ErrInvalidDocumentSequenceStyle - } -} - -// Next retrieves the next document from this sequence and returns it. This method will return -// io.EOF when it has reached the end of the sequence. -func (ds *DocumentSequence) Next() (Document, error) { - if ds == nil || ds.Pos >= len(ds.Data) { - return nil, io.EOF - } - switch ds.Style { - case SequenceStyle: - doc, _, ok := ReadDocument(ds.Data[ds.Pos:]) - if !ok { - return nil, ErrCorruptedDocument - } - ds.Pos += len(doc) - return doc, nil - case ArrayStyle: - if ds.Pos < 4 { - if len(ds.Data) < 4 { - return nil, ErrCorruptedDocument - } - ds.Pos = 4 // Skip the length of the document - } - if len(ds.Data[ds.Pos:]) == 1 && ds.Data[ds.Pos] == 0x00 { - return nil, io.EOF // At the end of the document - } - elem, _, ok := ReadElement(ds.Data[ds.Pos:]) - if !ok { - return nil, ErrCorruptedDocument - } - ds.Pos += len(elem) - val := elem.Value() - if val.Type != bsontype.EmbeddedDocument { - return nil, ErrNonDocument - } - return val.Data, nil - default: - return nil, ErrInvalidDocumentSequenceStyle - } -} diff --git a/x/bsonx/bsoncore/document_sequence_test.go b/x/bsonx/bsoncore/document_sequence_test.go deleted file mode 100644 index bf40fa878d..0000000000 --- a/x/bsonx/bsoncore/document_sequence_test.go +++ /dev/null @@ -1,422 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2022-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsoncore - -import ( - "bytes" - "errors" - "io" - "strconv" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestDocumentSequence(t *testing.T) { - - genArrayStyle := func(num int) []byte { - idx, seq := AppendDocumentStart(nil) - for i := 0; i < num; i++ { - seq = AppendDocumentElement( - seq, strconv.Itoa(i), - BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), - ) - } - seq, _ = AppendDocumentEnd(seq, idx) - return seq - } - genSequenceStyle := func(num int) []byte { - var seq []byte - for i := 0; i < num; i++ { - seq = append(seq, BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159))...) 
- } - return seq - } - - idx, arrayStyle := AppendDocumentStart(nil) - idx2, arrayStyle := AppendDocumentElementStart(arrayStyle, "0") - arrayStyle = AppendDoubleElement(arrayStyle, "pi", 3.14159) - arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2) - idx2, arrayStyle = AppendDocumentElementStart(arrayStyle, "1") - arrayStyle = AppendStringElement(arrayStyle, "hello", "world") - arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2) - arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx) - - t.Run("Documents", func(t *testing.T) { - testCases := []struct { - name string - style DocumentSequenceStyle - data []byte - documents []Document - err error - }{ - { - "SequenceStle/corrupted document", - SequenceStyle, - []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, - nil, - ErrCorruptedDocument, - }, - { - "SequenceStyle/success", - SequenceStyle, - BuildDocument( - BuildDocument( - nil, - AppendStringElement(AppendDoubleElement(nil, "pi", 3.14159), "hello", "world"), - ), - AppendDoubleElement(AppendStringElement(nil, "hello", "world"), "pi", 3.14159), - ), - []Document{ - BuildDocument(nil, AppendStringElement(AppendDoubleElement(nil, "pi", 3.14159), "hello", "world")), - BuildDocument(nil, AppendDoubleElement(AppendStringElement(nil, "hello", "world"), "pi", 3.14159)), - }, - nil, - }, - { - "ArrayStyle/insufficient bytes", - ArrayStyle, - []byte{0x01, 0x02, 0x03, 0x04, 0x05}, - nil, - ErrCorruptedDocument, - }, - { - "ArrayStyle/non-document", - ArrayStyle, - BuildDocument(nil, AppendDoubleElement(nil, "0", 12345.67890)), - nil, - ErrNonDocument, - }, - { - "ArrayStyle/success", - ArrayStyle, - arrayStyle, - []Document{ - BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), - BuildDocument(nil, AppendStringElement(nil, "hello", "world")), - }, - nil, - }, - {"Invalid DocumentSequenceStyle", 0, nil, nil, ErrInvalidDocumentSequenceStyle}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ds := &DocumentSequence{ - Style: tc.style, - Data: tc.data, - } - documents, err := ds.Documents() - if !cmp.Equal(documents, tc.documents) { - t.Errorf("Documents do not match. got %v; want %v", documents, tc.documents) - } - if !errors.Is(err, tc.err) { - t.Errorf("Errors do not match. 
got %v; want %v", err, tc.err) - } - }) - } - }) - t.Run("Next", func(t *testing.T) { - seqDoc := BuildDocument( - BuildDocument( - nil, - AppendDoubleElement(nil, "pi", 3.14159), - ), - AppendStringElement(nil, "hello", "world"), - ) - - idx, arrayStyle := AppendDocumentStart(nil) - idx2, arrayStyle := AppendDocumentElementStart(arrayStyle, "0") - arrayStyle = AppendDoubleElement(arrayStyle, "pi", 3.14159) - arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2) - idx2, arrayStyle = AppendDocumentElementStart(arrayStyle, "1") - arrayStyle = AppendStringElement(arrayStyle, "hello", "world") - arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2) - arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx) - - testCases := []struct { - name string - style DocumentSequenceStyle - data []byte - pos int - document Document - err error - }{ - {"io.EOF", 0, make([]byte, 10), 10, nil, io.EOF}, - { - "SequenceStyle/corrupted document", - SequenceStyle, - []byte{0x01, 0x02, 0x03, 0x04}, - 0, - nil, - ErrCorruptedDocument, - }, - { - "SequenceStyle/success/first", - SequenceStyle, - seqDoc, - 0, - BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), - nil, - }, - { - "SequenceStyle/success/second", - SequenceStyle, - seqDoc, - 17, - BuildDocument(nil, AppendStringElement(nil, "hello", "world")), - nil, - }, - { - "ArrayStyle/corrupted document/too short", - ArrayStyle, - []byte{0x01, 0x02, 0x03}, - 0, - nil, - ErrCorruptedDocument, - }, - { - "ArrayStyle/corrupted document/invalid element", - ArrayStyle, - []byte{0x00, 0x00, 0x00, 0x00, 0x01, '0', 0x00, 0x01, 0x02}, - 0, - nil, - ErrCorruptedDocument, - }, - { - "ArrayStyle/non-document", - ArrayStyle, - BuildDocument(nil, AppendDoubleElement(nil, "0", 12345.67890)), - 0, - nil, - ErrNonDocument, - }, - { - "ArrayStyle/success/first", - ArrayStyle, - arrayStyle, - 0, - BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), - nil, - }, - { - "ArrayStyle/success/second", - ArrayStyle, - arrayStyle, - 24, - BuildDocument(nil, AppendStringElement(nil, "hello", "world")), - nil, - }, - {"Invalid DocumentSequenceStyle", 0, make([]byte, 4), 0, nil, ErrInvalidDocumentSequenceStyle}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ds := &DocumentSequence{ - Style: tc.style, - Data: tc.data, - Pos: tc.pos, - } - document, err := ds.Next() - if !bytes.Equal(document, tc.document) { - t.Errorf("Documents do not match. got %v; want %v", document, tc.document) - } - if !errors.Is(err, tc.err) { - t.Errorf("Errors do not match. 
got %v; want %v", err, tc.err) - } - }) - } - }) - - t.Run("Full Iteration", func(t *testing.T) { - testCases := []struct { - name string - style DocumentSequenceStyle - data []byte - count int - }{ - {"SequenceStyle/success/nil", SequenceStyle, nil, 0}, - {"SequenceStyle/success/0", SequenceStyle, []byte{}, 0}, - {"SequenceStyle/success/1", SequenceStyle, genSequenceStyle(1), 1}, - {"SequenceStyle/success/2", SequenceStyle, genSequenceStyle(2), 2}, - {"SequenceStyle/success/10", SequenceStyle, genSequenceStyle(10), 10}, - {"SequenceStyle/success/100", SequenceStyle, genSequenceStyle(100), 100}, - {"ArrayStyle/success/nil", ArrayStyle, nil, 0}, - {"ArrayStyle/success/0", ArrayStyle, []byte{0x05, 0x00, 0x00, 0x00, 0x00}, 0}, - {"ArrayStyle/success/1", ArrayStyle, genArrayStyle(1), 1}, - {"ArrayStyle/success/2", ArrayStyle, genArrayStyle(2), 2}, - {"ArrayStyle/success/10", ArrayStyle, genArrayStyle(10), 10}, - {"ArrayStyle/success/100", ArrayStyle, genArrayStyle(100), 100}, - } - - for _, tc := range testCases { - t.Run("Documents/"+tc.name, func(t *testing.T) { - ds := &DocumentSequence{ - Style: tc.style, - Data: tc.data, - } - docs, err := ds.Documents() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - count := len(docs) - if count != tc.count { - t.Errorf("Coun't fully iterate documents, wrong count. got %v; want %v", count, tc.count) - } - }) - t.Run("Next/"+tc.name, func(t *testing.T) { - ds := &DocumentSequence{ - Style: tc.style, - Data: tc.data, - } - var docs []Document - for { - doc, err := ds.Next() - if errors.Is(err, io.EOF) { - break - } - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - docs = append(docs, doc) - } - count := len(docs) - if count != tc.count { - t.Errorf("Coun't fully iterate documents, wrong count. 
got %v; want %v", count, tc.count) - } - }) - } - }) - t.Run("DocumentCount", func(t *testing.T) { - testCases := []struct { - name string - style DocumentSequenceStyle - data []byte - count int - }{ - { - "SequenceStyle/corrupt document/first", - SequenceStyle, - []byte{0x01, 0x02, 0x03}, - 0, - }, - { - "SequenceStyle/corrupt document/second", - SequenceStyle, - []byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, - 0, - }, - {"SequenceStyle/success/nil", SequenceStyle, nil, 0}, - {"SequenceStyle/success/0", SequenceStyle, []byte{}, 0}, - {"SequenceStyle/success/1", SequenceStyle, genSequenceStyle(1), 1}, - {"SequenceStyle/success/2", SequenceStyle, genSequenceStyle(2), 2}, - {"SequenceStyle/success/10", SequenceStyle, genSequenceStyle(10), 10}, - {"SequenceStyle/success/100", SequenceStyle, genSequenceStyle(100), 100}, - { - "ArrayStyle/corrupt document/length", - ArrayStyle, - []byte{0x01, 0x02, 0x03}, - 0, - }, - { - "ArrayStyle/corrupt element/first", - ArrayStyle, - BuildDocument(nil, []byte{0x01, 0x00, 0x03, 0x04, 0x05}), - 0, - }, - { - "ArrayStyle/corrupt element/second", - ArrayStyle, - BuildDocument(nil, []byte{0x0A, 0x00, 0x01, 0x00, 0x03, 0x04, 0x05}), - 0, - }, - {"ArrayStyle/success/nil", ArrayStyle, nil, 0}, - {"ArrayStyle/success/0", ArrayStyle, []byte{0x05, 0x00, 0x00, 0x00, 0x00}, 0}, - {"ArrayStyle/success/1", ArrayStyle, genArrayStyle(1), 1}, - {"ArrayStyle/success/2", ArrayStyle, genArrayStyle(2), 2}, - {"ArrayStyle/success/10", ArrayStyle, genArrayStyle(10), 10}, - {"ArrayStyle/success/100", ArrayStyle, genArrayStyle(100), 100}, - {"Invalid DocumentSequenceStyle", 0, nil, 0}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ds := &DocumentSequence{ - Style: tc.style, - Data: tc.data, - } - count := ds.DocumentCount() - if count != tc.count { - t.Errorf("Document counts don't match. got %v; want %v", count, tc.count) - } - }) - } - }) - t.Run("Empty", func(t *testing.T) { - testCases := []struct { - name string - ds *DocumentSequence - isEmpty bool - }{ - {"ArrayStyle/is empty/nil", nil, true}, - {"ArrayStyle/is empty/0", &DocumentSequence{Style: ArrayStyle, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}}, true}, - {"ArrayStyle/is not empty/non-0", &DocumentSequence{Style: ArrayStyle, Data: genArrayStyle(10)}, false}, - {"SequenceStyle/is empty/nil", nil, true}, - {"SequenceStyle/is empty/0", &DocumentSequence{Style: SequenceStyle, Data: []byte{}}, true}, - {"SequenceStyle/is not empty/non-0", &DocumentSequence{Style: SequenceStyle, Data: genSequenceStyle(10)}, false}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - isEmpty := tc.ds.Empty() - if isEmpty != tc.isEmpty { - t.Errorf("Unexpected Empty result. got %v; want %v", isEmpty, tc.isEmpty) - } - }) - } - }) - t.Run("ResetIterator", func(t *testing.T) { - ds := &DocumentSequence{Pos: 1234567890} - want := 0 - ds.ResetIterator() - if ds.Pos != want { - t.Errorf("Unexpected position after ResetIterator. got %d; want %d", ds.Pos, want) - } - }) - t.Run("no panic on nil", func(t *testing.T) { - capturePanic := func() { - if err := recover(); err != nil { - t.Errorf("Unexpected panic. 
got %v; want ", err) - } - } - t.Run("DocumentCount", func(t *testing.T) { - defer capturePanic() - var ds *DocumentSequence - _ = ds.DocumentCount() - }) - t.Run("Empty", func(t *testing.T) { - defer capturePanic() - var ds *DocumentSequence - _ = ds.Empty() - }) - t.Run("ResetIterator", func(t *testing.T) { - defer capturePanic() - var ds *DocumentSequence - ds.ResetIterator() - }) - t.Run("Documents", func(t *testing.T) { - defer capturePanic() - var ds *DocumentSequence - _, _ = ds.Documents() - }) - t.Run("Next", func(t *testing.T) { - defer capturePanic() - var ds *DocumentSequence - _, _ = ds.Next() - }) - }) -} diff --git a/x/bsonx/bsoncore/iterator.go b/x/bsonx/bsoncore/iterator.go new file mode 100644 index 0000000000..f4f6236d77 --- /dev/null +++ b/x/bsonx/bsoncore/iterator.go @@ -0,0 +1,115 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncore + +import ( + "errors" + "fmt" + "io" + + "go.mongodb.org/mongo-driver/bson/bsontype" +) + +// errCorruptedDocument is returned when a full document couldn't be read from +// the sequence. +var errCorruptedDocument = errors.New("invalid DocumentSequence: corrupted document") + +// Iterator maintains a list of BSON values and keeps track of the current +// position in relation to its Next() method. +type Iterator struct { + List Array // List of BSON values + pos int // The position of the iterator in the list in reference to Next() +} + +// Count returned the number of elements in the iterator's list. +func (iter *Iterator) Count() int { + if iter == nil { + return 0 + } + + _, rem, ok := ReadLength(iter.List) + if !ok { + return 0 + } + + var count int + for len(rem) > 1 { + _, rem, ok = ReadElement(rem) + if !ok { + return 0 + } + count++ + } + return count +} + +// Empty returns true if the iterator's list is empty. +func (iter *Iterator) Empty() bool { + return len(iter.List) <= 5 +} + +// Reset will reset the iteration point for the Next method to the beginning of +// the list. +func (iter *Iterator) Reset() { + iter.pos = 0 +} + +// Documents traverses the list as documents and returns them. This method +// assumes that the underlying list is composed of documents and will return +// an error otherwise. +func (iter *Iterator) Documents() ([]Document, error) { + if iter == nil || len(iter.List) == 0 { + return nil, nil + } + + vals, err := iter.List.Values() + if err != nil { + return nil, errCorruptedDocument + } + + docs := make([]Document, 0, len(vals)) + for _, v := range vals { + if v.Type != bsontype.EmbeddedDocument { + return nil, fmt.Errorf("invalid DocumentSequence: a non-document value was found in sequence") + } + + docs = append(docs, v.Data) + } + + return docs, nil +} + +// Next retrieves the next value from the list and returns it. This method will +// return io.EOF when it has reached the end of the list. 
+func (iter *Iterator) Next() (*Value, error) { + if iter == nil || iter.pos >= len(iter.List) { + return nil, io.EOF + } + + if iter.pos < 4 { + if len(iter.List) < 4 { + return nil, errCorruptedDocument + } + + iter.pos = 4 // Skip the length of the document + } + + rem := iter.List[iter.pos:] + if len(rem) == 1 && rem[0] == 0x00 { + return nil, io.EOF // At the end of the document + } + + elem, _, ok := ReadElement(rem) + if !ok { + return nil, errCorruptedDocument + } + + iter.pos += len(elem) + val := elem.Value() + + return &val, nil +} diff --git a/x/bsonx/bsoncore/iterator_test.go b/x/bsonx/bsoncore/iterator_test.go new file mode 100644 index 0000000000..1011f6033f --- /dev/null +++ b/x/bsonx/bsoncore/iterator_test.go @@ -0,0 +1,306 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncore + +import ( + "io" + "testing" + + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" +) + +func TestIterator_Reset(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + values []Value + }{ + { + name: "documents", + values: []Value{ + { + Type: bsontype.EmbeddedDocument, + Data: BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), + }, + { + Type: bsontype.EmbeddedDocument, + Data: BuildDocument(nil, AppendDoubleElement(nil, "grav", 9.8)), + }, + }, + }, + { + name: "strings", + values: []Value{ + { + Type: bsontype.String, + Data: AppendString(nil, "foo"), + }, + { + Type: bsontype.String, + Data: AppendString(nil, "bar"), + }, + }, + }, + { + name: "type mixing", + values: []Value{ + { + Type: bsontype.String, + Data: AppendString(nil, "foo"), + }, + { + Type: bsontype.Boolean, + Data: AppendBoolean(nil, true), + }, + { + Type: bsontype.EmbeddedDocument, + Data: BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), + }, + }, + }, + } + + for _, tcase := range tests { + tcase := tcase + + t.Run(tcase.name, func(t *testing.T) { + t.Parallel() + + // 1. Create the iterator + array := BuildArray(nil, tcase.values...) + iter := &Iterator{List: array} + + // 2. Read one of the documents using Next() + _, err := iter.Next() + assert.NoError(t, err) + + // 3. Reset the position + iter.Reset() + + // 4. Assert that we get the first value when re-running Next. 
+ got, err := iter.Next() + + assert.NoError(t, err) + assert.NotNil(t, got) + assert.Equal(t, tcase.values[0], *got) + }) + } +} + +func TestIterator_Count(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + values []Value + want int + }{ + { + name: "empty", + values: []Value{}, + want: 0, + }, + { + name: "nil", + values: nil, + want: 0, + }, + { + name: "singleton", + values: []Value{ + { + Type: bsontype.String, + Data: AppendString(nil, "foo"), + }, + }, + want: 1, + }, + { + name: "non singleton", + values: []Value{ + { + Type: bsontype.String, + Data: AppendString(nil, "foo"), + }, + { + Type: bsontype.String, + Data: AppendString(nil, "bar"), + }, + }, + want: 2, + }, + { + name: "document bearing", + values: []Value{ + { + Type: bsontype.EmbeddedDocument, + Data: BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), + }, + }, + want: 1, + }, + { + name: "type mixing", + values: []Value{ + { + Type: bsontype.String, + Data: AppendString(nil, "foo"), + }, + { + Type: bsontype.Boolean, + Data: AppendBoolean(nil, true), + }, + { + Type: bsontype.EmbeddedDocument, + Data: BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), + }, + }, + want: 3, + }, + } + + for _, tcase := range tests { + tcase := tcase + + t.Run(tcase.name, func(t *testing.T) { + t.Parallel() + + var array Array + if tcase.values != nil { + array = BuildArray(nil, tcase.values...) + } + + got := (&Iterator{List: array}).Count() + assert.Equal(t, tcase.want, got) + }) + } +} + +func TestIterator_Next(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + values []Value + err error + }{ + { + name: "empty", + values: []Value{}, + err: io.EOF, + }, + { + name: "nil", + values: nil, + err: io.EOF, + }, + { + name: "singleton", + values: []Value{ + { + Type: bsontype.String, + Data: AppendString(nil, "foo"), + }, + }, + }, + { + name: "document bearing", + values: []Value{ + { + Type: bsontype.EmbeddedDocument, + Data: BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), + }, + }, + }, + { + name: "type mixing", + values: []Value{ + { + Type: bsontype.String, + Data: AppendString(nil, "foo"), + }, + { + Type: bsontype.Boolean, + Data: AppendBoolean(nil, true), + }, + { + Type: bsontype.EmbeddedDocument, + Data: BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), + }, + }, + }, + } + + for _, tcase := range tests { + tcase := tcase + + t.Run(tcase.name, func(t *testing.T) { + t.Parallel() + + var array Array + if tcase.values != nil { + array = BuildArray(nil, tcase.values...) + } + + iter := &Iterator{List: array} + + for _, want := range tcase.values { + got, err := iter.Next() + require.NoErrorf(t, err, "failed to parse the next value") + + assert.Equal(t, want.Type, got.Type) + assert.Equal(t, want.Data, got.Data) + } + + // Make sure the last call to next results in an EOF. + _, err := iter.Next() + assert.ErrorIs(t, err, io.EOF) + }) + } + +} + +// BenchmarkNext measures the performance of the Next function. +func BenchmarkIterator_Next(b *testing.B) { + values := []Value{ + { + Type: bsontype.Double, + Data: AppendDouble(nil, 3.14159), + }, + { + Type: bsontype.String, + Data: AppendString(nil, "foo"), + }, + { + Type: bsontype.EmbeddedDocument, + Data: BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)), + }, + { + Type: bsontype.Boolean, + Data: AppendBoolean(nil, true), + }, + } + + iter := &Iterator{} + iter.List = BuildArray(nil, values...) 
+ + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := iter.Next() + if err == io.EOF { + // If we reach the end of the list, reset the iterator for the next iteration. + iter.pos = 0 + } else if err != nil { + b.Fatalf("Unexpected error: %v", err) + } + } +} diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 827e536137..f98d5e739b 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -45,7 +45,7 @@ type BatchCursor struct { connection PinnedConnection batchSize int32 maxTimeMS int64 - currentBatch *bsoncore.DocumentSequence + currentBatch *bsoncore.Iterator firstBatch bool cmdMonitor *event.CommandMonitor postBatchResumeToken bsoncore.Document @@ -64,7 +64,7 @@ type CursorResponse struct { ErrorProcessor ErrorProcessor // This will only be set when pinning to a connection. Connection PinnedConnection Desc description.Server - FirstBatch *bsoncore.DocumentSequence + FirstBatch *bsoncore.Iterator Database string Collection string ID int64 @@ -102,7 +102,8 @@ func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { if !ok { return CursorResponse{}, fmt.Errorf("firstBatch should be an array but is a BSON %s", elem.Value().Type) } - curresp.FirstBatch = &bsoncore.DocumentSequence{Style: bsoncore.ArrayStyle, Data: arr} + + curresp.FirstBatch = &bsoncore.Iterator{List: arr} case "ns": ns, ok := elem.Value().StringValueOK() if !ok { @@ -163,8 +164,14 @@ type CursorOptions struct { } // NewBatchCursor creates a new BatchCursor from the provided parameters. -func NewBatchCursor(cr CursorResponse, clientSession *session.Client, clock *session.ClusterClock, opts CursorOptions) (*BatchCursor, error) { - ds := cr.FirstBatch +func NewBatchCursor( + cr CursorResponse, + clientSession *session.Client, + clock *session.ClusterClock, + opts CursorOptions, +) (*BatchCursor, error) { + firstBatch := cr.FirstBatch + bc := &BatchCursor{ clientSession: clientSession, clock: clock, @@ -186,46 +193,27 @@ func NewBatchCursor(cr CursorResponse, clientSession *session.Client, clock *ses encoderFn: opts.MarshalValueEncoderFn, } - if ds != nil { - bc.numReturned = int32(ds.DocumentCount()) - } - if cr.Desc.WireVersion == nil { - bc.limit = opts.Limit - - // Take as many documents from the batch as needed. - if bc.limit != 0 && bc.limit < bc.numReturned { - for i := int32(0); i < bc.limit; i++ { - _, err := ds.Next() - if err != nil { - return nil, err - } - } - ds.Data = ds.Data[:ds.Pos] - ds.ResetIterator() - } + if firstBatch != nil { + bc.numReturned = int32(firstBatch.Count()) } - bc.currentBatch = ds + bc.currentBatch = firstBatch + return bc, nil } // NewEmptyBatchCursor returns a batch cursor that is empty. func NewEmptyBatchCursor() *BatchCursor { - return &BatchCursor{currentBatch: new(bsoncore.DocumentSequence)} + return &BatchCursor{currentBatch: new(bsoncore.Iterator)} } -// NewBatchCursorFromDocuments returns a batch cursor with current batch set to a sequence-style -// DocumentSequence containing the provided documents. -func NewBatchCursorFromDocuments(documents []byte) *BatchCursor { +// NewBatchCursorFromList returns a batch cursor with current batch set to an +// itertor that can traverse the BSON data contained within the array. 
+func NewBatchCursorFromList(array []byte) *BatchCursor { return &BatchCursor{ - currentBatch: &bsoncore.DocumentSequence{ - Data: documents, - Style: bsoncore.SequenceStyle, - }, - // BatchCursors created with this function have no associated ID nor server, so no getMore - // calls will be made. - id: 0, - server: nil, + currentBatch: &bsoncore.Iterator{List: array}, + id: 0, + server: nil, } } @@ -260,10 +248,14 @@ func (bc *BatchCursor) Next(ctx context.Context) bool { // Batch will return a DocumentSequence for the current batch of documents. The returned // DocumentSequence is only valid until the next call to Next or Close. -func (bc *BatchCursor) Batch() *bsoncore.DocumentSequence { return bc.currentBatch } +func (bc *BatchCursor) Batch() *bsoncore.Iterator { + return bc.currentBatch +} // Err returns the latest error encountered. -func (bc *BatchCursor) Err() error { return bc.err } +func (bc *BatchCursor) Err() error { + return bc.err +} // Close closes this batch cursor. func (bc *BatchCursor) Close(ctx context.Context) error { @@ -273,9 +265,9 @@ func (bc *BatchCursor) Close(ctx context.Context) error { err := bc.KillCursor(ctx) bc.id = 0 - bc.currentBatch.Data = nil - bc.currentBatch.Style = 0 - bc.currentBatch.ResetIterator() + + bc.currentBatch.List = nil + bc.currentBatch.Reset() connErr := bc.unpinConnection() if err == nil { @@ -304,7 +296,7 @@ func (bc *BatchCursor) Server() Server { } func (bc *BatchCursor) clearBatch() { - bc.currentBatch.Data = bc.currentBatch.Data[:0] + bc.currentBatch.List = bc.currentBatch.List[:0] } // KillCursor kills cursor on server without closing batch cursor @@ -405,10 +397,12 @@ func (bc *BatchCursor) getMore(ctx context.Context) { if !ok { return fmt.Errorf("cursor.nextBatch should be an array but is a BSON %s", response.Lookup("cursor", "nextBatch").Type) } - bc.currentBatch.Style = bsoncore.ArrayStyle - bc.currentBatch.Data = batch - bc.currentBatch.ResetIterator() - bc.numReturned += int32(bc.currentBatch.DocumentCount()) // Required for legacy operations which don't support limit. + + bc.currentBatch.List = batch + bc.currentBatch.Reset() + + // Required for legacy operations which don't support limit. + bc.numReturned += int32(bc.currentBatch.Count()) pbrt, err := response.LookupErr("cursor", "postBatchResumeToken") if err != nil { From 91b9075a5334e221c571408eaaf7fb801a9adefc Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 16 Jan 2024 11:25:20 -0600 Subject: [PATCH 2/3] GODRIVER-3071 [master] Correct uint Encoding BSON Documentation (#1516) Co-authored-by: Preston Vasquez --- bson/doc.go | 5 ++--- bson/encoder_example_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/bson/doc.go b/bson/doc.go index 048b5eb998..05e33a4412 100644 --- a/bson/doc.go +++ b/bson/doc.go @@ -68,10 +68,9 @@ // 2. int8, int16, and int32 marshal to a BSON int32. // 3. int marshals to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, inclusive, and a BSON int64 // otherwise. -// 4. int64 marshals to BSON int64. +// 4. int64 marshals to BSON int64 (unless [Encoder.IntMinSize] is set). // 5. uint8 and uint16 marshal to a BSON int32. -// 6. uint, uint32, and uint64 marshal to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, -// inclusive, and BSON int64 otherwise. +// 6. uint, uint32, and uint64 marshal to a BSON int64 (unless [Encoder.IntMinSize] is set). // 7. BSON null and undefined values will unmarshal into the zero value of a field (e.g. 
unmarshalling a BSON null or // undefined value into a string will yield the empty string.). // diff --git a/bson/encoder_example_test.go b/bson/encoder_example_test.go index 576e6bd791..69487a3091 100644 --- a/bson/encoder_example_test.go +++ b/bson/encoder_example_test.go @@ -234,3 +234,33 @@ func ExampleEncoder_multipleExtendedJSONDocuments() { // {"x":{"$numberInt":"3"},"y":{"$numberInt":"4"}} // {"x":{"$numberInt":"4"},"y":{"$numberInt":"5"}} } + +func ExampleEncoder_IntMinSize() { + // Create an encoder that will marshal integers as the minimum BSON int size + // (either 32 or 64 bits) that can represent the integer value. + type foo struct { + Bar uint32 + } + + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + if err != nil { + panic(err) + } + + enc, err := bson.NewEncoder(vw) + if err != nil { + panic(err) + } + + enc.IntMinSize() + + err = enc.Encode(foo{2}) + if err != nil { + panic(err) + } + + fmt.Println(bson.Raw(buf.Bytes()).String()) + // Output: + // {"bar": {"$numberInt":"2"}} +} From df800a9fc535505363c62d9583fc9f3c8c3a8d9e Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Tue, 16 Jan 2024 14:26:57 -0700 Subject: [PATCH 3/3] GODRIVER-2762 Use minimum RTT for CSOT (#1507) Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> --- internal/cmd/compilecheck/go.mod | 1 - internal/cmd/compilecheck/go.sum | 2 - internal/integration/client_test.go | 103 ----- .../unified/testrunner_operation.go | 18 + .../command-execution.json | 394 ++++++++++++++++++ .../command-execution.yml | 250 +++++++++++ x/mongo/driver/driver.go | 3 - x/mongo/driver/operation.go | 20 +- x/mongo/driver/operation_test.go | 91 ++-- x/mongo/driver/topology/rtt_monitor.go | 150 +++---- x/mongo/driver/topology/rtt_monitor_test.go | 265 +++++------- x/mongo/driver/topology/server.go | 3 +- 12 files changed, 875 insertions(+), 425 deletions(-) create mode 100644 testdata/client-side-operations-timeout/command-execution.json create mode 100644 testdata/client-side-operations-timeout/command-execution.yml diff --git a/internal/cmd/compilecheck/go.mod b/internal/cmd/compilecheck/go.mod index d3be1dedec..151d32646a 100644 --- a/internal/cmd/compilecheck/go.mod +++ b/internal/cmd/compilecheck/go.mod @@ -12,7 +12,6 @@ require go.mongodb.org/mongo-driver v1.11.7 require ( github.com/golang/snappy v0.0.1 // indirect github.com/klauspost/compress v1.13.6 // indirect - github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect diff --git a/internal/cmd/compilecheck/go.sum b/internal/cmd/compilecheck/go.sum index 3180c9060e..83cc061005 100644 --- a/internal/cmd/compilecheck/go.sum +++ b/internal/cmd/compilecheck/go.sum @@ -4,8 +4,6 @@ github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod 
h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 494ad0b94f..8350db58e0 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -609,109 +609,6 @@ func TestClient(t *testing.T) { assert.Equal(t, 0, closed, "expected no connections to be closed") }) - mt.Run("RTT90 is monitored", func(mt *mtest.T) { - mt.Parallel() - - // Reset the client with a dialer that delays all network round trips by 300ms and set the - // heartbeat interval to 100ms to reduce the time it takes to collect RTT samples. - mt.ResetClient(options.Client(). - SetDialer(newSlowConnDialer(slowConnDialerDelay)). - SetHeartbeatInterval(reducedHeartbeatInterval)) - - // Assert that RTT90s are eventually >300ms. - topo := getTopologyFromClient(mt.Client) - assert.Soon(mt, func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(100 * time.Millisecond) - - // Wait for all of the server's RTT90s to be >300ms. - done := true - for _, desc := range topo.Description().Servers { - server, err := topo.FindServer(desc) - assert.Nil(mt, err, "FindServer error: %v", err) - if server.RTTMonitor().P90() <= 300*time.Millisecond { - done = false - } - } - if done { - return - } - } - }, 10*time.Second) - }) - - // Test that if Timeout is set and the RTT90 is greater than the remaining timeout for an operation, the - // operation is not sent to the server, fails with a timeout error, and no connections are closed. - mt.Run("RTT90 used to prevent sending requests", func(mt *mtest.T) { - mt.Parallel() - - // Assert that we can call Ping with a 250ms timeout. - ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) - defer cancel() - err := mt.Client.Ping(ctx, nil) - assert.Nil(mt, err, "Ping error: %v", err) - - // Reset the client with a dialer that delays all network round trips by 300ms, set the - // heartbeat interval to 100ms to reduce the time it takes to collect RTT samples, and - // set a Timeout of 0 (infinite) on the Client to ensure that RTT90 is used as a sending - // threshold. - tpm := eventtest.NewTestPoolMonitor() - mt.ResetClient(options.Client(). - SetPoolMonitor(tpm.PoolMonitor). - SetDialer(newSlowConnDialer(slowConnDialerDelay)). - SetHeartbeatInterval(reducedHeartbeatInterval). - SetTimeout(0)) - - // Assert that RTT90s are eventually >275ms. - topo := getTopologyFromClient(mt.Client) - assert.Soon(mt, func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(100 * time.Millisecond) - - // Wait for all of the server's RTT90s to be >275ms. - done := true - for _, desc := range topo.Description().Servers { - server, err := topo.FindServer(desc) - assert.Nil(mt, err, "FindServer error: %v", err) - if server.RTTMonitor().P90() <= 275*time.Millisecond { - done = false - } - } - if done { - return - } - } - }, 10*time.Second) - - // Once we've waited for the RTT90 for the servers to be >275ms, run 10 Ping operations - // with a timeout of 275ms and expect that they return timeout errors. 
- for i := 0; i < 10; i++ { - ctx, cancel = context.WithTimeout(context.Background(), 275*time.Millisecond) - err := mt.Client.Ping(ctx, nil) - cancel() - assert.NotNil(mt, err, "expected Ping to return an error") - assert.True(mt, mongo.IsTimeout(err), "expected a timeout error, got: %v", err) - } - - // Assert that the Ping timeouts result in no connections being closed. - closed := len(tpm.Events(func(e *event.PoolEvent) bool { return e.Type == event.ConnectionClosed })) - assert.Equal(t, 0, closed, "expected no connections to be closed") - }) - // Test that OP_MSG is used for authentication-related commands on 3.6+ (WV 6+). Do not test when API version is // set, as handshakes will always use OP_MSG. opMsgOpts := mtest.NewOptions().ClientType(mtest.Proxy).MinServerVersion("3.6").Auth(true).RequireAPIVersion(false) diff --git a/internal/integration/unified/testrunner_operation.go b/internal/integration/unified/testrunner_operation.go index 49ff143d68..3027fbe393 100644 --- a/internal/integration/unified/testrunner_operation.go +++ b/internal/integration/unified/testrunner_operation.go @@ -186,6 +186,15 @@ func executeTestRunnerOperation(ctx context.Context, op *operation, loopDone <-c } } } + return nil + case "wait": + waitMS, err := convertValueToMilliseconds(args.Lookup("ms")) + if err != nil { + return err + } + + time.Sleep(waitMS) + return nil case "runOnThread": operationRaw, err := args.LookupErr("operation") @@ -484,3 +493,12 @@ func verifyIndexExists(ctx context.Context, dbName, collName, indexName string, } return nil } + +func convertValueToMilliseconds(val bson.RawValue) (time.Duration, error) { + int32Val, ok := val.Int32OK() + if !ok { + return 0, fmt.Errorf("failed to convert value of type %s to int32", val.Type) + } + + return time.Duration(int32Val) * time.Millisecond, nil +} diff --git a/testdata/client-side-operations-timeout/command-execution.json b/testdata/client-side-operations-timeout/command-execution.json new file mode 100644 index 0000000000..10f87d43ac --- /dev/null +++ b/testdata/client-side-operations-timeout/command-execution.json @@ -0,0 +1,394 @@ +{ + "description": "timeoutMS behaves correctly during command execution", + "schemaVersion": "1.9", + "runOnRequirements": [ + { + "minServerVersion": "4.9", + "topologies": [ + "single", + "replicaset", + "sharded-replicaset", + "sharded" + ], + "serverless": "forbid" + } + ], + "createEntities": [ + { + "client": { + "id": "failPointClient", + "useMultipleMongoses": false + } + } + ], + "initialData": [ + { + "collectionName": "coll", + "databaseName": "test", + "documents": [] + }, + { + "collectionName": "timeoutColl", + "databaseName": "test", + "documents": [] + } + ], + "tests": [ + { + "description": "maxTimeMS value in the command is less than timeoutMS", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "hello", + "isMaster" + ], + "appName": "reduceMaxTimeMSTest", + "blockConnection": true, + "blockTimeMS": 50 + } + } + } + }, + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "uriOptions": { + "appName": "reduceMaxTimeMSTest", + "w": 1, + "timeoutMS": 500, + "heartbeatFrequencyMS": 500 + }, + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database", + "client": "client", + 
"databaseName": "test" + } + }, + { + "collection": { + "id": "timeoutCollection", + "database": "database", + "collectionName": "timeoutColl" + } + } + ] + } + }, + { + "name": "insertOne", + "object": "timeoutCollection", + "arguments": { + "document": { + "_id": 1 + }, + "timeoutMS": 100000 + } + }, + { + "name": "wait", + "object": "testRunner", + "arguments": { + "ms": 1000 + } + }, + { + "name": "insertOne", + "object": "timeoutCollection", + "arguments": { + "document": { + "_id": 2 + } + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert", + "databaseName": "test", + "command": { + "insert": "timeoutColl" + } + } + }, + { + "commandStartedEvent": { + "commandName": "insert", + "databaseName": "test", + "command": { + "insert": "timeoutColl", + "maxTimeMS": { + "$$lte": 450 + } + } + } + } + ] + } + ] + }, + { + "description": "command is not sent if RTT is greater than timeoutMS", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "hello", + "isMaster" + ], + "appName": "rttTooHighTest", + "blockConnection": true, + "blockTimeMS": 50 + } + } + } + }, + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "uriOptions": { + "appName": "rttTooHighTest", + "w": 1, + "timeoutMS": 10, + "heartbeatFrequencyMS": 500 + }, + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "test" + } + }, + { + "collection": { + "id": "timeoutCollection", + "database": "database", + "collectionName": "timeoutColl" + } + } + ] + } + }, + { + "name": "insertOne", + "object": "timeoutCollection", + "arguments": { + "document": { + "_id": 1 + }, + "timeoutMS": 100000 + } + }, + { + "name": "wait", + "object": "testRunner", + "arguments": { + "ms": 1000 + } + }, + { + "name": "insertOne", + "object": "timeoutCollection", + "arguments": { + "document": { + "_id": 2 + } + }, + "expectError": { + "isTimeoutError": true + } + }, + { + "name": "insertOne", + "object": "timeoutCollection", + "arguments": { + "document": { + "_id": 3 + } + }, + "expectError": { + "isTimeoutError": true + } + }, + { + "name": "insertOne", + "object": "timeoutCollection", + "arguments": { + "document": { + "_id": 4 + } + }, + "expectError": { + "isTimeoutError": true + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert", + "databaseName": "test", + "command": { + "insert": "timeoutColl" + } + } + } + ] + } + ] + }, + { + "description": "short-circuit is not enabled with only 1 RTT measurement", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "hello", + "isMaster" + ], + "appName": "reduceMaxTimeMSTest", + "blockConnection": true, + "blockTimeMS": 100 + } + } + } + }, + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "uriOptions": { + "appName": "reduceMaxTimeMSTest", + "w": 1, + "timeoutMS": 90, + "heartbeatFrequencyMS": 100000 + }, + 
"observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "test" + } + }, + { + "collection": { + "id": "timeoutCollection", + "database": "database", + "collectionName": "timeoutColl" + } + } + ] + } + }, + { + "name": "insertOne", + "object": "timeoutCollection", + "arguments": { + "document": { + "_id": 1 + }, + "timeoutMS": 100000 + } + }, + { + "name": "insertOne", + "object": "timeoutCollection", + "arguments": { + "document": { + "_id": 2 + } + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert", + "databaseName": "test", + "command": { + "insert": "timeoutColl" + } + } + }, + { + "commandStartedEvent": { + "commandName": "insert", + "databaseName": "test", + "command": { + "insert": "timeoutColl", + "maxTimeMS": { + "$$lte": 450 + } + } + } + } + ] + } + ] + } + ] +} diff --git a/testdata/client-side-operations-timeout/command-execution.yml b/testdata/client-side-operations-timeout/command-execution.yml new file mode 100644 index 0000000000..400a90867a --- /dev/null +++ b/testdata/client-side-operations-timeout/command-execution.yml @@ -0,0 +1,250 @@ +description: "timeoutMS behaves correctly during command execution" + +schemaVersion: "1.9" + +runOnRequirements: + # The appName filter cannot be used to set a fail point on connection handshakes until server version 4.9 due to + # SERVER-49220/SERVER-49336. + - minServerVersion: "4.9" + # Skip load-balanced and serverless which do not support RTT measurements. + topologies: [ single, replicaset, sharded ] + serverless: forbid + +createEntities: + - client: + id: &failPointClient failPointClient + useMultipleMongoses: false + +initialData: + # The corresponding entities for the collections defined here are created in test-level createEntities operations. + # This is done so that tests can set fail points that will affect all of the handshakes and heartbeats done by a + # client. The collection and database names are listed here so that the collections will be dropped and re-created at + # the beginning of each test. + - collectionName: ®ularCollectionName coll + databaseName: &databaseName test + documents: [] + - collectionName: &timeoutCollectionName timeoutColl + databaseName: &databaseName test + documents: [] + +tests: + - description: "maxTimeMS value in the command is less than timeoutMS" + operations: + # Artificially increase the server RTT to ~50ms. + - name: failPoint + object: testRunner + arguments: + client: *failPointClient + failPoint: + configureFailPoint: failCommand + mode: "alwaysOn" + data: + failCommands: ["hello", "isMaster"] + appName: &appName reduceMaxTimeMSTest + blockConnection: true + blockTimeMS: 50 + # Create a client with the app name specified in the fail point and timeoutMS higher than blockTimeMS. + # Also create database and collection entities derived from the new client. + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: &client client + useMultipleMongoses: false + uriOptions: + appName: *appName + w: 1 # Override server's w:majority default to speed up the test. 
+ timeoutMS: 500 + heartbeatFrequencyMS: 500 + observeEvents: + - commandStartedEvent + - database: + id: &database database + client: *client + databaseName: *databaseName + - collection: + id: &timeoutCollection timeoutCollection + database: *database + collectionName: *timeoutCollectionName + # Do an operation with a large timeout to ensure the servers are discovered. + - name: insertOne + object: *timeoutCollection + arguments: + document: { _id: 1 } + timeoutMS: 100000 + # Wait until short-circuiting has been enabled (at least 2 RTT measurements). + - name: wait + object: testRunner + arguments: + ms: 1000 + # Do an operation with timeoutCollection so the event will include a maxTimeMS field. + - name: insertOne + object: *timeoutCollection + arguments: + document: { _id: 2 } + expectEvents: + - client: *client + events: + - commandStartedEvent: + commandName: insert + databaseName: *databaseName + command: + insert: *timeoutCollectionName + - commandStartedEvent: + commandName: insert + databaseName: *databaseName + command: + insert: *timeoutCollectionName + maxTimeMS: { $$lte: 450 } + + - description: "command is not sent if RTT is greater than timeoutMS" + operations: + # Artificially increase the server RTT to ~50ms. + - name: failPoint + object: testRunner + arguments: + client: *failPointClient + failPoint: + configureFailPoint: failCommand + mode: "alwaysOn" + data: + failCommands: ["hello", "isMaster"] + appName: &appName rttTooHighTest + blockConnection: true + blockTimeMS: 50 + # Create a client with the app name specified in the fail point. Also create database and collection entities + # derived from the new client. There is one collection entity with no timeoutMS and another with a timeoutMS + # that's lower than the fail point's blockTimeMS value. + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: &client client + useMultipleMongoses: false + uriOptions: + appName: *appName + w: 1 # Override server's w:majority default to speed up the test. + timeoutMS: 10 + heartbeatFrequencyMS: 500 + observeEvents: + - commandStartedEvent + - database: + id: &database database + client: *client + databaseName: *databaseName + - collection: + id: &timeoutCollection timeoutCollection + database: *database + collectionName: *timeoutCollectionName + # Do an operation with a large timeout to ensure the servers are discovered. + - name: insertOne + object: *timeoutCollection + arguments: + document: { _id: 1 } + timeoutMS: 100000 + # Wait until short-circuiting has been enabled (at least 2 RTT measurements). + - name: wait + object: testRunner + arguments: + ms: 1000 + # Do an operation with timeoutCollection which will error. + - name: insertOne + object: *timeoutCollection + arguments: + document: { _id: 2 } + expectError: + isTimeoutError: true + # Do an operation with timeoutCollection which will error. + - name: insertOne + object: *timeoutCollection + arguments: + document: { _id: 3 } + expectError: + isTimeoutError: true + # Do an operation with timeoutCollection which will error. + - name: insertOne + object: *timeoutCollection + arguments: + document: { _id: 4 } + expectError: + isTimeoutError: true + expectEvents: + # There should only be one event, which corresponds to the first + # insertOne call. For the subsequent insertOne calls, drivers should + # fail client-side. 
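+      # (With the fail point holding hello responses for ~50ms and timeoutMS set to 10ms, the
+      # minimum observed RTT alone exceeds the per-operation budget, so once two RTT samples
+      # exist the driver can fail the later inserts before putting a command on the wire.)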
+ - client: *client + events: + - commandStartedEvent: + commandName: insert + databaseName: *databaseName + command: + insert: *timeoutCollectionName + + - description: "short-circuit is not enabled with only 1 RTT measurement" + operations: + # Artificially increase the server RTT to ~300ms. + - name: failPoint + object: testRunner + arguments: + client: *failPointClient + failPoint: + configureFailPoint: failCommand + mode: "alwaysOn" + data: + failCommands: ["hello", "isMaster"] + appName: &appName reduceMaxTimeMSTest + blockConnection: true + blockTimeMS: 100 + # Create a client with the app name specified in the fail point and timeoutMS lower than blockTimeMS. + # Also create database and collection entities derived from the new client. + - name: createEntities + object: testRunner + arguments: + entities: + - client: + id: &client client + useMultipleMongoses: false + uriOptions: + appName: *appName + w: 1 # Override server's w:majority default to speed up the test. + timeoutMS: 90 + heartbeatFrequencyMS: 100000 # Override heartbeatFrequencyMS to ensure only 1 RTT is recorded. + observeEvents: + - commandStartedEvent + - database: + id: &database database + client: *client + databaseName: *databaseName + - collection: + id: &timeoutCollection timeoutCollection + database: *database + collectionName: *timeoutCollectionName + # Do an operation with a large timeout to ensure the servers are discovered. + - name: insertOne + object: *timeoutCollection + arguments: + document: { _id: 1 } + timeoutMS: 100000 + # Do an operation with timeoutCollection which will succeed. If this + # fails it indicates the driver mistakenly used the min RTT even though + # there has only been one sample. + - name: insertOne + object: *timeoutCollection + arguments: + document: { _id: 2 } + expectEvents: + - client: *client + events: + - commandStartedEvent: + commandName: insert + databaseName: *databaseName + command: + insert: *timeoutCollectionName + - commandStartedEvent: + commandName: insert + databaseName: *databaseName + command: + insert: *timeoutCollectionName + maxTimeMS: { $$lte: 450 } diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 584bb6012a..d0c0ee5c22 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -82,9 +82,6 @@ type RTTMonitor interface { // Min returns the minimum observed round-trip time over the window period. Min() time.Duration - // P90 returns the 90th percentile observed round-trip time over the window period. - P90() time.Duration - // Stats returns stringified stats of the current state of the monitor. Stats() string } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index b1f0bce873..cd47ea13e3 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -668,7 +668,7 @@ func (op Operation) Execute(ctx context.Context) error { } // Calculate maxTimeMS value to potentially be appended to the wire message. 
-		maxTimeMS, err := op.calculateMaxTimeMS(ctx, srvr.RTTMonitor().P90(), srvr.RTTMonitor().Stats())
+		maxTimeMS, err := op.calculateMaxTimeMS(ctx, srvr.RTTMonitor().Min(), srvr.RTTMonitor().Stats())
 		if err != nil {
 			return err
 		}
@@ -759,14 +759,8 @@ func (op Operation) Execute(ctx context.Context) error {
 		if ctx.Err() != nil {
 			err = ctx.Err()
 		} else if deadline, ok := ctx.Deadline(); ok {
-			if csot.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTTMonitor().P90()).After(deadline) {
-				err = fmt.Errorf(
-					"remaining time %v until context deadline is less than 90th percentile RTT: %w\n%v",
-					time.Until(deadline),
-					ErrDeadlineWouldBeExceeded,
-					srvr.RTTMonitor().Stats())
-			} else if time.Now().Add(srvr.RTTMonitor().Min()).After(deadline) {
-				err = context.DeadlineExceeded
+			if time.Now().Add(srvr.RTTMonitor().Min()).After(deadline) {
+				err = fmt.Errorf("%w: %v", ErrDeadlineWouldBeExceeded, srvr.RTTMonitor().Stats())
 			}
 		}
 
@@ -1546,22 +1540,22 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer)
 // if the ctx is a Timeout context. If the context is not a Timeout context, it uses the
 // operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is
 // not a Timeout context, calculateMaxTimeMS returns 0.
-func (op Operation) calculateMaxTimeMS(ctx context.Context, rtt90 time.Duration, rttStats string) (uint64, error) {
+func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration, rttStats string) (uint64, error) {
 	if csot.IsTimeoutContext(ctx) {
 		if deadline, ok := ctx.Deadline(); ok {
 			remainingTimeout := time.Until(deadline)
-			maxTime := remainingTimeout - rtt90
 
 			// Always round up to the next millisecond value so we never truncate the calculated
 			// maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms).
- maxTimeMS := int64((maxTime + (time.Millisecond - 1)) / time.Millisecond) + maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) if maxTimeMS <= 0 { return 0, fmt.Errorf( - "remaining time %v until context deadline is less than or equal to 90th percentile RTT: %w\n%v", + "remaining time %v until context deadline is less than or equal to rtt minimum: %w\n%v", remainingTimeout, ErrDeadlineWouldBeExceeded, rttStats) } + return uint64(maxTimeMS), nil } } else if op.MaxTime != nil { diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index be16be3f50..9fbfaae133 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -269,53 +269,72 @@ func TestOperation(t *testing.T) { }) }) t.Run("calculateMaxTimeMS", func(t *testing.T) { - timeout := 5 * time.Second - maxTime := 2 * time.Second - negMaxTime := -2 * time.Second - shortRTT := 50 * time.Millisecond - longRTT := 10 * time.Second + var ( + timeout = 5 * time.Second + maxTime = 2 * time.Second + negMaxTime = -2 * time.Second + shortRTT = 50 * time.Millisecond + longRTT = 10 * time.Second + verShortRTT = 400 * time.Microsecond + ) + timeoutCtx, cancel := csot.MakeTimeoutContext(context.Background(), timeout) defer cancel() testCases := []struct { - name string - op Operation - ctx context.Context - rtt90 time.Duration - want uint64 - err error + name string + op Operation + ctx context.Context + rtt RTTMonitor + rttMin time.Duration + rttStats string + want uint64 + err error }{ { - name: "uses context deadline and rtt90 with timeout", - op: Operation{MaxTime: &maxTime}, - ctx: timeoutCtx, - rtt90: shortRTT, - want: 5000, - err: nil, + name: "uses context deadline and rtt90 with timeout", + op: Operation{MaxTime: &maxTime}, + ctx: timeoutCtx, + rttMin: shortRTT, + rttStats: "", + want: 5000, + err: nil, + }, + { + name: "uses MaxTime without timeout", + op: Operation{MaxTime: &maxTime}, + ctx: context.Background(), + rttMin: longRTT, + rttStats: "", + want: 2000, + err: nil, }, { - name: "uses MaxTime without timeout", - op: Operation{MaxTime: &maxTime}, - ctx: context.Background(), - rtt90: longRTT, - want: 2000, - err: nil, + name: "errors when remaining timeout is less than rtt90", + op: Operation{MaxTime: &maxTime}, + ctx: timeoutCtx, + rttMin: timeout, + rttStats: "", + want: 0, + err: ErrDeadlineWouldBeExceeded, }, { - name: "errors when remaining timeout is less than rtt90", - op: Operation{MaxTime: &maxTime}, - ctx: timeoutCtx, - rtt90: timeout, - want: 0, - err: ErrDeadlineWouldBeExceeded, + name: "errors when MaxTime is negative", + op: Operation{MaxTime: &negMaxTime}, + ctx: context.Background(), + rttMin: longRTT, + rttStats: "", + want: 0, + err: ErrNegativeMaxTime, }, { - name: "errors when MaxTime is negative", - op: Operation{MaxTime: &negMaxTime}, - ctx: context.Background(), - rtt90: longRTT, - want: 0, - err: ErrNegativeMaxTime, + name: "sub millisecond rtt should round up", + op: Operation{MaxTime: &verShortRTT}, + ctx: context.Background(), + rttMin: longRTT, + rttStats: "", + want: 1, + err: nil, }, } for _, tc := range testCases { @@ -324,7 +343,7 @@ func TestOperation(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - got, err := tc.op.calculateMaxTimeMS(tc.ctx, tc.rtt90, "") + got, err := tc.op.calculateMaxTimeMS(tc.ctx, tc.rttMin, tc.rttStats) // Assert that the calculated maxTimeMS is less than or equal to the expected value. 
A few
 			// milliseconds will have elapsed toward the context deadline, and (remainingTimeout
diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go
index 0934beed89..8b0a4b4950 100644
--- a/x/mongo/driver/topology/rtt_monitor.go
+++ b/x/mongo/driver/topology/rtt_monitor.go
@@ -7,21 +7,20 @@
 package topology
 
 import (
+	"container/list"
 	"context"
 	"fmt"
-	"math"
 	"sync"
 	"time"
 
-	"github.com/montanaflynn/stats"
 	"go.mongodb.org/mongo-driver/x/mongo/driver"
 	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
 )
 
 const (
-	rttAlphaValue = 0.2
-	minSamples    = 10
-	maxSamples    = 500
+	rttAlphaValue             = 0.2
+	minRTTSamplesForMovingMin = 2
+	maxRTTSamplesForMovingMin = 10
 )
 
 type rttConfig struct {
@@ -45,12 +44,10 @@ type rttMonitor struct {
 	// disconnecting will await the cancellation of a started connection. The
 	// use case for rttMonitor.connect needs to be goroutine safe.
 	connMu        sync.Mutex
-	samples       []time.Duration
-	offset        int
-	minRTT        time.Duration
-	rtt90         time.Duration
 	averageRTT    time.Duration
 	averageRTTSet bool
+	movingMin     *list.List
+	minRTT        time.Duration
 
 	closeWg  sync.WaitGroup
 	cfg      *rttConfig
@@ -67,15 +64,12 @@ func newRTTMonitor(cfg *rttConfig) *rttMonitor {
 	}
 	ctx, cancel := context.WithCancel(context.Background())
 
-	// Determine the number of samples we need to keep to store the minWindow of RTT durations. The
-	// number of samples must be between [10, 500].
-	numSamples := int(math.Max(minSamples, math.Min(maxSamples, float64((cfg.minRTTWindow)/cfg.interval))))
 
 	return &rttMonitor{
-		samples:  make([]time.Duration, numSamples),
-		cfg:      cfg,
-		ctx:      ctx,
-		cancelFn: cancel,
+		cfg:       cfg,
+		ctx:       ctx,
+		cancelFn:  cancel,
+		movingMin: list.New(),
 	}
 }
 
@@ -132,8 +126,8 @@ func (r *rttMonitor) start() {
 		// successfully established the new connection. Otherwise, close the connection and try to
 		// create another new connection.
 		if err == nil {
-			r.addSample(conn.helloRTT)
 			r.runHellos(conn)
+			r.addSample(conn.helloRTT)
 		}
 
 		// Close any connection here because we're either about to try to create another new
@@ -197,29 +191,51 @@ func (r *rttMonitor) reset() {
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
-	for i := range r.samples {
-		r.samples[i] = 0
-	}
-	r.offset = 0
-	r.minRTT = 0
-	r.rtt90 = 0
+	r.movingMin = list.New()
 	r.averageRTT = 0
 	r.averageRTTSet = false
 }
 
+// appendMovingMin will append the RTT to the movingMin list, which tracks the
+// minimum RTT within the last "maxRTTSamplesForMovingMin" RTT samples.
+func (r *rttMonitor) appendMovingMin(rtt time.Duration) {
+	if r.movingMin == nil || rtt < 0 {
+		return
+	}
+
+	if r.movingMin.Len() == maxRTTSamplesForMovingMin {
+		r.movingMin.Remove(r.movingMin.Front())
+	}
+
+	r.movingMin.PushBack(rtt)
+}
+
+// min will return the minimum value in the movingMin list, or 0 if fewer than
+// "minRTTSamplesForMovingMin" samples have been collected.
+func (r *rttMonitor) min() time.Duration {
+	if r.movingMin == nil || r.movingMin.Len() < minRTTSamplesForMovingMin {
+		return 0
+	}
+
+	var min time.Duration
+	for e := r.movingMin.Front(); e != nil; e = e.Next() {
+		val := e.Value.(time.Duration)
+
+		if min == 0 || val < min {
+			min = val
+		}
+	}
+
+	return min
+}
+
 func (r *rttMonitor) addSample(rtt time.Duration) {
 	// Lock for the duration of this method. We're doing compuationally inexpensive work very infrequently, so lock
 	// contention isn't expected.
 	r.mu.Lock()
 	defer r.mu.Unlock()
 
-	r.samples[r.offset] = rtt
-	r.offset = (r.offset + 1) % len(r.samples)
-	// Set the minRTT and 90th percentile RTT of all collected samples.
Require at least 10 samples before - // setting these to prevent noisy samples on startup from artificially increasing RTT and to allow the - // calculation of a 90th percentile. - r.minRTT = min(r.samples, minSamples) - r.rtt90 = percentile(90.0, r.samples, minSamples) + r.appendMovingMin(rtt) + r.minRTT = r.min() if !r.averageRTTSet { r.averageRTT = rtt @@ -230,48 +246,6 @@ func (r *rttMonitor) addSample(rtt time.Duration) { r.averageRTT = time.Duration(rttAlphaValue*float64(rtt) + (1-rttAlphaValue)*float64(r.averageRTT)) } -// min returns the minimum value of the slice of duration samples. Zero values are not considered -// samples and are ignored. If no samples or fewer than minSamples are found in the slice, min -// returns 0. -func min(samples []time.Duration, minSamples int) time.Duration { - count := 0 - min := time.Duration(math.MaxInt64) - for _, d := range samples { - if d > 0 { - count++ - } - if d > 0 && d < min { - min = d - } - } - if count == 0 || count < minSamples { - return 0 - } - return min -} - -// percentile returns the specified percentile value of the slice of duration samples. Zero values -// are not considered samples and are ignored. If no samples or fewer than minSamples are found -// in the slice, percentile returns 0. -func percentile(perc float64, samples []time.Duration, minSamples int) time.Duration { - // Convert Durations to float64s. - floatSamples := make([]float64, 0, len(samples)) - for _, sample := range samples { - if sample > 0 { - floatSamples = append(floatSamples, float64(sample)) - } - } - if len(floatSamples) == 0 || len(floatSamples) < minSamples { - return 0 - } - - p, err := stats.Percentile(floatSamples, perc) - if err != nil { - panic(fmt.Errorf("x/mongo/driver/topology: error calculating %f percentile RTT: %v for samples:\n%v", perc, err, floatSamples)) - } - return time.Duration(p) -} - // EWMA returns the exponentially weighted moving average observed round-trip time. func (r *rttMonitor) EWMA() time.Duration { r.mu.RLock() @@ -288,41 +262,11 @@ func (r *rttMonitor) Min() time.Duration { return r.minRTT } -// P90 returns the 90th percentile observed round-trip time over the window period. -func (r *rttMonitor) P90() time.Duration { - r.mu.RLock() - defer r.mu.RUnlock() - - return r.rtt90 -} - // Stats returns stringified stats of the current state of the monitor. func (r *rttMonitor) Stats() string { r.mu.RLock() defer r.mu.RUnlock() - // Calculate standard deviation and average (non-EWMA) of samples. 
- var sum float64 - floatSamples := make([]float64, 0, len(r.samples)) - for _, sample := range r.samples { - if sample > 0 { - floatSamples = append(floatSamples, float64(sample)) - sum += float64(sample) - } - } - - var avg, stdDev float64 - if len(floatSamples) > 0 { - avg = sum / float64(len(floatSamples)) - - var err error - stdDev, err = stats.StandardDeviation(floatSamples) - if err != nil { - panic(fmt.Errorf("x/mongo/driver/topology: error calculating standard deviation RTT: %v for samples:\n%v", err, floatSamples)) - } - } - return fmt.Sprintf(`Round-trip-time monitor statistics:`+"\n"+ - `average RTT: %v, minimum RTT: %v, 90th percentile RTT: %v, standard dev: %v`+"\n", - time.Duration(avg), r.minRTT, r.rtt90, time.Duration(stdDev)) + `moving average RTT: %v, minimum RTT: %v`+"\n", r.averageRTT, r.minRTT) } diff --git a/x/mongo/driver/topology/rtt_monitor_test.go b/x/mongo/driver/topology/rtt_monitor_test.go index b2f73f8862..5fa1cb9bf1 100644 --- a/x/mongo/driver/topology/rtt_monitor_test.go +++ b/x/mongo/driver/topology/rtt_monitor_test.go @@ -8,9 +8,9 @@ package topology import ( "bytes" + "container/list" "context" "io" - "math" "net" "sync" "sync/atomic" @@ -83,7 +83,7 @@ func (*mockSlowConn) SetReadDeadline(_ time.Time) error { return nil } func (*mockSlowConn) SetWriteDeadline(_ time.Time) error { return nil } func TestRTTMonitor(t *testing.T) { - t.Run("measures the average, minimum and 90th percentile RTT", func(t *testing.T) { + t.Run("measures the average and minimum RTT", func(t *testing.T) { t.Parallel() dialer := DialerFunc(func(_ context.Context, _, _ string) (net.Conn, error) { @@ -103,10 +103,10 @@ func TestRTTMonitor(t *testing.T) { assert.Eventuallyf( t, - func() bool { return rtt.EWMA() > 0 && rtt.Min() > 0 && rtt.P90() > 0 }, + func() bool { return rtt.EWMA() > 0 && rtt.Min() > 0 }, 1*time.Second, 10*time.Millisecond, - "expected EWMA(), Min() and P90() to return positive durations within 1 second") + "expected EWMA() and Min() to return positive durations within 1 second") assert.True( t, rtt.EWMA() > 0, @@ -117,46 +117,6 @@ func TestRTTMonitor(t *testing.T) { rtt.Min() > 0, "expected Min() to return a positive duration, got %v", rtt.Min()) - assert.True( - t, - rtt.P90() > 0, - "expected P90() to return a positive duration, got %v", - rtt.P90()) - }) - - t.Run("creates the correct size samples slice", func(t *testing.T) { - t.Parallel() - - cases := []struct { - desc string - interval time.Duration - wantSamplesLen int - }{ - { - desc: "default", - interval: 10 * time.Second, - wantSamplesLen: 30, - }, - { - desc: "min", - interval: 10 * time.Minute, - wantSamplesLen: 10, - }, - { - desc: "max", - interval: 1 * time.Millisecond, - wantSamplesLen: 500, - }, - } - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - rtt := newRTTMonitor(&rttConfig{ - interval: tc.interval, - minRTTWindow: 5 * time.Minute, - }) - assert.Equal(t, tc.wantSamplesLen, len(rtt.samples), "expected samples length to match") - }) - } }) t.Run("can connect and disconnect repeatedly", func(t *testing.T) { @@ -213,12 +173,7 @@ func TestRTTMonitor(t *testing.T) { 1*time.Second, 10*time.Millisecond, "expected Min() to return a positive duration within 1 second") - assert.Eventuallyf( - t, - func() bool { return rtt.P90() > 0 }, - 1*time.Second, - 10*time.Millisecond, - "expected P90() to return a positive duration within 1 second") + rtt.reset() } }) @@ -315,12 +270,6 @@ func TestRTTMonitor(t *testing.T) { 1*time.Second, 10*time.Millisecond, "expected Min() to return a 
positive duration within 1 second") - assert.Eventuallyf( - t, - func() bool { return rtt.P90() > 0 }, - 1*time.Second, - 10*time.Millisecond, - "expected P90() to return a positive duration within 1 second") rtt.disconnect() l.Close() @@ -328,149 +277,139 @@ func TestRTTMonitor(t *testing.T) { }) } -func TestMin(t *testing.T) { - cases := []struct { - desc string - samples []time.Duration - minSamples int - want time.Duration +// makeArithmeticSamples will make an arithmetic sequence of time.Duration +// samples starting at the lower value as ms and ending at the upper value as +// ms. For example, if lower=1 and upder=4, then the return value will be +// [1ms, 2ms, 3ms, 4ms]. +func makeArithmeticSamples(lower, upper int) []time.Duration { + samples := []time.Duration{} + for i := 0; i < upper-lower+1; i++ { + samples = append(samples, time.Duration(lower+i)*time.Millisecond) + } + + return samples +} + +func TestRTTMonitor_appendMovingMin(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + samples []time.Duration + want []time.Duration }{ { - desc: "Should return the min for minSamples = 0", - samples: []time.Duration{1, 0, 0, 0}, - minSamples: 0, - want: 1, - }, - { - desc: "Should return 0 for fewer than minSamples samples", - samples: []time.Duration{1, 0, 0, 0}, - minSamples: 2, - want: 0, - }, - { - desc: "Should return 0 for empty samples slice", - samples: []time.Duration{}, - minSamples: 0, - want: 0, - }, - { - desc: "Should return 0 for no valid samples", - samples: []time.Duration{0, 0, 0}, - minSamples: 0, - want: 0, + name: "singleton", + samples: makeArithmeticSamples(1, 1), + want: makeArithmeticSamples(1, 1), }, { - desc: "Should return max int64 if all samples are max int64", - samples: []time.Duration{math.MaxInt64, math.MaxInt64, math.MaxInt64}, - minSamples: 0, - want: math.MaxInt64, + name: "multiplicity", + samples: makeArithmeticSamples(1, 2), + want: makeArithmeticSamples(1, 2), }, { - desc: "Should return the minimum if there are enough samples", - samples: []time.Duration{1 * time.Second, 100 * time.Millisecond, 150 * time.Millisecond, 0, 0, 0}, - minSamples: 3, - want: 100 * time.Millisecond, + name: "exceed maxRTTSamples", + samples: makeArithmeticSamples(1, 11), + want: makeArithmeticSamples(2, 11), }, { - desc: "Should return 0 if there are are not enough samples", - samples: []time.Duration{1 * time.Second, 100 * time.Millisecond, 0, 0, 0, 0}, - minSamples: 3, - want: 0, + name: "exceed maxRTTSamples but only with negative values", + samples: makeArithmeticSamples(-1, 9), + want: makeArithmeticSamples(0, 9), }, } - for _, tc := range cases { - tc := tc - t.Run(tc.desc, func(t *testing.T) { + for _, test := range tests { + test := test // capture the range variable + + t.Run(test.name, func(t *testing.T) { t.Parallel() - got := min(tc.samples, tc.minSamples) - assert.Equal(t, tc.want, got, "unexpected result from min()") + rtt := &rttMonitor{ + movingMin: list.New(), + } + + for _, sample := range test.samples { + rtt.appendMovingMin(sample) + } + + pos := 0 + for e := rtt.movingMin.Front(); e != nil; e = e.Next() { + assert.Equal(t, test.want[pos], e.Value) + + pos++ + } }) } } -func TestPercentile(t *testing.T) { - cases := []struct { - desc string - samples []time.Duration - minSamples int - percentile float64 - want time.Duration +func TestRTTMonitor_min(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + samples []time.Duration + want time.Duration }{ { - desc: "Should return 0 for fewer than minSamples samples", - 
samples: []time.Duration{1, 0, 0, 0}, - minSamples: 2, - percentile: 90.0, - want: 0, + name: "empty", + samples: []time.Duration{}, + want: 0, }, { - desc: "Should return 0 for empty samples slice", - samples: []time.Duration{}, - minSamples: 0, - percentile: 90.0, - want: 0, + name: "one", + samples: makeArithmeticSamples(1, 1), + want: 0, }, { - desc: "Should return 0 for no valid samples", - samples: []time.Duration{0, 0, 0}, - minSamples: 0, - percentile: 90.0, - want: 0, + name: "two", + samples: makeArithmeticSamples(1, 2), + want: 1 * time.Millisecond, }, { - desc: "First tertile when minSamples = 0", - samples: []time.Duration{1, 2, 3, 0, 0, 0}, - minSamples: 0, - percentile: 33.34, - want: 1, + name: "non-unit lower bound", + samples: makeArithmeticSamples(2, 9), + want: 2 * time.Millisecond, }, { - desc: "90th percentile when there are enough samples", + name: "negative lower bound with 2 values", + samples: []time.Duration{-1, 1}, + want: 0, + }, + { + name: "negative lower bound with 3 values", samples: []time.Duration{ - 100 * time.Millisecond, - 200 * time.Millisecond, - 300 * time.Millisecond, - 400 * time.Millisecond, - 500 * time.Millisecond, - 600 * time.Millisecond, - 700 * time.Millisecond, - 800 * time.Millisecond, - 900 * time.Millisecond, - 1 * time.Second, - 0, 0, 0}, - minSamples: 10, - percentile: 90.0, - want: 900 * time.Millisecond, + -1 * time.Millisecond, + 1 * time.Millisecond, + 2 * time.Millisecond}, + want: 1 * time.Millisecond, }, { - desc: "10th percentile when there are enough samples", + name: "non-sequential", samples: []time.Duration{ - 100 * time.Millisecond, - 200 * time.Millisecond, - 300 * time.Millisecond, - 400 * time.Millisecond, - 500 * time.Millisecond, - 600 * time.Millisecond, - 700 * time.Millisecond, - 800 * time.Millisecond, - 900 * time.Millisecond, - 1 * time.Second, - 0, 0, 0}, - minSamples: 10, - percentile: 10.0, - want: 100 * time.Millisecond, + 2 * time.Millisecond, + 1 * time.Millisecond, + }, + want: 1 * time.Millisecond, }, } - for _, tc := range cases { - tc := tc - t.Run(tc.desc, func(t *testing.T) { + for _, test := range tests { + test := test // capture the range variable + + t.Run(test.name, func(t *testing.T) { t.Parallel() - got := percentile(tc.percentile, tc.samples, tc.minSamples) - assert.Equal(t, tc.want, got, "unexpected result from percentile()") + rtt := &rttMonitor{ + movingMin: list.New(), + } + + for _, sample := range test.samples { + rtt.appendMovingMin(sample) + } + + assert.Equal(t, test.want, rtt.min()) }) } } diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 82bfdccc60..7005cd01fd 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -659,6 +659,7 @@ func (s *Server) update() { // If the server supports streaming or we're already streaming, we want to move to streaming the next response // without waiting. If the server has transitioned to Unknown from a network error, we want to do another // check without waiting in case it was a transient error and the server isn't actually down. 
+		serverSupportsStreaming := desc.Kind != description.Unknown && desc.TopologyVersion != nil
 		connectionIsStreaming := s.conn != nil && s.conn.getCurrentlyStreaming()
 		transitionedFromNetworkError := desc.LastError != nil && unwrapConnectionError(desc.LastError) != nil &&
 			previousDescription.Kind != description.Unknown
@@ -667,7 +668,7 @@
 			s.rttMonitor.connect()
 		}
 
-		if isStreamable(s) || connectionIsStreaming || transitionedFromNetworkError {
+		if isStreamable(s) && (serverSupportsStreaming || connectionIsStreaming) || transitionedFromNetworkError {
 			continue
 		}
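
As a companion to the RTT-monitor and calculateMaxTimeMS changes above, here is a minimal, self-contained Go sketch of the same two ideas: keep a small moving window of RTT samples, report its minimum only once at least two samples exist, and derive maxTimeMS by subtracting that minimum from the remaining timeout and rounding up to the next millisecond. The type and helper names (movingMinRTT, newMovingMinRTT) are illustrative only and are not part of the driver; the window sizes simply mirror the minRTTSamplesForMovingMin/maxRTTSamplesForMovingMin constants introduced in rtt_monitor.go.

package main

import (
	"container/list"
	"fmt"
	"time"
)

// movingMinRTT keeps the most recent maxSamples RTT measurements and reports
// their minimum once at least minSamples have been observed.
type movingMinRTT struct {
	samples    *list.List
	minSamples int
	maxSamples int
}

func newMovingMinRTT() *movingMinRTT {
	return &movingMinRTT{samples: list.New(), minSamples: 2, maxSamples: 10}
}

func (m *movingMinRTT) add(rtt time.Duration) {
	if rtt < 0 {
		return // ignore invalid samples
	}
	if m.samples.Len() == m.maxSamples {
		m.samples.Remove(m.samples.Front()) // drop the oldest sample
	}
	m.samples.PushBack(rtt)
}

func (m *movingMinRTT) min() time.Duration {
	if m.samples.Len() < m.minSamples {
		return 0 // not enough data to short-circuit safely
	}
	var min time.Duration
	for e := m.samples.Front(); e != nil; e = e.Next() {
		if d := e.Value.(time.Duration); min == 0 || d < min {
			min = d
		}
	}
	return min
}

func main() {
	m := newMovingMinRTT()
	m.add(50 * time.Millisecond)
	fmt.Println(m.min()) // 0s: a single sample is not enough to act on

	m.add(40 * time.Millisecond)
	fmt.Println(m.min()) // 40ms

	// Derive maxTimeMS the same way calculateMaxTimeMS does: remaining timeout
	// minus the minimum RTT, rounded up to the next whole millisecond.
	remaining := 100 * time.Millisecond
	maxTimeMS := int64((remaining - m.min() + time.Millisecond - 1) / time.Millisecond)
	fmt.Println(maxTimeMS) // 60
}

Compared with the previous percentile-based window, the list-backed minimum needs no third-party stats dependency and only a handful of samples before it can be trusted.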