diff --git a/core/connection/connection.go b/core/connection/connection.go index 8334e5417b..3eb9fc2efe 100644 --- a/core/connection/connection.go +++ b/core/connection/connection.go @@ -50,6 +50,14 @@ type Dialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } +// DialerFunc is a type implemented by functions that can be used as a Dialer. +type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// DialContext implements the Dialer interface. +func (df DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return df(ctx, network, address) +} + // DefaultDialer is the Dialer implementation that is used by this package. Changing this // will also change the Dialer used for this package. This should only be changed why all // of the connections being made need to use a different Dialer. Most of the time, using a diff --git a/core/connection/connection_test.go b/core/connection/connection_test.go index 910fe84c92..365e77fd60 100644 --- a/core/connection/connection_test.go +++ b/core/connection/connection_test.go @@ -7,26 +7,83 @@ package connection import ( + "context" "net" + "sync" "testing" ) // bootstrapConnection creates a listener that will listen for a single connection // on the return address. The user provided run function will be called with the accepted // connection. The user is responsible for closing the connection. -func bootstrapConnection(t *testing.T, run func(net.Conn)) net.Addr { - l, err := net.Listen("tcp", ":0") +func bootstrapConnections(t *testing.T, num int, run func(net.Conn)) net.Addr { + l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Errorf("Could not set up a listener: %v", err) t.FailNow() } go func() { - c, err := l.Accept() - if err != nil { - t.Errorf("Could not accept a connection: %v", err) + for i := 0; i < num; i++ { + c, err := l.Accept() + if err != nil { + t.Errorf("Could not accept a connection: %v", err) + } + go run(c) } _ = l.Close() - run(c) }() return l.Addr() } + +type netconn struct { + net.Conn + closed chan struct{} + d *dialer +} + +func (nc *netconn) Close() error { + nc.closed <- struct{}{} + nc.d.connclosed(nc) + return nc.Conn.Close() +} + +type dialer struct { + Dialer + opened map[*netconn]struct{} + closed map[*netconn]struct{} + sync.Mutex +} + +func newdialer(d Dialer) *dialer { + return &dialer{Dialer: d, opened: make(map[*netconn]struct{}), closed: make(map[*netconn]struct{})} +} + +func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + d.Lock() + defer d.Unlock() + c, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + nc := &netconn{Conn: c, closed: make(chan struct{}, 1), d: d} + d.opened[nc] = struct{}{} + return nc, nil +} + +func (d *dialer) connclosed(nc *netconn) { + d.Lock() + defer d.Unlock() + d.closed[nc] = struct{}{} +} + +func (d *dialer) lenopened() int { + d.Lock() + defer d.Unlock() + return len(d.opened) +} + +func (d *dialer) lenclosed() int { + d.Lock() + defer d.Unlock() + return len(d.closed) +} diff --git a/core/connection/error.go b/core/connection/error.go index bda3a756f9..0222604e65 100644 --- a/core/connection/error.go +++ b/core/connection/error.go @@ -34,3 +34,8 @@ type NetworkError struct { func (ne NetworkError) Error() string { return fmt.Sprintf("connection(%s): %s", ne.ConnectionID, ne.Wrapped.Error()) } + +// PoolError is an error returned from a Pool method. 
+type PoolError string + +func (pe PoolError) Error() string { return string(pe) } diff --git a/core/connection/pool.go b/core/connection/pool.go index 950c125cac..ed3fdbe139 100644 --- a/core/connection/pool.go +++ b/core/connection/pool.go @@ -8,28 +8,56 @@ package connection import ( "context" - "errors" "sync" "sync/atomic" "github.com/mongodb/mongo-go-driver/core/addr" "github.com/mongodb/mongo-go-driver/core/description" + "github.com/mongodb/mongo-go-driver/core/wiremessage" "golang.org/x/sync/semaphore" ) // ErrPoolClosed is returned from an attempt to use a closed pool. -var ErrPoolClosed = errors.New("pool is closed") +var ErrPoolClosed = PoolError("pool is closed") // ErrSizeLargerThanCapacity is returned from an attempt to create a pool with a size // larger than the capacity. -var ErrSizeLargerThanCapacity = errors.New("size is larger than capacity") +var ErrSizeLargerThanCapacity = PoolError("size is larger than capacity") + +// ErrPoolConnected is returned from an attempt to connect an already connected pool. +var ErrPoolConnected = PoolError("pool is connected") + +// ErrPoolDisconnected is returned from an attempt to disconnect an already disconnected +// or disconnecting pool. +var ErrPoolDisconnected = PoolError("pool is disconnected or disconnecting") + +// ErrConnectionClosed is returned from an attempt to use an already closed connection. +var ErrConnectionClosed = Error{ConnectionID: "", message: "connection is closed"} + +// These constants represent the connection states of a pool. +const ( + disconnected int32 = iota + disconnecting + connected +) // Pool is used to pool Connections to a server. type Pool interface { // Get must return a nil *description.Server if the returned connection is // not a newly dialed connection. Get(context.Context) (Connection, *description.Server, error) - Close() error + // Connect handles the initialization of a Pool and allows Connections to be + // retrieved and pooled. Implementations must return an error if Connect is + // called more than once before calling Disconnect. + Connect(context.Context) error + // Disconnect closes all connections managed by this Pool. Implementations must + // either wait until all of the connections in use have been returned and + // closed or the context expires before returning. If the context expires + // via cancellation, deadline, timeout, or some other manner, implementations + // must close the in use connections. If this method returns with no errors, + // all connections managed by this pool must be closed. Calling Disconnect + // multiple times after a single Connect call must result in an error.
+ Disconnect(context.Context) error Drain() error } @@ -39,6 +67,10 @@ type pool struct { conns chan *pooledConnection generation uint64 sem *semaphore.Weighted + connected int32 + nextid uint64 + capacity uint64 + inflight map[uint64]*pooledConnection sync.Mutex } @@ -55,6 +87,9 @@ func NewPool(address addr.Addr, size, capacity uint64, opts ...Option) (Pool, er conns: make(chan *pooledConnection, size), generation: 0, sem: semaphore.NewWeighted(int64(capacity)), + connected: disconnected, + capacity: capacity, + inflight: make(map[uint64]*pooledConnection), opts: opts, } return p, nil @@ -65,52 +100,70 @@ func (p *pool) Drain() error { return nil } -func (p *pool) Close() error { - p.Lock() - conns := p.conns - p.conns = nil - p.Unlock() - - if conns == nil { - return nil +func (p *pool) Connect(ctx context.Context) error { + if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) { + return ErrPoolConnected } + atomic.AddUint64(&p.generation, 1) + return nil +} - close(conns) - - var err error - - for pc := range conns { - err = pc.Close() +func (p *pool) Disconnect(ctx context.Context) error { + if !atomic.CompareAndSwapInt32(&p.connected, connected, disconnecting) { + return ErrPoolDisconnected } - return err + // We first clear out the idle connections, then we attempt to acquire the entire capacity + // semaphore. If the context is either cancelled, the deadline expires, or there is a timeout + // the semaphore acquire method will return an error. If that happens, we will aggressively + // close the remaining open connections. If we were able to successfully acquire the semaphore, + // then all of the in flight connections have been closed and we release the semaphore. +loop: + for { + select { + case pc := <-p.conns: + // This error would be overwritten by the semaphore + _ = pc.Close() + default: + break loop + } + } + err := p.sem.Acquire(ctx, int64(p.capacity)) + if err != nil { + p.Lock() + // We copy the remaining connections to close into a slice, then + // iterate the slice to do the closing. This allows us to use a single + // function to actually clean up and close connections at the expense of + // a double iteration in the worst case. 
+ toClose := make([]*pooledConnection, 0, len(p.inflight)) + for _, pc := range p.inflight { + toClose = append(toClose, pc) + } + p.Unlock() + for _, pc := range toClose { + _ = pc.Close() + } + } else { + p.sem.Release(int64(p.capacity)) + } + atomic.StoreInt32(&p.connected, disconnected) + return nil } func (p *pool) Get(ctx context.Context) (Connection, *description.Server, error) { - p.Lock() - conns := p.conns - p.Unlock() - - if conns == nil { + if atomic.LoadInt32(&p.connected) != connected { return nil, nil, ErrPoolClosed } - return p.get(ctx, conns) -} - -func (p *pool) get(ctx context.Context, conns chan *pooledConnection) (Connection, *description.Server, error) { g := atomic.LoadUint64(&p.generation) select { - case c := <-conns: - if c == nil { - return nil, nil, ErrPoolClosed - } + case c := <-p.conns: if c.Expired() { - go c.Connection.Close() - return p.get(ctx, conns) + go p.closeConnection(c) + return p.Get(ctx) } - return c, nil, nil + return &acquired{Connection: c}, nil, nil case <-ctx.Done(): return nil, nil, ctx.Err() default: @@ -123,33 +176,45 @@ func (p *pool) get(ctx context.Context, conns chan *pooledConnection) (Connectio return nil, nil, err } - return &pooledConnection{ + pc := &pooledConnection{ Connection: c, p: p, generation: g, - }, desc, nil + id: atomic.AddUint64(&p.nextid, 1), + } + p.Lock() + if atomic.LoadInt32(&p.connected) != connected { + p.Unlock() + p.closeConnection(pc) + return nil, nil, ErrPoolClosed + } + defer p.Unlock() + p.inflight[pc.id] = pc + return &acquired{Connection: pc}, desc, nil } } -func (p *pool) returnConnection(pc *pooledConnection) error { - if pc.Expired() { - return pc.Connection.Close() +func (p *pool) closeConnection(pc *pooledConnection) error { + if !atomic.CompareAndSwapInt32(&pc.closed, 0, 1) { + return nil } - + pc.p.sem.Release(1) p.Lock() - defer p.Unlock() + delete(p.inflight, pc.id) + p.Unlock() + return pc.Connection.Close() +} - if p.conns == nil { - pc.p.sem.Release(1) - return pc.Connection.Close() +func (p *pool) returnConnection(pc *pooledConnection) error { + if atomic.LoadInt32(&p.connected) != connected || pc.Expired() { + return p.closeConnection(pc) } select { case p.conns <- pc: return nil default: - pc.p.sem.Release(1) - return pc.Connection.Close() + return p.closeConnection(pc) } } @@ -161,6 +226,8 @@ type pooledConnection struct { Connection p *pool generation uint64 + id uint64 + closed int32 } func (pc *pooledConnection) Close() error { @@ -170,3 +237,65 @@ func (pc *pooledConnection) Close() error { func (pc *pooledConnection) Expired() bool { return pc.Connection.Expired() || pc.p.isExpired(pc.generation) } + +type acquired struct { + Connection + + sync.Mutex +} + +func (a *acquired) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error { + a.Lock() + defer a.Unlock() + if a.Connection == nil { + return ErrConnectionClosed + } + return a.Connection.WriteWireMessage(ctx, wm) +} + +func (a *acquired) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) { + a.Lock() + defer a.Unlock() + if a.Connection == nil { + return nil, ErrConnectionClosed + } + return a.Connection.ReadWireMessage(ctx) +} + +func (a *acquired) Close() error { + a.Lock() + defer a.Unlock() + if a.Connection == nil { + return nil + } + err := a.Connection.Close() + a.Connection = nil + return err +} + +func (a *acquired) Expired() bool { + a.Lock() + defer a.Unlock() + if a.Connection == nil { + return true + } + return a.Connection.Expired() +} + +func (a *acquired) Alive() bool { + 
a.Lock() + defer a.Unlock() + if a.Connection == nil { + return false + } + return a.Connection.Alive() +} + +func (a *acquired) ID() string { + a.Lock() + defer a.Unlock() + if a.Connection == nil { + return "" + } + return a.Connection.ID() +} diff --git a/core/connection/pool_test.go b/core/connection/pool_test.go new file mode 100644 index 0000000000..388235fa74 --- /dev/null +++ b/core/connection/pool_test.go @@ -0,0 +1,542 @@ +package connection + +import ( + "context" + "errors" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/mongodb/mongo-go-driver/core/addr" +) + +func TestPool(t *testing.T) { + noerr := func(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Errorf("Unepexted error: %v", err) + t.FailNow() + } + } + t.Run("NewPool", func(t *testing.T) { + t.Run("should be connected", func(t *testing.T) { + P, err := NewPool(addr.Addr(""), 1, 2) + p := P.(*pool) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + if p.connected != connected { + t.Errorf("Expected new pool to be connected. got %v; want %v", p.connected, connected) + } + }) + t.Run("size cannot be larger than capcity", func(t *testing.T) { + _, err := NewPool(addr.Addr(""), 5, 1) + if err != ErrSizeLargerThanCapacity { + t.Errorf("Should recieve error when size is larger than capacity. got %v; want %v", err, ErrSizeLargerThanCapacity) + } + }) + }) + t.Run("Disconnect", func(t *testing.T) { + t.Run("cannot disconnect twice", func(t *testing.T) { + p, err := NewPool(addr.Addr(""), 1, 2) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + err = p.Disconnect(context.Background()) + noerr(t, err) + err = p.Disconnect(context.Background()) + if err != ErrPoolDisconnected { + t.Errorf("Should not be able to call disconnect twice. got %v; want %v", err, ErrPoolDisconnected) + } + }) + t.Run("closes idle connections", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + conns := [3]Connection{} + for idx := range [3]struct{}{} { + conns[idx], _, err = p.Get(context.Background()) + noerr(t, err) + } + for idx := range [3]struct{}{} { + err = conns[idx].Close() + noerr(t, err) + } + if d.lenopened() != 3 { + t.Errorf("Should have opened 3 connections, but didn't. got %d; want %d", d.lenopened(), 3) + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = p.Disconnect(ctx) + noerr(t, err) + if d.lenclosed() != 3 { + t.Errorf("Should have closed 3 connections, but didn't. 
got %d; want %d", d.lenclosed(), 3) + } + close(cleanup) + ok := p.(*pool).sem.TryAcquire(int64(p.(*pool).capacity)) + if !ok { + t.Errorf("clean shutdown should acquire and release semaphore, but semaphore still held") + } else { + p.(*pool).sem.Release(int64(p.(*pool).capacity)) + } + }) + t.Run("closes inflight connections when context expires", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + conns := [3]Connection{} + for idx := range [3]struct{}{} { + conns[idx], _, err = p.Get(context.Background()) + noerr(t, err) + } + for idx := range [2]struct{}{} { + err = conns[idx].Close() + noerr(t, err) + } + if d.lenopened() != 3 { + t.Errorf("Should have opened 3 connections, but didn't. got %d; want %d", d.lenopened(), 3) + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + cancel() + err = p.Disconnect(ctx) + noerr(t, err) + if d.lenclosed() != 3 { + t.Errorf("Should have closed 3 connections, but didn't. got %d; want %d", d.lenclosed(), 3) + } + close(cleanup) + ok := p.(*pool).sem.TryAcquire(int64(p.(*pool).capacity)) + if !ok { + t.Errorf("clean shutdown should acquire and release semaphore, but semaphore still held") + } else { + p.(*pool).sem.Release(int64(p.(*pool).capacity)) + } + }) + t.Run("properly sets the connection state on return", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + c, _, err := p.Get(context.Background()) + noerr(t, err) + err = c.Close() + noerr(t, err) + if d.lenopened() != 1 { + t.Errorf("Should have opened 3 connections, but didn't. got %d; want %d", d.lenopened(), 1) + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = p.Disconnect(ctx) + noerr(t, err) + if d.lenclosed() != 1 { + t.Errorf("Should have closed 3 connections, but didn't. got %d; want %d", d.lenclosed(), 1) + } + close(cleanup) + state := atomic.LoadInt32(&(p.(*pool)).connected) + if state != disconnected { + t.Errorf("Should have set the connection state on return. got %d; want %d", state, disconnected) + } + }) + }) + t.Run("Connect", func(t *testing.T) { + t.Run("can reconnect a disconnected pool", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + c, _, err := p.Get(context.Background()) + noerr(t, err) + gen := c.(*acquired).Connection.(*pooledConnection).generation + if gen != 1 { + t.Errorf("Connection should have a newer generation. got %d; want %d", gen, 1) + } + err = c.Close() + noerr(t, err) + if d.lenopened() != 1 { + t.Errorf("Should have opened 3 connections, but didn't. 
got %d; want %d", d.lenopened(), 1) + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = p.Disconnect(ctx) + noerr(t, err) + if d.lenclosed() != 1 { + t.Errorf("Should have closed 3 connections, but didn't. got %d; want %d", d.lenclosed(), 1) + } + close(cleanup) + state := atomic.LoadInt32(&(p.(*pool)).connected) + if state != disconnected { + t.Errorf("Should have set the connection state on return. got %d; want %d", state, disconnected) + } + err = p.Connect(context.Background()) + noerr(t, err) + + c, _, err = p.Get(context.Background()) + noerr(t, err) + gen = atomic.LoadUint64(&(c.(*acquired).Connection.(*pooledConnection)).generation) + if gen != 2 { + t.Errorf("Connection should have a newer generation. got %d; want %d", gen, 2) + } + err = c.Close() + noerr(t, err) + if d.lenopened() != 2 { + t.Errorf("Should have opened 3 connections, but didn't. got %d; want %d", d.lenopened(), 2) + } + }) + t.Run("cannot connect multiple times without disconnect", func(t *testing.T) { + p, err := NewPool(addr.Addr(""), 3, 3) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + err = p.Connect(context.Background()) + if err != ErrPoolConnected { + t.Errorf("Shouldn't be able to connect to already connected pool. got %v; want %v", err, ErrPoolConnected) + } + err = p.Connect(context.Background()) + if err != ErrPoolConnected { + t.Errorf("Shouldn't be able to connect to already connected pool. got %v; want %v", err, ErrPoolConnected) + } + err = p.Disconnect(context.Background()) + noerr(t, err) + err = p.Connect(context.Background()) + if err != nil { + t.Errorf("Should be able to connect to pool after disconnect. got %v; want ", err) + } + }) + t.Run("can disconnect and reconnect multiple times", func(t *testing.T) { + p, err := NewPool(addr.Addr(""), 3, 3) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + err = p.Disconnect(context.Background()) + noerr(t, err) + err = p.Connect(context.Background()) + if err != nil { + t.Errorf("Should be able to connect to disconnected pool. got %v; want ", err) + } + err = p.Disconnect(context.Background()) + noerr(t, err) + err = p.Connect(context.Background()) + if err != nil { + t.Errorf("Should be able to connect to disconnected pool. got %v; want ", err) + } + err = p.Disconnect(context.Background()) + noerr(t, err) + err = p.Connect(context.Background()) + if err != nil { + t.Errorf("Should be able to connect to pool after disconnect. got %v; want ", err) + } + }) + }) + t.Run("Get", func(t *testing.T) { + t.Run("return context error when already cancelled", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + cancel() + _, _, err = p.Get(ctx) + if err != context.Canceled { + t.Errorf("Should return context error when already cancelled. 
got %v; want %v", err, context.Canceled) + } + close(cleanup) + }) + t.Run("return context error when attempting to acquire semaphore", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + ok := p.(*pool).sem.TryAcquire(3) + if !ok { + t.Errorf("Could not acquire the entire semaphore.") + } + _, _, err = p.Get(ctx) + if err != context.DeadlineExceeded { + t.Errorf("Should return context error when already canclled. got %v; want %v", err, context.DeadlineExceeded) + } + close(cleanup) + }) + t.Run("return error when attempting to create new connection", func(t *testing.T) { + want := errors.New("create new connection error") + var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) { return nil, want } + p, err := NewPool(addr.Addr(""), 1, 2, WithDialer(func(Dialer) Dialer { return dialer })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + _, _, got := p.Get(context.Background()) + if got != want { + t.Errorf("Should return error from calling New. got %v; want %v", got, want) + } + }) + t.Run("adds connection to inflight pool", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 1, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + c, _, err := p.Get(ctx) + noerr(t, err) + inflight := len(p.(*pool).inflight) + if inflight != 1 { + t.Errorf("Incorrect number of inlight connections. got %d; want %d", inflight, 1) + } + err = c.Close() + noerr(t, err) + close(cleanup) + }) + t.Run("closes expired connections", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 2, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool( + addr.Addr(address.String()), 3, 3, + WithDialer(func(Dialer) Dialer { return d }), + WithIdleTimeout(func(time.Duration) time.Duration { return 10 * time.Millisecond }), + ) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + c, _, err := p.Get(ctx) + noerr(t, err) + if d.lenopened() != 1 { + t.Errorf("Should have opened 1 connection, but didn't. got %d; want %d", d.lenopened(), 1) + } + err = c.Close() + noerr(t, err) + time.Sleep(15 * time.Millisecond) + if d.lenclosed() != 0 { + t.Errorf("Should have closed 0 connections, but didn't. got %d; want %d", d.lenopened(), 0) + } + c, _, err = p.Get(ctx) + noerr(t, err) + if d.lenopened() != 2 { + t.Errorf("Should have opened 2 connections, but didn't. got %d; want %d", d.lenopened(), 2) + } + time.Sleep(10 * time.Millisecond) + if d.lenclosed() != 1 { + t.Errorf("Should have closed 1 connection, but didn't. 
got %d; want %d", d.lenopened(), 1) + } + close(cleanup) + }) + t.Run("recycles connections", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + for range [3]struct{}{} { + c, _, err := p.Get(context.Background()) + noerr(t, err) + err = c.Close() + noerr(t, err) + if d.lenopened() != 1 { + t.Errorf("Should have opened 1 connection, but didn't. got %d; want %d", d.lenopened(), 1) + } + } + close(cleanup) + }) + t.Run("cannot get from disconnected pool", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 3, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Microsecond) + defer cancel() + err = p.Disconnect(ctx) + noerr(t, err) + _, _, err = p.Get(context.Background()) + if err != ErrPoolClosed { + t.Errorf("Should get error from disconnected pool. got %v; want %v", err, ErrPoolClosed) + } + close(cleanup) + }) + t.Run("pool closes excess connections when returned", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 1, 3, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + conns := [3]Connection{} + for idx := range [3]struct{}{} { + conns[idx], _, err = p.Get(context.Background()) + noerr(t, err) + } + for idx := range [3]struct{}{} { + err = conns[idx].Close() + noerr(t, err) + } + if d.lenopened() != 3 { + t.Errorf("Should have opened 3 connections, but didn't. got %d; want %d", d.lenopened(), 3) + } + if d.lenclosed() != 2 { + t.Errorf("Should have closed 2 connections, but didn't. got %d; want %d", d.lenopened(), 2) + } + close(cleanup) + }) + t.Run("cannot get more than capacity connections", func(t *testing.T) { + cleanup := make(chan struct{}) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 1, 2, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + conns := [2]Connection{} + for idx := range [2]struct{}{} { + conns[idx], _, err = p.Get(context.Background()) + noerr(t, err) + } + if d.lenopened() != 2 { + t.Errorf("Should have opened 2 connections, but didn't. got %d; want %d", d.lenopened(), 2) + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + _, _, err = p.Get(ctx) + if err != context.DeadlineExceeded { + t.Errorf("Should not be able to get more than capacity connections. 
got %v; want %v", err, context.DeadlineExceeded) + } + err = conns[0].Close() + noerr(t, err) + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + c, _, err := p.Get(ctx) + noerr(t, err) + err = c.Close() + noerr(t, err) + + err = p.Drain() + noerr(t, err) + + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + c, _, err = p.Get(ctx) + noerr(t, err) + if d.lenopened() != 3 { + t.Errorf("Should have opened 3 connections, but didn't. got %d; want %d", d.lenopened(), 3) + } + close(cleanup) + }) + }) + t.Run("Connection", func(t *testing.T) { + t.Run("Connection Close Does Not Error After Pool Is Disconnected", func(t *testing.T) { + cleanup := make(chan struct{}) + defer close(cleanup) + address := bootstrapConnections(t, 3, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + p, err := NewPool(addr.Addr(address.String()), 2, 4, WithDialer(func(Dialer) Dialer { return d })) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + c1, _, err := p.Get(context.Background()) + noerr(t, err) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = p.Disconnect(ctx) + noerr(t, err) + err = c1.Close() + if err != nil { + t.Errorf("Conneciton Close should not error after Pool is Disconnected, but got error: %v", err) + } + }) + t.Run("Does not return to pool twice", func(t *testing.T) { + cleanup := make(chan struct{}) + defer close(cleanup) + address := bootstrapConnections(t, 1, func(nc net.Conn) { + <-cleanup + nc.Close() + }) + d := newdialer(&net.Dialer{}) + P, err := NewPool(addr.Addr(address.String()), 2, 4, WithDialer(func(Dialer) Dialer { return d })) + p := P.(*pool) + noerr(t, err) + err = p.Connect(context.Background()) + noerr(t, err) + c1, _, err := p.Get(context.Background()) + noerr(t, err) + if len(p.conns) != 0 { + t.Errorf("Should be no connections in pool. got %d; want %d", len(p.conns), 0) + } + err = c1.Close() + noerr(t, err) + err = c1.Close() + noerr(t, err) + if len(p.conns) != 1 { + t.Errorf("Should not return connection to pool twice. 
got %d; want %d", len(p.conns), 1) + } + }) + }) +} diff --git a/core/examples/cluster_monitoring/main.go b/core/examples/cluster_monitoring/main.go index 6d129c4428..2c7df91eaa 100644 --- a/core/examples/cluster_monitoring/main.go +++ b/core/examples/cluster_monitoring/main.go @@ -7,6 +7,7 @@ package main import ( + "context" "log" "github.com/kr/pretty" @@ -18,7 +19,7 @@ func main() { if err != nil { log.Fatalf("could not create topology: %v", err) } - topo.Init() + topo.Connect(context.Background()) sub, err := topo.Subscribe() if err != nil { diff --git a/core/examples/count/main.go b/core/examples/count/main.go index a0ad54b248..4d17c62a19 100644 --- a/core/examples/count/main.go +++ b/core/examples/count/main.go @@ -42,7 +42,7 @@ func main() { if err != nil { log.Fatal(err) } - t.Init() + t.Connect(context.Background()) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() diff --git a/core/examples/server_monitoring/main.go b/core/examples/server_monitoring/main.go index 96580b43fc..e21fc73be1 100644 --- a/core/examples/server_monitoring/main.go +++ b/core/examples/server_monitoring/main.go @@ -7,6 +7,7 @@ package main import ( + "context" "log" "time" @@ -17,7 +18,8 @@ import ( ) func main() { - s, err := topology.NewServer( + s, err := topology.ConnectServer( + context.Background(), addr.Addr("localhost:27017"), topology.WithHeartbeatInterval(func(time.Duration) time.Duration { return 2 * time.Second }), topology.WithConnectionOptions( diff --git a/core/examples/workload/main.go b/core/examples/workload/main.go index bfec67741e..695f206e75 100644 --- a/core/examples/workload/main.go +++ b/core/examples/workload/main.go @@ -67,7 +67,7 @@ func main() { <-done log.Println("interupt received: shutting down") - _ = c.Close() + _ = c.Disconnect(ctx) log.Println("finished") } diff --git a/core/integration/aggregate_test.go b/core/integration/aggregate_test.go index 8ce28f2c0e..6b19b77636 100644 --- a/core/integration/aggregate_test.go +++ b/core/integration/aggregate_test.go @@ -119,7 +119,7 @@ func TestCommandAggregate(t *testing.T) { t.Run("MaxTimeMS", func(t *testing.T) { t.Skip("max time is flaky on the server") - server, err := topology.NewServer(addr.Addr(*host)) + server, err := topology.ConnectServer(context.Background(), addr.Addr(*host)) noerr(t, err) conn, err := server.Connection(context.Background()) noerr(t, err) diff --git a/core/integration/command_test.go b/core/integration/command_test.go index ae617998fe..e5e79ac1bb 100644 --- a/core/integration/command_test.go +++ b/core/integration/command_test.go @@ -27,7 +27,7 @@ func TestCommand(t *testing.T) { } t.Parallel() - server, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) + server, err := topology.ConnectServer(context.Background(), addr.Addr(*host), serveropts(t)...) 
noerr(t, err) ctx := context.Background() diff --git a/core/integration/main_test.go b/core/integration/main_test.go index f092fdbaa5..d73c5aa02e 100644 --- a/core/integration/main_test.go +++ b/core/integration/main_test.go @@ -7,12 +7,16 @@ package integration import ( + "context" "flag" "fmt" + "net" "os" "strings" + "sync" "testing" + "github.com/mongodb/mongo-go-driver/core/connection" "github.com/mongodb/mongo-go-driver/core/connstring" ) @@ -44,6 +48,14 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +func noerr(t *testing.T, err error) { + if err != nil { + t.Helper() + t.Errorf("Unepexted error: %v", err) + t.FailNow() + } +} + // addTLSConfigToURI checks for the environmental variable indicating that the tests are being run // on an SSL-enabled server, and if so, returns a new URI with the necessary configuration. func addTLSConfigToURI(uri string) string { @@ -64,3 +76,56 @@ func addTLSConfigToURI(uri string) string { return uri + "ssl=true&sslCertificateAuthorityFile=" + caFile } + +type netconn struct { + net.Conn + closed chan struct{} + d *dialer +} + +func (nc *netconn) Close() error { + nc.closed <- struct{}{} + nc.d.connclosed(nc) + return nc.Conn.Close() +} + +type dialer struct { + connection.Dialer + opened map[*netconn]struct{} + closed map[*netconn]struct{} + sync.Mutex +} + +func newdialer(d connection.Dialer) *dialer { + return &dialer{Dialer: d, opened: make(map[*netconn]struct{}), closed: make(map[*netconn]struct{})} +} + +func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + d.Lock() + defer d.Unlock() + c, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + nc := &netconn{Conn: c, closed: make(chan struct{}, 1), d: d} + d.opened[nc] = struct{}{} + return nc, nil +} + +func (d *dialer) connclosed(nc *netconn) { + d.Lock() + defer d.Unlock() + d.closed[nc] = struct{}{} +} + +func (d *dialer) lenopened() int { + d.Lock() + defer d.Unlock() + return len(d.opened) +} + +func (d *dialer) lenclosed() int { + d.Lock() + defer d.Unlock() + return len(d.closed) +} diff --git a/core/integration/pool_test.go b/core/integration/pool_test.go index 204675dd1b..9187b837e5 100644 --- a/core/integration/pool_test.go +++ b/core/integration/pool_test.go @@ -21,6 +21,7 @@ import ( func TestPool(t *testing.T) { noerr := func(t *testing.T, err error) { if err != nil { + t.Helper() t.Errorf("Unepexted error: %v", err) t.FailNow() } @@ -55,6 +56,8 @@ func TestPool(t *testing.T) { if err != nil { t.Errorf("Unexpected error while creating pool: %v", err) } + err = p.Connect(context.TODO()) + noerr(t, err) c1, _, err := p.Get(context.Background()) noerr(t, err) @@ -75,6 +78,8 @@ func TestPool(t *testing.T) { if err != nil { t.Errorf("Unexpected error while creating pool: %v", err) } + err = p.Connect(context.TODO()) + noerr(t, err) c1, _, err := p.Get(context.Background()) noerr(t, err) first := c1.ID() @@ -93,6 +98,8 @@ func TestPool(t *testing.T) { if err != nil { t.Errorf("Unexpected error while creating pool: %v", err) } + err = p.Connect(context.TODO()) + noerr(t, err) ctx, cancel := context.WithCancel(context.Background()) cancel() _, _, err = p.Get(ctx) @@ -105,6 +112,8 @@ func TestPool(t *testing.T) { if err != nil { t.Errorf("Unexpected error while creating pool: %v", err) } + err = p.Connect(context.TODO()) + noerr(t, err) _, _, err = p.Get(context.Background()) if !strings.Contains(err.Error(), "dial tcp") { t.Errorf("Expected context called error, but got: %v", err) @@ -115,7 +124,11 @@ 
func TestPool(t *testing.T) { if err != nil { t.Errorf("Unexpected error while creating pool: %v", err) } - err = p.Close() + err = p.Connect(context.TODO()) + noerr(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + err = p.Disconnect(ctx) noerr(t, err) _, _, err = p.Get(context.Background()) if err != connection.ErrPoolClosed { @@ -127,9 +140,13 @@ func TestPool(t *testing.T) { if err != nil { t.Errorf("Unexpected error while creating pool: %v", err) } + err = p.Connect(context.TODO()) + noerr(t, err) c1, _, err := p.Get(context.Background()) noerr(t, err) - err = p.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + err = p.Disconnect(ctx) noerr(t, err) err = c1.Close() if err != nil { @@ -153,6 +170,8 @@ func TestPool(t *testing.T) { if err != nil { t.Errorf("Unexpected error while creating pool: %v", err) } + err = p.Connect(context.TODO()) + noerr(t, err) c1, _, err := p.Get(context.Background()) noerr(t, err) if c1.Expired() != false { @@ -168,18 +187,6 @@ func TestPool(t *testing.T) { // Implement this once there is a more testable Dialer. t.Skip() }) - t.Run("Close Is Idempotent", func(t *testing.T) { - p, err := connection.NewPool(addr.Addr(*host), 2, 4, opts...) - if err != nil { - t.Errorf("Unexpected error while creating pool: %v", err) - } - err = p.Close() - noerr(t, err) - err = p.Close() - if err != nil { - t.Errorf("Should be able to call Close twice, but got error: %v", err) - } - }) t.Run("Pool Close Closes All Connections In A Pool", func(t *testing.T) { // Implement this once there is a more testable Dialer. t.Skip() diff --git a/core/integration/server_test.go b/core/integration/server_test.go index 5bffb5c944..cfc3dbe6c9 100644 --- a/core/integration/server_test.go +++ b/core/integration/server_test.go @@ -8,6 +8,7 @@ package integration import ( "context" + "net" "strings" "testing" "time" @@ -22,25 +23,26 @@ import ( func TestTopologyServer(t *testing.T) { noerr := func(t *testing.T, err error) { if err != nil { + t.Helper() t.Errorf("Unepexted error: %v", err) t.FailNow() } } t.Run("After close, should not return new connection", func(t *testing.T) { - s, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) + s, err := topology.ConnectServer(context.Background(), addr.Addr(*host), serveropts(t)...) noerr(t, err) - err = s.Close() + err = s.Disconnect(context.TODO()) noerr(t, err) _, err = s.Connection(context.Background()) - if err != connection.ErrPoolClosed { + if err != topology.ErrServerClosed { t.Errorf("Expected error from getting a connection from closed server, but got %v", err) } }) t.Run("Shouldn't be able to get more than max connections", func(t *testing.T) { t.Parallel() - s, err := topology.NewServer(addr.Addr(*host), + s, err := topology.ConnectServer(context.Background(), addr.Addr(*host), serveropts( t, topology.WithMaxConnections(func(uint16) uint16 { return 2 }), @@ -80,7 +82,7 @@ func TestTopologyServer(t *testing.T) { t.Run("Write network timeout", func(t *testing.T) {}) }) t.Run("Close should close all subscription channels", func(t *testing.T) { - s, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) + s, err := topology.ConnectServer(context.Background(), addr.Addr(*host), serveropts(t)...) 
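+		// Subscribe now requires a connected Server (it returns
+		// ErrSubscribeAfterClosed otherwise), hence the switch to ConnectServer here.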
noerr(t, err) var done1, done2 = make(chan struct{}), make(chan struct{}) @@ -105,7 +107,7 @@ func TestTopologyServer(t *testing.T) { close(done2) }() - err = s.Close() + err = s.Disconnect(context.TODO()) noerr(t, err) select { @@ -121,12 +123,12 @@ func TestTopologyServer(t *testing.T) { } }) t.Run("Subscribe after Close should return an error", func(t *testing.T) { - s, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) + s, err := topology.ConnectServer(context.Background(), addr.Addr(*host), serveropts(t)...) noerr(t, err) sub, err := s.Subscribe() noerr(t, err) - err = s.Close() + err = s.Disconnect(context.TODO()) noerr(t, err) for range sub.C { @@ -137,6 +139,116 @@ func TestTopologyServer(t *testing.T) { t.Errorf("Did not receive expected error. got %v; want %v", err, topology.ErrSubscribeAfterClosed) } }) + t.Run("Disconnect", func(t *testing.T) { + t.Run("cannot disconnect before connecting", func(t *testing.T) { + s, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) + noerr(t, err) + + got := s.Disconnect(context.TODO()) + if got != topology.ErrServerClosed { + t.Errorf("Expected a server disconnected error. got %v; want %v", got, topology.ErrServerClosed) + } + }) + t.Run("cannot disconnect twice", func(t *testing.T) { + s, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + + got := s.Disconnect(context.TODO()) + if got != nil { + t.Errorf("Expected no server disconnected error. got %v; want ", got) + } + got = s.Disconnect(context.TODO()) + if got != topology.ErrServerClosed { + t.Errorf("Expected a server disconnected error. got %v; want %v", got, topology.ErrServerClosed) + } + }) + t.Run("all open sockets should be closed after disconnect", func(t *testing.T) { + d := newdialer(&net.Dialer{}) + s, err := topology.NewServer( + addr.Addr(*host), + serveropts( + t, + topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + return append(opts, connection.WithDialer(func(connection.Dialer) connection.Dialer { return d })) + }), + )..., + ) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + + conns := [3]connection.Connection{} + for idx := range [3]struct{}{} { + conns[idx], err = s.Connection(context.TODO()) + noerr(t, err) + } + for idx := range [2]struct{}{} { + err = conns[idx].Close() + noerr(t, err) + } + if d.lenopened() < 3 { + t.Errorf("Should have opened at least 3 connections, but didn't. got %d; want >%d", d.lenopened(), 3) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = s.Disconnect(ctx) + noerr(t, err) + if d.lenclosed() < 3 { + t.Errorf("Should have closed at least 3 connections, but didn't. got %d; want >%d", d.lenclosed(), 3) + } + }) + }) + t.Run("Connect", func(t *testing.T) { + t.Run("can reconnect a disconnected server", func(t *testing.T) { + s, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + + err = s.Disconnect(context.TODO()) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + }) + t.Run("cannot connect multiple times without disconnect", func(t *testing.T) { + s, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) 
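+			// A Server built with NewServer starts out disconnected; the first Connect
+			// succeeds, and any further Connect without an intervening Disconnect is
+			// expected to fail with ErrServerConnected, which is what this test checks.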
+ noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + + err = s.Disconnect(context.TODO()) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + err = s.Connect(context.TODO()) + if err != topology.ErrServerConnected { + t.Errorf("Did not receive expected error. got %v; want %v", err, topology.ErrServerConnected) + } + }) + t.Run("can disconnect and reconnect multiple times", func(t *testing.T) { + s, err := topology.NewServer(addr.Addr(*host), serveropts(t)...) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + + err = s.Disconnect(context.TODO()) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + + err = s.Disconnect(context.TODO()) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + + err = s.Disconnect(context.TODO()) + noerr(t, err) + err = s.Connect(context.TODO()) + noerr(t, err) + }) + }) } func serveropts(t *testing.T, opts ...topology.ServerOption) []topology.ServerOption { diff --git a/core/integration/topology_test.go b/core/integration/topology_test.go new file mode 100644 index 0000000000..27563215f4 --- /dev/null +++ b/core/integration/topology_test.go @@ -0,0 +1,142 @@ +package integration + +import ( + "context" + "net" + "testing" + + "github.com/mongodb/mongo-go-driver/core/connection" + "github.com/mongodb/mongo-go-driver/core/connstring" + "github.com/mongodb/mongo-go-driver/core/description" + "github.com/mongodb/mongo-go-driver/core/topology" +) + +func TestTopologyTopology(t *testing.T) { + t.Run("Disconnect", func(t *testing.T) { + t.Run("cannot disconnect before connecting", func(t *testing.T) { + topo, err := topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return connectionString })) + noerr(t, err) + err = topo.Disconnect(context.TODO()) + if err != topology.ErrTopologyClosed { + t.Errorf("Expected a topology disconnected error. got %v; want %v", err, topology.ErrTopologyClosed) + } + }) + t.Run("cannot disconnect twice", func(t *testing.T) { + topo, err := topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return connectionString })) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + err = topo.Disconnect(context.TODO()) + noerr(t, err) + err = topo.Disconnect(context.TODO()) + if err != topology.ErrTopologyClosed { + t.Errorf("Expected a topology disconnected error. got %v; want %v", err, topology.ErrTopologyClosed) + } + }) + t.Run("all open sockets should be closed after disconnect", func(t *testing.T) { + d := newdialer(&net.Dialer{}) + topo, err := topology.New( + topology.WithConnString( + func(connstring.ConnString) connstring.ConnString { return connectionString }, + ), + topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption { + return append( + opts, + topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option { + return append( + opts, + connection.WithDialer(func(connection.Dialer) connection.Dialer { return d }), + ) + }), + ) + }), + ) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + ss, err := topo.SelectServer(context.TODO(), description.WriteSelector()) + noerr(t, err) + + conns := [3]connection.Connection{} + for idx := range [3]struct{}{} { + conns[idx], err = ss.Connection(context.TODO()) + noerr(t, err) + } + for idx := range [2]struct{}{} { + err = conns[idx].Close() + noerr(t, err) + } + if d.lenopened() < 3 { + t.Errorf("Should have opened at least 3 connections, but didn't. 
got %d; want >%d", d.lenopened(), 3) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = topo.Disconnect(ctx) + noerr(t, err) + if d.lenclosed() != d.lenopened() { + t.Errorf( + "Should have closed the same number of connections as opened. closed %d; opened %d", + d.lenclosed(), d.lenopened()) + } + }) + }) + t.Run("Connect", func(t *testing.T) { + t.Run("can reconnect a disconnected topology", func(t *testing.T) { + topo, err := topology.New( + topology.WithConnString( + func(connstring.ConnString) connstring.ConnString { return connectionString }, + ), + ) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + err = topo.Disconnect(context.TODO()) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + }) + t.Run("cannot connect multiple times without disconnect", func(t *testing.T) { + topo, err := topology.New( + topology.WithConnString( + func(connstring.ConnString) connstring.ConnString { return connectionString }, + ), + ) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + err = topo.Disconnect(context.TODO()) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + err = topo.Connect(context.TODO()) + if err != topology.ErrTopologyConnected { + t.Errorf("Expected a topology connected error. got %v; want %v", err, topology.ErrTopologyConnected) + } + }) + t.Run("can disconnect and reconnect multiple times", func(t *testing.T) { + topo, err := topology.New( + topology.WithConnString( + func(connstring.ConnString) connstring.ConnString { return connectionString }, + ), + ) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + + err = topo.Disconnect(context.TODO()) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + + err = topo.Disconnect(context.TODO()) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + + err = topo.Disconnect(context.TODO()) + noerr(t, err) + err = topo.Connect(context.TODO()) + noerr(t, err) + }) + }) +} diff --git a/core/topology/connection.go b/core/topology/connection.go index d608301f92..5b33a272f5 100644 --- a/core/topology/connection.go +++ b/core/topology/connection.go @@ -8,7 +8,6 @@ package topology import ( "context" - "errors" "net" "github.com/mongodb/mongo-go-driver/core/connection" @@ -20,36 +19,22 @@ import ( // error is returned, the pool on the server can be cleared. 
type sconn struct { connection.Connection - s *Server + s *Server + id uint64 } func (sc *sconn) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) { - if sc.Connection == nil { - return nil, errors.New("already closed") - } wm, err := sc.Connection.ReadWireMessage(ctx) sc.processErr(err) return wm, err } func (sc *sconn) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error { - if sc.Connection == nil { - return errors.New("already closed") - } err := sc.Connection.WriteWireMessage(ctx, wm) sc.processErr(err) return err } -func (sc *sconn) Close() error { - if sc.Connection == nil { - return nil - } - err := sc.Connection.Close() - sc.Connection = nil - return err -} - func (sc *sconn) processErr(err error) { ne, ok := err.(connection.NetworkError) if !ok { diff --git a/core/topology/initial_dns_seedlist_discovery_test.go b/core/topology/initial_dns_seedlist_discovery_test.go index 7c541dd388..682f242365 100644 --- a/core/topology/initial_dns_seedlist_discovery_test.go +++ b/core/topology/initial_dns_seedlist_discovery_test.go @@ -92,7 +92,7 @@ func runSeedlistTest(t *testing.T, filename string, test *seedlistTestCase) { // make a topology from the options c, err := New(WithConnString(func(connstring.ConnString) connstring.ConnString { return cs })) require.NoError(t, err) - c.Init() + c.Connect(context.Background()) for _, host := range test.Hosts { _, err := getServerByAddress(host, c) diff --git a/core/topology/server.go b/core/topology/server.go index 4da25919f2..614649ec8e 100644 --- a/core/topology/server.go +++ b/core/topology/server.go @@ -11,6 +11,7 @@ import ( "errors" "math" "sync" + "sync/atomic" "time" "github.com/mongodb/mongo-go-driver/bson" @@ -23,11 +24,16 @@ import ( ) const minHeartbeatInterval = 500 * time.Millisecond +const connectionSemaphoreSize = math.MaxInt64 // ErrServerClosed occurs when an attempt to get a connection is made after // the server has been closed. var ErrServerClosed = errors.New("server is closed") +// ErrServerConnected occurs when an attempt to connect is made after a server +// has already been connected. +var ErrServerConnected = errors.New("server is connected") + // SelectedServer represents a specific server that was selected during server selection. // It contains the kind of the topology it was selected from. type SelectedServer struct { @@ -45,20 +51,26 @@ func (ss *SelectedServer) Description() description.SelectedServer { } } +// These constants represent the connection states of a server. +const ( + disconnected int32 = iota + disconnecting + connected + connecting +) + // Server is a single server within a topology. type Server struct { cfg *serverConfig address addr.Addr - l sync.Mutex - closed bool - done chan struct{} - checkNow chan struct{} - closewg sync.WaitGroup - pool connection.Pool + connectionstate int32 + done chan struct{} + checkNow chan struct{} + closewg sync.WaitGroup + pool connection.Pool - desc description.Server - dmtx sync.RWMutex + desc atomic.Value // holds a description.Server averageRTTSet bool averageRTT time.Duration @@ -69,6 +81,20 @@ type Server struct { subscriptionsClosed bool } +// ConnectServer creates a new Server and then initializes it using the +// Connect method. +func ConnectServer(ctx context.Context, address addr.Addr, opts ...ServerOption) (*Server, error) { + srvr, err := NewServer(address, opts...)
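+	// NewServer leaves the Server in the disconnected state; ConnectServer simply
+	// pairs construction with the Connect call below so that callers get a
+	// ready-to-use Server in one step.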
+ if err != nil { + return nil, err + } + err = srvr.Connect(ctx) + if err != nil { + return nil, err + } + return srvr, nil +} + // NewServer creates a new server. The mongodb server at the address will be monitored // on an internal monitoring goroutine. func NewServer(address addr.Addr, opts ...ServerOption) (*Server, error) { @@ -84,10 +110,9 @@ func NewServer(address addr.Addr, opts ...ServerOption) (*Server, error) { done: make(chan struct{}), checkNow: make(chan struct{}, 1), - desc: description.Server{Addr: address}, - subscribers: make(map[uint64]chan description.Server), } + s.desc.Store(description.Server{Addr: address}) var maxConns uint64 if cfg.maxConns == 0 { @@ -96,41 +121,59 @@ func NewServer(address addr.Addr, opts ...ServerOption) (*Server, error) { maxConns = uint64(cfg.maxConns) } - // TODO(skriptble): Add a configurer here that will take any newly dialed connections for this pool - // and put their server descriptions through the fsm. s.pool, err = connection.NewPool(address, uint64(cfg.maxIdleConns), maxConns, cfg.connectionOpts...) if err != nil { return nil, err } - go s.update() - s.closewg.Add(1) - return s, nil } -// Close closes the server. -func (s *Server) Close() error { - s.l.Lock() - defer s.l.Unlock() +// Connect initialzies the Server by starting background monitoring goroutines. +// This method must be called before a Server can be used. +func (s *Server) Connect(ctx context.Context) error { + if !atomic.CompareAndSwapInt32(&s.connectionstate, disconnected, connected) { + return ErrServerConnected + } + s.desc.Store(description.Server{Addr: s.address}) + go s.update() + s.closewg.Add(1) + return s.pool.Connect(ctx) +} - if s.closed { - return nil +// Disconnect closes sockets to the server referenced by this Server. +// Subscriptions to this Server will be closed. Disconnect will shutdown +// any monitoring goroutines, close the idle connection pool, and will +// wait until all the in use connections have been returned to the connection +// pool and are closed before returning. If the context expires via +// cancellation, deadline, or timeout before the in use connections have been +// returned, the in use connections will be closed, resulting in the failure of +// any in flight read or write operations. If this method returns with no +// errors, all connections associated with this Server have been closed. +func (s *Server) Disconnect(ctx context.Context) error { + if !atomic.CompareAndSwapInt32(&s.connectionstate, connected, disconnecting) { + return ErrServerClosed } - s.closed = true - close(s.done) - err := s.pool.Close() + // For every call to Connect there must be at least 1 goroutine that is + // waiting on the done channel. + s.done <- struct{}{} + err := s.pool.Disconnect(ctx) if err != nil { return err } + s.closewg.Wait() + atomic.StoreInt32(&s.connectionstate, disconnected) return nil } // Connection gets a connection to the server. func (s *Server) Connection(ctx context.Context) (connection.Connection, error) { + if atomic.LoadInt32(&s.connectionstate) != connected { + return nil, ErrServerClosed + } conn, desc, err := s.pool.Get(ctx) if err != nil { return nil, err @@ -138,14 +181,13 @@ func (s *Server) Connection(ctx context.Context) (connection.Connection, error) if desc != nil { go s.updateDescription(*desc, false) } - return &sconn{Connection: conn, s: s}, nil + sc := &sconn{Connection: conn, s: s} + return sc, nil } // Description returns a description of the server as of the last heartbeat. 
func (s *Server) Description() description.Server { - s.dmtx.RLock() - defer s.dmtx.RUnlock() - return s.desc + return s.desc.Load().(description.Server) } // SelectedDescription returns a description.SelectedServer with a Kind of @@ -163,10 +205,11 @@ func (s *Server) SelectedDescription() description.SelectedServer { // updated server descriptions will be sent. The channel will have a buffer // size of one, and will be pre-populated with the current description. func (s *Server) Subscribe() (*ServerSubscription, error) { + if atomic.LoadInt32(&s.connectionstate) != connected { + return nil, ErrSubscribeAfterClosed + } ch := make(chan description.Server, 1) - s.dmtx.Lock() - defer s.dmtx.Unlock() - ch <- s.desc + ch <- s.desc.Load().(description.Server) s.subLock.Lock() defer s.subLock.Unlock() @@ -198,15 +241,19 @@ func (s *Server) RequestImmediateCheck() { // update handles performing heartbeats and updating any subscribers of the // newest description.Server retrieved. func (s *Server) update() { - defer func() { - // TODO(skriptble): What should we do here? - _ = recover() - }() + defer s.closewg.Done() heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval) rateLimiter := time.NewTicker(minHeartbeatInterval) checkNow := s.checkNow done := s.done + defer func() { + if r := recover(); r != nil { + // We keep this goroutine alive attempting to read from the done channel. + <-done + } + }() + var conn connection.Connection var desc description.Server @@ -226,7 +273,6 @@ func (s *Server) update() { s.subscriptionsClosed = true s.subLock.Unlock() conn.Close() - s.closewg.Done() return } @@ -246,9 +292,7 @@ func (s *Server) updateDescription(desc description.Server, initial bool) { // ¯\_(ツ)_/¯ _ = recover() }() - s.dmtx.Lock() - s.desc = desc - s.dmtx.Unlock() + s.desc.Store(desc) s.subLock.Lock() for _, c := range s.subscribers { diff --git a/core/topology/topology.go b/core/topology/topology.go index dab9ee1bf1..871175fb42 100644 --- a/core/topology/topology.go +++ b/core/topology/topology.go @@ -15,6 +15,7 @@ import ( "errors" "math/rand" "sync" + "sync/atomic" "time" "github.com/mongodb/mongo-go-driver/core/addr" @@ -29,6 +30,10 @@ var ErrSubscribeAfterClosed = errors.New("cannot subscribe after close") // closed Topology. var ErrTopologyClosed = errors.New("topology is closed") +// ErrTopologyConnected is returned whena user attempts to connect to an +// already connected Topology. +var ErrTopologyConnected = errors.New("topology is connected or connecting") + // ErrServerSelectionTimeout is returned from server selection when the server // selection process took longer than allowed by the timeout. var ErrServerSelectionTimeout = errors.New("server selection timeout") @@ -44,26 +49,30 @@ const ( // Topology respresents a MongoDB deployment. type Topology struct { - // There are too many closed booleans, but we can fix that later - // with a refactor. For now, just making a definitive one to guard - // the Close method. - closed bool - initialized bool - l sync.Mutex + connectionstate int32 cfg *config - desc description.Topology - dmtx sync.Mutex + desc atomic.Value // holds a description.Topology + + done chan struct{} - fsm *fsm - changes chan description.Server + fsm *fsm + changes chan description.Server + changeswg sync.WaitGroup + // This should really be encapsulated into it's own type. This will likely + // require a redesign so we can share a minimum of data between the + // subscribers and the topology. 
subscribers map[uint64]chan description.Topology currentSubscriberID uint64 subscriptionsClosed bool subLock sync.Mutex + // We should redesign how we connect and handle individual servers. This is + // too difficult to maintain and it's rather easy to accidentally access + // the servers without acquiring the lock or checking if the servers are + // closed. This lock should also be an RWMutex. serversLock sync.Mutex serversClosed bool servers map[addr.Addr]*Server @@ -80,11 +89,13 @@ func New(opts ...Option) (*Topology, error) { t := &Topology{ cfg: cfg, + done: make(chan struct{}), fsm: newFSM(), changes: make(chan description.Server), subscribers: make(map[uint64]chan description.Topology), servers: make(map[addr.Addr]*Server), } + t.desc.Store(description.Topology{}) if cfg.replicaSetName != "" { t.fsm.SetName = cfg.replicaSetName @@ -98,69 +109,76 @@ func New(opts ...Option) (*Topology, error) { return t, nil } -// Init initializes a Topology and starts the monitoring process. This function +// Connect initializes a Topology and starts the monitoring process. This function // must be called to properly monitor the topology. -func (t *Topology) Init() { - t.l.Lock() - defer t.l.Unlock() - if t.initialized { - return +func (t *Topology) Connect(ctx context.Context) error { + if !atomic.CompareAndSwapInt32(&t.connectionstate, disconnected, connecting) { + return ErrTopologyConnected } - t.initialized = true + t.desc.Store(description.Topology{}) + var err error t.serversLock.Lock() for _, a := range t.cfg.seedList { address := addr.Addr(a).Canonicalize() t.fsm.Servers = append(t.fsm.Servers, description.Server{Addr: address}) - t.addServer(address) + err = t.addServer(ctx, address) } t.serversLock.Unlock() go t.update() + t.changeswg.Add(1) + + atomic.StoreInt32(&t.connectionstate, connected) + return err } -// Close closes the topology. It stops the monitoring thread and +// Disconnect closes the topology. It stops the monitoring thread and // closes all open subscriptions. -func (t *Topology) Close() error { - t.l.Lock() - defer t.l.Unlock() - if t.closed { - return nil +func (t *Topology) Disconnect(ctx context.Context) error { + if !atomic.CompareAndSwapInt32(&t.connectionstate, connected, disconnecting) { + return ErrTopologyClosed } - t.closed = true t.serversLock.Lock() t.serversClosed = true for address, server := range t.servers { - t.removeServer(address, server) + t.removeServer(ctx, address, server) } t.serversLock.Unlock() t.wg.Wait() - close(t.changes) + t.done <- struct{}{} + t.changeswg.Wait() - t.dmtx.Lock() - t.desc = description.Topology{} - t.dmtx.Unlock() + t.desc.Store(description.Topology{}) + atomic.StoreInt32(&t.connectionstate, disconnected) return nil } // Description returns a description of the topology. func (t *Topology) Description() description.Topology { - t.dmtx.Lock() - defer t.dmtx.Unlock() - return t.desc + td, ok := t.desc.Load().(description.Topology) + if !ok { + td = description.Topology{} + } + return td } // Subscribe returns a Subscription on which all updated description.Topologys // will be sent. The channel of the subscription will have a buffer size of one, // and will be pre-populated with the current description.Topology.
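The subscription machinery documented above hands out channels with a buffer of one, pre-populated with the current description, and the update loop further down drains any stale value before sending the newest one, so a slow subscriber never blocks monitoring. A small, self-contained sketch of that publish pattern, with illustrative names rather than the driver's types:

```go
// Package pubsub sketches the size-one buffered channel fan-out used for
// description updates. Names are illustrative, not the driver's.
package pubsub

import "sync"

type publisher struct {
	mu   sync.Mutex
	subs []chan string
}

// Subscribe returns a channel seeded with the current value, mirroring how
// Topology.Subscribe pre-populates its channel with the latest description.
func (p *publisher) Subscribe(current string) <-chan string {
	ch := make(chan string, 1)
	ch <- current
	p.mu.Lock()
	p.subs = append(p.subs, ch)
	p.mu.Unlock()
	return ch
}

// Publish delivers the newest value without ever blocking: if a subscriber has
// not read the previous value yet, it is drained and replaced.
func (p *publisher) Publish(v string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for _, ch := range p.subs {
		select {
		case <-ch: // drop the stale value
		default:
		}
		ch <- v // cannot block: capacity is one and the buffer was just drained
	}
}
```

The same drain-then-send shape appears in the Topology.update loop later in this diff.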
func (t *Topology) Subscribe() (*Subscription, error) { + if atomic.LoadInt32(&t.connectionstate) != connected { + return nil, errors.New("cannot subscribe to Topology that is not connected") + } ch := make(chan description.Topology, 1) - t.dmtx.Lock() - ch <- t.desc - t.dmtx.Unlock() + td, ok := t.desc.Load().(description.Topology) + if !ok { + td = description.Topology{} + } + ch <- td t.subLock.Lock() defer t.subLock.Unlock() @@ -181,6 +199,9 @@ func (t *Topology) Subscribe() (*Subscription, error) { // RequestImmediateCheck will send heartbeats to all the servers in the // topology right away, instead of waiting for the heartbeat timeout. func (t *Topology) RequestImmediateCheck() { + if atomic.LoadInt32(&t.connectionstate) != connected { + return + } t.serversLock.Lock() for _, server := range t.servers { server.RequestImmediateCheck() @@ -192,6 +213,9 @@ func (t *Topology) RequestImmediateCheck() { // server selection spec, and will time out after severSelectionTimeout or when the // parent context is done. func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (*SelectedServer, error) { + if atomic.LoadInt32(&t.connectionstate) != connected { + return nil, ErrTopologyClosed + } var ssTimeoutCh <-chan time.Time if t.cfg.serverSelectionTimeout > 0 { @@ -230,11 +254,11 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect // findServer will attempt to find a server that fits the given server description. // This method will return nil, nil if a matching server could not be found. func (t *Topology) findServer(selected description.Server) (*SelectedServer, error) { - t.l.Lock() - defer t.l.Unlock() - if t.closed { + if atomic.LoadInt32(&t.connectionstate) != connected { return nil, ErrTopologyClosed } + t.serversLock.Lock() + defer t.serversLock.Unlock() server, ok := t.servers[selected.Addr] if !ok { return nil, nil @@ -279,42 +303,48 @@ func (t *Topology) selectServer(ctx context.Context, subscriptionCh <-chan descr } func (t *Topology) update() { + defer t.changeswg.Done() defer func() { // ¯\_(ツ)_/¯ - _ = recover() + if r := recover(); r != nil { + <-t.done + } }() - for change := range t.changes { - current, err := t.apply(change) - if err != nil { - continue - } + for { + select { + case change := <-t.changes: + current, err := t.apply(context.TODO(), change) + if err != nil { + continue + } - t.dmtx.Lock() - t.desc = current - t.dmtx.Unlock() + t.desc.Store(current) - t.subLock.Lock() - for _, ch := range t.subscribers { - // We drain the description if there's one in the channel - select { - case <-ch: - default: + t.subLock.Lock() + for _, ch := range t.subscribers { + // We drain the description if there's one in the channel + select { + case <-ch: + default: + } + ch <- current } - ch <- current + t.subLock.Unlock() + case <-t.done: + t.subLock.Lock() + for id, ch := range t.subscribers { + close(ch) + delete(t.subscribers, id) + } + t.subscriptionsClosed = true + t.subLock.Unlock() + return } - t.subLock.Unlock() - } - t.subLock.Lock() - for id, ch := range t.subscribers { - close(ch) - delete(t.subscribers, id) } - t.subscriptionsClosed = true - t.subLock.Unlock() } -func (t *Topology) apply(desc description.Server) (description.Topology, error) { +func (t *Topology) apply(ctx context.Context, desc description.Server) (description.Topology, error) { var err error prev := t.fsm.Topology @@ -332,34 +362,32 @@ func (t *Topology) apply(desc description.Server) (description.Topology, error) for _, removed := range 
diff.Removed { if s, ok := t.servers[removed.Addr]; ok { - t.removeServer(removed.Addr, s) + t.removeServer(ctx, removed.Addr, s) } } for _, added := range diff.Added { - t.addServer(added.Addr) + t.addServer(ctx, added.Addr) } t.serversLock.Unlock() return current, nil } -func (t *Topology) addServer(address addr.Addr) { +func (t *Topology) addServer(ctx context.Context, address addr.Addr) error { if _, ok := t.servers[address]; ok { - return + return nil } - svr, err := NewServer(address, t.cfg.serverOpts...) + svr, err := ConnectServer(ctx, address, t.cfg.serverOpts...) if err != nil { - // ¯\_(ツ)_/¯ - return + return err } t.servers[address] = svr var sub *ServerSubscription sub, err = svr.Subscribe() if err != nil { - // ¯\_(ツ)_/¯ - return + return err } t.wg.Add(1) @@ -369,10 +397,12 @@ func (t *Topology) addServer(address addr.Addr) { } t.wg.Done() }() + + return nil } -func (t *Topology) removeServer(address addr.Addr, server *Server) { - server.Close() +func (t *Topology) removeServer(ctx context.Context, address addr.Addr, server *Server) { + _ = server.Disconnect(ctx) delete(t.servers, address) } diff --git a/internal/testutil/config.go b/internal/testutil/config.go index e5e785d171..8ce5fde21f 100644 --- a/internal/testutil/config.go +++ b/internal/testutil/config.go @@ -70,7 +70,7 @@ func Topology(t *testing.T) *topology.Topology { if err != nil { liveTopologyErr = err } else { - liveTopology.Init() + liveTopology.Connect(context.Background()) s, err := liveTopology.SelectServer(context.Background(), description.WriteSelector()) require.NoError(t, err) diff --git a/mongo/client.go b/mongo/client.go index 69d3f092d7..5d8f90d1d7 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -34,6 +34,19 @@ type Client struct { writeConcern *writeconcern.WriteConcern } +// Connect creates a new Client and then initializes it using the Connect method. +func Connect(ctx context.Context, uri string, opts *ClientOptions) (*Client, error) { + c, err := NewClientWithOptions(uri, opts) + if err != nil { + return nil, err + } + err = c.Connect(ctx) + if err != nil { + return nil, err + } + return c, nil +} + // NewClient creates a new client to connect to a cluster specified by the uri. func NewClient(uri string) (*Client, error) { cs, err := connstring.Parse(uri) @@ -62,6 +75,24 @@ func NewClientFromConnString(cs connstring.ConnString) (*Client, error) { return newClient(cs, nil) } +// Connect initializes the Client by starting background monitoring goroutines. +// This method must be called before a Client can be used. +func (c *Client) Connect(ctx context.Context) error { + return c.topology.Connect(ctx) +} + +// Disconnect closes sockets to the topology referenced by this Client. It will +// shut down any monitoring goroutines, close the idle connection pool, and will +// wait until all the in use connections have been returned to the connection +// pool and closed before returning. If the context expires via cancellation, +// deadline, or timeout before the in use connections have returned, the in use +// connections will be closed, resulting in the failure of any in flight read +// or write operations. If this method returns with no errors, all connections +// associated with this Client have been closed. 
+func (c *Client) Disconnect(ctx context.Context) error { + return c.topology.Disconnect(ctx) +} + func newClient(cs connstring.ConnString, opts *ClientOptions) (*Client, error) { client := &Client{ connString: cs, @@ -88,7 +119,10 @@ func newClient(cs connstring.ConnString, opts *ClientOptions) (*Client, error) { return nil, err } - topo.Init() + err = topo.Connect(context.Background()) + if err != nil { + return nil, err + } client.topology = topo client.readConcern = readConcernFromConnString(&client.connString) @@ -145,16 +179,16 @@ func writeConcernFromConnString(cs *connstring.ConnString) *writeconcern.WriteCo } // Database returns a handle for a given database. -func (client *Client) Database(name string) *Database { - return newDatabase(client, name) +func (c *Client) Database(name string) *Database { + return newDatabase(c, name) } // ConnectionString returns the connection string of the cluster the client is connected to. -func (client *Client) ConnectionString() string { - return client.connString.Original +func (c *Client) ConnectionString() string { + return c.connString.Original } -func (client *Client) listDatabasesHelper(ctx context.Context, filter interface{}, +func (c *Client) listDatabasesHelper(ctx context.Context, filter interface{}, nameOnly bool) (ListDatabasesResult, error) { f, err := TransformDocument(filter) @@ -172,7 +206,7 @@ func (client *Client) listDatabasesHelper(ctx context.Context, filter interface{ // The spec indicates that we should not run the listDatabase command on a secondary in a // replica set. - res, err := dispatch.ListDatabases(ctx, cmd, client.topology, description.ReadPrefSelector(readpref.Primary())) + res, err := dispatch.ListDatabases(ctx, cmd, c.topology, description.ReadPrefSelector(readpref.Primary())) if err != nil { return ListDatabasesResult{}, err } @@ -180,13 +214,13 @@ func (client *Client) listDatabasesHelper(ctx context.Context, filter interface{ } // ListDatabases returns a ListDatabasesResult. -func (client *Client) ListDatabases(ctx context.Context, filter interface{}) (ListDatabasesResult, error) { - return client.listDatabasesHelper(ctx, filter, false) +func (c *Client) ListDatabases(ctx context.Context, filter interface{}) (ListDatabasesResult, error) { + return c.listDatabasesHelper(ctx, filter, false) } // ListDatabaseNames returns a slice containing the names of all of the databases on the server. -func (client *Client) ListDatabaseNames(ctx context.Context, filter interface{}) ([]string, error) { - res, err := client.listDatabasesHelper(ctx, filter, true) +func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}) ([]string, error) { + res, err := c.listDatabasesHelper(ctx, filter, true) if err != nil { return nil, err }
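At the mongo package level, the change means callers drive the client lifecycle explicitly: Connect (or the one-step mongo.Connect helper) before use, Disconnect when done. A hedged usage sketch of the intended surface follows; the URI is a placeholder, and whether NewClientWithOptions accepts a nil *ClientOptions is not shown in this diff and is assumed here:

```go
package main

import (
	"context"
	"log"
	"time"

	"github.com/mongodb/mongo-go-driver/mongo"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// One-step construction and connection, equivalent to
	// NewClientWithOptions followed by Connect. The nil options argument is a
	// placeholder; pass real ClientOptions as needed.
	client, err := mongo.Connect(ctx, "mongodb://localhost:27017", nil)
	if err != nil {
		log.Fatal(err)
	}

	// Disconnect waits for in-use connections to be returned to the pool, or
	// closes them outright once ctx expires.
	defer func() {
		if err := client.Disconnect(ctx); err != nil {
			log.Println("disconnect:", err)
		}
	}()

	// ... use client ...
}
```

One thing worth double-checking in this diff: newClient still calls topo.Connect(context.Background()) during construction, so an explicit Client.Connect on a client built through that path looks like it would return ErrTopologyConnected; presumably that construction-time call is meant to go away once callers migrate to the explicit lifecycle.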