[WebNN] Better int64 integration
This PR adds workarounds to enable int64 support on WebNN backends that do
not support the int64 data type.

- Do not fall back ops whose only blocker is the int64 limitation.
- Convert all int64 initializers and input values to int32, and surface
 potential overflow errors (see the sketch after this list).
- Register all int64 model inputs and outputs as int32 ml-tensors.
- Handle ONNX ops whose inputs or outputs need conversion between int64
 and int32, e.g. ArgMax, ArgMin, Cast.
- Convert int32 output data back to int64 when reading results.
- Disallow the 'ml-tensor' preferredOutputLocation for int64 outputs when
 the context lacks int64 support.
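
The heart of the workaround is an overflow-checked narrowing from int64 to
int32, plus the matching widening on the way back. Below is a minimal
standalone TypeScript sketch of that round trip; the commit's actual helpers
are convertInt64ToInt32 and convertInt32ToInt64 in tensor-manager.ts, which
operate on raw Uint8Array buffers.

// Sketch only: mirrors the conversion added in tensor-manager.ts below.
const toInt32 = (data: BigInt64Array): Int32Array =>
  Int32Array.from(data, (v) => {
    // Values outside the int32 range cannot be narrowed safely; fail loudly.
    if (v > 2147483647n || v < -2147483648n) {
      throw new Error(`Overflow converting int64 value ${v} to int32`);
    }
    return Number(v);
  });
const toInt64 = (data: Int32Array): BigInt64Array => BigInt64Array.from(data, BigInt);

// In-range values survive the round trip unchanged.
console.log(toInt64(toInt32(new BigInt64Array([1n, -2n, 3n])))); // 1n, -2n, 3n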

Fixed microsoft#21401
Honry committed Feb 26, 2025
1 parent e46c0d8 commit 331f768
Showing 13 changed files with 271 additions and 38 deletions.
24 changes: 21 additions & 3 deletions js/web/lib/wasm/jsep/backend-webnn.ts
@@ -12,7 +12,7 @@ import { DataType } from '../wasm-common';
import { getInstance } from '../wasm-factory';

import { createView } from './tensor-view';
import { TensorId, createTensorManager } from './webnn/tensor-manager';
import { TensorId, createTensorManager, convertInt64ToInt32 } from './webnn/tensor-manager';
import { configureLogger, LOG_DEBUG } from './log';

/*
@@ -288,6 +288,7 @@ export class WebNNBackend {
builder: MLGraphBuilder,
desc: MLOperandDescriptor,
mountedFiles: Map<string, Uint8Array> | undefined,
shouldConvertInt64ToInt32: boolean = false,
): MLOperand {
// If available, "Module.MountedFiles" is a Map for all preloaded files.
if (!mountedFiles) {
@@ -323,7 +324,13 @@
bufferView = new Uint32Array(buffer);
break;
case 'int64':
bufferView = new BigInt64Array(buffer);
if (shouldConvertInt64ToInt32) {
// Int64 is not supported by the current context; use int32 instead.
bufferView = convertInt64ToInt32(new Uint8Array(buffer), false);
desc.dataType = 'int32';
} else {
bufferView = new BigInt64Array(buffer);
}
break;
case 'uint64':
bufferView = new BigUint64Array(buffer);
@@ -340,7 +347,13 @@
throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
}

LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}}`);
LOG_DEBUG(
'verbose',
() =>
`[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}} ${
shouldConvertInt64ToInt32 ? '(Note: int64 data was registered as int32 as a workaround)' : ''
}`,
);

return builder.constant(desc, bufferView);
}
@@ -357,6 +370,11 @@
return inputNames.includes(inputName);
}

public isInt64Supported(sessionId: number): boolean {
const context = this.mlContextBySessionId.get(sessionId);
return !!context?.opSupportLimits()['input']['dataTypes'].includes('int64');
}

public flush(): void {
// Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
}
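For reference, isInt64Supported above boils down to one capability probe on
the session's MLContext. A hedged sketch of the same check in isolation
(contextSupportsInt64 is a hypothetical name; MLContext and opSupportLimits
follow the declarations added to webnn.d.ts below):

// Sketch: does a WebNN context accept int64 graph inputs?
// Mirrors WebNNBackend.isInt64Supported, minus the session-id lookup.
function contextSupportsInt64(context: MLContext): boolean {
  return context.opSupportLimits().input.dataTypes.includes('int64');
}
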
146 changes: 128 additions & 18 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
@@ -88,6 +88,9 @@ const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number
class TensorWrapper {
// The id of the last session that used this tensor.
public sessionId: number;
// This flag is used to indicate whether we should convert data from int64 to int32.
public shouldConvertInt64toInt32: boolean = false;
public isInt64ToInt32Converted: boolean = false;

private mlContext: MLContext;
private mlTensor: MLTensor;
@@ -100,12 +103,15 @@
tensor: MLTensor;
dataType: MLOperandDataType;
shape: readonly number[];
shouldConvertInt64toInt32?: boolean;
}) {
this.sessionId = descriptor.sessionId;
this.mlContext = descriptor.context;
this.mlTensor = descriptor.tensor;
this.dataType = descriptor.dataType;
this.tensorShape = descriptor.shape;
const { sessionId, context, tensor, dataType, shape, shouldConvertInt64toInt32 = false } = descriptor;
this.sessionId = sessionId;
this.mlContext = context;
this.mlTensor = tensor;
this.dataType = dataType;
this.tensorShape = shape;
this.shouldConvertInt64toInt32 = shouldConvertInt64toInt32;
}

public get tensor(): MLTensor {
@@ -133,13 +139,35 @@
this.mlContext.writeTensor(this.mlTensor, data);
}

public async read(): Promise<ArrayBuffer>;
public async read(dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined>;
async read(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
if (dstBuffer) {
return this.mlContext.readTensor(this.mlTensor, dstBuffer);
public async read(shouldConvertInt32ToInt64?: boolean): Promise<ArrayBuffer>;
public async read(
shouldConvertInt32ToInt64?: boolean,
dstBuffer?: ArrayBufferView | ArrayBuffer,
): Promise<ArrayBuffer | undefined>;
public async read(
shouldConvertInt32ToInt64?: boolean,
dstBuffer?: ArrayBufferView | ArrayBuffer,
): Promise<ArrayBuffer | undefined> {
if (shouldConvertInt32ToInt64) {
// This int64 data was saved as int32 as a workaround; read it back as int64.
const data = await this.mlContext.readTensor(this.mlTensor);
const int64Data = convertInt32ToInt64(new Uint8Array(data));

if (dstBuffer) {
const targetBuffer =
dstBuffer instanceof ArrayBuffer
? new Uint8Array(dstBuffer)
: new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength);
targetBuffer.set(int64Data);
return undefined;
} else {
return int64Data.buffer;
}
} else {
return dstBuffer
? await this.mlContext.readTensor(this.mlTensor, dstBuffer)
: await this.mlContext.readTensor(this.mlTensor);
}
return this.mlContext.readTensor(this.mlTensor);
}

public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean {
@@ -150,6 +178,10 @@
this.tensorShape.every((v, i) => v === shape[i])
);
}

public setIsInt64ToInt32Converted(isConverted: boolean): void {
this.isInt64ToInt32Converted = isConverted;
}
}

/**
@@ -184,6 +216,14 @@
copyOld: boolean,
): Promise<MLTensor> {
const context = this.tensorManager.getMLContext(sessionId);
// If the data type is int64 and the context does not support int64, we need to convert it to int32.
const shouldConvertInt64toInt32 =
dataType === 'int64' && !context.opSupportLimits()['input']['dataTypes'].includes('int64');
if (shouldConvertInt64toInt32) {
dataType = 'int32';
LOG_DEBUG('verbose', () => `[WebNN] TensorIdTracker.ensureTensor: convert dataType from int64 to int32`);
}

if (this.wrapper) {
if (this.wrapper.canReuseTensor(context, dataType, shape)) {
return this.wrapper.tensor;
@@ -200,9 +240,19 @@

// eslint-disable-next-line no-bitwise
const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE;
this.wrapper = await this.tensorManager.getCachedTensor(sessionId, dataType, shape, usage, true, true);
this.wrapper = await this.tensorManager.getCachedTensor(
sessionId,
dataType,
shape,
usage,
true,
true,
shouldConvertInt64toInt32,
);

if (copyOld && this.activeUpload) {
// We don't need to convert the old int64 data to int32,
// because it was already converted when it was uploaded.
this.wrapper.write(this.activeUpload);
this.activeUpload = undefined;
}
@@ -212,6 +262,12 @@

public upload(data: Uint8Array): void {
if (this.wrapper) {
if (this.wrapper.shouldConvertInt64toInt32) {
// Convert int64 to int32.
const newData = convertInt64ToInt32(data, true);
this.wrapper.setIsInt64ToInt32Converted(true);
data = newData instanceof Int32Array ? new Uint8Array(newData.buffer) : newData;
}
if (data.byteLength === this.wrapper.byteLength) {
this.wrapper.write(data);
return;
@@ -230,24 +286,30 @@

public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
if (this.activeUpload) {
// If this.activeUpload has been converted to int32, we need to convert it back to int64 data.
const dstData = this.wrapper?.isInt64ToInt32Converted
? convertInt32ToInt64(this.activeUpload)
: this.activeUpload;

if (dstBuffer) {
if (dstBuffer instanceof ArrayBuffer) {
new Uint8Array(dstBuffer).set(this.activeUpload);
new Uint8Array(dstBuffer).set(dstData);
} else {
new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload);
new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(dstData);
}
return;
} else {
return this.activeUpload.buffer;
return dstData.buffer;
}
}
if (!this.wrapper) {
throw new Error('Tensor has not been created.');
}

if (!dstBuffer) {
return this.wrapper.read();
return this.wrapper.read(this.wrapper?.shouldConvertInt64toInt32);
}
return this.wrapper.read(dstBuffer);
return this.wrapper.read(this.wrapper?.shouldConvertInt64toInt32, dstBuffer);
}
}

@@ -367,6 +429,7 @@ class TensorManagerImpl implements TensorManager {
usage: MLTensorUsageFlags | undefined,
writable: boolean,
readable: boolean,
shouldConvertInt64toInt32: boolean = false,
): Promise<TensorWrapper> {
const context = this.getMLContext(sessionId);
for (const [index, tensor] of this.freeTensors.entries()) {
@@ -386,7 +449,7 @@
writable,
readable,
});
return new TensorWrapper({ sessionId, context, tensor, dataType, shape });
return new TensorWrapper({ sessionId, context, tensor, dataType, shape, shouldConvertInt64toInt32 });
}

/**
Expand All @@ -402,3 +465,50 @@ class TensorManagerImpl implements TensorManager {

export const createTensorManager = (...args: ConstructorParameters<typeof TensorManagerImpl>): TensorManager =>
new TensorManagerImpl(...args);

// Convert BigInt64Array buffer data to Int32Array buffer data.
export function convertInt64ToInt32(data: Uint8Array, returnUint8 = true): Uint8Array | Int32Array {
// Make sure it is a multiple of 8 bytes (BigInt64Array).
if (data.byteLength % 8 !== 0) {
throw new Error('Invalid Uint8Array length; must be a multiple of 8 bytes (BigInt64Array).');
}

// Convert Uint8Array to BigInt64Array.
const numElements = data.byteLength / 8;
const bigInt64Array = new BigInt64Array(data.buffer, data.byteOffset, numElements);

// Convert BigInt64Array to Int32Array (same number of elements).
const int32Array = new Int32Array(numElements);

for (let i = 0; i < numElements; i++) {
const value = bigInt64Array[i];

// Check for overflow.
if (value > 2147483647n || value < -2147483648n) {
throw new Error(`Overflow occurred when converting BigInt to Int32 at index ${i}: ${value}`);
}

int32Array[i] = Number(value);
}

// Return based on the requested format.
return returnUint8 ? new Uint8Array(int32Array.buffer) : int32Array;
}

// Convert Int32Array buffer data to BigInt64Array buffer data.
function convertInt32ToInt64(data: Uint8Array): Uint8Array {
// Make sure it is a multiple of 4 bytes (Int32Array).
if (data.byteLength % 4 !== 0) {
throw new Error('Invalid Uint8Array length; must be a multiple of 4 bytes (Int32Array).');
}

// Convert Uint8Array to Int32Array.
const numElements = data.byteLength / 4;
const int32Array = new Int32Array(data.buffer, data.byteOffset, numElements);

// Convert Int32Array to BigInt64Array (same number of elements).
const bigInt64Array = BigInt64Array.from(int32Array, BigInt);

// Return BigInt64Array buffer data.
return new Uint8Array(bigInt64Array.buffer);
}
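
A usage sketch of the two helpers above, assuming it runs inside
tensor-manager.ts (convertInt32ToInt64 is module-private): sixteen bytes of
int64 data narrow to eight bytes of int32 data and widen back losslessly,
while out-of-range values throw instead of silently truncating.

// Sketch: round-trip int64 bytes through the int32 representation.
const int64Bytes = new Uint8Array(new BigInt64Array([42n, -7n]).buffer);
const int32Bytes = convertInt64ToInt32(int64Bytes) as Uint8Array; // 8 bytes
const widened = convertInt32ToInt64(int32Bytes); // 16 bytes again
console.log(new BigInt64Array(widened.buffer)); // BigInt64Array [42n, -7n]
// convertInt64ToInt32(new Uint8Array(new BigInt64Array([2n ** 40n]).buffer)); // throws: overflow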
9 changes: 9 additions & 0 deletions js/web/lib/wasm/jsep/webnn/webnn.d.ts
@@ -415,4 +415,13 @@ interface MLContext {
readTensor(sourceTensor: MLTensor): Promise<ArrayBuffer>;
readTensor(sourceTensor: MLTensor, destinationData: ArrayBufferView|ArrayBuffer): Promise<undefined>;
dispatch(graph: MLGraph, inputs: MLNamedTensor, outputs: MLNamedTensor): void;
opSupportLimits(): MLOpSupportLimits;
}

interface MLOpSupportLimits {
input: MLSupportLimits;
}

interface MLSupportLimits {
dataTypes: MLOperandDataType[];
}
8 changes: 7 additions & 1 deletion js/web/lib/wasm/wasm-core-impl.ts
@@ -820,13 +820,19 @@ export const run = async (
]);
} else if (preferredLocation === 'ml-tensor' && size > 0) {
const ensureTensor = wasm.jsepEnsureTensor;
if (!ensureTensor) {
const isInt64Supported = wasm.jsepIsInt64Supported;
if (!ensureTensor || !isInt64Supported) {
throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.');
}
const tensorSize = calculateTensorSizeInBytes(dataType, size);
if (tensorSize === undefined || !isMLTensorSupportedType(type)) {
throw new Error(`Unsupported data type: ${type}`);
}
if (type === 'int64' && !isInt64Supported(sessionId)) {
throw new Error(
`preferredLocation "ml-tensor" for int64 output is not supported by current WebNN Context.`,
);
}

// If the graph has been partitioned, the output tensor may have not been created. For this reason, we use
// ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new tensor
8 changes: 8 additions & 0 deletions js/web/lib/wasm/wasm-types.ts
@@ -261,6 +261,7 @@ export declare namespace JSEP {
* @param dataLength - specify the external data length.
* @param builder - specify the MLGraphBuilder used for constructing the Constant.
* @param desc - specify the MLOperandDescriptor of the Constant.
* @param shouldConvertInt64ToInt32 - specify whether to convert int64 to int32.
* @returns the WebNN Constant operand for the specified external data.
*/
jsepRegisterMLConstant(
@@ -269,6 +270,7 @@
dataLength: number,
builder: MLGraphBuilder,
desc: MLOperandDescriptor,
shouldConvertInt64ToInt32: boolean,
): MLOperand;

/**
@@ -291,6 +293,12 @@
* @returns the MLTensor ID for the temporary MLTensor.
*/
jsepCreateTemporaryTensor: (sessionId: number, dataType: DataType, shape: readonly number[]) => Promise<number>;
/**
* [exported from pre-jsep.js] Check if a session's associated WebNN Context supports int64.
* @param sessionId - specify the session ID.
* @returns whether the WebNN Context supports int64.
*/
jsepIsInt64Supported: (sessionId: number) => boolean;
}
}

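A hedged sketch of how a caller consumes the new jsepIsInt64Supported hook;
assertInt64OutputSupported is a hypothetical wrapper around the guard that
wasm-core-impl.ts performs above.

// Sketch: refuse an int64 'ml-tensor' output when the WebNN context
// cannot represent int64. `isInt64Supported` is the session-level hook
// declared in this file.
function assertInt64OutputSupported(
  isInt64Supported: (sessionId: number) => boolean,
  sessionId: number,
  dataType: string,
): void {
  if (dataType === 'int64' && !isInt64Supported(sessionId)) {
    throw new Error('preferredLocation "ml-tensor" for int64 output is not supported by the current WebNN context.');
  }
}
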
16 changes: 13 additions & 3 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -144,9 +144,19 @@ bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& we
const std::string_view webnn_data_type = it->second;

// Check if WebNN supports the data type.
emscripten::val is_supported =
webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val(std::string(webnn_data_type)));
return is_supported.as<bool>();
bool is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
emscripten::val(std::string(webnn_data_type)))
.as<bool>();

if (webnn_data_type == "int64" &&
!is_supported &&
webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val("int32")).as<bool>()) {
// Current context doesn't support int64, but int32 is supported.
// We can use int32 as a workaround.
is_supported = true;
}

return is_supported;
}

// Check if the input or output data type of ONNX node is supported by the WebNN operator.
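The relaxed check above treats int64 as supported whenever int32 is, since
the JS side converts the data before it reaches WebNN. The same predicate as
a TypeScript sketch (hypothetical names; the C++ above is the authoritative
version):

// Sketch: a data type passes if supported natively, or if it is int64
// and int32 is available as the conversion target.
function isDataTypeSupported(webnnDataType: string, supportedTypes: string[]): boolean {
  if (supportedTypes.includes(webnnDataType)) return true;
  return webnnDataType === 'int64' && supportedTypes.includes('int32');
}
console.log(isDataTypeSupported('int64', ['int32', 'float32'])); // true
console.log(isDataTypeSupported('uint64', ['int32', 'float32'])); // false
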
onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc
@@ -43,8 +43,12 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

emscripten::val options = emscripten::val::object();
options.set("keepDimensions", keep_dims == 1);
// TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API.
options.set("outputDataType", "int64");
std::string output_data_type = "int64";
if (!model_builder.IsInt64Supported()) {
// Int64 is not supported by the current context; use int32 instead.
output_data_type = "int32";
}
options.set("outputDataType", output_data_type);
options.set("label", node.Name());
emscripten::val output = emscripten::val::object();
