diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index 2b9a9208e2e53..f9cf74071d14d 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -12,7 +12,7 @@ import { DataType } from '../wasm-common';
 import { getInstance } from '../wasm-factory';
 
 import { createView } from './tensor-view';
-import { TensorId, createTensorManager } from './webnn/tensor-manager';
+import { TensorId, createTensorManager, convertInt64ToInt32 } from './webnn/tensor-manager';
 import { configureLogger, LOG_DEBUG } from './log';
 
 /*
@@ -288,6 +288,7 @@ export class WebNNBackend {
     builder: MLGraphBuilder,
     desc: MLOperandDescriptor,
     mountedFiles: Map<string, Uint8Array> | undefined,
+    shouldConvertInt64ToInt32: boolean = false,
   ): MLOperand {
     // If available, "Module.MountedFiles" is a Map for all preloaded files.
     if (!mountedFiles) {
@@ -323,7 +324,13 @@ export class WebNNBackend {
         bufferView = new Uint32Array(buffer);
         break;
       case 'int64':
-        bufferView = new BigInt64Array(buffer);
+        if (shouldConvertInt64ToInt32) {
+          // Int64 is not supported by the current context; use int32 instead.
+          bufferView = convertInt64ToInt32(new Uint8Array(buffer), false);
+          desc.dataType = 'int32';
+        } else {
+          bufferView = new BigInt64Array(buffer);
+        }
         break;
       case 'uint64':
         bufferView = new BigUint64Array(buffer);
@@ -340,7 +347,13 @@ export class WebNNBackend {
       throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
     }
 
-    LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}}`);
+    LOG_DEBUG(
+      'verbose',
+      () =>
+        `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}} ${
+          shouldConvertInt64ToInt32 ? '(Note: the int64 data was registered as int32 as a workaround)' : ''
+        }`,
+    );
     return builder.constant(desc, bufferView);
   }
 
@@ -357,6 +370,11 @@ export class WebNNBackend {
     return inputNames.includes(inputName);
   }
 
+  public isInt64Supported(sessionId: number): boolean {
+    const context = this.mlContextBySessionId.get(sessionId);
+    return !!context?.opSupportLimits()['input']['dataTypes'].includes('int64');
+  }
+
   public flush(): void {
     // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
   }
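The new `isInt64Supported` entry point is a thin wrapper over the WebNN capability API. Below is a minimal standalone sketch of the same probe, assuming a WebNN-enabled browser and the `MLOpSupportLimits` shape declared in the `webnn.d.ts` change further down; `contextSupportsInt64` is a hypothetical helper, not part of this patch:

```ts
// Probe whether a WebNN MLContext accepts int64 graph inputs, mirroring
// WebNNBackend.isInt64Supported above. Requires a WebNN-enabled browser
// (navigator.ml) and the MLOpSupportLimits typings from webnn.d.ts.
async function contextSupportsInt64(): Promise<boolean> {
  const context = await navigator.ml.createContext({ deviceType: 'gpu' });
  // opSupportLimits().input.dataTypes lists the operand data types the
  // context accepts for graph inputs, e.g. ['float32', 'int32', ...].
  return context.opSupportLimits().input.dataTypes.includes('int64');
}
```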
diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
index ebdd5069aa089..aa4cec247f9a5 100644
--- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
+++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
@@ -88,6 +88,9 @@ const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number
 class TensorWrapper {
   // The id of the last session that used this tensor.
   public sessionId: number;
+  // Indicates whether the data should be converted from int64 to int32.
+  public shouldConvertInt64toInt32: boolean = false;
+  public isInt64ToInt32Converted: boolean = false;
 
   private mlContext: MLContext;
   private mlTensor: MLTensor;
@@ -100,12 +103,15 @@ class TensorWrapper {
     tensor: MLTensor;
     dataType: MLOperandDataType;
     shape: readonly number[];
+    shouldConvertInt64toInt32?: boolean;
   }) {
-    this.sessionId = descriptor.sessionId;
-    this.mlContext = descriptor.context;
-    this.mlTensor = descriptor.tensor;
-    this.dataType = descriptor.dataType;
-    this.tensorShape = descriptor.shape;
+    const { sessionId, context, tensor, dataType, shape, shouldConvertInt64toInt32 = false } = descriptor;
+    this.sessionId = sessionId;
+    this.mlContext = context;
+    this.mlTensor = tensor;
+    this.dataType = dataType;
+    this.tensorShape = shape;
+    this.shouldConvertInt64toInt32 = shouldConvertInt64toInt32;
   }
 
   public get tensor(): MLTensor {
@@ -133,13 +139,35 @@ class TensorWrapper {
     this.mlContext.writeTensor(this.mlTensor, data);
   }
 
-  public async read(): Promise<ArrayBuffer>;
-  public async read(dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined>;
-  async read(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
-    if (dstBuffer) {
-      return this.mlContext.readTensor(this.mlTensor, dstBuffer);
+  public async read(shouldConvertInt32ToInt64?: boolean): Promise<ArrayBuffer>;
+  public async read(
+    shouldConvertInt32ToInt64?: boolean,
+    dstBuffer?: ArrayBufferView | ArrayBuffer,
+  ): Promise<undefined>;
+  public async read(
+    shouldConvertInt32ToInt64?: boolean,
+    dstBuffer?: ArrayBufferView | ArrayBuffer,
+  ): Promise<ArrayBuffer | undefined> {
+    if (shouldConvertInt32ToInt64) {
+      // This int64 data was saved as int32 as a workaround; read it back as int64.
+      const data = await this.mlContext.readTensor(this.mlTensor);
+      const int64Data = convertInt32ToInt64(new Uint8Array(data));
+
+      if (dstBuffer) {
+        const targetBuffer =
+          dstBuffer instanceof ArrayBuffer
+            ? new Uint8Array(dstBuffer)
+            : new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength);
+        targetBuffer.set(int64Data);
+        return undefined;
+      } else {
+        return int64Data.buffer;
+      }
+    } else {
+      return dstBuffer
+        ? await this.mlContext.readTensor(this.mlTensor, dstBuffer)
+        : await this.mlContext.readTensor(this.mlTensor);
     }
-    return this.mlContext.readTensor(this.mlTensor);
   }
 
   public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean {
@@ -150,6 +178,10 @@ class TensorWrapper {
       this.tensorShape.every((v, i) => v === shape[i])
     );
   }
+
+  public setIsInt64ToInt32Converted(isConverted: boolean): void {
+    this.isInt64ToInt32Converted = isConverted;
+  }
 }
 
 /**
@@ -184,6 +216,14 @@ class TensorIdTracker {
     copyOld: boolean,
   ): Promise<MLTensor> {
     const context = this.tensorManager.getMLContext(sessionId);
+    // If the data type is int64 and the context does not support int64, fall back to int32.
+    const shouldConvertInt64toInt32 =
+      dataType === 'int64' && !context.opSupportLimits()['input']['dataTypes'].includes('int64');
+    if (shouldConvertInt64toInt32) {
+      dataType = 'int32';
+      LOG_DEBUG('verbose', () => `[WebNN] TensorIdTracker.ensureTensor: convert dataType from int64 to int32`);
+    }
+
     if (this.wrapper) {
       if (this.wrapper.canReuseTensor(context, dataType, shape)) {
         return this.wrapper.tensor;
       }
@@ -200,9 +240,19 @@ class TensorIdTracker {
     // eslint-disable-next-line no-bitwise
     const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE;
-    this.wrapper = await this.tensorManager.getCachedTensor(sessionId, dataType, shape, usage, true, true);
+    this.wrapper = await this.tensorManager.getCachedTensor(
+      sessionId,
+      dataType,
+      shape,
+      usage,
+      true,
+      true,
+      shouldConvertInt64toInt32,
+    );
 
     if (copyOld && this.activeUpload) {
+      // The old data needs no int64-to-int32 conversion here;
+      // it was already converted when it was uploaded.
       this.wrapper.write(this.activeUpload);
       this.activeUpload = undefined;
     }
@@ -212,6 +262,12 @@ class TensorIdTracker {
 
   public upload(data: Uint8Array): void {
     if (this.wrapper) {
+      if (this.wrapper.shouldConvertInt64toInt32) {
+        // Convert int64 to int32.
+        const newData = convertInt64ToInt32(data, true);
+        this.wrapper.setIsInt64ToInt32Converted(true);
+        data = newData instanceof Int32Array ? new Uint8Array(newData.buffer) : newData;
+      }
       if (data.byteLength === this.wrapper.byteLength) {
         this.wrapper.write(data);
         return;
@@ -230,24 +286,30 @@ class TensorIdTracker {
 
   public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
     if (this.activeUpload) {
+      // If activeUpload was converted to int32, convert it back to int64 before returning it.
+      const dstData = this.wrapper?.isInt64ToInt32Converted
+        ? convertInt32ToInt64(this.activeUpload)
+        : this.activeUpload;
+
       if (dstBuffer) {
         if (dstBuffer instanceof ArrayBuffer) {
-          new Uint8Array(dstBuffer).set(this.activeUpload);
+          new Uint8Array(dstBuffer).set(dstData);
         } else {
-          new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload);
+          new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(dstData);
         }
         return;
       } else {
-        return this.activeUpload.buffer;
+        return dstData.buffer;
      }
     }
 
     if (!this.wrapper) {
       throw new Error('Tensor has not been created.');
     }
+
     if (!dstBuffer) {
-      return this.wrapper.read();
+      return this.wrapper.read(this.wrapper?.shouldConvertInt64toInt32);
     }
-    return this.wrapper.read(dstBuffer);
+    return this.wrapper.read(this.wrapper?.shouldConvertInt64toInt32, dstBuffer);
   }
 }
@@ -367,6 +429,7 @@ class TensorManagerImpl implements TensorManager {
     usage: MLTensorUsageFlags | undefined,
     writable: boolean,
     readable: boolean,
+    shouldConvertInt64toInt32: boolean = false,
   ): Promise<TensorWrapper> {
     const context = this.getMLContext(sessionId);
     for (const [index, tensor] of this.freeTensors.entries()) {
@@ -386,7 +449,7 @@ class TensorManagerImpl implements TensorManager {
       writable,
       readable,
     });
-    return new TensorWrapper({ sessionId, context, tensor, dataType, shape });
+    return new TensorWrapper({ sessionId, context, tensor, dataType, shape, shouldConvertInt64toInt32 });
   }
 
   /**
@@ -402,3 +465,50 @@ class TensorManagerImpl implements TensorManager {
 
 export const createTensorManager = (...args: ConstructorParameters<typeof TensorManagerImpl>): TensorManager =>
   new TensorManagerImpl(...args);
+
+// Convert BigInt64Array buffer data to Int32Array buffer data.
+export function convertInt64ToInt32(data: Uint8Array, returnUint8 = true): Uint8Array | Int32Array {
+  // Make sure it is a multiple of 8 bytes (BigInt64Array).
+  if (data.byteLength % 8 !== 0) {
+    throw new Error('Invalid Uint8Array length - must be a multiple of 8 (BigInt).');
+  }
+
+  // Convert Uint8Array to BigInt64Array.
+  const numElements = data.byteLength / 8;
+  const bigInt64Array = new BigInt64Array(data.buffer, data.byteOffset, numElements);
+
+  // Convert BigInt64Array to Int32Array (same number of elements).
+  const int32Array = new Int32Array(numElements);
+
+  for (let i = 0; i < numElements; i++) {
+    const value = bigInt64Array[i];
+
+    // Check for overflow.
+    if (value > 2147483647n || value < -2147483648n) {
+      throw new Error(`Overflow occurred when converting BigInt to Int32 at index ${i}: ${value}`);
+    }
+
+    int32Array[i] = Number(value);
+  }
+
+  // Return based on the requested format.
+  return returnUint8 ? new Uint8Array(int32Array.buffer) : int32Array;
+}
+
+// Convert Int32Array buffer data to BigInt64Array buffer data.
+function convertInt32ToInt64(data: Uint8Array): Uint8Array {
+  // Make sure it is a multiple of 4 bytes (Int32Array).
+  if (data.byteLength % 4 !== 0) {
+    throw new Error('Invalid Uint8Array length - must be a multiple of 4 (Int32).');
+  }
+
+  // Convert Uint8Array to Int32Array.
+  const numElements = data.byteLength / 4;
+  const int32Array = new Int32Array(data.buffer, data.byteOffset, numElements);
+
+  // Convert Int32Array to BigInt64Array (same number of elements).
+  const bigInt64Array = BigInt64Array.from(int32Array, BigInt);
+
+  // Return BigInt64Array buffer data.
+  return new Uint8Array(bigInt64Array.buffer);
+}
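The two helpers at the end of this file implement a down-and-up round trip that is exact whenever every value fits in 32 bits. A usage sketch, assuming `convertInt32ToInt64` were exported next to `convertInt64ToInt32` (in the patch it is module-private):

```ts
// int64 values inside [-2^31, 2^31 - 1] survive the round trip unchanged;
// anything outside that range makes convertInt64ToInt32 throw.
const original = new BigInt64Array([0n, 1n, -2n, 2147483647n]);
const int64Bytes = new Uint8Array(original.buffer);

// Downcast before writing to the MLTensor (what upload() does above).
const int32Bytes = convertInt64ToInt32(int64Bytes, true) as Uint8Array;

// Widen again after reading back (what read()/download() do above).
const restored = new BigInt64Array(convertInt32ToInt64(int32Bytes).buffer);
// restored now equals original, element for element.
```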
diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
index c513b2ec2ed8b..0ebff457d5b33 100644
--- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts
+++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
@@ -415,4 +415,13 @@ interface MLContext {
   readTensor(sourceTensor: MLTensor): Promise<ArrayBuffer>;
   readTensor(sourceTensor: MLTensor, destinationData: ArrayBufferView|ArrayBuffer): Promise<undefined>;
   dispatch(graph: MLGraph, inputs: MLNamedTensor, outputs: MLNamedTensor): void;
+  opSupportLimits(): MLOpSupportLimits;
+}
+
+interface MLOpSupportLimits {
+  input: MLSupportLimits;
+}
+
+interface MLSupportLimits {
+  dataTypes: MLOperandDataType[];
 }
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 4bccfa76fdda3..497cfc8360d8c 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -820,13 +820,19 @@ export const run = async (
         ]);
       } else if (preferredLocation === 'ml-tensor' && size > 0) {
         const ensureTensor = wasm.jsepEnsureTensor;
-        if (!ensureTensor) {
+        const isInt64Supported = wasm.jsepIsInt64Supported;
+        if (!ensureTensor || !isInt64Supported) {
           throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.');
         }
         const tensorSize = calculateTensorSizeInBytes(dataType, size);
         if (tensorSize === undefined || !isMLTensorSupportedType(type)) {
           throw new Error(`Unsupported data type: ${type}`);
         }
+        if (type === 'int64' && !isInt64Supported(sessionId)) {
+          throw new Error(
+            `preferredLocation "ml-tensor" for int64 output is not supported by the current WebNN Context.`,
+          );
+        }
 
         // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use
         // ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new tensor
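The guard added to `run` boils down to a small predicate: int64 outputs may only live in an MLTensor when the session's context reports int64 support. A sketch of that decision, where `wasm` stands in for the Emscripten module and `canUseMLTensorOutput` is a hypothetical name (the real declaration appears in `wasm-types.ts` below):

```ts
// Decide whether an 'ml-tensor' output location is viable for a given
// output type; run() throws before dispatch rather than producing
// silently truncated int64 data.
function canUseMLTensorOutput(
  wasm: { jsepIsInt64Supported?: (sessionId: number) => boolean },
  sessionId: number,
  type: string,
): boolean {
  if (type !== 'int64') {
    return true; // non-int64 types are unaffected by the fallback
  }
  return wasm.jsepIsInt64Supported?.(sessionId) ?? false;
}
```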
diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts
index b4871e145f4d7..69563e5e07b1b 100644
--- a/js/web/lib/wasm/wasm-types.ts
+++ b/js/web/lib/wasm/wasm-types.ts
@@ -261,6 +261,7 @@ export declare namespace JSEP {
    * @param dataLength - specify the external data length.
    * @param builder - specify the MLGraphBuilder used for constructing the Constant.
    * @param desc - specify the MLOperandDescriptor of the Constant.
+   * @param shouldConvertInt64ToInt32 - specify whether to convert int64 to int32.
    * @returns the WebNN Constant operand for the specified external data.
    */
   jsepRegisterMLConstant(
@@ -269,6 +270,7 @@ export declare namespace JSEP {
     dataLength: number,
     builder: MLGraphBuilder,
     desc: MLOperandDescriptor,
+    shouldConvertInt64ToInt32: boolean,
   ): MLOperand;
 
   /**
@@ -291,6 +293,12 @@ export declare namespace JSEP {
    * @returns the MLTensor ID for the temporary MLTensor.
    */
   jsepCreateTemporaryTensor: (sessionId: number, dataType: DataType, shape: readonly number[]) => Promise<TensorId>;
+  /**
+   * [exported from pre-jsep.js] Check if a session's associated WebNN Context supports int64.
+   * @param sessionId - specify the session ID.
+   * @returns whether the WebNN Context supports int64.
+   */
+  jsepIsInt64Supported: (sessionId: number) => boolean;
 }
diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc
index 20f3ffddd2779..47f65cd0b8e85 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.cc
+++ b/onnxruntime/core/providers/webnn/builders/helper.cc
@@ -144,9 +144,19 @@ bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
   const std::string_view webnn_data_type = it->second;
 
   // Check if WebNN supports the data type.
-  emscripten::val is_supported =
-      webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val(std::string(webnn_data_type)));
-  return is_supported.as<bool>();
+  bool is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
+                                                                       emscripten::val(std::string(webnn_data_type)))
+                          .as<bool>();
+
+  if (webnn_data_type == "int64" &&
+      !is_supported &&
+      webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val("int32")).as<bool>()) {
+    // The current context doesn't support int64, but int32 is supported.
+    // Use int32 as a workaround.
+    is_supported = true;
+  }
+
+  return is_supported;
 }
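For readers following along on the JS side, here is the same fallback rule as the C++ above, expressed as a TypeScript sketch (`isDataTypeSupported` is illustrative, not code from this patch):

```ts
// IsSupportedDataType's fallback rule: a data type is usable if the context
// lists it, except that int64 is also accepted whenever int32 can stand in.
function isDataTypeSupported(webnnDataType: string, supportedDataTypes: readonly string[]): boolean {
  if (supportedDataTypes.includes(webnnDataType)) {
    return true;
  }
  // int64 falls back to int32 when the context lacks native int64 support.
  return webnnDataType === 'int64' && supportedDataTypes.includes('int32');
}
```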
diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc
index d61ae1a1f6be7..6814b019f699c 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc
@@ -43,8 +43,12 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   emscripten::val options = emscripten::val::object();
   options.set("keepDimensions", keep_dims == 1);
-  // TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API.
-  options.set("outputDataType", "int64");
+  std::string output_data_type = "int64";
+  if (!model_builder.IsInt64Supported()) {
+    // Int64 is not supported by the current context; use int32 instead.
+    output_data_type = "int32";
+  }
+  options.set("outputDataType", output_data_type);
   options.set("label", node.Name());
 
   emscripten::val output = emscripten::val::object();
diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
index 9eacc192d4c02..7c08f3e5045f5 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
@@ -73,6 +73,10 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
 
   emscripten::val options = emscripten::val::object();
   options.set("label", node.Name());
+  if (to_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 && !model_builder.IsInt64Supported()) {
+    // Int64 is not supported by the current context; use int32 instead.
+    operand_type = "int32";
+  }
   emscripten::val output =
       model_builder.GetBuilder().call<emscripten::val>("cast", input, emscripten::val(operand_type), options);
 
diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc
index cbaff79f4fd4f..d633bb96f858e 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc
@@ -159,14 +159,22 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   if (position_ids_is_offset) {
     // We generate a sequence from 0 to sequence_length and add the offset to it.
     const std::vector<uint32_t> position_ids_range_shape = {1, sequence_length};
-    emscripten::val position_ids_range_buffer = emscripten::val::global("BigInt64Array").new_(sequence_length);
+    std::string typed_array_name = "BigInt64Array";
+    int position_ids_data_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
+    const bool is_int64_supported = model_builder.IsInt64Supported();
+    if (!is_int64_supported) {
+      // Int64 is not supported by the current context; use int32 instead.
+      typed_array_name = "Int32Array";
+      position_ids_data_type = ONNX_NAMESPACE::TensorProto_DataType_INT32;
+    }
+    emscripten::val position_ids_range_buffer = emscripten::val::global(typed_array_name.c_str()).new_(sequence_length);
     for (uint32_t i = 0; i < sequence_length; i++) {
-      position_ids_range_buffer.set(i, emscripten::val::global("BigInt")(i));
+      position_ids_range_buffer.set(i, is_int64_supported ? emscripten::val::global("BigInt")(i) : emscripten::val(i));
     }
     emscripten::val position_ids_range_desc = emscripten::val::object();
     position_ids_range_desc.set("shape", emscripten::val::array(position_ids_range_shape));
     position_ids_range_desc.set("dimensions", emscripten::val::array(position_ids_range_shape));
-    position_ids_range_desc.set("dataType", emscripten::val("int64"));
+    ORT_RETURN_IF_NOT(SetWebnnDataType(position_ids_range_desc, position_ids_data_type), "Unsupported data type");
     emscripten::val position_ids_range = wnn_builder.call<emscripten::val>(
         "constant", position_ids_range_desc, position_ids_range_buffer);
     // Add the offset to the sequence.
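Through Embind, the C++ above ends up constructing roughly the following JavaScript values. A TypeScript sketch, where `builder` is a hypothetical `MLGraphBuilder` and `int64Supported` is the result of the capability check:

```ts
// Build the position-ids range constant [0, 1, ..., sequenceLength - 1],
// choosing the typed array and operand data type by int64 support.
const sequenceLength = 8; // example value
const rangeBuffer = int64Supported
  ? BigInt64Array.from({ length: sequenceLength }, (_, i) => BigInt(i))
  : Int32Array.from({ length: sequenceLength }, (_, i) => i);
const rangeDesc: MLOperandDescriptor = {
  dataType: int64Supported ? 'int64' : 'int32',
  shape: [1, sequenceLength],
  dimensions: [1, sequenceLength],
};
const positionIdsRange = builder.constant(rangeDesc, rangeBuffer);
```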
diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc
index 360c6588898f1..2bea1af896eee 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc
@@ -29,12 +29,20 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   const auto rank = static_cast<uint32_t>(input_shape.size());
   emscripten::val desc = emscripten::val::object();
-  ORT_RETURN_IF_NOT(SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_INT64), "Unsupported data type");
   emscripten::val dims = emscripten::val::array();
   dims.call<void>("push", rank);
   desc.set("dimensions", dims);
   desc.set("shape", dims);
-  emscripten::val shape_buffer = emscripten::val::global("BigInt64Array").new_(emscripten::val::array(input_shape));
+  int data_type = ONNX_NAMESPACE::TensorProto_DataType_INT64;
+  std::string typed_array_name = "BigInt64Array";
+  if (!model_builder.IsInt64Supported()) {
+    // Int64 is not supported by the current context; use int32 instead.
+    data_type = ONNX_NAMESPACE::TensorProto_DataType_INT32;
+    typed_array_name = "Int32Array";
+  }
+  ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
+  emscripten::val shape_buffer =
+      emscripten::val::global(typed_array_name.c_str()).new_(emscripten::val::array(input_shape));
   emscripten::val shape_constant = model_builder.GetBuilder().call<emscripten::val>("constant", desc, shape_buffer);
 
   NodeAttrHelper helper(node);
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index ace6519a1fc11..afa67fc67cea4 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -34,6 +34,9 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
   if (!wnn_builder_.as<bool>()) {
     ORT_THROW("Failed to create WebNN builder.");
   }
+  if (wnn_limits["input"]["dataTypes"].call<emscripten::val>("includes", emscripten::val("int64")).as<bool>()) {
+    is_int64_supported_ = true;
+  }
 }
 
 Status ModelBuilder::Initialize() {
@@ -125,6 +128,10 @@ Status ModelBuilder::RegisterInitializers() {
     emscripten::val view = emscripten::val::undefined();
     std::byte* tensor_ptr = nullptr;
 
+    // A flag to indicate if we should convert int64 to int32.
+    const bool should_convert_int64_to_int32 = !is_int64_supported_ &&
+                                               data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64;
+
     if (utils::HasExternalData(tensor)) {
       // Create WebNN Constant from external data.
       std::basic_string<ORTCHAR_T> external_file_path;
@@ -138,7 +145,8 @@ Status ModelBuilder::RegisterInitializers() {
                                             static_cast<int32_t>(data_offset),
                                             static_cast<int32_t>(tensor_byte_size),
                                             wnn_builder_,
-                                            desc);
+                                            desc,
+                                            should_convert_int64_to_int32);
     } else {
       if (tensor.has_raw_data()) {
         tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
@@ -195,6 +203,25 @@ Status ModelBuilder::RegisterInitializers() {
           break;
       }
 
+      // If int64 is not supported, convert the int64 data to int32.
+      std::vector<int32_t> int32_data(num_elements);
+      if (should_convert_int64_to_int32) {
+        try {
+          std::transform(reinterpret_cast<const int64_t*>(tensor_ptr),
+                         reinterpret_cast<const int64_t*>(tensor_ptr) + static_cast<size_t>(num_elements),
+                         int32_data.begin(),
+                         [](int64_t val) -> int32_t {
+                           return gsl::narrow<int32_t>(val);
+                         });
+          LOGS(logger_, VERBOSE) << "Initializer '" << name << "' is converted from int64 to int32.";
+        } catch (const std::exception& e) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, e.what());
+        }
+        view = emscripten::val{emscripten::typed_memory_view(num_elements, int32_data.data())};
+
+        desc.set("dataType", emscripten::val("int32"));
+      }
+
       // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
       // buffers in JS side. Simply create a copy to fix it.
       operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
@@ -203,7 +230,7 @@ Status ModelBuilder::RegisterInitializers() {
       // TODO: support other type.
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "The initializer of graph has unsupported type, name: ",
-                             tensor.name(), " type: ", data_type);
+                             name, " type: ", data_type);
     }
     wnn_operands_.insert(std::make_pair(name, operand));
   }
@@ -259,6 +286,10 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) {
   }
 
   if (is_input) {
+    if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 && !is_int64_supported_) {
+      // Int64 is not supported by the current context; use int32 instead.
+      desc.set("dataType", emscripten::val("int32"));
+    }
     wnn_operands_.insert(std::make_pair(name, wnn_builder_.call<emscripten::val>("input", name, desc)));
     emscripten::val::module_property("jsepRegisterGraphInput")(name);
     input_names_.push_back(name);
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h
index 4e2d84f481df0..8b2027919f3b2 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.h
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -33,6 +33,8 @@ class ModelBuilder {
   const GraphViewer& GetGraphViewer() const { return graph_viewer_; }
   InitializedTensorSet GetInitializerTensors();
 
+  bool IsInt64Supported() const { return is_int64_supported_; }
+
   const emscripten::val& GetBuilder() const { return wnn_builder_; }
   const emscripten::val& GetContext() const { return wnn_context_; }
   const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); }
@@ -71,6 +73,7 @@ class ModelBuilder {
 
   emscripten::val wnn_context_ = emscripten::val::undefined();
   emscripten::val wnn_builder_ = emscripten::val::undefined();
+  bool is_int64_supported_{false};
   DataLayout preferred_layout_;
   WebnnDeviceType wnn_device_type_;
   emscripten::val wnn_limits_ = emscripten::val::undefined();
diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js
index 0c83e71a921cb..2ac367fc9caa8 100644
--- a/onnxruntime/wasm/pre-jsep.js
+++ b/onnxruntime/wasm/pre-jsep.js
@@ -243,13 +243,27 @@ Module['jsepInit'] = (name, params) => {
     Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => {
      return backend['createMLContext'](optionsOrGpuDevice);
     };
-    Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => {
-      return backend['registerMLConstant'](
-        externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles);
+    Module['jsepRegisterMLConstant'] = (
+      externalFilePath,
+      dataOffset,
+      dataLength,
+      builder,
+      desc,
+      shouldConvertInt64ToInt32,
+    ) => {
+      return backend['registerMLConstant'](
+        externalFilePath,
+        dataOffset,
+        dataLength,
+        builder,
+        desc,
+        Module.MountedFiles,
+        shouldConvertInt64ToInt32,
+      );
     };
     Module['jsepRegisterGraphInput'] = backend['registerGraphInput'].bind(backend);
     Module['jsepIsGraphInput'] = backend['isGraphInput'].bind(backend);
     Module['jsepCreateTemporaryTensor'] = backend['createTemporaryTensor'].bind(backend);
+    Module['jsepIsInt64Supported'] = backend['isInt64Supported'].bind(backend);
  }
 };
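For reference, a sketch of how the wasm core ends up consuming the two exports wired up here; `sessionId`, `externalFilePath`, `dataOffset`, `dataLength`, `builder`, and `desc` are stand-ins for values the caller already has:

```ts
// Query int64 support for a session, then register an external initializer,
// requesting the int64 -> int32 rewrite only when the context needs it.
const int64Supported = Module.jsepIsInt64Supported(sessionId);
const shouldConvert = desc.dataType === 'int64' && !int64Supported;
const constant = Module.jsepRegisterMLConstant(
  externalFilePath,
  dataOffset,
  dataLength,
  builder,
  desc,
  shouldConvert,
);
```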