Skip to content

Commit

Permalink
Fixed PollingController to be a mixin so it can be used with both V1 …
Browse files Browse the repository at this point in the history
…and V2 controllers
  • Loading branch information
shanejonas committed Sep 27, 2023
1 parent ff9e2b3 commit ddf7c70
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 135 deletions.
21 changes: 6 additions & 15 deletions packages/polling-controller/src/PollingController.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ControllerMessenger } from '@metamask/base-controller';

import type { PollingCompleteType } from './PollingController';
import PollingController from './PollingController';
import { PollingController } from './PollingController';

const TICK_TIME = 1000;

Expand All @@ -27,7 +27,6 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
controller.start('mainnet');
jest.advanceTimersByTime(TICK_TIME);
Expand All @@ -48,7 +47,6 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
const pollingToken = controller.start('mainnet');
jest.advanceTimersByTime(TICK_TIME);
Expand All @@ -69,7 +67,6 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
const pollingToken1 = controller.start('mainnet');
controller.start('mainnet');
Expand All @@ -93,7 +90,6 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
controller.start('mainnet');
expect(() => {
Expand All @@ -113,7 +109,6 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
controller.start('mainnet');
expect(() => {
Expand All @@ -136,7 +131,6 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
controller.start('mainnet');
jest.advanceTimersByTime(TICK_TIME);
Expand All @@ -158,7 +152,6 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
controller.start('mainnet');
controller.start('mainnet');
Expand All @@ -182,20 +175,19 @@ describe('PollingController', () => {
PollingCompleteType<typeof name>
>();

mockMessenger.subscribe(`${name}:pollingComplete`, pollingComplete);

Check failure on line 178 in packages/polling-controller/src/PollingController.test.ts

View workflow job for this annotation

GitHub Actions / Lint, build, and test / Lint (20.x)

Delete `⏎`
const controller = new MyGasFeeController({
messenger: mockMessenger,
metadata: {},
name,
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
controller.onPollingComplete(pollingComplete);
const pollingToken = controller.start('mainnet');
controller.stop(pollingToken);
expect(pollingComplete).toHaveBeenCalledTimes(1);
});
it('should poll at the interval length passed via the constructor', async () => {
it('should poll at the interval length when set via setIntervalLength', async () => {
jest.useFakeTimers();

class MyGasFeeController extends PollingController<any, any, any> {
Expand All @@ -208,8 +200,8 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME * 3,
});
controller.setIntervalLength(TICK_TIME * 3);
controller.start('mainnet');
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
Expand Down Expand Up @@ -238,7 +230,6 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME,
});
controller.start('mainnet');
controller.start('rinkeby');
Expand All @@ -259,7 +250,7 @@ describe('PollingController', () => {
controller.stopAll();
});

it('should poll multiple networkClientIds at the interval length passed via the constructor', async () => {
it('should poll multiple networkClientIds when setting interval length', async () => {
jest.useFakeTimers();

class MyGasFeeController extends PollingController<any, any, any> {
Expand All @@ -272,8 +263,8 @@ describe('PollingController', () => {
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
pollingIntervalLength: TICK_TIME * 2,
});
controller.setIntervalLength(TICK_TIME * 2);
controller.start('mainnet');
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
Expand Down
226 changes: 106 additions & 120 deletions packages/polling-controller/src/PollingController.ts
Original file line number Diff line number Diff line change
@@ -1,150 +1,136 @@
import { BaseControllerV2 } from '@metamask/base-controller';
import type {
RestrictedControllerMessenger,
StateMetadata,
} from '@metamask/base-controller';
import { BaseController, BaseControllerV2 } from '@metamask/base-controller';
import type { NetworkClientId } from '@metamask/network-controller';
import type { Json } from '@metamask/utils';
import { v4 as random } from 'uuid';

export type PollingCompleteType<N extends string> = {
type: `${N}:pollingComplete`;
payload: [string];
};

type Constructor = new (...args: any[]) => {};

/**
* PollingController is an abstract class that implements the polling
* functionality for a controller. It is meant to be extended by a controller
* that needs to poll for data by networkClientId.
* PollingControllerMixin
*
* @param Base - The base class to mix onto.
* @returns The mixin.
*/
export default abstract class PollingController<
Name extends string,
State extends Record<string, Json>,
messenger extends RestrictedControllerMessenger<
Name,
any,
PollingCompleteType<Name> | any,
string,
string
>,
> extends BaseControllerV2<Name, State, messenger> {
readonly #intervalLength: number;
function PollingControllerMixin<TBase extends Constructor>(Base: TBase) {
/**
* PollingController is an abstract class that implements the polling
* functionality for a controller. It is meant to be extended by a controller
* that needs to poll for data by networkClientId.
*
*/
abstract class PollingControllerBase extends Base {
readonly #networkClientIdTokensMap: Map<NetworkClientId, Set<string>> =
new Map();

private readonly networkClientIdTokensMap: Map<NetworkClientId, Set<string>> =
new Map();
readonly #intervalIds: Record<NetworkClientId, NodeJS.Timeout> = {};

private readonly intervalIds: Record<NetworkClientId, NodeJS.Timeout> = {};
#callbacks: Set<(networkClientId: NetworkClientId) => void> = new Set();

constructor({
name,
state,
messenger,
metadata,
pollingIntervalLength,
}: {
name: Name;
state: State;
metadata: StateMetadata<State>;
messenger: messenger;
pollingIntervalLength: number;
}) {
super({
name,
state,
messenger,
metadata,
});
#intervalLength = 1000;

if (!pollingIntervalLength) {
throw new Error('pollingIntervalLength required for PollingController');
getIntervalLength() {
return this.#intervalLength;
}

this.#intervalLength = pollingIntervalLength;
}
setIntervalLength(length: number) {
this.#intervalLength = length;
}

/**
* Starts polling for a networkClientId
*
* @param networkClientId - The networkClientId to start polling for
* @returns void
*/
start(networkClientId: NetworkClientId) {
const innerPollToken = random();
if (this.networkClientIdTokensMap.has(networkClientId)) {
const set = this.networkClientIdTokensMap.get(networkClientId);
set?.add(innerPollToken);
} else {
const set = new Set<string>();
set.add(innerPollToken);
this.networkClientIdTokensMap.set(networkClientId, set);
/**
* Starts polling for a networkClientId
*
* @param networkClientId - The networkClientId to start polling for
* @returns void
*/
start(networkClientId: NetworkClientId) {
const innerPollToken = random();
if (this.#networkClientIdTokensMap.has(networkClientId)) {
const set = this.#networkClientIdTokensMap.get(networkClientId);
set?.add(innerPollToken);
} else {
const set = new Set<string>();
set.add(innerPollToken);
this.#networkClientIdTokensMap.set(networkClientId, set);
}
this.#poll(networkClientId);
return innerPollToken;
}
this.#poll(networkClientId);
return innerPollToken;
}

/**
* Stops polling for all networkClientIds
*/
stopAll() {
this.networkClientIdTokensMap.forEach((tokens, _networkClientId) => {
tokens.forEach((token) => {
this.stop(token);
/**
* Stops polling for all networkClientIds
*/
stopAll() {
this.#networkClientIdTokensMap.forEach((tokens, _networkClientId) => {
tokens.forEach((token) => {
this.stop(token);
});
});
});
}

/**
* Stops polling for a networkClientId
*
* @param pollingToken - The polling token to stop polling for
*/
stop(pollingToken: string) {
if (!pollingToken) {
throw new Error('pollingToken required');
}
let found = false;
this.networkClientIdTokensMap.forEach((tokens, networkClientId) => {
if (tokens.has(pollingToken)) {
found = true;
this.networkClientIdTokensMap
.get(networkClientId)
?.delete(pollingToken);
if (this.networkClientIdTokensMap.get(networkClientId)?.size === 0) {
clearTimeout(this.intervalIds[networkClientId]);
delete this.intervalIds[networkClientId];
this.networkClientIdTokensMap.delete(networkClientId);
this.messagingSystem.publish(
`${this.name}:pollingComplete`,
networkClientId,
);

/**
* Stops polling for a networkClientId
*
* @param pollingToken - The polling token to stop polling for
*/
stop(pollingToken: string) {
if (!pollingToken) {
throw new Error('pollingToken required');
}
let found = false;
this.#networkClientIdTokensMap.forEach((tokens, networkClientId) => {
if (tokens.has(pollingToken)) {
found = true;
this.#networkClientIdTokensMap
.get(networkClientId)
?.delete(pollingToken);
if (this.#networkClientIdTokensMap.get(networkClientId)?.size === 0) {
clearTimeout(this.#intervalIds[networkClientId]);
delete this.#intervalIds[networkClientId];
this.#networkClientIdTokensMap.delete(networkClientId);
this.#callbacks.forEach((callback) => {
callback(networkClientId);
});
this.#callbacks.clear();
}
}
});
if (!found) {
throw new Error('pollingToken not found');
}
});
if (!found) {
throw new Error('pollingToken not found');
}
}

/**
* Executes the poll for a networkClientId
*
* @param networkClientId - The networkClientId to execute the poll for
*/
abstract executePoll(networkClientId: NetworkClientId): Promise<void>;
/**
* Executes the poll for a networkClientId
*
* @param networkClientId - The networkClientId to execute the poll for
*/
abstract executePoll(networkClientId: NetworkClientId): Promise<void>;

#poll(networkClientId: NetworkClientId) {
if (this.intervalIds[networkClientId]) {
clearTimeout(this.intervalIds[networkClientId]);
delete this.intervalIds[networkClientId];
}
this.intervalIds[networkClientId] = setTimeout(async () => {
try {
await this.executePoll(networkClientId);
} catch (error) {
console.error(error);
#poll(networkClientId: NetworkClientId) {
if (this.#intervalIds[networkClientId]) {
clearTimeout(this.#intervalIds[networkClientId]);
delete this.#intervalIds[networkClientId];
}
this.#poll(networkClientId);
}, this.#intervalLength);
this.#intervalIds[networkClientId] = setTimeout(async () => {
try {
await this.executePoll(networkClientId);
} catch (error) {
console.error(error);
}
this.#poll(networkClientId);
}, this.#intervalLength);
}

onPollingComplete(callback: (networkClientId: NetworkClientId) => void) {
this.#callbacks.add(callback);
}
}
return PollingControllerBase;
}

export const PollingController = PollingControllerMixin(BaseControllerV2);
export const PollingControllerV1 = PollingControllerMixin(BaseController);

0 comments on commit ddf7c70

Please sign in to comment.