From 909096354b2bd312816022d7a0d75d2c909038c3 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Fri, 21 Oct 2022 13:02:52 +0800 Subject: [PATCH] [Lang] [bug] Allow filling a field with Expr (#6391) Issue: #6318 ### Brief Summary The implementation and tests have been refactored by the way to ease future maintenance efforts. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/_kernels.py | 19 ++--------- python/taichi/lang/matrix.py | 32 +++++++------------ tests/python/test_field.py | 20 ------------ tests/python/test_fill.py | 62 +++++++++++++++++------------------- 4 files changed, 43 insertions(+), 90 deletions(-) diff --git a/python/taichi/_kernels.py b/python/taichi/_kernels.py index 09abf7db9bf64..cd7808c5c20ae 100644 --- a/python/taichi/_kernels.py +++ b/python/taichi/_kernels.py @@ -1,5 +1,4 @@ -from typing import Iterable - +from taichi._funcs import field_fill_taichi_scope from taichi._lib.utils import get_os_name from taichi.lang import ops from taichi.lang._ndrange import ndrange @@ -237,20 +236,8 @@ def clear_loss(l: template()): @kernel -def fill_matrix(mat: template(), vals: template()): - for I in grouped(mat): - for p in static(range(mat.n)): - for q in static(range(mat.m)): - if static(mat[I].ndim == 2): - if static(isinstance(vals[p], Iterable)): - mat[I][p, q] = vals[p][q] - else: - mat[I][p, q] = vals[p] - else: - if static(isinstance(vals[p], Iterable)): - mat[I][p] = vals[p][q] - else: - mat[I][p] = vals[p] +def field_fill_python_scope(F: template(), val: template()): + field_fill_taichi_scope(F, val) @kernel diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index c1e148ef00584..57f82e29c450c 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1630,34 +1630,24 @@ def fill(self, val): """Fills this matrix field with specified values. Args: - val (Union[Number, List, Tuple, Matrix]): Values to fill, + val (Union[Number, Expr, List, Tuple, Matrix]): Values to fill, should have consistent dimension consistent with `self`. """ - if isinstance(val, numbers.Number): - val = tuple( - [tuple([val for _ in range(self.m)]) for _ in range(self.n)]) - elif isinstance(val, - (list, tuple)) and isinstance(val[0], numbers.Number): - assert self.m == 1 - val = tuple(val) - elif is_vector(val): - val = tuple([(val(i), ) for i in range(self.n * self.m)]) + if isinstance(val, numbers.Number) or (isinstance(val, expr.Expr) + and not val.is_tensor()): + val = list(list(val for _ in range(self.m)) for _ in range(self.n)) elif isinstance(val, Matrix): - val_tuple = [] - for i in range(val.n): - row = [] - for j in range(val.m): - row.append(val(i, j)) - row = tuple(row) - val_tuple.append(row) - val = tuple(val_tuple) + val = val.to_list() + else: + assert isinstance(val, (list, tuple)) + val = tuple(tuple(x) if isinstance(x, list) else x for x in val) assert len(val) == self.n if self.ndim != 1: assert len(val[0]) == self.m - if in_python_scope(): - from taichi._kernels import fill_matrix # pylint: disable=C0415 - fill_matrix(self, val) + from taichi._kernels import \ + field_fill_python_scope # pylint: disable=C0415 + field_fill_python_scope(self, val) else: from taichi._funcs import \ field_fill_taichi_scope # pylint: disable=C0415 diff --git a/tests/python/test_field.py b/tests/python/test_field.py index b7af5e70c21b4..d8e934329e686 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -304,26 +304,6 @@ def test_indexing_mat_field_with_np_int(): val[idx][idx, idx] -@test_utils.test(exclude=[ti.cc], debug=True) -def test_field_fill(): - x = ti.field(int, shape=(3, 3)) - x.fill(2) - - y = ti.field(float, shape=(3, 3)) - y.fill(2.0) - - z = ti.Vector.field(3, float, shape=(3, 3)) - z.fill([1, 2, 3]) - - @ti.kernel - def test(): - x.fill(3) - y.fill(3.0) - z.fill([4, 5, 6]) - - test() - - @test_utils.test() def test_python_for_in(): x = ti.field(int, shape=3) diff --git a/tests/python/test_fill.py b/tests/python/test_fill.py index 5692fe38860f4..9a7f1b1dada23 100644 --- a/tests/python/test_fill.py +++ b/tests/python/test_fill.py @@ -3,69 +3,65 @@ @test_utils.test() -def test_fill_scalar(): - val = ti.field(ti.i32) +def test_fill_scalar_field(): n = 4 m = 7 + val = ti.field(ti.i32, shape=(n, m)) - ti.root.dense(ti.ij, (n, m)).place(val) - + val.fill(2) for i in range(n): for j in range(m): - val[i, j] = i + j * 3 + assert val[i, j] == 2 - val.fill(2) + @ti.kernel + def fill_in_kernel(v: ti.i32): + val.fill(v) + fill_in_kernel(3) for i in range(n): for j in range(m): - assert val[i, j] == 2 + assert val[i, j] == 3 @test_utils.test() -def test_fill_matrix_scalar(): - val = ti.Matrix.field(2, 3, ti.i32) - +def test_fill_matrix_field_with_scalar(): n = 4 m = 7 + val = ti.Matrix.field(2, 3, ti.i32, shape=(n, m)) - ti.root.dense(ti.ij, (n, m)).place(val) - + val.fill(2) for i in range(n): for j in range(m): - for p in range(2): - for q in range(3): - val[i, j][p, q] = i + j * 3 + assert (val[i, j] == 2).all() - val.fill(2) + @ti.kernel + def fill_in_kernel(v: ti.i32): + val.fill(v) + fill_in_kernel(3) for i in range(n): for j in range(m): - for p in range(2): - for q in range(3): - assert val[i, j][p, q] == 2 + assert (val[i, j] == 3).all() @test_utils.test() -def test_fill_matrix_matrix(): - val = ti.Matrix.field(2, 3, ti.i32) - +def test_fill_matrix_field_with_matrix(): n = 4 m = 7 + val = ti.Matrix.field(2, 3, ti.i32, shape=(n, m)) - ti.root.dense(ti.ij, (n, m)).place(val) - + mat = ti.Matrix([[0, 1, 2], [2, 3, 4]]) + val.fill(mat) for i in range(n): for j in range(m): - for p in range(2): - for q in range(3): - val[i, j][p, q] = i + j * 3 - - mat = ti.Matrix([[0, 1, 2], [2, 3, 4]]) + assert (val[i, j] == mat).all() - val.fill(mat) + @ti.kernel + def fill_in_kernel(v: ti.types.matrix(2, 3, ti.i32)): + val.fill(v) + mat = ti.Matrix([[4, 5, 6], [6, 7, 8]]) + fill_in_kernel(mat) for i in range(n): for j in range(m): - for p in range(2): - for q in range(3): - assert val[i, j][p, q] == mat(p, q) + assert (val[i, j] == mat).all()