Skip to content

Commit

Permalink
Set up a generator wrap function (#244)
Browse files Browse the repository at this point in the history
* Set up a generator wrap function

* Add test for wrapGenerator function

* Strengthen tests around generator wrap

* Remove console logs

* Add test to cover console warn

* Add coverage to isAborted method
  • Loading branch information
thekevinscott authored Mar 4, 2022
1 parent 1dfc9e5 commit 1673e23
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 38 deletions.
34 changes: 12 additions & 22 deletions packages/upscalerjs/src/upscale.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
WARNING_PROGRESS_WITHOUT_PATCH_SIZE,
WARNING_UNDEFINED_PADDING,
} from './upscale';
import { wrapGenerator } from './utils';
import * as tensorAsBase from 'tensor-as-base64';
import * as image from './image.generated';
import { IModelDefinition, } from './types';
Expand Down Expand Up @@ -90,8 +91,7 @@ describe('getCopyOfInput', () => {
);
expect(getCopyOfInput(input)).not.toEqual(input);
});

})
});

describe('getConsistentTensorDimensions', () => {
interface IOpts {
Expand Down Expand Up @@ -1050,16 +1050,6 @@ describe('getRowsAndColumns', () => {
});
});

async function wrapGen<T>(gen: AsyncGenerator<T>) {
let { value: result, done } = await gen.next();
while (done === false) {
const genResult = await gen.next();
result = genResult.value;
done = genResult.done;
}
return result;
}

describe('predict', () => {
const origWarn = console.warn;
afterEach(() => {
Expand All @@ -1079,7 +1069,7 @@ describe('predict', () => {
const model = {
predict: jest.fn(() => pred),
} as unknown as tf.LayersModel;
const result = await wrapGen(
const result = await wrapGenerator(
predict(img.expandDims(0), {
}, { model, modelDefinition: { scale: 2, } as IModelDefinition })
);
Expand All @@ -1101,7 +1091,7 @@ describe('predict', () => {
return tf.fill([2, 2, 3,], pixel.dataSync()[0]).expandDims(0);
}),
} as unknown as tf.LayersModel;
const result = await wrapGen(predict(
const result = await wrapGenerator(predict(
img.expandDims(0),
{
patchSize: 1,
Expand Down Expand Up @@ -1158,7 +1148,7 @@ describe('predict', () => {
}),
} as unknown as tf.LayersModel;
const progress = jest.fn();
await wrapGen(
await wrapGenerator(
predict(img, {
patchSize,
padding: 0,
Expand Down Expand Up @@ -1187,7 +1177,7 @@ describe('predict', () => {
}),
} as unknown as tf.LayersModel;
const progress = jest.fn((_1: any, _2: any) => {});
await wrapGen(
await wrapGenerator(
predict(img, {
patchSize,
padding: 0,
Expand Down Expand Up @@ -1255,7 +1245,7 @@ describe('predict', () => {
]);
}
});
await wrapGen(
await wrapGenerator(
predict(img, {
patchSize,
padding: 0,
Expand Down Expand Up @@ -1332,7 +1322,7 @@ describe('predict', () => {
]);
}
});
await wrapGen(
await wrapGenerator(
predict(img, {
patchSize,
padding: 0,
Expand Down Expand Up @@ -1366,7 +1356,7 @@ describe('predict', () => {
.expandDims(0);
}),
} as unknown as tf.LayersModel;
await wrapGen(
await wrapGenerator(
predict(img, {
patchSize,
}, { model, modelDefinition: { scale, } as IModelDefinition })
Expand All @@ -1386,7 +1376,7 @@ describe('predict', () => {
.expandDims(0);
}),
} as unknown as tf.LayersModel;
await wrapGen(
await wrapGenerator(
predict(img, {
progress: () => { },
}, { model, modelDefinition: { scale, } as IModelDefinition })
Expand All @@ -1412,7 +1402,7 @@ describe('upscale', () => {
predict: jest.fn(() => tf.ones([1, 2, 2, 3,])),
} as unknown as tf.LayersModel;
(mockedTensorAsBase as any).default = async() => 'foobarbaz';
const result = await wrapGen(upscale(img, {}, { model, modelDefinition: { scale: 2, } as IModelDefinition, }));
const result = await wrapGenerator(upscale(img, {}, { model, modelDefinition: { scale: 2, } as IModelDefinition, }));
expect(result).toEqual('foobarbaz');
});

Expand All @@ -1433,7 +1423,7 @@ describe('upscale', () => {
predict: jest.fn(() => upscaledTensor),
} as unknown as tf.LayersModel;
(mockedTensorAsBase as any).default = async() => 'foobarbaz';
const result = await wrapGen(upscale(img, { output: 'tensor', }, { model, modelDefinition: { scale: 2, } as IModelDefinition, }));
const result = await wrapGenerator(upscale(img, { output: 'tensor', }, { model, modelDefinition: { scale: 2, } as IModelDefinition, }));
if (typeof result === 'string') {
throw new Error('Unexpected string type');
}
Expand Down
24 changes: 10 additions & 14 deletions packages/upscalerjs/src/upscale.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { tf, } from './dependencies.generated';
import type { UpscaleArgs, IModelDefinition, ProcessFn, ResultFormat, UpscaleResponse, Progress, MultiArgProgress, } from './types';
import { getImageAsTensor, } from './image.generated';
import tensorAsBase64 from 'tensor-as-base64';
import { warn, isTensor, isProgress, isMultiArgTensorProgress, isAborted, } from './utils';
import { wrapGenerator, warn, isTensor, isProgress, isMultiArgTensorProgress, isAborted, } from './utils';
import type { GetImageAsTensorInput, } from './image.generated';

export class AbortError extends Error {
Expand Down Expand Up @@ -384,7 +384,6 @@ export async function* upscale<P extends Progress<O, PO>, O extends ResultFormat
done = genResult.done;
yield upscaledPixels;
}

preprocessedPixels.dispose();

const postprocessedPixels = getProcessedPixels<tf.Tensor3D>(
Expand Down Expand Up @@ -412,16 +411,7 @@ export async function cancellableUpscale<P extends Progress<O, PO>, O extends Re
{ signal, ...args }: UpscaleArgs<P, O, PO>,
internalArgs: UpscaleInternalArgs,
): Promise<UpscaleResponse<O>> {
const gen = upscale(
input,
args,
internalArgs,
);
if (isAborted(signal)) {
throw new AbortError();
}
let result: IteratorResult<YieldedIntermediaryValue, UpscaleResponse<O>>;
for (result = await gen.next(); !result.done; result = await gen.next()) {
const tick = async (result?: YieldedIntermediaryValue) => {
await tf.nextFrame();
if (isAborted(signal)) {
if (isTensor(result)) {
Expand All @@ -430,6 +420,12 @@ export async function cancellableUpscale<P extends Progress<O, PO>, O extends Re
throw new AbortError();
}
}

return result.value;
await tick();
const upscaledPixels = await wrapGenerator(upscale(
input,
args,
internalArgs,
), tick);
await tick();
return upscaledPixels;
}
122 changes: 121 additions & 1 deletion packages/upscalerjs/src/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,125 @@
import * as tf from '@tensorflow/tfjs';
import { isSingleArgProgress, isMultiArgTensorProgress, isString, isFourDimensionalTensor, isThreeDimensionalTensor, isTensor, } from './utils';
import { wrapGenerator, isSingleArgProgress, isMultiArgTensorProgress, isString, isFourDimensionalTensor, isThreeDimensionalTensor, isTensor, warn, isAborted, } from './utils';

describe('isAborted', () => {
it('handles an undefined signal', () => {
expect(isAborted()).toEqual(false);
});

it('handles a non-aborted signal', () => {
const controller = new AbortController();
expect(isAborted(controller.signal)).toEqual(false);
});

it('handles an aborted signal', () => {
const controller = new AbortController();
controller.abort();
expect(isAborted(controller.signal)).toEqual(true);
});
});

describe('warn', () => {
const origWarn = console.warn;
afterEach(() => {
console.warn = origWarn;
});

it('logs a string to console', () => {
const fn = jest.fn();
console.warn = fn;
warn('foo');
expect(fn).toHaveBeenCalledTimes(1);
expect(fn).toHaveBeenCalledWith('foo');
});

it('logs an array of strings to console', () => {
const fn = jest.fn();
console.warn = fn;
warn([
'foo',
'bar',
'baz'
]);
expect(fn).toHaveBeenCalledTimes(1);
expect(fn).toHaveBeenCalledWith('foo\nbar\nbaz');
});
});

describe('wrapGenerator', () => {
it('wraps a sync generator', async () => {
function* foo() {
yield 'foo';
yield 'bar';
return 'baz';
}

const result = await wrapGenerator(foo())
expect(result).toEqual('baz');
});

it('wraps an async generator', async () => {
async function* foo() {
yield 'foo';
yield 'bar';
return 'baz';
}

const result = await wrapGenerator(foo())
expect(result).toEqual('baz');
});

it('calls a callback function in the generator', async () => {
async function* foo() {
yield 'foo';
yield 'bar';
return 'baz';
}

const callback = jest.fn();

await wrapGenerator(foo(), callback);
expect(callback).toHaveBeenCalledTimes(2);
expect(callback).toHaveBeenCalledWith('foo');
expect(callback).toHaveBeenCalledWith('bar');
expect(callback).not.toHaveBeenCalledWith('baz');
});

it('accepts an async callback function', async () => {
async function* foo() {
yield 'foo';
yield 'bar';
return 'baz';
}

const callback = jest.fn(async () => {});
await wrapGenerator(foo(), callback);
expect(callback).toHaveBeenCalledTimes(2);
expect(callback).toHaveBeenCalledWith('foo');
expect(callback).toHaveBeenCalledWith('bar');
expect(callback).not.toHaveBeenCalledWith('baz');
});

it('should await the async callback function', (done) => {
async function* foo() {
yield 'foo';
yield 'bar';
return 'baz';
}

const wait = () => new Promise(resolve => setTimeout(resolve));
let called = 0;
const callback = jest.fn(async () => {
called++;
await wait();
if (called < 2) {
expect(callback).toHaveBeenCalledTimes(called);
} else if (called === 2) {
done();
}
});
wrapGenerator(foo(), callback);
}, 100);
});

describe('isSingleArgProgress', () => {
it('returns true for function', () => {
Expand Down
21 changes: 20 additions & 1 deletion packages/upscalerjs/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,23 @@ export const isMultiArgTensorProgress = (p: Progress<any, any>, output: ResultFo
return progressOutput === undefined && output === 'tensor' || progressOutput === 'tensor';
}

export const isAborted = (abortSignal?: AbortSignal) => !!abortSignal && abortSignal?.aborted;
export const isAborted = (abortSignal?: AbortSignal) => {
if (abortSignal) {
return abortSignal.aborted;
}
return false;
};

type PostNext<T = unknown> = ((value: T) => (void | Promise<void>));
export async function wrapGenerator<T = unknown, TReturn = any, TNext = unknown>(
gen: Generator<T, TReturn, TNext> | AsyncGenerator<T, TReturn, TNext>,
postNext?: PostNext<T>
): Promise<TReturn> {
let result: IteratorResult<T, TReturn>;
for (result = await gen.next(); !result.done; result = await gen.next()) {
if (postNext) {
await postNext(result.value);
}
}
return <TReturn>result.value;
}

0 comments on commit 1673e23

Please sign in to comment.