diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
index 6760e70303..0610bffe13 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
@@ -22,3 +22,5 @@ unary_op:
       OPERATOR: sqrt(X)
     - NAME: tanh
       OPERATOR: tanh(clamp(X, -15.0, 15.0))
+    - NAME: hardshrink
+      OPERATOR: X * (vec4(greaterThan(X, vec4(A))) + vec4(lessThan(X, vec4(B))))
diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
index a21a680894..82063059c7 100644
--- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
@@ -100,6 +100,17 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
         kClampShaderName);                                                \
   }
 
+#define DEFINE_HARDSHRINK_FN(op_name)                                      \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {   \
+    return add_unary_op_node(                                              \
+        graph,                                                             \
+        args[0],                                                           \
+        get_val_or_inf(graph, args[1], /*max = */ false),                  \
+        -get_val_or_inf(graph, args[1], /*max = */ true),                  \
+        args[2],                                                           \
+        "hardshrink");                                                     \
+  }
+
 void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   // args[1] is the `approximate` string
   // https://fburl.com/code/9omngmyo
@@ -116,6 +127,7 @@ DEFINE_ACTIVATION_FN(tanh);
 DEFINE_CLAMP_FN(clamp);
 DEFINE_CLAMP_FN(hardtanh);
 DEFINE_RELU_FN(relu);
+DEFINE_HARDSHRINK_FN(hardshrink);
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.abs.default, abs);
@@ -127,6 +139,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
   VK_REGISTER_OP(aten.sqrt.default, sqrt);
   VK_REGISTER_OP(aten.tanh.default, tanh);
+  VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 8bfe409c1f..35b4bbc9b3 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -794,4 +794,5 @@ def get_gelu_inputs():
     "aten._log_softmax.default": get_softmax_inputs(),
     "aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(),
     "aten.gelu.default": get_gelu_inputs(),
+    "aten.hardshrink.default": get_unary_ops_inputs(),
 }
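
For reference, the new `hardshrink` OPERATOR is a branchless, vectorized form of PyTorch's hardshrink: each lane keeps `X` when `X > lambd` or `X < -lambd` and is zeroed otherwise. In this diff `A` is `lambd` and `B` is `-lambd`, as wired up by `DEFINE_HARDSHRINK_FN`. A minimal sketch of the same math in PyTorch (the helper name `hardshrink_reference` is illustrative, not part of the change):

```python
import torch

def hardshrink_reference(x: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
    # Branchless form mirroring the GLSL operator:
    # keep x where x > lambd or x < -lambd, zero it elsewhere.
    mask = (x > lambd).to(x.dtype) + (x < -lambd).to(x.dtype)
    return x * mask

x = torch.tensor([-1.0, -0.3, 0.0, 0.3, 1.0])
assert torch.equal(
    hardshrink_reference(x, 0.5),
    torch.nn.functional.hardshrink(x, 0.5),
)
```

The two masks are mutually exclusive, so their sum is a 0/1 mask and the multiply reproduces hardshrink without any per-element branching, which maps cleanly onto the `greaterThan`/`lessThan` vector comparisons in the shader.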