Skip to content

Commit

Permalink
aten.hardshrink.default in unary_ops (pytorch#3681)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3681

Implement aten.hardshrink in unary_ops
```
func: hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
```

An alternative standalone implementation can be found in D57500203.

Reviewed By: SS-JIA

Differential Revision: D57573208

fbshipit-source-id: 7939543fddf8e39f6bb4e681c43d386140ae560b
  • Loading branch information
Yujie Hui authored and facebook-github-bot committed May 20, 2024
1 parent 788676f commit 2b603df
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 0 deletions.
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ unary_op:
OPERATOR: sqrt(X)
- NAME: tanh
  # Input is clamped to [-15, 15] — presumably to avoid overflow/NaN in the
  # underlying exp(); tanh is already fully saturated (±1) at those bounds.
  OPERATOR: tanh(clamp(X, -15.0, 15.0))
- NAME: hardshrink
  # hardshrink(x) = x when x > lambd or x < -lambd, else 0.
  # A/B appear to be bound to lambd and -lambd by the caller (see
  # DEFINE_HARDSHRINK_FN in UnaryOp.cpp) — confirm against the shader codegen.
  # greaterThan/lessThan yield componentwise 0/1 vectors, and the two
  # conditions are mutually exclusive, so their sum is a 0/1 keep-mask.
  OPERATOR: X * (vec4(greaterThan(X, vec4(A))) + vec4(lessThan(X, vec4(B))))
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
kClampShaderName); \
}

// Defines the graph-building entry point for aten.hardshrink.
// hardshrink(x, lambd) keeps x where x > lambd or x < -lambd and zeroes it
// otherwise. lambd is read from args[1] and passed as the node's "min"
// value, with its negation as the "max" value; these feed the A/B operands
// of the hardshrink OPERATOR in unary_op.yaml.
// NOTE(review): get_val_or_inf is called twice with opposite `max` flags —
// presumably so an absent lambd resolves to +inf/-inf respectively. Do not
// hoist into a single call without confirming get_val_or_inf's semantics.
// (Comments stay outside the macro: `//` inside a backslash-continued
// definition would splice the next line into the comment.)
#define DEFINE_HARDSHRINK_FN(op_name) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_unary_op_node( \
graph, \
args[0], \
get_val_or_inf(graph, args[1], /*max = */ false), \
-get_val_or_inf(graph, args[1], /*max = */ true), \
args[2], \
"hardshrink"); \
}

void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
// args[1] is the `approximate` string
// https://fburl.com/code/9omngmyo
Expand All @@ -116,6 +127,7 @@ DEFINE_ACTIVATION_FN(tanh);
// Instantiate one dispatch function per op via the DEFINE_* macros above;
// each expands to `void <name>(ComputeGraph&, const std::vector<ValueRef>&)`
// (signature visible in DEFINE_HARDSHRINK_FN; the others presumably match).
DEFINE_CLAMP_FN(clamp);
DEFINE_CLAMP_FN(hardtanh);
DEFINE_RELU_FN(relu);
DEFINE_HARDSHRINK_FN(hardshrink);

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.abs.default, abs);
Expand All @@ -127,6 +139,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
VK_REGISTER_OP(aten.sqrt.default, sqrt);
VK_REGISTER_OP(aten.tanh.default, tanh);
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
}

} // namespace vkcompute
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,4 +794,5 @@ def get_gelu_inputs():
"aten._log_softmax.default": get_softmax_inputs(),
"aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(),
"aten.gelu.default": get_gelu_inputs(),
"aten.hardshrink.default": get_unary_ops_inputs(),
}

0 comments on commit 2b603df

Please sign in to comment.