From 5c3afe1c97b6f73b301a4234c852a2dabc57e5e2 Mon Sep 17 00:00:00 2001
From: Zhanlue Yang
Date: Fri, 2 Dec 2022 14:43:33 +0800
Subject: [PATCH] [bug] MatrixType bug fix: Fix error with BLS (#6664)

Issue: https://github.com/taichi-dev/taichi/issues/5819

### Brief Summary
Fixed a set of issues to make BLS tests work.

1. Modified GroupedNDRange generator to directly yield `Expr with TensorType` instead of `_IntermediateMatrix` when `real_matrix=True`
2. Added support for `rescale_index()` to handle `Expr with TensorType`
3. Added scalarization for `indices` of SNode ops

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yi Xu
---
 python/taichi/lang/_ndrange.py           | 10 +++-
 python/taichi/lang/snode.py              | 64 ++++++++++++++++----
 tests/python/test_bls_assume_in_range.py | 75 ++++++++++++++++++++----
 3 files changed, 126 insertions(+), 23 deletions(-)

diff --git a/python/taichi/lang/_ndrange.py b/python/taichi/lang/_ndrange.py
index f414bf5e75cab..e24d531652d41 100644
--- a/python/taichi/lang/_ndrange.py
+++ b/python/taichi/lang/_ndrange.py
@@ -1,10 +1,11 @@
 import collections.abc
 
 import numpy as np
-from taichi.lang import ops
+from taichi.lang import impl, ops
 from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
 from taichi.lang.expr import Expr
-from taichi.lang.matrix import _IntermediateMatrix
+from taichi.lang.matrix import _IntermediateMatrix, make_matrix
+from taichi.types import primitive_types
 from taichi.types.utils import is_integral
 
 
@@ -144,7 +145,10 @@ def __init__(self, r):
 
     def __iter__(self):
         for ind in self.r:
-            yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
+            if impl.current_cfg().real_matrix:
+                yield make_matrix(list(ind), dt=primitive_types.i32)
+            else:
+                yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
 
 
 __all__ = ['ndrange']
diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py
index 00eda7bfbe356..440fb72d3596a 100644
--- a/python/taichi/lang/snode.py
+++ b/python/taichi/lang/snode.py
@@ -1,3 +1,4 @@
+import functools
 import numbers
 import warnings
 
@@ -7,6 +8,34 @@
 from taichi.lang.util import get_traceback
 
 
+def _get_expanded_indices(indices):
+    if isinstance(indices, matrix.Matrix):
+        indices = indices.entries
+    elif isinstance(indices, expr.Expr) and indices.is_tensor():
+        indices = [
+            expr.Expr(x)
+            for x in impl.get_runtime().prog.current_ast_builder().expand_expr(
+                [indices.ptr])
+        ]
+    return indices
+
+
+def _expand_indices(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        # indices is the second argument to ti.append, ti.activate, ...
+        if len(args) > 1:
+            args = list(args)
+            args[1] = _get_expanded_indices(args[1])
+        else:
+            assert "indices" in kwargs.keys()
+            kwargs["indices"] = _get_expanded_indices(kwargs["indices"])
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 class SNode:
     """A Python-side SNode wrapper.
 
@@ -357,25 +386,35 @@ def rescale_index(a, b, I):
     Returns:
         Ib (:class:`~taichi.Vector`): rescaled grouped loop index
     """
+
     assert isinstance(
         a, (Field, SNode)), "The first argument must be a field or an SNode"
     assert isinstance(
         b, (Field, SNode)), "The second argument must be a field or an SNode"
     if isinstance(I, list):
-        I = matrix.Vector(I)
+        n = len(I)
     else:
         assert isinstance(
-            I, matrix.Matrix
-        ), "The third argument must be an index (list or ti.Vector)"
-    entries = [I(i) for i in range(I.n)]
-    for n in range(min(I.n, min(len(a.shape), len(b.shape)))):
-        if a.shape[n] > b.shape[n]:
-            entries[n] = I(n) // (a.shape[n] // b.shape[n])
-        if a.shape[n] < b.shape[n]:
-            entries[n] = I(n) * (b.shape[n] // a.shape[n])
-    return matrix.Vector(entries)
+            I, (expr.Expr, matrix.Matrix)
+        ), "The third argument must be an index (list, ti.Vector, or Expr with TensorType)"
+        n = I.n
+
+    from taichi.lang.kernel_impl import pyfunc  # pylint: disable=C0415
+
+    @pyfunc
+    def _rescale_index():
+        result = matrix.Vector([I[i] for i in range(n)])
+        for i in impl.static(range(min(n, min(len(a.shape), len(b.shape))))):
+            if a.shape[i] > b.shape[i]:
+                result[i] = I[i] // (a.shape[i] // b.shape[i])
+            if a.shape[i] < b.shape[i]:
+                result[i] = I[i] * (b.shape[i] // a.shape[i])
+        return result
+
+    return _rescale_index()
 
 
+@_expand_indices
 def append(node, indices, val):
     """Append a value `val` to a SNode `node` at index `indices`.
 
@@ -392,6 +431,7 @@ def append(node, indices, val):
     return a
 
 
+@_expand_indices
 def is_active(node, indices):
     """Explicitly query whether a cell in a SNode `node` at location
     `indices` is active or not.
@@ -408,6 +448,7 @@ def is_active(node, indices):
                                       expr.make_expr_group(indices)))
 
 
+@_expand_indices
 def activate(node, indices):
     """Explicitly activate a cell of `node` at location `indices`.
 
@@ -419,6 +460,7 @@ def activate(node, indices):
         node._snode.ptr, expr.make_expr_group(indices))
 
 
+@_expand_indices
 def deactivate(node, indices):
     """Explicitly deactivate a cell of `node` at location `indices`.
 
@@ -433,6 +475,7 @@ def deactivate(node, indices):
         node._snode.ptr, expr.make_expr_group(indices))
 
 
+@_expand_indices
 def length(node, indices):
     """Return the length of the dynamic SNode `node` at index `indices`.
 
@@ -448,6 +491,7 @@ def length(node, indices):
                                    expr.make_expr_group(indices)))
 
 
+@_expand_indices
 def get_addr(f, indices):
     """Query the memory address (on CUDA/x64) of field `f` at index `indices`.
 
diff --git a/tests/python/test_bls_assume_in_range.py b/tests/python/test_bls_assume_in_range.py
index c7dc0cbe3ceab..0e98c659a677f 100644
--- a/tests/python/test_bls_assume_in_range.py
+++ b/tests/python/test_bls_assume_in_range.py
@@ -4,8 +4,7 @@
 from .bls_test_template import bls_particle_grid
 
 
-@test_utils.test(require=ti.extension.bls)
-def test_scattering():
+def _test_scattering():
     bls_particle_grid(N=128,
                       ppc=10,
                       block_size=8,
@@ -13,8 +12,7 @@ def test_scattering():
                       use_offset=False)
 
 
-@test_utils.test(require=ti.extension.bls)
-def test_scattering_offset():
+def _test_scattering_offset():
     bls_particle_grid(N=128,
                       ppc=10,
                       block_size=8,
@@ -22,8 +20,7 @@ def test_scattering_offset():
                       use_offset=True)
 
 
-@test_utils.test(require=ti.extension.bls)
-def test_scattering_two_pointer_levels():
+def _test_scattering_two_pointer_levels():
     bls_particle_grid(N=128,
                       ppc=10,
                       block_size=8,
@@ -32,8 +29,7 @@ def test_scattering_two_pointer_levels():
                       use_offset=False)
 
 
-@test_utils.test(require=ti.extension.bls)
-def test_gathering():
+def _test_gathering():
     bls_particle_grid(N=128,
                       ppc=10,
                       block_size=8,
@@ -41,8 +37,7 @@ def test_gathering():
                       use_offset=False)
 
 
-@test_utils.test(require=ti.extension.bls)
-def test_gathering_offset():
+def _test_gathering_offset():
     bls_particle_grid(N=128,
                       ppc=10,
                       block_size=8,
@@ -50,4 +45,64 @@ def test_gathering_offset():
                       use_offset=True)
 
 
+@test_utils.test(require=ti.extension.bls)
+def test_gathering():
+    _test_gathering()
+
+
+@test_utils.test(require=ti.extension.bls)
+def test_gathering_offset():
+    _test_gathering_offset()
+
+
+@test_utils.test(require=ti.extension.bls)
+def test_scattering_two_pointer_levels():
+    _test_scattering_two_pointer_levels()
+
+
+@test_utils.test(require=ti.extension.bls)
+def test_scattering():
+    _test_scattering()
+
+
+@test_utils.test(require=ti.extension.bls)
+def test_scattering_offset():
+    _test_scattering_offset()
+
+
+@test_utils.test(require=ti.extension.bls,
+                 real_matrix=True,
+                 real_matrix_scalarize=True)
+def test_gathering_matrix_scalarize():
+    _test_gathering()
+
+
+@test_utils.test(require=ti.extension.bls,
+                 real_matrix=True,
+                 real_matrix_scalarize=True)
+def test_gathering_offset_matrix_scalarize():
+    _test_gathering_offset()
+
+
+@test_utils.test(require=ti.extension.bls,
+                 real_matrix=True,
+                 real_matrix_scalarize=True)
+def test_scattering_matrix_scalarize():
+    _test_scattering()
+
+
+@test_utils.test(require=ti.extension.bls,
+                 real_matrix=True,
+                 real_matrix_scalarize=True)
+def test_scattering_offset_matrix_scalarize():
+    _test_scattering_offset()
+
+
+@test_utils.test(require=ti.extension.bls,
+                 real_matrix=True,
+                 real_matrix_scalarize=True)
+def test_scattering_two_pointer_levels_matrix_scalarize():
+    _test_scattering_two_pointer_levels()
+
+
 # TODO: debug mode behavior of assume_in_range