From 840522430b4a0169e5b48f13a5d5df33f4a02c37 Mon Sep 17 00:00:00 2001 From: archibate <1931127624@qq.com> Date: Sun, 5 Jul 2020 15:22:33 +0800 Subject: [PATCH] [test] [std] Add matrix SSA violation regression test --- tests/python/test_linalg.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/python/test_linalg.py b/tests/python/test_linalg.py index c977ebf33a450..c0da5cad6cf0f 100644 --- a/tests/python/test_linalg.py +++ b/tests/python/test_linalg.py @@ -2,6 +2,7 @@ import numpy as np from taichi import approx import pytest +import math @ti.all_archs @@ -65,6 +66,25 @@ def init(): assert aNormalized[None][2] == approx(3.0 * invSqrt14) +@ti.all_archs +def test_matrix_ssa(): + a = ti.Vector(2, ti.f32, ()) + b = ti.Matrix(2, 2, ti.f32, ()) + + @ti.kernel + def func(): + a[None] = a[None].normalized() + b[None] = b[None].transpose() + + inv_sqrt2 = 1 / math.sqrt(2) + + a[None] = [1, 1] + b[None] = [[1, 2], [3, 4]] + func() + assert a[None].value == ti.Vector([inv_sqrt2, inv_sqrt2]) + assert b[None].value == ti.Matrix([[1, 3], [2, 4]]) + + @ti.all_archs def test_cross(): a = ti.Vector(3, dt=ti.f32) @@ -243,8 +263,6 @@ def test_mat_inverse(): @ti.all_archs def test_matrix_factories(): - import math - a = ti.Vector.var(3, dt=ti.i32, shape=3) b = ti.Matrix.var(2, 2, dt=ti.f32, shape=2) c = ti.Matrix.var(2, 3, dt=ti.f32, shape=2)