diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 798c8bd681f3e..cfc5a4c3867f0 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -153,6 +153,10 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign): # Unpack: a, b, c = ti.Vector([1., 2., 3.]) if isinstance(values, impl.Expr) and values.ptr.is_tensor(): + if len(values.get_shape()) > 1: + raise ValueError( + 'Matrices with more than one columns cannot be unpacked') + values = ctx.ast_builder.expand_expr([values.ptr]) if len(values) == 1: values = values[0] diff --git a/tests/python/test_tuple_assign.py b/tests/python/test_tuple_assign.py index a4dc7e317d3c2..83e547887b57a 100644 --- a/tests/python/test_tuple_assign.py +++ b/tests/python/test_tuple_assign.py @@ -207,8 +207,7 @@ def func(): func() -@test_utils.test(arch=get_host_arch_list()) -def test_unpack_mismatch_matrix(): +def _test_unpack_mismatch_matrix(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) c = ti.field(ti.f32, ()) @@ -223,6 +222,18 @@ def func(): func() +@test_utils.test(arch=get_host_arch_list()) +def test_unpack_mismatch_matrix(): + _test_unpack_mismatch_matrix() + + +@test_utils.test(arch=get_host_arch_list(), + real_matrix=True, + real_matrix_scalarize=True) +def test_unpack_mismatch_matrix_scalarize(): + _test_unpack_mismatch_matrix() + + @test_utils.test(arch=get_host_arch_list()) def test_unpack_from_shape(): a = ti.field(ti.f32, ())