Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Make dynamic indexing compatible with BLS #3244

Merged
merged 1 commit into from
Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,14 +559,7 @@ def block_local(*args):

Args:
*args (List[Field]): A list of sparse Taichi fields.

Raises:
InvalidOperationError: If the ``dynamic_index`` feature (experimental)
is enabled.
"""
if ti.current_cfg().dynamic_index:
raise InvalidOperationError(
'dynamic_index is not allowed when block_local is turned on.')
for a in args:
for v in a.get_field_members():
_ti_core.insert_snode_access_flag(
Expand Down
4 changes: 2 additions & 2 deletions tests/python/bls_test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def p2g(use_shared: ti.template(), m: ti.template()):
u0 = ti.assume_in_range(u_[0], Im[0], 0, 1)
u1 = ti.assume_in_range(u_[1], Im[1], 0, 1)

u = ti.Vector([u0, u1])
u = ti.Vector([u0, u1], dt=ti.i32)

for offset in ti.static(ti.grouped(ti.ndrange(extend, extend))):
m[u + offset] += scatter_weight
Expand Down Expand Up @@ -230,7 +230,7 @@ def g2p(use_shared: ti.template(), s: ti.template()):
u0 = ti.assume_in_range(u_[0], Im[0], 0, 1)
u1 = ti.assume_in_range(u_[1], Im[1], 0, 1)

u = ti.Vector([u0, u1])
u = ti.Vector([u0, u1], dt=ti.i32)

tot = 0.0

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_ad_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def complex():
assert a.grad[i] == g[i]


@ti.test(require=[ti.extension.adstack, ti.extension.bls], dynamic_index=False)
@ti.test(require=[ti.extension.adstack, ti.extension.bls])
def test_triple_for_loops_bls():
N = 8
M = 3
Expand Down
49 changes: 12 additions & 37 deletions tests/python/test_bls.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,7 @@
import pytest

import taichi as ti


@ti.test(require=ti.extension.bls, dynamic_index=True)
def test_bls_with_dynamic_index():
x, y = ti.field(ti.f32), ti.field(ti.f32)

N = 64
bs = 16

ti.root.pointer(ti.i, N // bs).dense(ti.i, bs).place(x, y)

@ti.kernel
def populate():
for i in range(N):
x[i] = i

@ti.kernel
def call_block_local():
ti.block_local(x)

populate()
with pytest.raises(ti.InvalidOperationError):
call_block_local()


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_simple_1d():
x, y = ti.field(ti.f32), ti.field(ti.f32)

Expand All @@ -53,7 +28,7 @@ def copy():
assert y[i] == i


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_simple_2d():
x, y = ti.field(ti.f32), ti.field(ti.f32)

Expand Down Expand Up @@ -86,57 +61,57 @@ def _test_bls_stencil(*args, **kwargs):
bls_test_template(*args, **kwargs)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_1d_trivial():
# y[i] = x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ))


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_1d():
# y[i] = x[i - 1] + x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((-1, ), (0, )))


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_2d():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_2d_nonsquare():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=(4, 16), stencil=stencil)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gather_3d():
stencil = [(-1, -1, -1), (2, 0, 1)]
_test_bls_stencil(3, 64, bs=(4, 8, 16), stencil=stencil)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_scatter_1d_trivial():
# y[i] = x[i]
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ), scatter=True)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_scatter_1d():
_test_bls_stencil(1, 128, bs=32, stencil=(
(1, ),
(0, ),
), scatter=True)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_scatter_2d():
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
_test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_multiple_inputs():
x, y, z, w, w2 = ti.field(ti.i32), ti.field(ti.i32), ti.field(
ti.i32), ti.field(ti.i32), ti.field(ti.i32)
Expand Down Expand Up @@ -171,7 +146,7 @@ def copy(bls: ti.template(), w: ti.template()):
assert w[i, j] == w2[i, j]


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_bls_large_block():
n = 2**10
block_size = 32
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_bls_assume_in_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _test_scattering_two_pointer_levels():
use_offset=False)


@ti.test(require=ti.extension.bls, dynamic_index=False)
@ti.test(require=ti.extension.bls)
def test_gathering():
bls_particle_grid(N=128,
ppc=10,
Expand Down