Skip to content

Commit

Permalink
Separate image.ts into separate files and add it to the generation sc…
Browse files Browse the repository at this point in the history
…ript (#192)
  • Loading branch information
Kevin Scott authored Feb 3, 2022
1 parent 6972db9 commit 66c041a
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 9 deletions.
1 change: 1 addition & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ ignore:
- "scripts"
- "test"
- "packages/test-scaffolding"
- "*.generated.ts"
4 changes: 4 additions & 0 deletions packages/upscalerjs/jestconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"testEnvironment": "node",
"setupFiles": ["./jest.setup.ts", "jest-canvas-mock"],
"collectCoverage": true,
"coveragePathIgnorePatterns": [
"node_modules",
".generated.ts"
],
"testRegex": "(/__tests__/.*|(\\.|/)(test|spec))\\.(jsx?|tsx?)$",
"moduleFileExtensions": ["ts", "tsx", "js", "jsx", "json", "node"]
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { JSDOM } from 'jsdom';
import { getImageAsPixels, getUnknownError, getInvalidTensorError } from './image';
import { getImageAsPixels, getUnknownError, getInvalidTensorError } from './image.browser';
import * as tf from '@tensorflow/tfjs';
jest.mock('@tensorflow/tfjs', () => {
const tf = jest.requireActual('@tensorflow/tfjs');
Expand Down
File renamed without changes.
65 changes: 65 additions & 0 deletions packages/upscalerjs/src/image.generated.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { tf } from './dependencies.generated';
import { isHTMLImageElement, isString, isFourDimensionalTensor, isThreeDimensionalTensor, isTensor } from './utils';

export const getUnknownError = (input: any) => new Error(
[
`Unknown input provided to loadImage that cannot be processed: ${JSON.stringify(input)}`,
`Can only handle a string pointing to a valid image resource, an HTMLImageElement element,`,
`or a 3 or 4 rank tensor.`,
].join(' '),
);

export const getInvalidTensorError = (input: tf.Tensor) => new Error(
[
`Unsupported dimensions for incoming pixels: ${input.shape.length}.`,
'Only 3 or 4 rank tensors are supported.',
].join(' '),
);

export const getImageAsPixels = async (
pixels: string | HTMLImageElement | tf.Tensor,
): Promise<{
tensor: tf.Tensor4D;
type: 'string' | 'HTMLImageElement' | 'tensor';
}> => {
if (isString(pixels)) {
const img = await new Promise<HTMLImageElement>((resolve, reject) => {
const img = new Image();
img.src = pixels;
img.crossOrigin = 'anonymous';
img.onload = () => resolve(img);
img.onerror = reject;
});
return {
tensor: tf.browser.fromPixels(img).expandDims(0) as tf.Tensor4D,
type: 'string',
};
}

if (isHTMLImageElement(pixels)) {
return {
tensor: tf.browser.fromPixels(pixels).expandDims(0) as tf.Tensor4D,
type: 'HTMLImageElement',
};
}

if (isTensor(pixels)) {
if (isFourDimensionalTensor(pixels)) {
return {
tensor: pixels,
type: 'tensor',
};
}

if (isThreeDimensionalTensor(pixels)) {
return {
tensor: pixels.expandDims(0) as tf.Tensor4D,
type: 'tensor',
};
}

throw getInvalidTensorError(pixels);
}

throw getUnknownError(pixels);
};
65 changes: 65 additions & 0 deletions packages/upscalerjs/src/image.node.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { tf } from './dependencies.generated';
import { isHTMLImageElement, isString, isFourDimensionalTensor, isThreeDimensionalTensor, isTensor } from './utils';

export const getUnknownError = (input: any) => new Error(
[
`Unknown input provided to loadImage that cannot be processed: ${JSON.stringify(input)}`,
`Can only handle a string pointing to a valid image resource, an HTMLImageElement element,`,
`or a 3 or 4 rank tensor.`,
].join(' '),
);

export const getInvalidTensorError = (input: tf.Tensor) => new Error(
[
`Unsupported dimensions for incoming pixels: ${input.shape.length}.`,
'Only 3 or 4 rank tensors are supported.',
].join(' '),
);

export const getImageAsPixels = async (
pixels: string | HTMLImageElement | tf.Tensor,
): Promise<{
tensor: tf.Tensor4D;
type: 'string' | 'HTMLImageElement' | 'tensor';
}> => {
if (isString(pixels)) {
const img = await new Promise<HTMLImageElement>((resolve, reject) => {
const img = new Image();
img.src = pixels;
img.crossOrigin = 'anonymous';
img.onload = () => resolve(img);
img.onerror = reject;
});
return {
tensor: tf.browser.fromPixels(img).expandDims(0) as tf.Tensor4D,
type: 'string',
};
}

if (isHTMLImageElement(pixels)) {
return {
tensor: tf.browser.fromPixels(pixels).expandDims(0) as tf.Tensor4D,
type: 'HTMLImageElement',
};
}

if (isTensor(pixels)) {
if (isFourDimensionalTensor(pixels)) {
return {
tensor: pixels,
type: 'tensor',
};
}

if (isThreeDimensionalTensor(pixels)) {
return {
tensor: pixels.expandDims(0) as tf.Tensor4D,
type: 'tensor',
};
}

throw getInvalidTensorError(pixels);
}

throw getUnknownError(pixels);
};
6 changes: 3 additions & 3 deletions packages/upscalerjs/src/upscale.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import upscale, {
getRowsAndColumns,
getTensorDimensions,
} from './upscale';
jest.mock('./image');
jest.mock('tensor-as-base64');
import * as tensorAsBase from 'tensor-as-base64';
import * as image from './image';
import * as image from './image.generated';
import { IModelDefinition } from './types';
jest.mock('./image.generated');
jest.mock('tensor-as-base64');

describe('getConsistentTensorDimensions', () => {
interface IOpts {
Expand Down
2 changes: 1 addition & 1 deletion packages/upscalerjs/src/upscale.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { tf } from './dependencies.generated';
import { IUpscaleOptions, IModelDefinition, ProcessFn } from './types';
import { getImageAsPixels } from './image';
import { getImageAsPixels } from './image.generated';
import tensorAsBase64 from 'tensor-as-base64';
import { warn } from './utils';

Expand Down
22 changes: 18 additions & 4 deletions scripts/scaffold-platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,29 @@ const getAdditionalDependencies = (platform: Platform): Array<string> => {
];
}

const SRC = path.resolve(__dirname, `../packages/upscalerjs/src`);

const platform = getPlatform(process.argv.pop());
const dependency = getDependency(platform);

const writeFile = (filename: string, content: Array<string>) => {
const outputPath = path.resolve(__dirname, `../packages/upscalerjs/src/${filename}`);
fs.writeFileSync(outputPath, `${content.map(l => l.trim()).join('\n')}\n`);
const writeFile = (filename: string, content: string) => {
const outputPath = path.resolve(SRC, filename);
fs.writeFileSync(outputPath, content);
};

writeFile('./dependencies.generated.ts', [
const writeLines = (filename: string, content: Array<string>) => writeFile(filename, `${content.map(l => l.trim()).join('\n')}\n`);

writeLines('./dependencies.generated.ts', [
`export * as tf from '${dependency}';`,
...getAdditionalDependencies(platform),
]);

const getImagePath = (platform: Platform) => {
if (platform === 'browser') {
return `image.browser.ts`;
}

return `image.node.ts`;
}

writeFile('./image.generated.ts', fs.readFileSync(path.resolve(SRC, getImagePath(platform))));

0 comments on commit 66c041a

Please sign in to comment.