From 5c7762841f776d5aec7990985571019a5c518651 Mon Sep 17 00:00:00 2001 From: listerily Date: Fri, 26 May 2023 11:21:30 +0800 Subject: [PATCH] [bug] Fix misbehaviour and assertion error on ti.math.sign --- python/taichi/math/mathimpl.py | 2 +- tests/python/test_unary_ops.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/taichi/math/mathimpl.py b/python/taichi/math/mathimpl.py index 947cea1fb71e9..6bd0b06b01245 100644 --- a/python/taichi/math/mathimpl.py +++ b/python/taichi/math/mathimpl.py @@ -273,7 +273,7 @@ def sign(x): >>> ti.math.sign(x) [-1.000000, 0.000000, 1.000000] """ - return ops.cast((x >= 0.0) - (x <= 0.0), float) + return ops.cast((x >= 0.0), float) - ops.cast((x <= 0.0), float) @func diff --git a/tests/python/test_unary_ops.py b/tests/python/test_unary_ops.py index 574e5e0d914db..6fe014bee872b 100644 --- a/tests/python/test_unary_ops.py +++ b/tests/python/test_unary_ops.py @@ -142,3 +142,14 @@ def test_u32(x: ti.uint32) -> ti.int32: assert test_u32(100) == 3 assert test_u32(1000) == 6 assert test_u32(10000) == 5 + + +@test_utils.test() +def test_sign(): + @ti.kernel + def foo(val: ti.f32) -> ti.f32: + return ti.math.sign(val) + + assert foo(0.5) == 1.0 + assert foo(-0.5) == -1.0 + assert foo(0.0) == 0.0