Skip to content

Commit

Permalink
Refactor upscale function into a generator
Browse files Browse the repository at this point in the history
  • Loading branch information
thekevinscott committed Mar 3, 2022
1 parent 151d9e5 commit 559a50d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
9 changes: 5 additions & 4 deletions packages/upscalerjs/src/upscale.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { tf } from './dependencies.generated';
import upscale, {
import {
predict,
getRowsAndColumns,
getTensorDimensions,
getCopyOfInput,
getProcessedPixels,
concatTensors,
upscale,
WARNING_PROGRESS_WITHOUT_PATCH_SIZE,
WARNING_UNDEFINED_PADDING,
} from './upscale';
Expand Down Expand Up @@ -1408,7 +1409,7 @@ describe('upscale', () => {
predict: jest.fn(() => tf.ones([1, 2, 2, 3,])),
} as unknown as tf.LayersModel;
(mockedTensorAsBase as any).default = async() => 'foobarbaz';
const result = await upscale(model, img, { scale: 2, } as IModelDefinition);
const result = await wrapGen(upscale(model, img, { scale: 2, } as IModelDefinition));
expect(result).toEqual('foobarbaz');
});

Expand All @@ -1429,9 +1430,9 @@ describe('upscale', () => {
predict: jest.fn(() => upscaledTensor),
} as unknown as tf.LayersModel;
(mockedTensorAsBase as any).default = async() => 'foobarbaz';
const result = await upscale(model, img, { scale: 2, } as IModelDefinition, {
const result = await wrapGen(upscale(model, img, { scale: 2, } as IModelDefinition, {
output: 'tensor',
});
}));
if (typeof result === 'string') {
throw new Error('Unexpected string type');
}
Expand Down
27 changes: 24 additions & 3 deletions packages/upscalerjs/src/upscale.ts
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,12 @@ export function getProcessedPixels<T extends tf.Tensor3D | tf.Tensor4D>(
// what input is in which format
export const getCopyOfInput = (input: GetImageAsTensorInput) => isTensor(input) ? input.clone() : input;

async function upscale<P extends Progress<O, PO>, O extends ReturnType = 'src', PO extends ReturnType = undefined>(
export async function* upscale<P extends Progress<O, PO>, O extends ReturnType = 'src', PO extends ReturnType = undefined>(
model: tf.LayersModel,
input: GetImageAsTensorInput,
modelDefinition: IModelDefinition,
options: IUpscaleOptions<P, O, PO> = {},
): Promise<UpscaleResponse<O>> {
): AsyncGenerator<UpscaleResponse<O>> {
const parsedInput = getCopyOfInput(input);
const startingPixels = await getImageAsTensor(parsedInput);

Expand Down Expand Up @@ -391,4 +391,25 @@ async function upscale<P extends Progress<O, PO>, O extends ReturnType = 'src',
return <UpscaleResponse<O>>base64Src;
};

export default upscale;
async function wrappedUpscale<P extends Progress<O, PO>, O extends ReturnType = 'src', PO extends ReturnType = undefined>(
model: tf.LayersModel,
input: GetImageAsTensorInput,
modelDefinition: IModelDefinition,
options: IUpscaleOptions<P, O, PO> = {},
): Promise<UpscaleResponse<O>> {
const gen = upscale(
model,
input,
modelDefinition,
options,
);
let { value: upscaledPixels, done } = await gen.next();
while (done === false) {
const genResult = await gen.next();
upscaledPixels = genResult.value;
done = genResult.done;
}
return upscaledPixels;
}

export default wrappedUpscale;
4 changes: 2 additions & 2 deletions packages/upscalerjs/src/upscaler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
} from './types';
import loadModel, { getModelDefinitions, } from './loadModel';
import warmup from './warmup';
import upscale from './upscale';
import upscaleImage from './upscale';
import type { GetImageAsTensorInput, } from './image.generated';

class Upscaler {
Expand Down Expand Up @@ -42,7 +42,7 @@ class Upscaler {
options: IUpscaleOptions<P, O, PO> = {},
) => {
const { model, modelDefinition, } = await this._model;
return upscale(model, image, modelDefinition, options);
return upscaleImage(model, image, modelDefinition, options);
};

getModelDefinitions = async () => {
Expand Down

0 comments on commit 559a50d

Please sign in to comment.