Skip to content

Commit

Permalink
Bug 1902166 [wpt PR 46725] - WebNN: Support axis for softmax operator…
Browse files Browse the repository at this point in the history
…, a=testonly

Automatic update from web-platform-tests
WebNN: Support axis for softmax operator

This CL adds axis parameter into the IDL and mojo definitions of softmax
operator [1]. It also updates the backends implementation to support the
new axis parameter.

In addition, the corresponding tests have also been updated.

[1] webmachinelearning/webnn#649

Bug: 338094927
Change-Id: Ib08ecbba61c27c94256953a952357eeda80241e6
Cq-Include-Trybots: luci.chromium.try​:win11-blink-rel,mac14.arm64-blink-rel,mac14-blink-rel
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5495877
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: Bin Miao <bin.miao@intel.com>
Reviewed-by: Austin Sullivan <asully@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1314418}

--

wpt-commits: 06bcaad19c8cdb96751ff3d80698533e85044a1d
wpt-pr: 46725
  • Loading branch information
miaobin authored and moz-wptsync-bot committed Jun 18, 2024
1 parent ef6402e commit d46e697
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@

// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-softmax

runWebNNConformanceTests('softmax', buildOperationWithSingleInput);
runWebNNConformanceTests('softmax', buildSoftmax);
87 changes: 87 additions & 0 deletions testing/web-platform/tests/webnn/resources/test_data/softmax.json
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,93 @@
],
"type": "float32"
}
},
{
"name": "softmax float32 3D constant tensor",
"inputs": {
"x": {
"shape": [1, 3, 4],
"data": [
0.4301910996437073,
0.5471914410591125,
-1.1637765169143677,
0.18390046060085297,
0.583903968334198,
0.17356790602207184,
0.5397239923477173,
-0.9535139799118042,
-0.5920282602310181,
-0.17344485223293304,
0.14395014941692352,
-0.37920907139778137
],
"type": "float32",
"constant": true
}
},
"axis": 1,
"expected": {
"name": "output",
"shape": [1, 3, 4],
"data": [
0.39589041471481323,
0.45983806252479553,
0.09812675416469574,
0.529077410697937,
0.4616699814796448,
0.31647709012031555,
0.5390242338180542,
0.16964708268642426,
0.142439603805542,
0.22368484735488892,
0.36284899711608887,
0.3012755215167999
],
"type": "float32"
}
},
{
"name": "softmax float32 4D tensor",
"inputs": {
"x": {
"shape": [3, 4, 1, 1],
"data": [
0.4301910996437073,
0.5471914410591125,
-1.1637765169143677,
0.18390046060085297,
0.583903968334198,
0.17356790602207184,
0.5397239923477173,
-0.9535139799118042,
-0.5920282602310181,
-0.17344485223293304,
0.14395014941692352,
-0.37920907139778137
],
"type": "float32"
}
},
"axis": 1,
"expected": {
"name": "output",
"shape": [3, 4, 1, 1],
"data": [
0.3216537833213806,
0.3615773916244507,
0.06533370912075043,
0.25143513083457947,
0.35271573066711426,
0.23400123417377472,
0.33747196197509766,
0.07581108063459396,
0.17110128700733185,
0.26004093885421753,
0.3571779429912567,
0.2116798311471939
],
"type": "float32"
}
}
]
}
14 changes: 14 additions & 0 deletions testing/web-platform/tests/webnn/resources/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,20 @@ const buildSlice = (operationName, builder, resources) => {
return namedOutputOperand;
};

const buildSoftmax = (operationName, builder, resources) => {
// MLOperand softmax(MLOperand input, [EnforceRange] unsigned long axis);
const namedOutputOperand = {};
const inputOperand = createSingleInputOperand(builder, resources);
if (resources.axis !== undefined) {
// invoke builder.softmax(input, axis)
namedOutputOperand[resources.expected.name] = builder[operationName](inputOperand, resources.axis);
} else {
// invoke builder.softmax(input)
namedOutputOperand[resources.expected.name] = builder[operationName](inputOperand);
}
return namedOutputOperand;
};

const buildSplit = (operationName, builder, resources) => {
// sequence<MLOperand> split(MLOperand input,
// (unsigned long or sequence<unsigned long>) splits,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,97 @@

'use strict';

validateInputFromAnotherBuilder('softmax');
const tests_without_axis = [
{
name: '[softmax] Test building Softmax with float32 input without axis.',
input: { dataType: 'float32', dimensions: [4, 3] },
output: { dataType: 'float32', dimensions: [4, 3] }
},
{
name: '[softmax] Test building Softmax with float16 input without axis.',
input: { dataType: 'float16', dimensions: [3, 5] },
output: { dataType: 'float16', dimensions: [3, 5] }
},
{
name: '[softmax] Throw if the input is not a non-floating point data.',
input: { dataType: 'int32', dimensions: [3, 2] }
},
{
name: '[softmax] Throw if the input dimensions is not 2.',
input: { dataType: 'float32', dimensions: [1, 4, 3] }
}
];

tests_without_axis.forEach(test =>
promise_test(async t => {
let input = builder.input(
`input`, { dataType: test.input.dataType, dimensions: test.input.dimensions }
);
if (test.output) {
const output = builder.softmax(input);
assert_equals(output.dataType(), test.output.dataType);
assert_array_equals(output.shape(), test.output.dimensions);
} else {
assert_throws_js(TypeError, () => builder.softmax(input));
}
}, test.name)
);

multi_builder_test(async (t, builder, otherBuilder) => {
const operandDescriptor = { dataType: 'float32', dimensions: [2, 3] };
const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);

assert_throws_js(
TypeError,
() => builder.softmax(inputFromOtherBuilder));
}, '[softmax without axis] throw if any input is from another builder');

const tests = [
{
name: '[softmax] Test building Softmax with float32 input.',
input: { dataType: 'float32', dimensions: [4, 4, 3] },
axis: 1,
output: { dataType: 'float32', dimensions: [4, 4, 3] }
},
{
name: '[softmax] Test building Softmax with float16 input.',
input: { dataType: 'float16', dimensions: [3, 1, 5, 2] },
axis: 2,
output: { dataType: 'float16', dimensions: [3, 1, 5, 2] }
},
{
name: '[softmax] Throw if the input is not a non-floating-point data.',
input: { dataType: 'int32', dimensions: [3, 1, 5, 2] },
axis: 3
},
{
name: '[softmax] Throw if the axis is greater than input rank - 1.',
input: { dataType: 'float16', dimensions: [3, 1, 5, 2] },
axis: 4
}
];

tests.forEach(test =>
promise_test(async t => {
let input = builder.input(
`input`, { dataType: test.input.dataType, dimensions: test.input.dimensions }
);
if (test.output) {
const output = builder.softmax(input, test.axis);
assert_equals(output.dataType(), test.output.dataType);
assert_array_equals(output.shape(), test.output.dimensions);
} else {
assert_throws_js(TypeError, () => builder.softmax(input, test.axis));
}
}, test.name)
);

multi_builder_test(async (t, builder, otherBuilder) => {
const operandDescriptor = { dataType: 'float32', dimensions: [1, 2, 3] };
const inputFromOtherBuilder = otherBuilder.input('input', operandDescriptor);
const axis = 1;

assert_throws_js(
TypeError,
() => builder.softmax(inputFromOtherBuilder, axis));
}, '[softmax] throw if any input is from another builder');

0 comments on commit d46e697

Please sign in to comment.