Skip to content

Commit

Permalink
GODRIVER-2348 Extend test coverage for csot
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed Feb 13, 2024
1 parent 9f4a8e5 commit e69bc0a
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 6 deletions.
7 changes: 4 additions & 3 deletions internal/csot/csot.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ import (

type withoutMaxTime struct{}

// WithoutMaxTime returns a new context with a "skipMaxTime" value that
// WithoutMaxTime returns a new context with a "withoutMaxTime" value that
// is used to inform operation construction to not add a maxTimeMS to a wire
// message, regardless of a context deadline. This is specifically used for
// monitoring where non-awaitable hello commands are put on the wire.
// monitoring where non-awaitable hello commands are put on the wire, or to
// indicate that the user has set a "0" (i.e. infinite) CSOT.
func WithoutMaxTime(ctx context.Context) context.Context {
return context.WithValue(ctx, withoutMaxTime{}, true)
}

// IsWithoutMaxTime checks if the provided context has been assigned the
// "skipMaxTime" value.
// "withoutMaxTime" value.
func IsWithoutMaxTime(ctx context.Context) bool {
return ctx.Value(withoutMaxTime{}) != nil
}
Expand Down
105 changes: 103 additions & 2 deletions internal/csot/csot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ import (
"go.mongodb.org/mongo-driver/internal/assert"
)

func newTestContext(t *testing.T, timeout time.Duration) context.Context {
func newTestContext(t *testing.T, timeout time.Duration, values ...interface{}) context.Context {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
t.Cleanup(cancel)

for _, value := range values {
ctx = context.WithValue(ctx, value, true)
}

return ctx
}

Expand Down Expand Up @@ -298,7 +302,7 @@ func TestValidChangeStreamTimeouts(t *testing.T) {
},
{
name: "no context deadline and maxAwaitTime with zero timeout",
parent: context.Background(),
parent: newTestContext(t, -1, withoutMaxTime{}),
maxAwaitTimeout: newDurPtr(1),
timeout: newDurPtr(-1),
wantTimeout: 0,
Expand All @@ -317,3 +321,100 @@ func TestValidChangeStreamTimeouts(t *testing.T) {
})
}
}

func TestWithTimeout(t *testing.T) {
t.Parallel()

tests := []struct {
name string
parent context.Context
timeout *time.Duration
wantTimeout time.Duration
wantDeadline bool
wantValues []interface{}
}{
{
name: "deadline set with non-zero timeout",
parent: newTestContext(t, 1),
timeout: newDurPtr(2),
wantTimeout: 1,
wantDeadline: true,
wantValues: []interface{}{},
},
{
name: "deadline set with zero timeout",
parent: newTestContext(t, 1),
timeout: newDurPtr(0),
wantTimeout: 1,
wantDeadline: true,
wantValues: []interface{}{},
},
{
name: "deadline set with nil timeout",
parent: newTestContext(t, 1),
timeout: nil,
wantTimeout: 1,
wantDeadline: true,
wantValues: []interface{}{},
},
{
name: "deadline unset with non-zero timeout",
parent: context.Background(),
timeout: newDurPtr(1),
wantTimeout: 1,
wantDeadline: true,
wantValues: []interface{}{},
},
{
name: "deadline unset with zero timeout",
parent: context.Background(),
timeout: newDurPtr(0),
wantTimeout: 0,
wantDeadline: false,
wantValues: []interface{}{withoutMaxTime{}},
},
{
name: "deadline unset with nil timeout",
parent: context.Background(),
timeout: nil,
wantTimeout: 0,
wantDeadline: false,
wantValues: []interface{}{},
},
{
name: "deadline unset with non-zero timeout with withoutMaxTime",
parent: WithoutMaxTime(context.Background()),
timeout: newDurPtr(1),
wantTimeout: 1,
wantDeadline: false,
wantValues: []interface{}{withoutMaxTime{}},
},
}

for _, test := range tests {
test := test // Capture the range variable

t.Run(test.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := WithTimeout(test.parent, test.timeout)
t.Cleanup(cancel)

deadline, gotDeadline := ctx.Deadline()
assert.Equal(t, test.wantDeadline, gotDeadline)

if gotDeadline {
delta := time.Until(deadline) - test.wantTimeout
tolerance := 5 * time.Millisecond

assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance)
assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance)
}

for _, wantValue := range test.wantValues {
assert.NotNil(t, ctx.Value(wantValue), "expected context to have value %v", wantValue)
}
})
}

}
2 changes: 1 addition & 1 deletion internal/integration/unified/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat
// Special handling for the "timeoutMS" field because it applies to (almost) all operations.
if tms, ok := op.Arguments.Lookup("timeoutMS").Int32OK(); ok {
timeout := time.Duration(tms) * time.Millisecond
newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, timeout)
newCtx, cancelFunc := csot.WithTimeout(ctx, &timeout)
// Redefine ctx to be the new timeout-derived context.
ctx = newCtx
// Cancel the timeout-derived context at the end of run to avoid a context leak.
Expand Down

0 comments on commit e69bc0a

Please sign in to comment.