diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 70288029dde42..f52bfff75a6f3 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -300,8 +300,7 @@ def subscript(self, *indices): # ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()). if isinstance( self.entries[0], - ti.Expr) and self.entries[0].ptr.is_global_ptr() and ( - ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu): + ti.Expr) and self.entries[0].ptr.is_global_ptr() and ti.is_extension_supported(ti.cfg.arch, ti.extension.dynamic_index): return ti.subscript_with_offset(self.entries[0], (i, j), self.m, True) else: diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index 5c0ba10343346..0dc106864c8a0 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -9,3 +9,4 @@ PER_EXTENSION(bls) // Block-local storage PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels PER_EXTENSION(extfunc) // Invoke external functions or backend source PER_EXTENSION(packed) // Shape will not be padded to a power of two +PER_EXTENSION(dynamic_index) // Dynamic index support for both global and local tensors diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index fd0d1c48cb31a..dfa04f51b24d0 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -318,6 +318,9 @@ class GlobalTensorElementStmt : public Stmt { } bool has_global_side_effect() const override { + // After access lowering, activation info is recorded in SNodeLookupStmt's + // activate for AOS sparse data structures. SOA sparse data structures are + // not supported for now. 
return false; } diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index 7dbae325f3067..a33b098d3a701 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -11,15 +11,17 @@ bool is_extension_supported(Arch arch, Extension ext) { {Arch::x64, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::assertion, Extension::extfunc, Extension::packed}}, + Extension::assertion, Extension::extfunc, Extension::packed, + Extension::dynamic_index}}, {Arch::arm64, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::assertion, Extension::packed}}, + Extension::assertion, Extension::packed, Extension::dynamic_index}}, {Arch::cuda, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, - Extension::bls, Extension::assertion, Extension::packed}}, + Extension::bls, Extension::assertion, Extension::packed, + Extension::dynamic_index}}, {Arch::metal, {Extension::adstack, Extension::assertion, Extension::quant_basic, Extension::async_mode, Extension::sparse}}, diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 74572077e78b0..c8257cd2a5076 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -161,18 +161,32 @@ def run(): assert np.allclose(r2[None].value.to_numpy(), ops(a, c)) -@ti.test(arch=[ti.cpu, ti.gpu]) +@ti.test(require=ti.extension.dynamic_index) def test_matrix_non_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) + v = ti.Vector.field(10, ti.i32, 5) @ti.kernel - def func(): + def func1(): for i in range(5): for j, k in ti.ndrange(2, 2): - m[i][j, k] = 12 - - func() + m[i][j, k] = j * j + k * k + assert m[1][0, 1] == 1 + assert m[2][1, 0] == 1 + assert m[3][1, 1] == 2 + func1() + assert m[4][0, 1] == 1 + @ti.kernel + def func2(): + for i in 
range(5): + for j in range(4): + v[i][j * j] = j * j + assert v[1][0] == 0 + assert v[1][1] == 1 + assert v[1][4] == 4 + func2() + assert v[1][9] == 9 @ti.test(arch=ti.cpu) def test_matrix_constant_index():