From 43b95b05120caaf471647daa3b3763944daa0ccc Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 31 Jan 2024 10:28:03 +0800 Subject: [PATCH] [js/webgpu] Support capture and replay for jsep (#18989) This PR expands the graph capture capability to JS EP, which is similar to #16081. But for JS EP, we don't use the CUDA Graph, instead, we records all gpu commands and replay them, which removes most of the cpu overhead to avoid the the situation that gpu waiting for cpu. mobilenetv2-12 becomes 3.7ms from 6ms on NV 3090 and becomes 3.38ms from 4.58ms on Intel A770. All limitations are similar with CUDA EP: 1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not supported. 2. Usage of graph capture is limited to models where-in all ops in the model can be partitioned to the JS EP or CPU EP and no memory copy between them. 3. Shapes of inputs/outputs cannot change across inference calls. 4. IObinding is required. The usage is like below: Method 1: specify outputs buffers explicitly. ``` const sessionOptions = { executionProviders: [ { name: "webgpu", }, ], enableGraphCapture: true, }; const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions); // prepare the inputBuffer/outputBuffer ... ... const feeds = { 'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims }) }; const fetches = { 'output': ort.Tensor.fromGpuBuffer(outputBuffer, { dataType: 'float32', dims: [1, 1000] }) }; let results = await session.run(feeds, fetches); // The first run will begin to capture the graph. // update inputBuffer content ... ... results = = await session.run(feeds, fetches); // The 2ed run and after will directly call replay to execute the graph. ... ... session.release(); ``` Method 2: Don't specify outputs buffers explicitly. Internally, when graph capture is enabled, it will set all outputs location to 'gpu-buffer'. ``` const sessionOptions = { executionProviders: [ { name: "webgpu", }, ], enableGraphCapture: true, }; const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions); // prepare the inputBuffer ... ... const feeds = { 'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims }) }; let results = await session.run(feeds); // The first run will begin to capture the graph. // update inputBuffer content ... ... results = = await session.run(feeds); // The 2ed run and after will directly call replay to execute the graph. ... ... session.release(); --- js/common/lib/inference-session.ts | 8 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 25 ++- js/web/lib/wasm/jsep/backend-webgpu.ts | 100 ++++++++++- js/web/lib/wasm/jsep/init.ts | 8 +- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 74 ++++++-- .../lib/wasm/jsep/webgpu/program-manager.ts | 15 +- js/web/lib/wasm/jsep/webgpu/types.ts | 2 + js/web/lib/wasm/session-options.ts | 12 ++ js/web/lib/wasm/wasm-core-impl.ts | 166 +++++++++++------- .../providers/js/js_execution_provider.cc | 49 +++++- .../core/providers/js/js_execution_provider.h | 18 +- .../core/providers/js/js_provider_factory.cc | 11 +- .../js/js_provider_factory_creator.h | 4 +- onnxruntime/core/session/inference_session.cc | 66 ++++--- .../core/session/provider_registration.cc | 2 +- onnxruntime/wasm/js_internal_api.js | 15 +- 16 files changed, 439 insertions(+), 136 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 1221b52cd4985..4f85c3b46e253 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -111,7 +111,7 @@ export declare namespace InferenceSession { optimizedModelFilePath?: string; /** - * Wether enable profiling. + * Whether enable profiling. * * This setting is a placeholder for a future use. */ @@ -154,6 +154,12 @@ export declare namespace InferenceSession { */ preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + /** + * Whether enable graph capture. + * This setting is available only in ONNXRuntime Web for WebGPU EP. + */ + enableGraphCapture?: boolean; + /** * Store configurations for a session. See * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/ diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 24d7062c85fcb..5dd715191c830 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,6 +13,9 @@ export declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; + type CaptureBeginFunction = () => void; + type CaptureEndFunction = () => void; + type ReplayFunction = () => void; } export interface OrtWasmModule extends EmscriptenModule { @@ -128,7 +131,8 @@ export interface OrtWasmModule extends EmscriptenModule { jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, - releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, + captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void; /** * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). @@ -158,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule { * @returns the GPU data ID for the registered GPU buffer. */ jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; - /** - * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. - * - * @param sessionId - specify the session ID. - */ - jsepUnregisterBuffers?: (sessionId: number) => void; /** * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. * @@ -183,9 +181,18 @@ export interface OrtWasmModule extends EmscriptenModule { (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; /** - * [exported from js_internal_api.js] Called when InferenceSession.run started. + * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is + * called. + * @param sessionId - specify the session ID. + * @returns */ - jsepOnRunStart: () => void; + jsepOnReleaseSession: (sessionId: number) => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a48fe99570abf..e1faecfc046e3 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -10,7 +10,14 @@ import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types'; +import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; + +interface CommandInfo { + readonly kernelId: number; + readonly computePipeline: GPUComputePipeline; + readonly bindGroup: GPUBindGroup; + readonly dispatchGroup: [number, number, number]; +} interface KernelInfo { readonly kernelType: string; @@ -103,6 +110,13 @@ export class WebGpuBackend { */ programManager: ProgramManager; + /** + * representing the session ID of which is currently being run. + * `null` means no session is being run. + * only valid when session.run is executed. + */ + currentSessionId: number|null = null; + /** * representing the kernel ID of which is currently being computed (CPU code perspective). * `null` means no kernel is being computed. @@ -155,6 +169,16 @@ export class WebGpuBackend { queryType: TimestampQuery; env: Env; + sessionStatus: SessionState = 'default'; + /** + * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + */ + capturedCommandList: Map = new Map(); + + /** + * a SessionID -> PendingKernelInfo[] mapping for profiling. + */ + private capturedPendingKernels: Map = new Map(); /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. @@ -228,6 +252,7 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { + const commandEncoder = this.getCommandEncoder(); const computePassDescriptor: GPUComputePassDescriptor = {}; if (this.queryType === 'at-passes') { @@ -238,7 +263,7 @@ export class WebGpuBackend { }; } - this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); + this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -494,7 +519,7 @@ export class WebGpuBackend { () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - if (this.queryType !== 'none') { + if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { const pendingKernelInfo: PendingKernelInfo = { kernelId: this.currentKernelId!, programName: artifact.programInfo.name, @@ -502,6 +527,9 @@ export class WebGpuBackend { outputTensorViews, }; this.pendingKernels.push(pendingKernelInfo); + + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); } this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); @@ -672,7 +700,71 @@ export class WebGpuBackend { } } } - onRunStart(): void { + + captureBegin(): void { + LOG_DEBUG('info', 'captureBegin'); + let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + if (!sessionCommandList) { + sessionCommandList = []; + this.capturedCommandList.set(this.currentSessionId!, sessionCommandList); + sessionPendingKernels = []; + this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels); + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'capturing'; + } + captureEnd(): void { + LOG_DEBUG('info', 'captureEnd'); + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + replay(): void { + LOG_DEBUG('info', 'replay'); + this.sessionStatus = 'replaying'; + const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + const length = sessionCommandList!.length; + this.pendingKernels = []; + for (let i = 0; i < length; i++) { + const computePassEncoder = this.getComputePassEncoder(); + const command = sessionCommandList![i]; + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); + this.pendingDispatchNumber++; + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { + this.flush(); + } + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + + onReleaseSession(sessionId: number): void { + this.unregisterBuffers(sessionId); + if (this.capturedCommandList.has(sessionId)) { + this.capturedCommandList.delete(sessionId); + } + if (this.capturedPendingKernels.has(sessionId)) { + this.capturedPendingKernels.delete(sessionId); + } + this.gpuDataManager.onReleaseSession(sessionId); + } + + onRunStart(sessionId: number): void { + this.currentSessionId = sessionId; this.setQueryType(); } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index f1794d71579bf..786ae41646554 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); - }); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay()); }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 6f3d9a52d9f5d..c17bd1e1477ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -60,9 +60,15 @@ export interface GpuDataManager { unregisterExternalBuffer(buffer: GPUBuffer): void; /** - * destroy all gpu buffers. Call this when the session.release is called. + * destroy all gpu buffers. */ dispose(): void; + + /** + * release session related data. + * @param sessionId - specify the session ID. + */ + onReleaseSession(sessionId: number): void; } interface StorageCacheValue { @@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager { // The external buffers registered users for IO Binding. private externalBuffers: Map; + // The pendingBuffers for capture graph. + // a SessionID -> GPUBuffer[] mapping. + private capturedPendingBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); @@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.buffersForUploadingPending = []; this.buffersPending = []; this.externalBuffers = new Map(); + this.capturedPendingBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager { () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ id}, buffer is the same, skip.`); return id; + } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { + throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. + Please use the previous external buffer!`); } this.externalBuffers.delete(previousBuffer); } else { @@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager { buffer.destroy(); } this.buffersForUploadingPending = []; - for (const buffer of this.buffersPending) { - // eslint-disable-next-line no-bitwise - if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. - this.freeBuffers.get(buffer.size)!.push(buffer); + + if (this.buffersPending.length === 0) { + return; + } + + if (this.backend.sessionStatus === 'default') { + for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise - } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { - // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. - this.freeUniformBuffers.get(buffer.size)!.push(buffer); - } else { - buffer.destroy(); + if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { + // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. + this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise + } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { + // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. + this.freeUniformBuffers.get(buffer.size)!.push(buffer); + } else { + buffer.destroy(); + } + } + this.buffersPending = []; + } else { + // Don't release intermediate tensors in non-default mode. + // TODO: reuse the storage buffers in non-default mode. + let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); + if (!capturedBuffers) { + capturedBuffers = []; + this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); } + for (const buffer of this.buffersPending) { + capturedBuffers.push(buffer); + } + this.buffersPending = []; } - this.buffersPending = []; } dispose() { @@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager { storage.gpuData.buffer.destroy(); }); + this.capturedPendingBuffers.forEach((buffers) => { + buffers.forEach(buffer => { + buffer.destroy(); + }); + }); this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); + this.capturedPendingBuffers = new Map(); + } + + onReleaseSession(sessionId: number) { + // release the captured pending buffers. + const pendingBuffers = this.capturedPendingBuffers.get(sessionId); + if (pendingBuffers) { + pendingBuffers.forEach(buffer => { + buffer.destroy(); + }); + this.capturedPendingBuffers.delete(sessionId); + } } } diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 72eb9713e26a8..9d05f607f817f 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -38,7 +38,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { entries.push({binding: entries.length, resource: {buffer: input.buffer}}); @@ -51,8 +50,20 @@ export class ProgramManager { } const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); - computePassEncoder.setBindGroup(0, bindGroup); + if (this.backend.sessionStatus === 'capturing') { + const commandInfo = { + kernelId: this.backend.currentKernelId!, + computePipeline: buildArtifact.computePipeline, + bindGroup, + dispatchGroup + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(commandInfo); + } + + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 789ac70a6913a..a34b6190b7244 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -5,6 +5,8 @@ import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; +export type SessionState = 'default'|'capturing'|'replaying'; + export enum GpuDataType { default = 0, upload = 1, diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 41ab2d52ca209..48eac57494726 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -168,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } + if (sessionOptions.enableGraphCapture !== undefined) { + if (typeof sessionOptions.enableGraphCapture !== 'boolean') { + throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); + } + const keyDataOffset = allocWasmString('enableGraphCapture', allocs); + const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError( + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); + } + } + if (sessionOptions.freeDimensionOverrides) { for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { if (typeof name !== 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 046336dc9cac0..37b9ed6a1002f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -139,7 +139,7 @@ type IOBindingState = { */ type SessionMetadata = [ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null + bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean ]; const activeSessions = new Map(); @@ -235,6 +235,8 @@ export const createSession = async( const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); + const enableGraphCapture = !!options?.enableGraphCapture; + const inputNames = []; const outputNames = []; const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; @@ -256,12 +258,20 @@ export const createSession = async( outputNames.push(nameString); if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (enableGraphCapture && options?.preferredOutputLocation === undefined) { + outputPreferredLocations.push('gpu-buffer'); + continue; + } const location = typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : options?.preferredOutputLocation?.[nameString] ?? 'cpu'; if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { throw new Error(`Not supported preferred output location: ${location}.`); } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${ + location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); + } outputPreferredLocations.push(location); } } @@ -281,7 +291,9 @@ export const createSession = async( }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); + activeSessions.set( + sessionHandle, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -313,13 +325,16 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session; if (ioBindingState) { + if (enableGraphCapture) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); + } wasm._OrtReleaseBinding(ioBindingState.handle); } - wasm.jsepUnregisterBuffers?.(sessionId); + wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -328,70 +343,75 @@ export const releaseSession = (sessionId: number): void => { }; export const prepareInputOutputTensor = - (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): - void => { - if (!tensor) { - tensorHandles.push(0); - return; - } + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, + enableGraphCapture = false): void => { + if (!tensor) { + tensorHandles.push(0); + return; + } - const wasm = getInstance(); + const wasm = getInstance(); - const dataType = tensor[0]; - const dims = tensor[1]; - const location = tensor[3]; + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; - let rawData: number; - let dataByteLength: number; + let rawData: number; + let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { - throw new Error('String tensor is not supported on GPU.'); - } + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } - if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else { - const data = tensor[2]; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - let dataIndex = rawData / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); - } - } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error( + `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); } - }; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; /** * perform inference run @@ -404,7 +424,12 @@ export const run = async( if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const sessionHandle = session[0]; + const inputNamesUTF8Encoded = session[1]; + const outputNamesUTF8Encoded = session[2]; + const ioBindingState = session[3]; + const enableGraphCapture = session[4]; + const inputOutputBound = session[5]; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -427,13 +452,15 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + prepareInputOutputTensor( + inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], + enableGraphCapture); } let inputValuesIndex = inputValuesOffset / 4; @@ -449,7 +476,7 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) { const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -486,9 +513,12 @@ export const run = async( } } } + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } - wasm.jsepOnRunStart?.(); + wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -595,10 +625,12 @@ export const run = async( } } - if (ioBindingState) { + if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); } - return output; } finally { wasm.stackRestore(beforeRunStack); diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index af9658271d210..308e1c7d952d9 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -3,6 +3,7 @@ #include "js_execution_provider.h" +#include #include #include #include @@ -681,9 +682,13 @@ std::unique_ptr RegisterKernels() { using namespace js; -JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) +JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options) : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, preferred_data_layout_{info.data_layout} { + if (session_options) { + enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true"; + LOGS_DEFAULT(VERBOSE) << "Graph capture enable: " << enable_graph_capture_; + } } std::vector JsExecutionProvider::CreatePreferredAllocators() { @@ -751,4 +756,46 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer JsExecutionProvider::~JsExecutionProvider() { } +Status JsExecutionProvider::OnRunStart() { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; + EM_ASM({ Module.jsepCaptureBegin(); }); + } + return Status::OK(); +} + +Status JsExecutionProvider::OnRunEnd(bool sync_stream) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { + if (IsGraphCaptureAllowed()) { + EM_ASM({ Module.jsepCaptureEnd(); }); + is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool JsExecutionProvider::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status JsExecutionProvider::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + EM_ASM({ Module.jsepReplay(); }); + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void JsExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 39d43498c0717..91a3256ec2bd5 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -5,6 +5,7 @@ #pragma once #include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" #include "core/graph/constants.h" #include "core/providers/providers.h" @@ -38,7 +39,7 @@ struct JsExecutionProviderInfo { class JsExecutionProvider : public IExecutionProvider { public: - JsExecutionProvider(const JsExecutionProviderInfo& info); + JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options); ~JsExecutionProvider() override; std::vector> GetCapability( @@ -57,7 +58,22 @@ class JsExecutionProvider : public IExecutionProvider { bool ConcurrentRunSupported() const override { return false; } std::vector CreatePreferredAllocators() override; + + Status OnRunStart() override; + Status OnRunEnd(bool sync_stream) override; + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); DataLayout preferred_data_layout_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory.cc b/onnxruntime/core/providers/js/js_provider_factory.cc index 5b7329a87cf6a..cbdf99f702150 100644 --- a/onnxruntime/core/providers/js/js_provider_factory.cc +++ b/onnxruntime/core/providers/js/js_provider_factory.cc @@ -10,21 +10,22 @@ namespace onnxruntime { struct JsProviderFactory : IExecutionProviderFactory { - JsProviderFactory(const ProviderOptions& provider_options) - : info_{provider_options} { + JsProviderFactory(const ProviderOptions& provider_options, const SessionOptions* session_options) + : info_{provider_options}, session_options_(session_options) { } std::unique_ptr CreateProvider() override { - return std::make_unique(info_); + return std::make_unique(info_, session_options_); } private: JsExecutionProviderInfo info_; + const SessionOptions* session_options_; }; std::shared_ptr JsProviderFactoryCreator::Create( - const ProviderOptions& provider_options) { - return std::make_shared(provider_options); + const ProviderOptions& provider_options, const SessionOptions* session_options) { + return std::make_shared(provider_options, session_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory_creator.h b/onnxruntime/core/providers/js/js_provider_factory_creator.h index dbabe255c2d7b..510b0fb4248ca 100644 --- a/onnxruntime/core/providers/js/js_provider_factory_creator.h +++ b/onnxruntime/core/providers/js/js_provider_factory_creator.h @@ -9,9 +9,11 @@ #include "core/providers/providers.h" namespace onnxruntime { +struct SessionOptions; struct JsProviderFactoryCreator { - static std::shared_ptr Create(const ProviderOptions& provider_options); + static std::shared_ptr Create(const ProviderOptions& provider_options, + const SessionOptions* session_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c8fc812fe1238..94a750940f6ef 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -146,28 +146,29 @@ static bool HasMemcpyNodes(const Graph& graph) { return false; } -static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { - bool nodes_on_cpu_and_cuda_eps_only = true; +static bool AreAllComputeNodesAssignedToCudaOrJsEp(const Graph& graph) { + bool nodes_on_cpu_and_cuda_and_js_eps_only = true; for (const auto& node : graph.Nodes()) { const auto& node_provider = node.GetExecutionProviderType(); // Empty node provider means CPU EP if (!node_provider.empty() && - node_provider != kCudaExecutionProvider && + !(node_provider == kCudaExecutionProvider || + node_provider == kJsExecutionProvider) && node_provider != kCpuExecutionProvider) { - nodes_on_cpu_and_cuda_eps_only = false; + nodes_on_cpu_and_cuda_and_js_eps_only = false; break; } } - // If we see nodes assigned to EPs other than CPU or CUDA + // If we see nodes assigned to EPs other than CPU, or CUDA/JS // (or) if there are Memcpy nodes, then all compute nodes have - // not been parititoned to the CUDA EP. + // not been parititoned to the CUDA/JS EP. // We allow CPU EPs to show up in the EP list as long as thre is no Memcpy // involved as shape subgraphs will be forced onto CPU and these will not have // Memcpy nodes involved. - return nodes_on_cpu_and_cuda_eps_only && !HasMemcpyNodes(graph); + return nodes_on_cpu_and_cuda_and_js_eps_only && !HasMemcpyNodes(graph); } static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { @@ -1761,7 +1762,7 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // Currently graph capture is only considered by CUDA EP, TRT EP and JS EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1774,47 +1775,62 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). - std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + // + // Check for JS EP: + // If the JS EP is part of the providers list for this session AND + // The JS EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the JS EP, + // Then the JS EP is cached for triggering a ReplayGraph() in Run(). + // + std::vector graph_support_ep_list = { + onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kJsExecutionProvider}; - for (auto& it : cuda_graph_support_ep_list) { + for (auto& it : graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA Graphs can't work with control flow nodes + // Graphs capture can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << "as the model has control flow nodes which can't be supported by " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + "This session cannot use the graph capture feature as requested by the user " + "as the model has control flow nodes which can't be supported by" + + target_ep->Type())); } - if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { - // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0) { + // Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. - if (!AreAllComputeNodesAssignedToCudaEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + if (!AreAllComputeNodesAssignedToCudaOrJsEp(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); + "This session cannot use the graph capture feature as requested by the user " + " as all compute graph nodes have not been partitioned to the " + + target_ep->Type())); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " + << "Use the graph capture feature with caution. " << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " + << "using the representative input used to capture the graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; + << "it is safe to use the graph capture feature."; } } else { // Following code path is for TRT EP currently. diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index ac059bfd00668..de5d93ef0a434 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -149,7 +149,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { provider_options["preferred_layout"] = preferred_layout; } - options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); + options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); #endif diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 7e9c0a6f99c32..cbc60c70b57aa 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -24,7 +24,7 @@ Module['unmountExternalData'] = () => { /** * init JSEP */ -Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => { +Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => { Module.jsepBackend = backend; Module.jsepAlloc = alloc; Module.jsepFree = free; @@ -33,6 +33,9 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepCreateKernel = createKernel; Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; + Module.jsepCaptureBegin = captureBegin; + Module.jsepCaptureEnd = captureEnd; + Module.jsepReplay = replay; // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) // It removes some overhead in cwarp() and ccall() that we don't need. @@ -181,16 +184,16 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { return backend['registerBuffer'](sessionId, index, buffer, size); }; - Module['jsepUnregisterBuffers'] = sessionId => { - backend['unregisterBuffers'](sessionId); - }; Module['jsepGetBuffer'] = (dataId) => { return backend['getBuffer'](dataId); }; Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; - Module['jsepOnRunStart'] = () => { - return backend['onRunStart'](); + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); + }; + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); }; };