Skip to content

Commit

Permalink
add extension and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
squarefk committed Jul 29, 2021
1 parent 9985bac commit 848db37
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 10 deletions.
3 changes: 1 addition & 2 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,7 @@ def subscript(self, *indices):
# ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()).
if isinstance(
self.entries[0],
ti.Expr) and self.entries[0].ptr.is_global_ptr() and (
ti.cfg.arch == ti.cpu or ti.cfg.arch == ti.gpu):
ti.Expr) and self.entries[0].ptr.is_global_ptr() and ti.is_extension_supported(ti.cfg.arch, ti.extension.dynamic_index):
return ti.subscript_with_offset(self.entries[0], (i, j),
self.m, True)
else:
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/extensions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ PER_EXTENSION(bls) // Block-local storage
PER_EXTENSION(assertion) // Run-time asserts in Taichi kernels
PER_EXTENSION(extfunc) // Invoke external functions or backend source
PER_EXTENSION(packed) // Shape will not be padded to a power of two
PER_EXTENSION(dynamic_index) // Dynamic index support for both global and local tensors
3 changes: 3 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ class GlobalTensorElementStmt : public Stmt {
}

bool has_global_side_effect() const override {
// After access lowered, activate info will be recorded in SNodeLookupStmt's
// activate for AOS sparse data structure. We don't support SOA sparse data
// structure for now.
return false;
}

Expand Down
8 changes: 5 additions & 3 deletions taichi/program/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ bool is_extension_supported(Arch arch, Extension ext) {
{Arch::x64,
{Extension::sparse, Extension::async_mode, Extension::quant,
Extension::quant_basic, Extension::data64, Extension::adstack,
Extension::assertion, Extension::extfunc, Extension::packed}},
Extension::assertion, Extension::extfunc, Extension::packed,
Extension::dynamic_index}},
{Arch::arm64,
{Extension::sparse, Extension::async_mode, Extension::quant,
Extension::quant_basic, Extension::data64, Extension::adstack,
Extension::assertion, Extension::packed}},
Extension::assertion, Extension::packed, Extension::dynamic_index}},
{Arch::cuda,
{Extension::sparse, Extension::async_mode, Extension::quant,
Extension::quant_basic, Extension::data64, Extension::adstack,
Extension::bls, Extension::assertion, Extension::packed}},
Extension::bls, Extension::assertion, Extension::packed,
Extension::dynamic_index}},
{Arch::metal,
{Extension::adstack, Extension::assertion, Extension::quant_basic,
Extension::async_mode, Extension::sparse}},
Expand Down
24 changes: 19 additions & 5 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,32 @@ def run():
assert np.allclose(r2[None].value.to_numpy(), ops(a, c))


@ti.test(arch=[ti.cpu, ti.gpu])
@ti.test(require=ti.extension.dynamic_index)
def test_matrix_non_constant_index():
m = ti.Matrix.field(2, 2, ti.i32, 5)
v = ti.Vector.field(10, ti.i32, 5)

@ti.kernel
def func():
def func1():
for i in range(5):
for j, k in ti.ndrange(2, 2):
m[i][j, k] = 12

func()
m[i][j, k] = j * j + k * k
assert m[1][0, 1] == 1
assert m[2][1, 0] == 1
assert m[3][1, 1] == 2
func1()
assert m[4][0, 1] == 1

@ti.kernel
def func2():
for i in range(5):
for j in range(4):
v[i][j * j] = j * j
assert v[1][0] == 0
assert v[1][1] == 1
assert v[1][4] == 4
func2()
assert v[1][9] == 9

@ti.test(arch=ti.cpu)
def test_matrix_constant_index():
Expand Down

0 comments on commit 848db37

Please sign in to comment.