Skip to content

Commit

Permalink
Add support for using the TensorFlow.js WebGPU backend
Browse files Browse the repository at this point in the history
  • Loading branch information
reillyeon committed Nov 11, 2023
1 parent 67f3db6 commit 7efa458
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 21 deletions.
25 changes: 17 additions & 8 deletions common/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,14 @@ export function getMedianValue(array) {
// Set tf.js backend based WebNN's 'MLDeviceType' option
export async function setPolyfillBackend(device) {
// Simulate WebNN's device selection using various tf.js backends.
// MLDeviceType: ['default', 'gpu', 'cpu']
// 'default' or 'gpu': tfjs-backend-webgl, 'cpu': tfjs-backend-wasm
if (!device) device = 'gpu';
// MLDeviceType: ['default', 'webgl', 'webgpu', 'cpu']
// 'default' or 'webgl': tfjs-backend-webgl, 'webgpu': tfjs-backend-webgpu,
// 'cpu': tfjs-backend-wasm
if (!device) device = 'webgpu';
// Use 'webgl' by default for better performance.
// Note: 'wasm' backend may run failed on some samples since
// some ops aren't supported on 'wasm' backend at present
const backend = device === 'cpu' ? 'wasm' : 'webgl';
const backend = device === 'cpu' ? 'wasm' : device;
const context = await navigator.ml.createContext();
const tf = context.tf;
if (tf) {
Expand All @@ -221,8 +222,8 @@ export async function setPolyfillBackend(device) {
throw new Error(`Failed to set tf.js backend ${backend}.`);
}
await tf.ready();
let backendInfo = backend == 'wasm' ? 'WASM' : 'WebGL';
if (backendInfo == 'WASM') {
let backendInfo = tf.getBackend();
if (backendInfo == 'wasm') {
const hasSimd = tf.env().features['WASM_HAS_SIMD_SUPPORT'];
const hasThreads = tf.env().features['WASM_HAS_MULTITHREAD_SUPPORT'];
if (hasThreads && hasSimd) {
Expand All @@ -239,6 +240,13 @@ export async function setPolyfillBackend(device) {
`WebNN-polyfill</a> with tf.js ${tf.version_core} ` +
`<b>${backendInfo}</b> backend.`, 'info');
}
switch (device) {
case 'webgl':
case 'webgpu':
return 'gpu';
default:
return 'cpu';
}
}

// Get url params
Expand Down Expand Up @@ -304,7 +312,7 @@ export async function setBackend(backend, device) {
// Create WebNN-polyfill script
await loadScript(webnnPolyfillUrl, webnnPolyfillId);
}
await setPolyfillBackend(device);
return await setPolyfillBackend(device);
} else if (backend === 'webnn') {
// For Electron
if (isElectron()) {
Expand All @@ -326,8 +334,9 @@ export async function setBackend(backend, device) {
addAlert(`WebNN is not supported!`, 'warning');
}
}
return device;
} else {
addAlert(`Unknow backend: ${backend}`, 'warning');
addAlert(`Unknown backend: ${backend}`, 'warning');
}
}

Expand Down
5 changes: 4 additions & 1 deletion face_recognition/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion facial_landmark_detection/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion lenet/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion nsnet2/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
4 changes: 2 additions & 2 deletions nsnet2/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ browseButton.onclick = () => {

export async function main() {
try {
const [backend, deviceType] =
let [backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
await setBackend(backend, deviceType);
deviceType = await setBackend(backend, deviceType);
// Handle frames parameter.
const searchParams = new URLSearchParams(location.search);
let frames = parseInt(searchParams.get('frames'));
Expand Down
5 changes: 4 additions & 1 deletion object_detection/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion rnnoise/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 3 additions & 2 deletions rnnoise/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,14 @@ export async function main() {
try {
const [backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
await utils.setBackend(backend, deviceType);
const contextOptions = {};
contextOptions['deviceType'] =
await utils.setBackend(backend, deviceType);
modelInfo.innerHTML = '';
await log(modelInfo, `Creating RNNoise with input shape ` +
`[${batchSize} (batch_size) x 100 (frames) x 42].`, true);
await log(modelInfo, '- Loading model...');
const powerPreference = utils.getUrlParams()[1];
const contextOptions = {deviceType};
if (powerPreference) {
contextOptions['powerPreference'] = powerPreference;
}
Expand Down
5 changes: 4 additions & 1 deletion semantic_segmentation/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion style_transfer/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down

0 comments on commit 7efa458

Please sign in to comment.