Skip to content

Commit

Permalink
[lang] Make dynamic indexing compatible with BLS (#3244)
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier authored Oct 22, 2021
1 parent 9c6aa37 commit 48c8c7d
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 48 deletions.
7 changes: 0 additions & 7 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,14 +559,7 @@ def block_local(*args):
Args:
*args (List[Field]): A list of sparse Taichi fields.
Raises:
InvalidOperationError: If the ``dynamic_index`` feature (experimental)
is enabled.
"""
if ti.current_cfg().dynamic_index:
raise InvalidOperationError(
'dynamic_index is not allowed when block_local is turned on.')
for a in args:
for v in a.get_field_members():
_ti_core.insert_snode_access_flag(
Expand Down
4 changes: 2 additions & 2 deletions tests/python/bls_test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def p2g(use_shared: ti.template(), m: ti.template()):
u0 = ti.assume_in_range(u_[0], Im[0], 0, 1)
u1 = ti.assume_in_range(u_[1], Im[1], 0, 1)

u = ti.Vector([u0, u1])
u = ti.Vector([u0, u1], dt=ti.i32)

for offset in ti.static(ti.grouped(ti.ndrange(extend, extend))):
m[u + offset] += scatter_weight
Expand Down Expand Up @@ -230,7 +230,7 @@ def g2p(use_shared: ti.template(), s: ti.template()):
u0 = ti.assume_in_range(u_[0], Im[0], 0, 1)
u1 = ti.assume_in_range(u_[1], Im[1], 0, 1)

u = ti.Vector([u0, u1])
u = ti.Vector([u0, u1], dt=ti.i32)

tot = 0.0

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_ad_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def complex():
assert a.grad[i] == g[i]


@ti.test(require=[ti.extension.adstack, ti.extension.bls], dynamic_index=False)
@ti.test(require=[ti.extension.adstack, ti.extension.bls])
def test_triple_for_loops_bls():
N = 8
M = 3
Expand Down
49 changes: 12 additions & 37 deletions tests/python/test_bls.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,7 @@
import pytest

import taichi as ti


@ti.test(require=ti.extension.bls, dynamic_index=True)
def test_bls_with_dynamic_index():
x, y = ti.field(ti.f32), ti.field(ti.f32)

N = 64
bs = 16

ti.root.pointer(ti.i, N // bs).dense(ti.i, bs).place(x, y)

@ti.kernel
def populate():
for i in range(N):
x[i] = i

@ti.kernel
def call_block_local():
ti.block_local(x)

populate()
with pytest.raises(ti.InvalidOperationError):
call_block_local()


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_simple_1d():
x, y = ti.field(ti.f32), ti.field(ti.f32)

Expand All @@ -53,7 +28,7 @@ def copy():
assert y[i] == i


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_simple_2d():
x, y = ti.field(ti.f32), ti.field(ti.f32)

Expand Down Expand Up @@ -86,57 +61,57 @@ def _test_bls_stencil(*args, **kwargs):
bls_test_template(*args, **kwargs)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_1d_trivial():
# y[i] = x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ))


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_1d():
# y[i] = x[i - 1] + x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((-1, ), (0, )))


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_2d():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_2d_nonsquare():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=(4, 16), stencil=stencil)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_3d():
stencil = [(-1, -1, -1), (2, 0, 1)]
_test_bls_stencil(3, 64, bs=(4, 8, 16), stencil=stencil)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_scatter_1d_trivial():
# y[i] = x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ), scatter=True)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_scatter_1d():
_test_bls_stencil(1, 128, bs=32, stencil=(
(1, ),
(0, ),
), scatter=True)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_scatter_2d():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_multiple_inputs():
x, y, z, w, w2 = ti.field(ti.i32), ti.field(ti.i32), ti.field(
ti.i32), ti.field(ti.i32), ti.field(ti.i32)
Expand Down Expand Up @@ -171,7 +146,7 @@ def copy(bls: ti.template(), w: ti.template()):
assert w[i, j] == w2[i, j]


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_bls_large_block():
n = 2**10
block_size = 32
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_bls_assume_in_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _test_scattering_two_pointer_levels():
use_offset=False)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gathering():
bls_particle_grid(N=128,
ppc=10,
Expand Down

0 comments on commit 48c8c7d

Please sign in to comment.