Skip to content

Commit

Permalink
[lang] Migrate TensorType expansion for svd from Python code to Front…
Browse files Browse the repository at this point in the history
…end IR (#6972)

Issue: #5819

### Brief Summary
  • Loading branch information
jim19930609 authored Dec 26, 2022
1 parent e06baff commit 99b39d0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
5 changes: 2 additions & 3 deletions python/taichi/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ def _svd3d(A, dt, iters=None):
Decomposed 3x3 matrices `U`, 'S' and `V`.
"""
assert A.n == 3 and A.m == 3
inputs = get_runtime().prog.current_ast_builder().expand_expr([A.ptr])
assert dt in [f32, f64]
if iters is None:
if dt == f32:
Expand All @@ -171,10 +170,10 @@ def _svd3d(A, dt, iters=None):
iters = 8
if dt == f32:
rets = get_runtime().prog.current_ast_builder().sifakis_svd_f32(
*inputs, iters)
A.ptr, iters)
else:
rets = get_runtime().prog.current_ast_builder().sifakis_svd_f64(
*inputs, iters)
A.ptr, iters)
assert len(rets) == 21
U_entries = rets[:9]
V_entries = rets[9:18]
Expand Down
25 changes: 14 additions & 11 deletions taichi/math/svd.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,20 @@ std::tuple<Expr,
Expr,
Expr,
Expr>
sifakis_svd_export(ASTBuilder *ast_builder,
const Expr &a00,
const Expr &a01,
const Expr &a02,
const Expr &a10,
const Expr &a11,
const Expr &a12,
const Expr &a20,
const Expr &a21,
const Expr &a22,
int num_iters) {
sifakis_svd_export(ASTBuilder *ast_builder, const Expr &mat, int num_iters) {
auto expanded_exprs = ast_builder->expand_expr({mat});
TI_ASSERT(expanded_exprs.size() == 9);

Expr a00 = expanded_exprs[0];
Expr a01 = expanded_exprs[1];
Expr a02 = expanded_exprs[2];
Expr a10 = expanded_exprs[3];
Expr a11 = expanded_exprs[4];
Expr a12 = expanded_exprs[5];
Expr a20 = expanded_exprs[6];
Expr a21 = expanded_exprs[7];
Expr a22 = expanded_exprs[8];

static_assert(sizeof(Tf) == sizeof(Ti), "");
constexpr Tf Four_Gamma_Squared = 5.82842712474619f;
constexpr Tf Sine_Pi_Over_Eight = 0.3826834323650897f;
Expand Down

0 comments on commit 99b39d0

Please sign in to comment.