diff --git a/python/taichi/_funcs.py b/python/taichi/_funcs.py index 31b692adc9d86..7ce41cdc0679b 100644 --- a/python/taichi/_funcs.py +++ b/python/taichi/_funcs.py @@ -162,7 +162,6 @@ def _svd3d(A, dt, iters=None): Decomposed 3x3 matrices `U`, 'S' and `V`. """ assert A.n == 3 and A.m == 3 - inputs = get_runtime().prog.current_ast_builder().expand_expr([A.ptr]) assert dt in [f32, f64] if iters is None: if dt == f32: @@ -171,10 +170,10 @@ def _svd3d(A, dt, iters=None): iters = 8 if dt == f32: rets = get_runtime().prog.current_ast_builder().sifakis_svd_f32( - *inputs, iters) + A.ptr, iters) else: rets = get_runtime().prog.current_ast_builder().sifakis_svd_f64( - *inputs, iters) + A.ptr, iters) assert len(rets) == 21 U_entries = rets[:9] V_entries = rets[9:18] diff --git a/taichi/math/svd.h b/taichi/math/svd.h index 3385cb8833afb..d0da10beb3da2 100644 --- a/taichi/math/svd.h +++ b/taichi/math/svd.h @@ -42,17 +42,20 @@ std::tuple -sifakis_svd_export(ASTBuilder *ast_builder, - const Expr &a00, - const Expr &a01, - const Expr &a02, - const Expr &a10, - const Expr &a11, - const Expr &a12, - const Expr &a20, - const Expr &a21, - const Expr &a22, - int num_iters) { +sifakis_svd_export(ASTBuilder *ast_builder, const Expr &mat, int num_iters) { + auto expanded_exprs = ast_builder->expand_expr({mat}); + TI_ASSERT(expanded_exprs.size() == 9); + + Expr a00 = expanded_exprs[0]; + Expr a01 = expanded_exprs[1]; + Expr a02 = expanded_exprs[2]; + Expr a10 = expanded_exprs[3]; + Expr a11 = expanded_exprs[4]; + Expr a12 = expanded_exprs[5]; + Expr a20 = expanded_exprs[6]; + Expr a21 = expanded_exprs[7]; + Expr a22 = expanded_exprs[8]; + static_assert(sizeof(Tf) == sizeof(Ti), ""); constexpr Tf Four_Gamma_Squared = 5.82842712474619f; constexpr Tf Sine_Pi_Over_Eight = 0.3826834323650897f;