Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN EP] Fix bug in Softmax #17665

Merged
merged 2 commits into from
Sep 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,18 @@
const auto input_size = input_shape.size();
// WebNN Softmax only support 2d input shape, reshape input to 2d.
if (input_size != 2) {
int32_t new_shape_0 = SafeInt<int32_t>(input_shape.data()[0]);
for (size_t i = 1; i < input_size - 1; i++) {
new_shape_0 *= input_shape.data()[i];
}
emscripten::val new_shape = emscripten::val::array();
new_shape.call<void>("push", new_shape_0);
new_shape.call<void>("push", static_cast<int32_t>(input_shape.back()));
NodeAttrHelper helper(node);
int32_t axis = helper.Get("axis", 1);
guschmue marked this conversation as resolved.
Show resolved Hide resolved
if (node.SinceVersion() >= 13)
// Opset 13 has default value -1.
axis = helper.Get("axis", -1);
// Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}].
axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_size));
int32_t first_dim = static_cast<int32_t>(std::reduce(input_shape.begin(), input_shape.begin() + axis,
1, std::multiplies<int64_t>()));
int32_t second_dim = static_cast<int32_t>(std::reduce(input_shape.begin() + axis, input_shape.end(),
1, std::multiplies<int64_t>()));
emscripten::val new_shape = emscripten::val::array(std::vector<int32_t>{first_dim, second_dim});
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, new_shape);
}
output = model_builder.GetBuilder().call<emscripten::val>("softmax", input);
Expand Down Expand Up @@ -76,9 +81,10 @@
return false;
}
NodeAttrHelper helper(node);
const int32_t axis = helper.Get("axis", 1);
// WebNN softmax only support input axis 1
if (axis != 1 && axis != -1) {
const int64_t axis = helper.Get("axis", 1);
// WebNN softmax only support reshape for the last axis or version before 13.
// TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1).

Check warning on line 86 in onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc:86: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) {
LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis;
return false;
}
Expand Down
Loading