diff --git a/packages/polling-controller/src/PollingController.test.ts b/packages/polling-controller/src/PollingController.test.ts index b6a9344de5..0a095c4547 100644 --- a/packages/polling-controller/src/PollingController.test.ts +++ b/packages/polling-controller/src/PollingController.test.ts @@ -1,4 +1,5 @@ import { ControllerMessenger } from '@metamask/base-controller'; +import EventEmitter from 'events'; import { useFakeTimers } from 'sinon'; import { advanceTime } from '../../../tests/helpers'; @@ -13,9 +14,22 @@ const createExecutePollMock = () => { return executePollMock; }; +class MyGasFeeController extends PollingController { + _executePoll = createExecutePollMock(); +} + describe('PollingController', () => { let clock: sinon.SinonFakeTimers; + let mockMessenger: any; + let controller: any; beforeEach(() => { + mockMessenger = new ControllerMessenger(); + controller = new MyGasFeeController({ + messenger: mockMessenger, + metadata: {}, + name: 'PollingController', + state: { foo: 'bar' }, + }); clock = useFakeTimers(); }); afterEach(() => { @@ -23,17 +37,6 @@ describe('PollingController', () => { }); describe('start', () => { it('should start polling if not polling', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); expect(controller._executePoll).toHaveBeenCalledTimes(1); @@ -44,17 +47,6 @@ describe('PollingController', () => { }); describe('stop', () => { it('should stop polling when called with a valid polling that was the only active pollingToken for a given networkClient', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); const pollingToken = controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); expect(controller._executePoll).toHaveBeenCalledTimes(1); @@ -65,17 +57,6 @@ describe('PollingController', () => { controller.stopAllPolling(); }); it('should not stop polling if called with one of multiple active polling tokens for a given networkClient', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); const pollingToken1 = controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); @@ -88,17 +69,6 @@ describe('PollingController', () => { controller.stopAllPolling(); }); it('should error if no pollingToken is passed', () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); controller.startPollingByNetworkClientId('mainnet'); expect(() => { controller.stopPollingByPollingToken(undefined as unknown as any); @@ -106,17 +76,6 @@ describe('PollingController', () => { controller.stopAllPolling(); }); it('should error if no matching pollingToken is found', () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); controller.startPollingByNetworkClientId('mainnet'); expect(() => { controller.stopPollingByPollingToken('potato'); @@ -124,19 +83,20 @@ describe('PollingController', () => { controller.stopAllPolling(); }); }); - describe('startPollingByNetworkClientId', () => { - it('should call _executePoll immediately and on interval if polling', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, + describe('setIntervalLength', () => { + it('should set getNetworkClientById (if previously set by setPollWithBlockTracker) to undefined when setting interval length', async () => { + controller.setPollWithBlockTracker(() => { + throw new Error('should not be called'); }); + expect(controller.getPollingWithBlockTracker()).toBe(true); + controller.setIntervalLength(1000); + expect(controller.getPollingWithBlockTracker()).toBe(false); + }); + }); + + describe('startPollingByNetworkClientId', () => { + it('should call _executePoll immediately and on interval if polling', async () => { controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); expect(controller._executePoll).toHaveBeenCalledTimes(1); @@ -144,17 +104,6 @@ describe('PollingController', () => { expect(controller._executePoll).toHaveBeenCalledTimes(3); }); it('should call _executePoll immediately once and continue calling _executePoll on interval when start is called again with the same networkClientId', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); @@ -169,36 +118,12 @@ describe('PollingController', () => { }); it('should publish "pollingComplete" when stop is called', async () => { const pollingComplete: any = jest.fn(); - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const name = 'PollingController'; - - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name, - state: { foo: 'bar' }, - }); controller.onPollingCompleteByNetworkClientId('mainnet', pollingComplete); const pollingToken = controller.startPollingByNetworkClientId('mainnet'); controller.stopPollingByPollingToken(pollingToken); expect(pollingComplete).toHaveBeenCalledTimes(1); }); it('should poll at the interval length when set via setIntervalLength', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); controller.setIntervalLength(TICK_TIME); controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); @@ -211,17 +136,7 @@ describe('PollingController', () => { expect(controller._executePoll).toHaveBeenCalledTimes(2); }); it('should start and stop polling sessions for different networkClientIds with the same options', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); + controller.setIntervalLength(TICK_TIME); const pollToken1 = controller.startPollingByNetworkClientId('mainnet', { address: '0x1', }); @@ -263,17 +178,6 @@ describe('PollingController', () => { }); describe('multiple networkClientIds', () => { it('should poll for each networkClientId', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); @@ -306,17 +210,6 @@ describe('PollingController', () => { }); it('should poll multiple networkClientIds when setting interval length', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); controller.setIntervalLength(TICK_TIME * 2); controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); @@ -368,19 +261,262 @@ describe('PollingController', () => { ['sepolia', {}], ]); }); + + describe('PollingControllerOnly', () => { + it('can be extended from and constructed', async () => { + class MyClass extends PollingControllerOnly { + _executePoll = createExecutePollMock(); + } + const c = new MyClass(); + expect(c._executePoll).toBeDefined(); + expect(c.getIntervalLength).toBeDefined(); + expect(c.setIntervalLength).toBeDefined(); + expect(c.stopAllPolling).toBeDefined(); + expect(c.startPollingByNetworkClientId).toBeDefined(); + expect(c.stopPollingByPollingToken).toBeDefined(); + }); + }); }); - describe('PollingControllerOnly', () => { - it('can be extended from and constructed', async () => { - class MyClass extends PollingControllerOnly { - _executePoll = createExecutePollMock(); + + describe('Polling on block times', () => { + class TestBlockTracker extends EventEmitter { + private latestBlockNumber: number; + + public interval: number; + + constructor({ interval } = { interval: 1000 }) { + super(); + this.latestBlockNumber = 0; + this.interval = interval; + this.start(interval); } - const c = new MyClass(); - expect(c._executePoll).toBeDefined(); - expect(c.getIntervalLength).toBeDefined(); - expect(c.setIntervalLength).toBeDefined(); - expect(c.stopAllPolling).toBeDefined(); - expect(c.startPollingByNetworkClientId).toBeDefined(); - expect(c.stopPollingByPollingToken).toBeDefined(); + + private start(interval: number) { + setInterval(() => { + this.latestBlockNumber += 1; + this.emit('latest', this.latestBlockNumber); + }, interval); + } + } + + let getNetworkClientById: jest.Mock; + let mainnetBlockTracker: TestBlockTracker; + let goerliBlockTracker: TestBlockTracker; + let sepoliaBlockTracker: TestBlockTracker; + beforeEach(() => { + mainnetBlockTracker = new TestBlockTracker({ interval: 5 }); + goerliBlockTracker = new TestBlockTracker({ interval: 10 }); + sepoliaBlockTracker = new TestBlockTracker({ interval: 15 }); + + getNetworkClientById = jest.fn().mockImplementation((networkClientId) => { + switch (networkClientId) { + case 'mainnet': + return { + blockTracker: mainnetBlockTracker, + }; + case 'goerli': + return { + blockTracker: goerliBlockTracker, + }; + case 'sepolia': + return { + blockTracker: sepoliaBlockTracker, + }; + default: + throw new Error(`Unknown networkClientId: ${networkClientId}`); + } + }); + }); + describe('setPollWithBlockTracker', () => { + it('should set the interval length to undefined', () => { + controller.setPollWithBlockTracker(getNetworkClientById); + + expect(controller.getIntervalLength()).toBeUndefined(); + }); + }); + + describe('startPollingByNetworkClientId', () => { + it('should start polling for the specified networkClientId', async () => { + controller.setPollWithBlockTracker(getNetworkClientById); + + controller.startPollingByNetworkClientId('mainnet'); + + expect(getNetworkClientById).toHaveBeenCalledWith('mainnet'); + + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll).toHaveBeenCalledTimes(1); + + await advanceTime({ clock, duration: 1 }); + + expect(controller._executePoll).toHaveBeenCalledTimes(1); + + await advanceTime({ clock, duration: 4 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + expect.arrayContaining(['mainnet', {}]), + expect.arrayContaining(['mainnet', {}]), + ]); + + controller.stopAllPolling(); + }); + + it('should poll on new block intervals for each networkClientId', async () => { + controller.setPollWithBlockTracker(getNetworkClientById); + + controller.startPollingByNetworkClientId('mainnet'); + controller.startPollingByNetworkClientId('goerli'); + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll).toHaveBeenCalledTimes(1); + expect(controller._executePoll).toHaveBeenCalledWith('mainnet', {}, 1); + + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['goerli', {}, 1], + ]); + + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['goerli', {}, 1], + ['mainnet', {}, 3], + ]); + + // 15ms have passed + // Start polling for sepolia, 15ms interval + controller.startPollingByNetworkClientId('sepolia'); + + await advanceTime({ clock, duration: 15 }); + + // at 30ms, 6 blocks have passed for mainnet (every 5ms), 3 for goerli (every 10ms), and 2 for sepolia (every 15ms) + // Didn't start listening to sepolia until 15ms had passed, so we only call executePoll on the 2nd block + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['goerli', {}, 1], + ['mainnet', {}, 3], + ['mainnet', {}, 4], + ['goerli', {}, 2], + ['mainnet', {}, 5], + ['mainnet', {}, 6], + ['goerli', {}, 3], + ['sepolia', {}, 2], + ]); + + controller.stopAllPolling(); + }); + }); + + describe('stopPollingByPollingToken', () => { + it('should should stop polling when all polling tokens for a networkClientId are deleted', async () => { + controller.setPollWithBlockTracker(getNetworkClientById); + + const pollingToken1 = + controller.startPollingByNetworkClientId('mainnet'); + + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll).toHaveBeenCalledTimes(1); + expect(controller._executePoll).toHaveBeenCalledWith('mainnet', {}, 1); + + const pollingToken2 = + controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ]); + + controller.stopPollingByPollingToken(pollingToken1); + + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ]); + + controller.stopPollingByPollingToken(pollingToken2); + + await advanceTime({ clock, duration: 15 }); + + // no further polling should occur + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ]); + }); + + it('should should stop polling for one networkClientId when all polling tokens for that networkClientId are deleted, without stopping polling for networkClientIds with active pollingTokens', async () => { + controller.setPollWithBlockTracker(getNetworkClientById); + + const pollingToken1 = + controller.startPollingByNetworkClientId('mainnet'); + + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll).toHaveBeenCalledWith('mainnet', {}, 1); + + const pollingToken2 = + controller.startPollingByNetworkClientId('mainnet'); + + await advanceTime({ clock, duration: 5 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ]); + + controller.startPollingByNetworkClientId('goerli'); + await advanceTime({ clock, duration: 5 }); + + // 3 blocks have passed for mainnet, 1 for goerli but we only started listening to goerli after 5ms + // so the next block will come at 20ms and be the 2nd block for goerli + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ]); + + controller.stopPollingByPollingToken(pollingToken1); + + await advanceTime({ clock, duration: 5 }); + + // 20ms have passed, 4 blocks for mainnet, 2 for goerli + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ['mainnet', {}, 4], + ['goerli', {}, 2], + ]); + + controller.stopPollingByPollingToken(pollingToken2); + + await advanceTime({ clock, duration: 20 }); + + // no further polling for mainnet should occur + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ['mainnet', {}, 4], + ['goerli', {}, 2], + ['goerli', {}, 3], + ['goerli', {}, 4], + ]); + + controller.stopAllPolling(); + }); }); }); }); diff --git a/packages/polling-controller/src/PollingController.ts b/packages/polling-controller/src/PollingController.ts index 0327a7806e..4442391316 100644 --- a/packages/polling-controller/src/PollingController.ts +++ b/packages/polling-controller/src/PollingController.ts @@ -1,5 +1,8 @@ import { BaseController, BaseControllerV1 } from '@metamask/base-controller'; -import type { NetworkClientId } from '@metamask/network-controller'; +import type { + NetworkClientId, + NetworkClient, +} from '@metamask/network-controller'; import type { Json } from '@metamask/utils'; import stringify from 'fast-json-stable-stringify'; import { v4 as random } from 'uuid'; @@ -38,17 +41,30 @@ function PollingControllerMixin(Base: TBase) { readonly #intervalIds: Record = {}; + readonly #activeListeners: Record< + PollingTokenSetId, + (options: Json) => Promise + > = {}; + #callbacks: Map< NetworkClientId, Set<(networkClientId: NetworkClientId) => void> > = new Map(); - #intervalLength = 1000; + #intervalLength: number | undefined = 1000; + + #getNetworkClientById: + | ((networkClientId: NetworkClientId) => NetworkClient) + | undefined; getIntervalLength() { return this.#intervalLength; } + getPollingWithBlockTracker() { + return this.#getNetworkClientById !== undefined; + } + /** * Sets the length of the polling interval * @@ -56,6 +72,22 @@ function PollingControllerMixin(Base: TBase) { */ setIntervalLength(length: number) { this.#intervalLength = length; + + // setting and using an interval is mutually exclusive with polling on new blocks + this.#getNetworkClientById = undefined; + } + + setPollWithBlockTracker( + getNetworkClientById: (networkClientId: NetworkClientId) => NetworkClient, + ) { + if (!getNetworkClientById) { + throw new Error('getNetworkClientById callback required'); + } + + this.#getNetworkClientById = getNetworkClientById; + + // using block times is mutually exclusive with polling on a static interval + this.#intervalLength = undefined; } /** @@ -111,8 +143,27 @@ function PollingControllerMixin(Base: TBase) { found = true; tokenSet.delete(pollingToken); if (tokenSet.size === 0) { - clearTimeout(this.#intervalIds[key]); - delete this.#intervalIds[key]; + // if applicable stop polling on a static interval + if (this.#intervalIds[key]) { + clearTimeout(this.#intervalIds[key]); + delete this.#intervalIds[key]; + } else if ( + // if applicable stop listening for new blocks + this.#getNetworkClientById !== undefined && + this.#activeListeners[key] + ) { + const [networkClientId] = key.split(':'); + const { blockTracker } = + this.#getNetworkClientById(networkClientId); + if (blockTracker) { + blockTracker.removeListener( + 'latest', + this.#activeListeners[key], + ); + } + delete this.#activeListeners[key]; + } + this.#pollingTokenSets.delete(key); this.#callbacks.get(key)?.forEach((callback) => { callback(key); @@ -139,6 +190,32 @@ function PollingControllerMixin(Base: TBase) { #poll(networkClientId: NetworkClientId, options: Json) { const key = getKey(networkClientId, options); + + // if #getNetworkClientById is defined, we want to poll on new blocks + if (this.#getNetworkClientById !== undefined) { + // if we're already listening for new blocks for this key, don't add another listener + if (this.#activeListeners[key]) { + return; + } + const blockTracker = + this.#getNetworkClientById(networkClientId)?.blockTracker; + + if (blockTracker) { + const updateOnNewBlock = this._executePoll.bind( + this, + networkClientId, + options, + ); + blockTracker.addListener('latest', updateOnNewBlock); + this.#activeListeners[key] = updateOnNewBlock; + return; + } + + throw new Error(` + Unable to retrieve blockTracker for networkClientId ${networkClientId} `); + } + + // if we're not polling on new blocks, use setTimeout const interval = this.#intervalIds[key]; if (interval) { clearTimeout(interval);