From 48557fa773c0e728940d0080904a7a579363ccfd Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Thu, 11 Jan 2024 15:00:16 +0800 Subject: [PATCH] [js/webgpu] Add HardSigmoid support --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 20 +++++++++++++++++++ js/web/test/suite-test-list.jsonc | 6 +++--- .../providers/js/js_execution_provider.cc | 2 ++ .../core/providers/js/operators/unary.cc | 3 +++ 6 files changed, 30 insertions(+), 3 deletions(-) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2f510308d9306..2557971eb4ded 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -52,6 +52,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 90e02da986b8f..cc504093ca0d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], + ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index a25e7fe4229b4..82311d72e58b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); }; +export interface HardSigmoidAttributes extends AttributeWithCacheKey { + readonly alpha: number; + readonly beta: number; +} + +export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => + createAttributeWithCacheKey(attributes as { + alpha: number; + beta: number; + }); + +export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'HardSigmoid', + a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ + attributes.beta})))`, + undefined, attributes.cacheKey)); +}; + export const sin = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); }; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 033b3b3f4b0f5..373b3c645df57 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -597,9 +597,9 @@ // // "test_hardmax_example", // // "test_hardmax_negative_axis", // // "test_hardmax_one_hot", - // // "test_hardsigmoid_default", - // // "test_hardsigmoid_example", - // // "test_hardsigmoid", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", "test_if", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index c2ff2ebc39e13..af9658271d210 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log); @@ -392,6 +393,7 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(13, Erf), KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), KERNEL_CREATE_INFO(13, Log), diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 78563d30b0136..9082527e3a8d7 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5) +JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid) + JSEP_KERNEL_IMPL(Log, Log) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)