From 2b603df91c0f13c22460fe436088562f802b34f9 Mon Sep 17 00:00:00 2001 From: Yujie Hui Date: Mon, 20 May 2024 11:48:48 -0700 Subject: [PATCH] aten.hardshrink.default in unary_ops (#3681) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/3681 Implement aten.hardshrink in unary_ops ``` func: hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) ``` Another standalone implementation can be found: D57500203 Reviewed By: SS-JIA Differential Revision: D57573208 fbshipit-source-id: 7939543fddf8e39f6bb4e681c43d386140ae560b --- .../vulkan/runtime/graph/ops/glsl/unary_op.yaml | 2 ++ backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 13 +++++++++++++ backends/vulkan/test/op_tests/cases.py | 1 + 3 files changed, 16 insertions(+) diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 6760e70303..0610bffe13 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -22,3 +22,5 @@ unary_op: OPERATOR: sqrt(X) - NAME: tanh OPERATOR: tanh(clamp(X, -15.0, 15.0)) + - NAME: hardshrink + OPERATOR: X * (vec4(greaterThan(X, vec4(A))) + vec4(lessThan(X, vec4(B)))) diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index a21a680894..82063059c7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -100,6 +100,17 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) { kClampShaderName); \ } +#define DEFINE_HARDSHRINK_FN(op_name) \ + void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \ + return add_unary_op_node( \ + graph, \ + args[0], \ + get_val_or_inf(graph, args[1], /*max = */ false), \ + -get_val_or_inf(graph, args[1], /*max = */ true), \ + args[2], \ + "hardshrink"); \ + } + void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) { 
// args[1] is the `approximate` string // https://fburl.com/code/9omngmyo @@ -116,6 +127,7 @@ DEFINE_ACTIVATION_FN(tanh); DEFINE_CLAMP_FN(clamp); DEFINE_CLAMP_FN(hardtanh); DEFINE_RELU_FN(relu); +DEFINE_HARDSHRINK_FN(hardshrink); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -127,6 +139,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.sigmoid.default, sigmoid); VK_REGISTER_OP(aten.sqrt.default, sqrt); VK_REGISTER_OP(aten.tanh.default, tanh); + VK_REGISTER_OP(aten.hardshrink.default, hardshrink); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 8bfe409c1f..35b4bbc9b3 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -794,4 +794,5 @@ def get_gelu_inputs(): "aten._log_softmax.default": get_softmax_inputs(), "aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(), "aten.gelu.default": get_gelu_inputs(), + "aten.hardshrink.default": get_unary_ops_inputs(), }