Skip to content

Commit

Permalink
[bug] MatrixType bug fix: Fix error with BLS (#6664)
Browse files Browse the repository at this point in the history
Issue: #5819

### Brief Summary
Fixed a set of issues to make BLS tests work.
1. Modified GroupedNDRange generator to directly yield `Expr with
TensorType` instead of `_IntermediateMatrix` when `real_matrix=True`
2. Added support for `rescale_index()` to handle `Expr with TensorType`
3. Added scalarization for `indices` of SNode ops

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yi Xu <xy_xuyi@foxmail.com>
  • Loading branch information
3 people authored Dec 2, 2022
1 parent fa2433d commit 5c3afe1
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 23 deletions.
10 changes: 7 additions & 3 deletions python/taichi/lang/_ndrange.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import collections.abc

import numpy as np
from taichi.lang import ops
from taichi.lang import impl, ops
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.expr import Expr
from taichi.lang.matrix import _IntermediateMatrix
from taichi.lang.matrix import _IntermediateMatrix, make_matrix
from taichi.types import primitive_types
from taichi.types.utils import is_integral


Expand Down Expand Up @@ -144,7 +145,10 @@ def __init__(self, r):

def __iter__(self):
for ind in self.r:
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
if impl.current_cfg().real_matrix:
yield make_matrix(list(ind), dt=primitive_types.i32)
else:
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)


__all__ = ['ndrange']
64 changes: 54 additions & 10 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import numbers
import warnings

Expand All @@ -7,6 +8,34 @@
from taichi.lang.util import get_traceback


def _get_expanded_indices(indices):
if isinstance(indices, matrix.Matrix):
indices = indices.entries
elif isinstance(indices, expr.Expr) and indices.is_tensor():
indices = [
expr.Expr(x)
for x in impl.get_runtime().prog.current_ast_builder().expand_expr(
[indices.ptr])
]
return indices


def _expand_indices(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# indices is the second argument to ti.append, ti.activate, ...
if len(args) > 1:
args = list(args)
args[1] = _get_expanded_indices(args[1])
else:
assert "indices" in kwargs.keys()
kwargs["indices"] = _get_expanded_indices(kwargs["indices"])

return func(*args, **kwargs)

return wrapper


class SNode:
"""A Python-side SNode wrapper.
Expand Down Expand Up @@ -357,25 +386,35 @@ def rescale_index(a, b, I):
Returns:
Ib (:class:`~taichi.Vector`): rescaled grouped loop index
"""

assert isinstance(
a, (Field, SNode)), "The first argument must be a field or an SNode"
assert isinstance(
b, (Field, SNode)), "The second argument must be a field or an SNode"
if isinstance(I, list):
I = matrix.Vector(I)
n = len(I)
else:
assert isinstance(
I, matrix.Matrix
), "The third argument must be an index (list or ti.Vector)"
entries = [I(i) for i in range(I.n)]
for n in range(min(I.n, min(len(a.shape), len(b.shape)))):
if a.shape[n] > b.shape[n]:
entries[n] = I(n) // (a.shape[n] // b.shape[n])
if a.shape[n] < b.shape[n]:
entries[n] = I(n) * (b.shape[n] // a.shape[n])
return matrix.Vector(entries)
I, (expr.Expr, matrix.Matrix)
), "The third argument must be an index (list, ti.Vector, or Expr with TensorType)"
n = I.n

from taichi.lang.kernel_impl import pyfunc # pylint: disable=C0415

@pyfunc
def _rescale_index():
result = matrix.Vector([I[i] for i in range(n)])
for i in impl.static(range(min(n, min(len(a.shape), len(b.shape))))):
if a.shape[i] > b.shape[i]:
result[i] = I[i] // (a.shape[i] // b.shape[i])
if a.shape[i] < b.shape[i]:
result[i] = I[i] * (b.shape[i] // a.shape[i])
return result

return _rescale_index()


@_expand_indices
def append(node, indices, val):
"""Append a value `val` to a SNode `node` at index `indices`.
Expand All @@ -392,6 +431,7 @@ def append(node, indices, val):
return a


@_expand_indices
def is_active(node, indices):
"""Explicitly query whether a cell in a SNode `node` at location
`indices` is active or not.
Expand All @@ -408,6 +448,7 @@ def is_active(node, indices):
expr.make_expr_group(indices)))


@_expand_indices
def activate(node, indices):
"""Explicitly activate a cell of `node` at location `indices`.
Expand All @@ -419,6 +460,7 @@ def activate(node, indices):
node._snode.ptr, expr.make_expr_group(indices))


@_expand_indices
def deactivate(node, indices):
"""Explicitly deactivate a cell of `node` at location `indices`.
Expand All @@ -433,6 +475,7 @@ def deactivate(node, indices):
node._snode.ptr, expr.make_expr_group(indices))


@_expand_indices
def length(node, indices):
"""Return the length of the dynamic SNode `node` at index `indices`.
Expand All @@ -448,6 +491,7 @@ def length(node, indices):
expr.make_expr_group(indices)))


@_expand_indices
def get_addr(f, indices):
"""Query the memory address (on CUDA/x64) of field `f` at index `indices`.
Expand Down
75 changes: 65 additions & 10 deletions tests/python/test_bls_assume_in_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,23 @@
from .bls_test_template import bls_particle_grid


@test_utils.test(require=ti.extension.bls)
def test_scattering():
def _test_scattering():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
scatter=True,
use_offset=False)


@test_utils.test(require=ti.extension.bls)
def test_scattering_offset():
def _test_scattering_offset():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
scatter=True,
use_offset=True)


@test_utils.test(require=ti.extension.bls)
def test_scattering_two_pointer_levels():
def _test_scattering_two_pointer_levels():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
Expand All @@ -32,22 +29,80 @@ def test_scattering_two_pointer_levels():
use_offset=False)


@test_utils.test(require=ti.extension.bls)
def test_gathering():
def _test_gathering():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
scatter=False,
use_offset=False)


@test_utils.test(require=ti.extension.bls)
def test_gathering_offset():
def _test_gathering_offset():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
scatter=False,
use_offset=True)


@test_utils.test(require=ti.extension.bls)
def test_gathering():
_test_gathering()


@test_utils.test(require=ti.extension.bls)
def test_gathering_offset():
_test_gathering_offset()


@test_utils.test(require=ti.extension.bls)
def test_scattering_two_pointer_levels():
_test_scattering_two_pointer_levels()


@test_utils.test(require=ti.extension.bls)
def test_scattering():
_test_scattering()


@test_utils.test(require=ti.extension.bls)
def test_scattering_offset():
_test_scattering_offset()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gathering_matrix_scalarize():
_test_gathering()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gathering_offset_matrix_scalarize():
_test_gathering_offset()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scattering_matrix_scalarize():
_test_scattering()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scattering_offset_matrix_scalarize():
_test_scattering_offset()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scattering_two_pointer_levels_matrix_scalarize():
_test_scattering_two_pointer_levels()


# TODO: debug mode behavior of assume_in_range

0 comments on commit 5c3afe1

Please sign in to comment.