Skip to content

Commit

Permalink
GODRIVER-2348 Update connectTimeoutMS to cover all blocking operation…
Browse files Browse the repository at this point in the history
…s during connection establishment
  • Loading branch information
prestonvasquez committed Jan 19, 2024
1 parent 15c6d42 commit 9af952f
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 54 deletions.
8 changes: 7 additions & 1 deletion mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,13 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) {
if err != nil {
return nil, err
}
client.serverAPI = topology.ServerAPIFromServerOptions(cfg.ServerOpts)

var connectTimeout time.Duration
if clientOpt.ConnectTimeout != nil {
connectTimeout = *clientOpt.ConnectTimeout
}

client.serverAPI = topology.ServerAPIFromServerOptions(connectTimeout, cfg.ServerOpts)

if client.deployment == nil {
client.deployment, err = topology.New(cfg)
Expand Down
14 changes: 11 additions & 3 deletions x/mongo/driver/topology/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -1091,10 +1091,18 @@ func (p *pool) createConnections(ctx context.Context, wg *sync.WaitGroup) {
// Pass the createConnections context to connect to allow pool close to
// cancel connection establishment so shutdown doesn't block indefinitely if
// connectTimeout=0.
ctx, cancel := context.WithTimeout(ctx, p.connectTimeout)
defer cancel()
//
// Per the specifications, an explicit value of connectTimeout=0 means the
// timeout is "infinite".
connctx := context.Background()
if p.connectTimeout != 0 {
var cancel context.CancelFunc
connctx, cancel = context.WithTimeout(ctx, p.connectTimeout)

defer cancel()
}

err := conn.connect(ctx)
err := conn.connect(connctx)
if err != nil {
w.tryDeliver(nil, err)

Expand Down
38 changes: 12 additions & 26 deletions x/mongo/driver/topology/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func NewServer(
connectTimeout time.Duration,
opts ...ServerOption,
) *Server {
cfg := newServerConfig(opts...)
cfg := newServerConfig(connectTimeout, opts...)
globalCtx, globalCtxCancel := context.WithCancel(context.Background())
s := &Server{
state: serverDisconnected,
Expand Down Expand Up @@ -767,8 +767,9 @@ func (s *Server) updateDescription(desc description.Server) {
func (s *Server) createConnection() *connection {
opts := copyConnectionOpts(s.cfg.connectionOpts)
opts = append(opts,
WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
// TODO(GODRIVER-2348): Deprecate this logic
WithReadTimeout(func(time.Duration) time.Duration { return 10 * time.Second }),
WithWriteTimeout(func(time.Duration) time.Duration { return 10 * time.Second }),
// We override whatever handshaker is currently attached to the options with a basic
// one because need to make sure we don't do auth.
WithHandshaker(func(h Handshaker) Handshaker {
Expand Down Expand Up @@ -797,27 +798,14 @@ func (s *Server) setupHeartbeatConnection(ctx context.Context) error {

s.conn = conn

s.heartbeatLock.Unlock()

// Apply a deadline of connectTimeoutMS to connect. Release the resources of
// this context when complete.
ctx, cancel := context.WithTimeout(ctx, s.cfg.heartbeatTimeout)
defer cancel()
if s.cfg.connectTimeout != 0 {
var cancelFn context.CancelFunc
ctx, cancelFn = context.WithTimeout(ctx, s.cfg.connectTimeout)

// There is a possibility that the topolgy is disconnected while the heartbeat
// connection is mid-setup. In this case, we need to cancel establishing the
// connection immediately.
done := make(chan struct{})
defer close(done) // close the Go routine if seutp conmpletes w/o signal.
defer cancelFn()
}

go func() {
select {
case <-s.cancelHeartbeatCheck:
cancel()
case <-done:
// Do nothing, cancel will be handled when the function returns.
}
}()
s.heartbeatLock.Unlock()

return s.conn.connect(ctx)
}
Expand All @@ -836,7 +824,6 @@ func (s *Server) cancelCheck() {
// indefinitely.
if !s.heartbeatCanceled {
s.cancelHeartbeatCheck <- struct{}{}
s.cancelHeartbeatSetup <- struct{}{}

s.heartbeatCanceled = true
}
Expand Down Expand Up @@ -942,7 +929,7 @@ func (s *Server) check(ctx context.Context) (description.Server, error) {
// If connectTimeoutMS=0, the socket timeout should be infinite. Otherwise, it is connectTimeoutMS +
// heartbeatFrequencyMS to account for the fact that the query will block for heartbeatFrequencyMS
// server-side.
socketTimeout := s.cfg.heartbeatTimeout
socketTimeout := s.cfg.connectTimeout
if socketTimeout != 0 {
socketTimeout += s.cfg.heartbeatInterval
}
Expand All @@ -954,8 +941,7 @@ func (s *Server) check(ctx context.Context) (description.Server, error) {
default:
// The server doesn't support the awaitable protocol. Set the socket timeout to connectTimeoutMS and
// execute a regular heartbeat without any additional parameters.

s.conn.setSocketTimeout(s.cfg.heartbeatTimeout)
s.conn.setSocketTimeout(s.cfg.connectTimeout)
err = baseOperation.Execute(ctx) // HERE
}

Expand Down
18 changes: 5 additions & 13 deletions x/mongo/driver/topology/server_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type serverConfig struct {
connectionOpts []ConnectionOption
appname string
heartbeatInterval time.Duration
heartbeatTimeout time.Duration
connectTimeout time.Duration
serverMonitoringMode string
serverMonitor *event.ServerMonitor
registry *bsoncodec.Registry
Expand All @@ -44,10 +44,10 @@ type serverConfig struct {
poolMaintainInterval time.Duration
}

func newServerConfig(opts ...ServerOption) *serverConfig {
func newServerConfig(connectTimeout time.Duration, opts ...ServerOption) *serverConfig {
cfg := &serverConfig{
heartbeatInterval: 10 * time.Second,
heartbeatTimeout: 10 * time.Second,
connectTimeout: connectTimeout,
registry: defaultRegistry,
}

Expand All @@ -66,8 +66,8 @@ type ServerOption func(*serverConfig)

// ServerAPIFromServerOptions will return the server API options if they have been functionally set on the ServerOption
// slice.
func ServerAPIFromServerOptions(opts []ServerOption) *driver.ServerAPIOptions {
return newServerConfig(opts...).serverAPI
func ServerAPIFromServerOptions(connectTimeout time.Duration, opts []ServerOption) *driver.ServerAPIOptions {
return newServerConfig(connectTimeout, opts...).serverAPI
}

func withMonitoringDisabled(fn func(bool) bool) ServerOption {
Expand Down Expand Up @@ -104,14 +104,6 @@ func WithHeartbeatInterval(fn func(time.Duration) time.Duration) ServerOption {
}
}

// WithHeartbeatTimeout configures how long to wait for a heartbeat socket to
// connection.
func WithHeartbeatTimeout(fn func(time.Duration) time.Duration) ServerOption {
return func(cfg *serverConfig) {
cfg.heartbeatTimeout = fn(cfg.heartbeatTimeout)
}
}

// WithMaxConnections configures the maximum number of connections to allow for
// a given server. If max is 0, then maximum connection pool size is not limited.
func WithMaxConnections(fn func(uint64) uint64) ServerOption {
Expand Down
16 changes: 12 additions & 4 deletions x/mongo/driver/topology/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -808,11 +808,12 @@ func TestServer(t *testing.T) {
})
t.Run("createConnection overwrites WithSocketTimeout", func(t *testing.T) {
socketTimeout := 40 * time.Second
connectTimeout := 10 * time.Second

s := NewServer(
address.Address("localhost"),
primitive.NewObjectID(),
defaultConnectionTimeout,
connectTimeout,
WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
return append(
connOpts,
Expand All @@ -823,9 +824,16 @@ func TestServer(t *testing.T) {
)

conn := s.createConnection()
assert.Equal(t, s.cfg.heartbeatTimeout, 10*time.Second, "expected heartbeatTimeout to be: %v, got: %v", 10*time.Second, s.cfg.heartbeatTimeout)
assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout)
assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout)
assert.Equal(t, s.cfg.connectTimeout, 10*time.Second,
"expected heartbeatTimeout to be: %v, got: %v", 10*time.Second, s.cfg.connectTimeout)

// TODO(GODRIVER-2348): The following two tests might be removed when
// feature-gating CSOT
assert.Equal(t, s.cfg.connectTimeout, conn.readTimeout,
"expected readTimeout to be: %v, got: %v", s.cfg.connectTimeout, conn.readTimeout)

assert.Equal(t, s.cfg.connectTimeout, conn.writeTimeout,
"expected writeTimeout to be: %v, got: %v", s.cfg.connectTimeout, conn.writeTimeout)
})
}

Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/topology.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ func (t *Topology) selectServerFromDescription(desc description.Topology,
func (t *Topology) pollSRVRecords(hosts string) {
defer t.pollingwg.Done()

serverConfig := newServerConfig(t.cfg.ServerOpts...)
serverConfig := newServerConfig(t.cfg.ConnectTimeout, t.cfg.ServerOpts...)
heartbeatInterval := serverConfig.heartbeatInterval

pollTicker := time.NewTicker(t.rescanSRVInterval)
Expand Down
7 changes: 1 addition & 6 deletions x/mongo/driver/topology/topology_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,7 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config,
}
}
connOpts = append(connOpts, WithHandshaker(handshaker))
// ConnectTimeout
if co.ConnectTimeout != nil {
serverOpts = append(serverOpts, WithHeartbeatTimeout(
func(time.Duration) time.Duration { return *co.ConnectTimeout },
))
}

// Dialer
if co.Dialer != nil {
connOpts = append(connOpts, WithDialer(
Expand Down

0 comments on commit 9af952f

Please sign in to comment.