From 08fbd11825889773fe8d5d3c91bdd473cade7f3d Mon Sep 17 00:00:00 2001
From: Qin Jiajia
Date: Thu, 1 Feb 2024 16:54:50 +0800
Subject: [PATCH 1/3] [WIP][js/webgpu] Add LeakyRelu activation for fusedConv

This PR 1) adds the LeakyRelu activation for fusedConv; 2) makes vec4
values work with float32 uniform attributes.

---
 .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts  |   2 +-
 .../ops/3rd-party/conv_backprop_mm_webgpu.ts  |   2 +-
 .../ops/3rd-party/matmul_packed_webgpu.ts     |   3 +-
 .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts  |   8 +-
 js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts |  47 +++---
 js/web/lib/wasm/jsep/webgpu/ops/matmul.ts     |   5 +-
 js/web/test/data/ops/fused-conv.jsonc         | 144 ++++++++++++++++++
 7 files changed, 185 insertions(+), 26 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
index bc39bd94e3072..cc3b419c754be 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -130,7 +130,7 @@ const conv2dCommonSnippet =
           isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
       const bType =
           isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType);
-      const applyActivation = getActivationSnippet(attributes, resType);
+      const applyActivation = getActivationSnippet(attributes, resType, dataType);
       const userCode = `
   fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
     ${isChannelsLast ? sampleX : sampleW}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
index d18f8586dd071..9768ea50e9c66 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -130,7 +130,7 @@ const conv2dTransposeCommonSnippet =
         return ${type}(0.0);
       `;

-      const applyActivation = getActivationSnippet(attributes, type);
+      const applyActivation = getActivationSnippet(attributes, type, 'f32');
       const userCode = `
   fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} {
     ${isChannelsLast ? sampleA : sampleW}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
index d9a8d59f731de..915f1c4320f32 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -481,7 +481,8 @@ export const createMatmulProgramInfo =
       const uniforms: UniformsArrayType =
          [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}];
       appendActivationUniforms(activationAttributes, uniforms);
-      const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
+      const baseType = tensorTypeToWsglStorageType(output.type.tensor);
+      const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType);
       const declareFunctions = matMulReadWriteFnSource(
           components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims],
           isChannelsLast);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
index 3c2c3cc4e046c..0084fd063ed1f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

-import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
+import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
 import {calculateOutputShape, ConvAttributes} from './conv';
 import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils';

@@ -47,7 +47,8 @@ export const createGroupedConvProgramInfo =

       const getShaderSource = (shaderHelper: ShaderHelper) => {
         const output = outputVariable('output', inputs[0].dataType, outputShape.length);
-        const applyActivation = getActivationSnippet(attributes, output.type.value);
+        const baseType = tensorTypeToWsglStorageType(output.type.tensor);
+        const applyActivation = getActivationSnippet(attributes, output.type.value, baseType);
         const x = inputVariable('x', inputs[0].dataType, xShape.length);
         const w = inputVariable('w', inputs[1].dataType, wShape.length);
         const inputVars = [x, w];
@@ -140,7 +141,8 @@ export const createGroupedConvVectorizeProgramInfo =
       const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
       const getShaderSource = (shaderHelper: ShaderHelper) => {
         const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
-        const applyActivation = getActivationSnippet(attributes, output.type.value);
+        const baseType = tensorTypeToWsglStorageType(output.type.tensor);
+        const applyActivation = getActivationSnippet(attributes, output.type.value, baseType);
         const x = inputVariable('x', inputs[0].dataType, xShape.length, components);
         const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
         const inputVars = [x, w];
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
index 60067c014613b..4ef3bd718019c 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
@@ -15,24 +15,28 @@ export interface InternalActivationAttributes {
   readonly beta?: number;
 }

-export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => {
-  switch (attributes.activation) {
-    case 'Relu':
-      return `value = max(value, ${valueType}(0.0));`;
-    case 'Sigmoid':
-      return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
-    case 'Clip':
-      return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`;
-    case 'HardSigmoid':
-      return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${valueType}(uniforms.alpha) * value + ${
-          valueType}(uniforms.beta)));`;
-    case '':
-      return '';
-    // TODO: adding other activations that can be fused.
-    default:
-      throw new Error(`Unsupported activation ${attributes.activation}`);
-  }
-};
+export const getActivationSnippet =
+    (attributes: InternalActivationAttributes, valueType: string, baseType: string): string => {
+      switch (attributes.activation) {
+        case 'Relu':
+          return `value = max(value, ${valueType}(0.0));`;
+        case 'Sigmoid':
+          return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
+        case 'Clip':
+          return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${
+              baseType}(uniforms.clip_max)));`;
+        case 'HardSigmoid':
+          return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${
+              baseType}(uniforms.beta)));`;
+        case 'LeakyRelu':
+          return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`;
+        case '':
+          return '';
+        // TODO: adding other activations that can be fused.
+        default:
+          throw new Error(`Unsupported activation ${attributes.activation}`);
+      }
+    };

 export const appendActivationUniformsData =
     (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => {
@@ -42,6 +46,8 @@ export const appendActivationUniformsData =
       } else if (attributes.activation === 'HardSigmoid') {
         programUniform.push(
             {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!});
+      } else if (attributes.activation === 'LeakyRelu') {
+        programUniform.push({type: DataType.float, data: attributes.alpha!});
       }
     };

@@ -50,6 +56,8 @@ export const appendActivationUniforms = (attributes: InternalActivationAttribute
     uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
   } else if (attributes.activation === 'HardSigmoid') {
     uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'});
+  } else if (attributes.activation === 'LeakyRelu') {
+    uniforms.push({name: 'alpha', type: 'f32'});
   }
 };

@@ -62,6 +70,9 @@ export const parseInternalActivationAttributes =
       } else if (activation === 'Clip') {
         const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP];
         return {activation, clipMax, clipMin};
+      } else if (activation === 'LeakyRelu') {
+        const [alpha] = attributes?.activation_params as [number] || [0.01];
+        return {activation, alpha};
       }
       return {activation};
     };
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
index b263451b99134..0bc0f0a7eb49e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
@@ -7,7 +7,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util';
 import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

 import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
-import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common';
+import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
 import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils';

 export const createNaiveMatmulProgramInfo =
@@ -47,7 +47,8 @@ export const createNaiveMatmulProgramInfo =
       const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
       const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
       const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
-      const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
+      const baseType = tensorTypeToWsglStorageType(output.type.tensor);
+      const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType);
       const inputVariables = [a, b];
       let processBias = '';
       if (hasBias) {
diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc
index c734d6db9b92a..1e5678059e8a6 100644
--- a/js/web/test/data/ops/fused-conv.jsonc
+++ b/js/web/test/data/ops/fused-conv.jsonc
@@ -286,5 +286,149 @@
         ]
       }
     ]
+  },
+  {
+    "name": "fused group-conv with LeakyRelu",
+    "operator": "FusedConv",
+    "attributes": [
+      { "name": "activation", "data": "LeakyRelu", "type": "string" },
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "group", "data": 3, "type": "int" },
+      { "name": "activation_params", "data": [2.0], "type": "floats" }
+    ],
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0,
+              18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0
+            ],
+            "dims": [1, 3, 3, 3],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+            "dims": [3, 1, 2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [9,-6,51,47,-170,-10,251,229,847,889,973,1015],
+            "dims": [1, 3, 2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "NHWC group-conv with LeakyRelu",
+    "operator": "Conv",
+    "attributes": [
+      { "name": "activation", "data": "LeakyRelu", "type": "string" },
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "group", "data": 3, "type": "int" },
+      { "name": "activation_params", "data": [2.0], "type": "floats" }
+    ],
+    "opset": { "domain": "com.ms.internal.nhwc", "version": 1 },
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0,
+              18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0
+            ],
+            "dims": [1, 3, 3, 3],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+            "dims": [3, 1, 2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [-162,63,-158,33,281,85,105,337,455,177,515,609],
+            "dims": [1, 2, 2, 3],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "fused conv with LeakyRelu",
+    "operator": "FusedConv",
+    "attributes": [
+      { "name": "activation", "data": "LeakyRelu", "type": "string" },
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "activation_params", "data": [2.0], "type": "floats" }
+    ],
+    "opset": { "domain": "com.microsoft", "version": 1 },
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [10, 20, -30, -40, -50, -60, 70, 80, 90],
+            "dims": [1, 1, 3, 3],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4],
+            "dims": [1, 1, 2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [-540,-860,390,430],
+            "dims": [1, 1, 2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "NHWC conv with LeakyRelu",
+    "operator": "Conv",
+    "attributes": [
+      { "name": "activation", "data": "LeakyRelu", "type": "string" },
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "activation_params", "data": [2.0], "type": "floats" }
+    ],
+    "opset": { "domain": "com.ms.internal.nhwc", "version": 1 },
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [10, 20, -30, -40, -50, -60, 70, 80, 90],
+            "dims": [1, 3, 3, 1],
+            "type": "float32"
+          },
+          {
+            "data": [1, 2, 3, 4],
+            "dims": [1, 1, 2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [-540,-860,390,430],
+            "dims": [1, 2, 2, 1],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
   }
 ]

From ebeb628c70049fe664f042330aae01aee753cd35 Mon Sep 17 00:00:00 2001
From: Qin Jiajia
Date: Fri, 2 Feb 2024 08:59:08 +0800
Subject: [PATCH 2/3] fix format errors

---
 js/web/test/data/ops/fused-conv.jsonc | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc
index 1e5678059e8a6..6a10e3b96a26a 100644
--- a/js/web/test/data/ops/fused-conv.jsonc
+++ b/js/web/test/data/ops/fused-conv.jsonc
@@ -317,7 +317,7 @@
         ],
         "outputs": [
           {
-            "data": [9,-6,51,47,-170,-10,251,229,847,889,973,1015],
+            "data": [9, -6, 51, 47, -170, -10, 251, 229, 847, 889, 973, 1015],
             "dims": [1, 3, 2, 2],
             "type": "float32"
           }
@@ -355,7 +355,7 @@
         ],
         "outputs": [
           {
-            "data": [-162,63,-158,33,281,85,105,337,455,177,515,609],
+            "data": [-162, 63, -158, 33, 281, 85, 105, 337, 455, 177, 515, 609],
             "dims": [1, 2, 2, 3],
             "type": "float32"
           }
@@ -389,7 +389,7 @@
         ],
         "outputs": [
           {
-            "data": [-540,-860,390,430],
+            "data": [-540, -860, 390, 430],
             "dims": [1, 1, 2, 2],
             "type": "float32"
           }
@@ -423,7 +423,7 @@
         ],
         "outputs": [
           {
-            "data": [-540,-860,390,430],
+            "data": [-540, -860, 390, 430],
             "dims": [1, 2, 2, 1],
             "type": "float32"
           }

From 71dd18bc32995d984d812066cff6603f4345f6b4 Mon Sep 17 00:00:00 2001
From: Qin Jiajia
Date: Fri, 2 Feb 2024 09:02:53 +0800
Subject: [PATCH 3/3] address comments

---
 .../wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 2 +-
 js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts                 | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
index 9768ea50e9c66..d18f8586dd071 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -130,7 +130,7 @@ const conv2dTransposeCommonSnippet =
         return ${type}(0.0);
       `;

-      const applyActivation = getActivationSnippet(attributes, type, 'f32');
+      const applyActivation = getActivationSnippet(attributes, type);
       const userCode = `
   fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} {
     ${isChannelsLast ? sampleA : sampleW}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
index 4ef3bd718019c..6e66abacf3471 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
@@ -16,7 +16,7 @@ export interface InternalActivationAttributes {
 }

 export const getActivationSnippet =
-    (attributes: InternalActivationAttributes, valueType: string, baseType: string): string => {
+    (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => {
       switch (attributes.activation) {
         case 'Relu':
           return `value = max(value, ${valueType}(0.0));`;
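
A note on the baseType parameter: the activation uniforms (alpha, beta, clip_min, clip_max) are always declared as f32 on the uniform buffer (see appendActivationUniforms), so when the shader's value type is f16-based, the scalar has to be cast to the value's base type before it can be combined with the value. Below is a minimal sketch of what getActivationSnippet produces for LeakyRelu under that scheme; the vec4<f16> value type is an assumed example rather than a real call site, and the import path is illustrative:

    import {getActivationSnippet} from './fuse-utils';

    // LeakyRelu over an assumed vec4<f16> value: uniforms.alpha (declared f32)
    // is cast to the f16 base type before the multiply, and select() keeps the
    // original value, component-wise, wherever it is non-negative.
    const snippet = getActivationSnippet({activation: 'LeakyRelu', alpha: 0.01}, 'vec4<f16>', 'f16');
    // snippet === 'value = select(f16(uniforms.alpha) * value, value, value >= vec4<f16>(0.0));'

After patch 3 the third argument defaults to 'f32', which is why the f32-only call site in conv_backprop_mm_webgpu.ts can drop it again while the conv, grouped-conv, and matmul paths derive it from the output tensor type via tensorTypeToWsglStorageType.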
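
The expected outputs in the new fused-conv.jsonc cases are a plain convolution followed by LeakyRelu with alpha = 2.0, taken from activation_params (the non-default alpha makes the scaled negative taps easy to spot; the operator default would be 0.01). A quick standalone check of the "fused conv with LeakyRelu" case, written as a plain-TypeScript sketch independent of the test harness:

    // 1x1x3x3 input (NCHW), 1x1x2x2 kernel, no padding, stride 1.
    const x = [10, 20, -30, -40, -50, -60, 70, 80, 90];
    const w = [1, 2, 3, 4];
    const alpha = 2.0;
    const leakyRelu = (v: number) => (v >= 0 ? v : alpha * v);

    const out: number[] = [];
    for (let oy = 0; oy < 2; oy++) {
      for (let ox = 0; ox < 2; ox++) {
        // Accumulate the 2x2 window dot product for this output position.
        let acc = 0;
        for (let ky = 0; ky < 2; ky++) {
          for (let kx = 0; kx < 2; kx++) {
            acc += x[(oy + ky) * 3 + (ox + kx)] * w[ky * 2 + kx];
          }
        }
        out.push(leakyRelu(acc));
      }
    }
    console.log(out);  // [-540, -860, 390, 430], matching the expected "data" above

For instance, the top-left tap is 10*1 + 20*2 + (-40)*3 + (-50)*4 = -270, which LeakyRelu maps to 2.0 * (-270) = -540, the first expected value.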