diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index c35f5e5391..af585e3535 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -154,8 +154,8 @@ export class Connection extends TypedEventEmitter { address: string; socketTimeoutMS: number; monitorCommands: boolean; + /** Indicates that the connection (including underlying TCP socket) has been closed. */ closed: boolean; - destroyed: boolean; lastHelloMS?: number; serverApi?: ServerApi; helloOk?: boolean; @@ -204,7 +204,6 @@ export class Connection extends TypedEventEmitter { this.monitorCommands = options.monitorCommands; this.serverApi = options.serverApi; this.closed = false; - this.destroyed = false; this[kHello] = null; this[kClusterTime] = null; @@ -297,10 +296,7 @@ export class Connection extends TypedEventEmitter { if (this.closed) { return; } - - this[kStream].destroy(error); - - this.closed = true; + this.destroy({ force: false }); for (const op of this[kQueue].values()) { op.cb(error); @@ -314,8 +310,7 @@ export class Connection extends TypedEventEmitter { if (this.closed) { return; } - - this.closed = true; + this.destroy({ force: false }); const message = `connection ${this.id} to ${this.address} closed`; for (const op of this[kQueue].values()) { @@ -332,9 +327,7 @@ export class Connection extends TypedEventEmitter { } this[kDelayedTimeoutId] = setTimeout(() => { - this[kStream].destroy(); - - this.closed = true; + this.destroy({ force: false }); const message = `connection ${this.id} to ${this.address} timed out`; const beforeHandshake = this.hello == null; @@ -447,31 +440,23 @@ export class Connection extends TypedEventEmitter { this.removeAllListeners(Connection.PINNED); this.removeAllListeners(Connection.UNPINNED); - if (this[kStream] == null || this.destroyed) { - this.destroyed = true; - if (typeof callback === 'function') { - callback(); - } - - return; - } + this[kMessageStream].destroy(); + this.closed = true; if (options.force) { this[kStream].destroy(); - this.destroyed = true; - if (typeof callback === 'function') { - callback(); + if (callback) { + return process.nextTick(callback); } - - return; } - this[kStream].end(() => { - this.destroyed = true; - if (typeof callback === 'function') { - callback(); + if (!this[kStream].writableEnded) { + this[kStream].end(callback); + } else { + if (callback) { + return process.nextTick(callback); } - }); + } } command( diff --git a/test/unit/cmap/connection.test.ts b/test/unit/cmap/connection.test.ts index c1801cf9e6..c510f1521d 100644 --- a/test/unit/cmap/connection.test.ts +++ b/test/unit/cmap/connection.test.ts @@ -31,6 +31,7 @@ const connectionOptionsDefaults = { /** The absolute minimum socket API needed by Connection as of writing this test */ class FakeSocket extends EventEmitter { + writableEnded: boolean; address() { // is never called } @@ -39,6 +40,14 @@ class FakeSocket extends EventEmitter { } destroy() { // is called, has no side effects + this.writableEnded = true; + } + end(cb) { + this.writableEnded = true; + // nextTick to simulate I/O delay + if (typeof cb === 'function') { + process.nextTick(cb); + } } get remoteAddress() { return 'iLoveJavaScript'; @@ -48,6 +57,20 @@ class FakeSocket extends EventEmitter { } } +class InputStream extends Readable { + writableEnded: boolean; + constructor(options?) { + super(options); + } + + end(cb) { + this.writableEnded = true; + if (typeof cb === 'function') { + process.nextTick(cb); + } + } +} + describe('new Connection()', function () { let server; after(() => mock.cleanup()); @@ -106,7 +129,7 @@ describe('new Connection()', function () { expect(err).to.be.instanceOf(MongoNetworkTimeoutError); expect(result).to.not.exist; - expect(conn).property('stream').property('destroyed', true); + expect(conn).property('stream').property('writableEnded', true); done(); }); @@ -175,7 +198,7 @@ describe('new Connection()', function () { context('when multiple hellos exist on the stream', function () { let callbackSpy; - const inputStream = new Readable(); + const inputStream = new InputStream(); const document = { ok: 1 }; const last = { isWritablePrimary: true }; @@ -394,7 +417,7 @@ describe('new Connection()', function () { connection = sinon.spy(new Connection(driverSocket, connectionOptionsDefaults)); const messageStreamSymbol = getSymbolFrom(connection, 'messageStream'); kDelayedTimeoutId = getSymbolFrom(connection, 'delayedTimeoutId'); - messageStream = connection[messageStreamSymbol]; + messageStream = sinon.spy(connection[messageStreamSymbol]); }); afterEach(() => { @@ -407,13 +430,15 @@ describe('new Connection()', function () { driverSocket.emit('timeout'); expect(connection.onTimeout).to.have.been.calledOnce; + expect(connection.destroy).to.not.have.been.called; expect(connection).to.have.property(kDelayedTimeoutId).that.is.instanceOf(NodeJSTimeoutClass); expect(connection).to.have.property('closed', false); - expect(driverSocket.destroy).to.not.have.been.called; + expect(driverSocket.end).to.not.have.been.called; clock.tick(1); - expect(driverSocket.destroy).to.have.been.calledOnce; + expect(driverSocket.end).to.have.been.calledOnce; + expect(connection.destroy).to.have.been.calledOnce; expect(connection).to.have.property('closed', true); }); @@ -438,6 +463,88 @@ describe('new Connection()', function () { expect(connection).to.have.property('closed', false); expect(connection).to.have.property(kDelayedTimeoutId, null); }); + + it('destroys the message stream and socket', () => { + expect(connection).to.have.property(kDelayedTimeoutId, null); + + driverSocket.emit('timeout'); + + clock.tick(1); + + expect(connection.onTimeout).to.have.been.calledOnce; + expect(connection).to.have.property(kDelayedTimeoutId).that.is.instanceOf(NodeJSTimeoutClass); + + expect(messageStream.destroy).to.have.been.calledOnce; + expect(driverSocket.destroy).to.not.have.been.called; + expect(driverSocket.end).to.have.been.calledOnce; + }); + }); + + describe('onError()', () => { + let connection: sinon.SinonSpiedInstance; + let clock: sinon.SinonFakeTimers; + let timerSandbox: sinon.SinonFakeTimers; + let driverSocket: sinon.SinonSpiedInstance; + let messageStream: MessageStream; + beforeEach(() => { + timerSandbox = createTimerSandbox(); + clock = sinon.useFakeTimers(); + driverSocket = sinon.spy(new FakeSocket()); + // @ts-expect-error: driverSocket does not fully satisfy the stream type, but that's okay + connection = sinon.spy(new Connection(driverSocket, connectionOptionsDefaults)); + const messageStreamSymbol = getSymbolFrom(connection, 'messageStream'); + messageStream = sinon.spy(connection[messageStreamSymbol]); + }); + + afterEach(() => { + timerSandbox.restore(); + clock.restore(); + }); + + it('destroys the message stream and socket', () => { + messageStream.emit('error'); + clock.tick(1); + expect(connection.onError).to.have.been.calledOnce; + connection.destroy({ force: false }); + clock.tick(1); + expect(messageStream.destroy).to.have.been.called; + expect(driverSocket.destroy).to.not.have.been.called; + expect(driverSocket.end).to.have.been.calledOnce; + }); + }); + + describe('onClose()', () => { + let connection: sinon.SinonSpiedInstance; + let clock: sinon.SinonFakeTimers; + let timerSandbox: sinon.SinonFakeTimers; + let driverSocket: sinon.SinonSpiedInstance; + let messageStream: MessageStream; + beforeEach(() => { + timerSandbox = createTimerSandbox(); + clock = sinon.useFakeTimers(); + + driverSocket = sinon.spy(new FakeSocket()); + // @ts-expect-error: driverSocket does not fully satisfy the stream type, but that's okay + connection = sinon.spy(new Connection(driverSocket, connectionOptionsDefaults)); + const messageStreamSymbol = getSymbolFrom(connection, 'messageStream'); + messageStream = sinon.spy(connection[messageStreamSymbol]); + }); + + afterEach(() => { + timerSandbox.restore(); + clock.restore(); + }); + + it('destroys the message stream and socket', () => { + driverSocket.emit('close'); + clock.tick(1); + expect(connection.onClose).to.have.been.calledOnce; + connection.destroy({ force: false }); + clock.tick(1); + expect(messageStream.destroy).to.have.been.called; + expect(driverSocket.destroy).to.not.have.been.called; + expect(driverSocket.end).to.have.been.calledOnce; + }); }); describe('.hasSessionSupport', function () { @@ -491,4 +598,96 @@ describe('new Connection()', function () { }); }); }); + + describe('destroy()', () => { + let connection: sinon.SinonSpiedInstance; + let clock: sinon.SinonFakeTimers; + let timerSandbox: sinon.SinonFakeTimers; + let driverSocket: sinon.SinonSpiedInstance; + let messageStream: MessageStream; + beforeEach(() => { + timerSandbox = createTimerSandbox(); + clock = sinon.useFakeTimers(); + + driverSocket = sinon.spy(new FakeSocket()); + // @ts-expect-error: driverSocket does not fully satisfy the stream type, but that's okay + connection = sinon.spy(new Connection(driverSocket, connectionOptionsDefaults)); + const messageStreamSymbol = getSymbolFrom(connection, 'messageStream'); + messageStream = sinon.spy(connection[messageStreamSymbol]); + }); + + afterEach(() => { + timerSandbox.restore(); + clock.restore(); + }); + + context('when options.force == true', function () { + it('calls stream.destroy', () => { + connection.destroy({ force: true }); + clock.tick(1); + expect(driverSocket.destroy).to.have.been.calledOnce; + }); + + it('does not call stream.end', () => { + connection.destroy({ force: true }); + clock.tick(1); + expect(driverSocket.end).to.not.have.been.called; + }); + + it('destroys the tcp socket', () => { + connection.destroy({ force: true }); + clock.tick(1); + expect(driverSocket.destroy).to.have.been.calledOnce; + }); + + it('destroys the messageStream', () => { + connection.destroy({ force: true }); + clock.tick(1); + expect(messageStream.destroy).to.have.been.calledOnce; + }); + + it('calls stream.destroy whenever destroy is called ', () => { + connection.destroy({ force: true }); + connection.destroy({ force: true }); + connection.destroy({ force: true }); + clock.tick(1); + expect(driverSocket.destroy).to.have.been.calledThrice; + }); + }); + + context('when options.force == false', function () { + it('calls stream.end', () => { + connection.destroy({ force: false }); + clock.tick(1); + expect(driverSocket.end).to.have.been.calledOnce; + }); + + it('does not call stream.destroy', () => { + connection.destroy({ force: false }); + clock.tick(1); + expect(driverSocket.destroy).to.not.have.been.called; + }); + + it('ends the tcp socket', () => { + connection.destroy({ force: false }); + clock.tick(1); + expect(driverSocket.end).to.have.been.calledOnce; + }); + + it('destroys the messageStream', () => { + connection.destroy({ force: false }); + clock.tick(1); + expect(messageStream.destroy).to.have.been.calledOnce; + }); + + it('calls stream.end exactly once when destroy is called multiple times', () => { + connection.destroy({ force: false }); + connection.destroy({ force: false }); + connection.destroy({ force: false }); + connection.destroy({ force: false }); + clock.tick(1); + expect(driverSocket.end).to.have.been.calledOnce; + }); + }); + }); });