diff --git a/packages/upscalerjs/src/upscale.test.ts b/packages/upscalerjs/src/upscale.test.ts index b2e63bc28..4b6fc1db0 100644 --- a/packages/upscalerjs/src/upscale.test.ts +++ b/packages/upscalerjs/src/upscale.test.ts @@ -12,6 +12,7 @@ import { WARNING_PROGRESS_WITHOUT_PATCH_SIZE, WARNING_UNDEFINED_PADDING, } from './upscale'; +import { wrapGenerator } from './utils'; import * as tensorAsBase from 'tensor-as-base64'; import * as image from './image.generated'; import { IModelDefinition, } from './types'; @@ -90,8 +91,7 @@ describe('getCopyOfInput', () => { ); expect(getCopyOfInput(input)).not.toEqual(input); }); - -}) +}); describe('getConsistentTensorDimensions', () => { interface IOpts { @@ -1050,16 +1050,6 @@ describe('getRowsAndColumns', () => { }); }); -async function wrapGen(gen: AsyncGenerator) { - let { value: result, done } = await gen.next(); - while (done === false) { - const genResult = await gen.next(); - result = genResult.value; - done = genResult.done; - } - return result; -} - describe('predict', () => { const origWarn = console.warn; afterEach(() => { @@ -1079,7 +1069,7 @@ describe('predict', () => { const model = { predict: jest.fn(() => pred), } as unknown as tf.LayersModel; - const result = await wrapGen( + const result = await wrapGenerator( predict(img.expandDims(0), { }, { model, modelDefinition: { scale: 2, } as IModelDefinition }) ); @@ -1101,7 +1091,7 @@ describe('predict', () => { return tf.fill([2, 2, 3,], pixel.dataSync()[0]).expandDims(0); }), } as unknown as tf.LayersModel; - const result = await wrapGen(predict( + const result = await wrapGenerator(predict( img.expandDims(0), { patchSize: 1, @@ -1158,7 +1148,7 @@ describe('predict', () => { }), } as unknown as tf.LayersModel; const progress = jest.fn(); - await wrapGen( + await wrapGenerator( predict(img, { patchSize, padding: 0, @@ -1187,7 +1177,7 @@ describe('predict', () => { }), } as unknown as tf.LayersModel; const progress = jest.fn((_1: any, _2: any) => {}); - await wrapGen( + await wrapGenerator( predict(img, { patchSize, padding: 0, @@ -1255,7 +1245,7 @@ describe('predict', () => { ]); } }); - await wrapGen( + await wrapGenerator( predict(img, { patchSize, padding: 0, @@ -1332,7 +1322,7 @@ describe('predict', () => { ]); } }); - await wrapGen( + await wrapGenerator( predict(img, { patchSize, padding: 0, @@ -1366,7 +1356,7 @@ describe('predict', () => { .expandDims(0); }), } as unknown as tf.LayersModel; - await wrapGen( + await wrapGenerator( predict(img, { patchSize, }, { model, modelDefinition: { scale, } as IModelDefinition }) @@ -1386,7 +1376,7 @@ describe('predict', () => { .expandDims(0); }), } as unknown as tf.LayersModel; - await wrapGen( + await wrapGenerator( predict(img, { progress: () => { }, }, { model, modelDefinition: { scale, } as IModelDefinition }) @@ -1412,7 +1402,7 @@ describe('upscale', () => { predict: jest.fn(() => tf.ones([1, 2, 2, 3,])), } as unknown as tf.LayersModel; (mockedTensorAsBase as any).default = async() => 'foobarbaz'; - const result = await wrapGen(upscale(img, {}, { model, modelDefinition: { scale: 2, } as IModelDefinition, })); + const result = await wrapGenerator(upscale(img, {}, { model, modelDefinition: { scale: 2, } as IModelDefinition, })); expect(result).toEqual('foobarbaz'); }); @@ -1433,7 +1423,7 @@ describe('upscale', () => { predict: jest.fn(() => upscaledTensor), } as unknown as tf.LayersModel; (mockedTensorAsBase as any).default = async() => 'foobarbaz'; - const result = await wrapGen(upscale(img, { output: 'tensor', }, { model, modelDefinition: { scale: 2, } as IModelDefinition, })); + const result = await wrapGenerator(upscale(img, { output: 'tensor', }, { model, modelDefinition: { scale: 2, } as IModelDefinition, })); if (typeof result === 'string') { throw new Error('Unexpected string type'); } diff --git a/packages/upscalerjs/src/upscale.ts b/packages/upscalerjs/src/upscale.ts index ae647a9e9..08153e9b5 100644 --- a/packages/upscalerjs/src/upscale.ts +++ b/packages/upscalerjs/src/upscale.ts @@ -2,7 +2,7 @@ import { tf, } from './dependencies.generated'; import type { UpscaleArgs, IModelDefinition, ProcessFn, ResultFormat, UpscaleResponse, Progress, MultiArgProgress, } from './types'; import { getImageAsTensor, } from './image.generated'; import tensorAsBase64 from 'tensor-as-base64'; -import { warn, isTensor, isProgress, isMultiArgTensorProgress, isAborted, } from './utils'; +import { wrapGenerator, warn, isTensor, isProgress, isMultiArgTensorProgress, isAborted, } from './utils'; import type { GetImageAsTensorInput, } from './image.generated'; export class AbortError extends Error { @@ -384,7 +384,6 @@ export async function* upscale

, O extends ResultFormat done = genResult.done; yield upscaledPixels; } - preprocessedPixels.dispose(); const postprocessedPixels = getProcessedPixels( @@ -412,16 +411,7 @@ export async function cancellableUpscale

, O extends Re { signal, ...args }: UpscaleArgs, internalArgs: UpscaleInternalArgs, ): Promise> { - const gen = upscale( - input, - args, - internalArgs, - ); - if (isAborted(signal)) { - throw new AbortError(); - } - let result: IteratorResult>; - for (result = await gen.next(); !result.done; result = await gen.next()) { + const tick = async (result?: YieldedIntermediaryValue) => { await tf.nextFrame(); if (isAborted(signal)) { if (isTensor(result)) { @@ -430,6 +420,12 @@ export async function cancellableUpscale

, O extends Re throw new AbortError(); } } - - return result.value; + await tick(); + const upscaledPixels = await wrapGenerator(upscale( + input, + args, + internalArgs, + ), tick); + await tick(); + return upscaledPixels; } diff --git a/packages/upscalerjs/src/utils.test.ts b/packages/upscalerjs/src/utils.test.ts index db7a56d8e..fc596df2a 100644 --- a/packages/upscalerjs/src/utils.test.ts +++ b/packages/upscalerjs/src/utils.test.ts @@ -1,5 +1,125 @@ import * as tf from '@tensorflow/tfjs'; -import { isSingleArgProgress, isMultiArgTensorProgress, isString, isFourDimensionalTensor, isThreeDimensionalTensor, isTensor, } from './utils'; +import { wrapGenerator, isSingleArgProgress, isMultiArgTensorProgress, isString, isFourDimensionalTensor, isThreeDimensionalTensor, isTensor, warn, isAborted, } from './utils'; + +describe('isAborted', () => { + it('handles an undefined signal', () => { + expect(isAborted()).toEqual(false); + }); + + it('handles a non-aborted signal', () => { + const controller = new AbortController(); + expect(isAborted(controller.signal)).toEqual(false); + }); + + it('handles an aborted signal', () => { + const controller = new AbortController(); + controller.abort(); + expect(isAborted(controller.signal)).toEqual(true); + }); +}); + +describe('warn', () => { + const origWarn = console.warn; + afterEach(() => { + console.warn = origWarn; + }); + + it('logs a string to console', () => { + const fn = jest.fn(); + console.warn = fn; + warn('foo'); + expect(fn).toHaveBeenCalledTimes(1); + expect(fn).toHaveBeenCalledWith('foo'); + }); + + it('logs an array of strings to console', () => { + const fn = jest.fn(); + console.warn = fn; + warn([ + 'foo', + 'bar', + 'baz' + ]); + expect(fn).toHaveBeenCalledTimes(1); + expect(fn).toHaveBeenCalledWith('foo\nbar\nbaz'); + }); +}); + +describe('wrapGenerator', () => { + it('wraps a sync generator', async () => { + function* foo() { + yield 'foo'; + yield 'bar'; + return 'baz'; + } + + const result = await wrapGenerator(foo()) + expect(result).toEqual('baz'); + }); + + it('wraps an async generator', async () => { + async function* foo() { + yield 'foo'; + yield 'bar'; + return 'baz'; + } + + const result = await wrapGenerator(foo()) + expect(result).toEqual('baz'); + }); + + it('calls a callback function in the generator', async () => { + async function* foo() { + yield 'foo'; + yield 'bar'; + return 'baz'; + } + + const callback = jest.fn(); + + await wrapGenerator(foo(), callback); + expect(callback).toHaveBeenCalledTimes(2); + expect(callback).toHaveBeenCalledWith('foo'); + expect(callback).toHaveBeenCalledWith('bar'); + expect(callback).not.toHaveBeenCalledWith('baz'); + }); + + it('accepts an async callback function', async () => { + async function* foo() { + yield 'foo'; + yield 'bar'; + return 'baz'; + } + + const callback = jest.fn(async () => {}); + await wrapGenerator(foo(), callback); + expect(callback).toHaveBeenCalledTimes(2); + expect(callback).toHaveBeenCalledWith('foo'); + expect(callback).toHaveBeenCalledWith('bar'); + expect(callback).not.toHaveBeenCalledWith('baz'); + }); + + it('should await the async callback function', (done) => { + async function* foo() { + yield 'foo'; + yield 'bar'; + return 'baz'; + } + + const wait = () => new Promise(resolve => setTimeout(resolve)); + let called = 0; + const callback = jest.fn(async () => { + called++; + await wait(); + if (called < 2) { + expect(callback).toHaveBeenCalledTimes(called); + } else if (called === 2) { + done(); + } + }); + wrapGenerator(foo(), callback); + }, 100); +}); describe('isSingleArgProgress', () => { it('returns true for function', () => { diff --git a/packages/upscalerjs/src/utils.ts b/packages/upscalerjs/src/utils.ts index 33e1745bb..15621c788 100644 --- a/packages/upscalerjs/src/utils.ts +++ b/packages/upscalerjs/src/utils.ts @@ -50,4 +50,23 @@ export const isMultiArgTensorProgress = (p: Progress, output: ResultFo return progressOutput === undefined && output === 'tensor' || progressOutput === 'tensor'; } -export const isAborted = (abortSignal?: AbortSignal) => !!abortSignal && abortSignal?.aborted; +export const isAborted = (abortSignal?: AbortSignal) => { + if (abortSignal) { + return abortSignal.aborted; + } + return false; +}; + +type PostNext = ((value: T) => (void | Promise)); +export async function wrapGenerator( + gen: Generator | AsyncGenerator, + postNext?: PostNext +): Promise { + let result: IteratorResult; + for (result = await gen.next(); !result.done; result = await gen.next()) { + if (postNext) { + await postNext(result.value); + } + } + return result.value; +}