From e8f981692de21ce9d804617c12f7d36aed4535f2 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Thu, 24 Nov 2022 15:02:44 +0800 Subject: [PATCH] [bug] MatrixType bug fix: Fix error with nested StructType and MatrixType (#6689) Issue: https://github.com/taichi-dev/taichi/issues/5819 ### Brief Summary 1. Modified `Matrix::fill()` to broadcast `val` into VectorType if `ndim == 1`, and to MatrixType if `ndim == 2` 2. Modified `Struct::fill()` to apply `matrix_op.fill()` in case of `Expr` with TensorType --- python/taichi/lang/matrix.py | 7 +++++- python/taichi/lang/struct.py | 12 ++++++--- taichi/transforms/lower_ast.cpp | 8 ++---- tests/python/test_custom_struct.py | 39 +++++++++++++++++++++++++----- 4 files changed, 49 insertions(+), 17 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index f67b95850df8f..90050709042dd 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1542,7 +1542,12 @@ def fill(self, val): """ if isinstance(val, numbers.Number) or (isinstance(val, expr.Expr) and not val.is_tensor()): - val = list(list(val for _ in range(self.m)) for _ in range(self.n)) + if self.ndim == 2: + val = list( + list(val for _ in range(self.m)) for _ in range(self.n)) + else: + assert self.ndim == 1 + val = list(val for _ in range(self.n)) elif isinstance(val, Matrix): val = val.to_list() else: diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 7581b9e2637ff..58c9b370ec22b 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -223,10 +223,14 @@ def fill(self, val): Args: val (Union[int, float]): Value to fill. """ - def assign_renamed(x, y): - return ops.assign(x, y) - - return self._element_wise_writeback_binary(assign_renamed, val) + for k, v in self.items: + if isinstance(v, impl.Expr) and v.ptr.is_tensor(): + from taichi.lang import matrix_ops # pylint: disable=C0415 + matrix_ops.fill(v, val) + elif isinstance(v, (Struct, Matrix)): + v._element_wise_binary(ops.assign, val) + else: + ops.assign(v, val) def __len__(self): """Get the number of entries in a custom struct""" diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 72805b14118ba..68dcb47b70401 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -396,25 +396,21 @@ class LowerAST : public IRVisitor { auto expr = assign->rhs; auto fctx = make_flatten_ctx(); flatten_rvalue(expr, &fctx); + flatten_lvalue(dest, &fctx); if (dest.is()) { - fctx.push_back( - assign->parent->lookup_var(assign->lhs.cast()->id), - expr->stmt); + fctx.push_back(dest->stmt, expr->stmt); } else if (dest.is()) { auto ix = dest.cast(); - flatten_lvalue(dest, &fctx); if (ix->is_local()) { fctx.push_back(dest->stmt, expr->stmt); } else { fctx.push_back(dest->stmt, expr->stmt); } } else if (dest.is()) { - flatten_lvalue(dest, &fctx); fctx.push_back(dest->stmt, expr->stmt); } else { TI_ASSERT(dest.is() && dest.cast()->is_ptr); - flatten_lvalue(dest, &fctx); fctx.push_back(dest->stmt, expr->stmt); } fctx.stmts.back()->set_tb(assign->tb); diff --git a/tests/python/test_custom_struct.py b/tests/python/test_custom_struct.py index 66bbde14f2599..f6dfa22e26576 100644 --- a/tests/python/test_custom_struct.py +++ b/tests/python/test_custom_struct.py @@ -77,8 +77,7 @@ def run_python_scope(): assert y[i].b == int(1.01 * i) -@test_utils.test() -def test_struct_fill(): +def _test_struct_fill(): n = 32 # also tests implicit cast @@ -114,6 +113,16 @@ def fill_elements(): assert np.allclose(x[i].b.to_numpy(), int(x[i].a)) +@test_utils.test() +def test_struct_fill(): + _test_struct_fill() + + +@test_utils.test(real_matrix=True, real_matrix_scalarize=True) +def test_struct_fill_matrix_scalarize(): + _test_struct_fill() + + @test_utils.test() def test_matrix_type(): n = 32 @@ -142,8 +151,7 @@ def run_python_scope(): assert np.allclose(x[i].to_numpy(), np.array([i + 1, i, i])) -@test_utils.test() -def test_struct_type(): +def _test_struct_type(): n = 32 vec3f = ti.types.vector(3, float) line3f = ti.types.struct(linedir=vec3f, length=float) @@ -204,6 +212,16 @@ def run_python_scope(): assert x[i].line.length == 5.0 +@test_utils.test() +def test_struct_type(): + _test_struct_type() + + +@test_utils.test(real_matrix=True, real_matrix_scalarize=True) +def test_struct_type_matrix_scalarize(): + _test_struct_type() + + @test_utils.test(exclude=ti.cc) def test_dataclass(): # example struct class type @@ -245,8 +263,7 @@ def get_area_field() -> ti.f32: assert np.isclose(get_area_field(), 4.0 * 3.14 * 4.0) -@test_utils.test() -def test_struct_assign(): +def _test_struct_assign(): n = 32 vec3f = ti.types.vector(3, float) line3f = ti.types.struct(linedir=vec3f, length=float) @@ -284,6 +301,16 @@ def run_python_scope(): assert x[i].line.length == i + 0.5 +@test_utils.test() +def test_struct_assign(): + _test_struct_assign() + + +@test_utils.test(real_matrix=True, real_matrix_scalarize=True) +def test_struct_assign_matrix_scalarize(): + _test_struct_assign() + + @test_utils.test() def test_compound_type_implicit_cast(): vec2i = ti.types.vector(2, int)