Skip to content

Commit

Permalink
[Lang] [bug] Allow filling a field with Expr (taichi-dev#6391)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#6318

### Brief Summary

The implementation and tests have been refactored by the way to ease
future maintenance efforts.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and jim19930609 committed Oct 25, 2022
1 parent 65b5311 commit 9090963
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 90 deletions.
19 changes: 3 additions & 16 deletions python/taichi/_kernels.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Iterable

from taichi._funcs import field_fill_taichi_scope
from taichi._lib.utils import get_os_name
from taichi.lang import ops
from taichi.lang._ndrange import ndrange
Expand Down Expand Up @@ -237,20 +236,8 @@ def clear_loss(l: template()):


@kernel
def fill_matrix(mat: template(), vals: template()):
for I in grouped(mat):
for p in static(range(mat.n)):
for q in static(range(mat.m)):
if static(mat[I].ndim == 2):
if static(isinstance(vals[p], Iterable)):
mat[I][p, q] = vals[p][q]
else:
mat[I][p, q] = vals[p]
else:
if static(isinstance(vals[p], Iterable)):
mat[I][p] = vals[p][q]
else:
mat[I][p] = vals[p]
def field_fill_python_scope(F: template(), val: template()):
field_fill_taichi_scope(F, val)


@kernel
Expand Down
32 changes: 11 additions & 21 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,34 +1630,24 @@ def fill(self, val):
"""Fills this matrix field with specified values.
Args:
val (Union[Number, List, Tuple, Matrix]): Values to fill,
val (Union[Number, Expr, List, Tuple, Matrix]): Values to fill,
should have consistent dimension consistent with `self`.
"""
if isinstance(val, numbers.Number):
val = tuple(
[tuple([val for _ in range(self.m)]) for _ in range(self.n)])
elif isinstance(val,
(list, tuple)) and isinstance(val[0], numbers.Number):
assert self.m == 1
val = tuple(val)
elif is_vector(val):
val = tuple([(val(i), ) for i in range(self.n * self.m)])
if isinstance(val, numbers.Number) or (isinstance(val, expr.Expr)
and not val.is_tensor()):
val = list(list(val for _ in range(self.m)) for _ in range(self.n))
elif isinstance(val, Matrix):
val_tuple = []
for i in range(val.n):
row = []
for j in range(val.m):
row.append(val(i, j))
row = tuple(row)
val_tuple.append(row)
val = tuple(val_tuple)
val = val.to_list()
else:
assert isinstance(val, (list, tuple))
val = tuple(tuple(x) if isinstance(x, list) else x for x in val)
assert len(val) == self.n
if self.ndim != 1:
assert len(val[0]) == self.m

if in_python_scope():
from taichi._kernels import fill_matrix # pylint: disable=C0415
fill_matrix(self, val)
from taichi._kernels import \
field_fill_python_scope # pylint: disable=C0415
field_fill_python_scope(self, val)
else:
from taichi._funcs import \
field_fill_taichi_scope # pylint: disable=C0415
Expand Down
20 changes: 0 additions & 20 deletions tests/python/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,26 +304,6 @@ def test_indexing_mat_field_with_np_int():
val[idx][idx, idx]


@test_utils.test(exclude=[ti.cc], debug=True)
def test_field_fill():
x = ti.field(int, shape=(3, 3))
x.fill(2)

y = ti.field(float, shape=(3, 3))
y.fill(2.0)

z = ti.Vector.field(3, float, shape=(3, 3))
z.fill([1, 2, 3])

@ti.kernel
def test():
x.fill(3)
y.fill(3.0)
z.fill([4, 5, 6])

test()


@test_utils.test()
def test_python_for_in():
x = ti.field(int, shape=3)
Expand Down
62 changes: 29 additions & 33 deletions tests/python/test_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,65 @@


@test_utils.test()
def test_fill_scalar():
val = ti.field(ti.i32)
def test_fill_scalar_field():
n = 4
m = 7
val = ti.field(ti.i32, shape=(n, m))

ti.root.dense(ti.ij, (n, m)).place(val)

val.fill(2)
for i in range(n):
for j in range(m):
val[i, j] = i + j * 3
assert val[i, j] == 2

val.fill(2)
@ti.kernel
def fill_in_kernel(v: ti.i32):
val.fill(v)

fill_in_kernel(3)
for i in range(n):
for j in range(m):
assert val[i, j] == 2
assert val[i, j] == 3


@test_utils.test()
def test_fill_matrix_scalar():
val = ti.Matrix.field(2, 3, ti.i32)

def test_fill_matrix_field_with_scalar():
n = 4
m = 7
val = ti.Matrix.field(2, 3, ti.i32, shape=(n, m))

ti.root.dense(ti.ij, (n, m)).place(val)

val.fill(2)
for i in range(n):
for j in range(m):
for p in range(2):
for q in range(3):
val[i, j][p, q] = i + j * 3
assert (val[i, j] == 2).all()

val.fill(2)
@ti.kernel
def fill_in_kernel(v: ti.i32):
val.fill(v)

fill_in_kernel(3)
for i in range(n):
for j in range(m):
for p in range(2):
for q in range(3):
assert val[i, j][p, q] == 2
assert (val[i, j] == 3).all()


@test_utils.test()
def test_fill_matrix_matrix():
val = ti.Matrix.field(2, 3, ti.i32)

def test_fill_matrix_field_with_matrix():
n = 4
m = 7
val = ti.Matrix.field(2, 3, ti.i32, shape=(n, m))

ti.root.dense(ti.ij, (n, m)).place(val)

mat = ti.Matrix([[0, 1, 2], [2, 3, 4]])
val.fill(mat)
for i in range(n):
for j in range(m):
for p in range(2):
for q in range(3):
val[i, j][p, q] = i + j * 3

mat = ti.Matrix([[0, 1, 2], [2, 3, 4]])
assert (val[i, j] == mat).all()

val.fill(mat)
@ti.kernel
def fill_in_kernel(v: ti.types.matrix(2, 3, ti.i32)):
val.fill(v)

mat = ti.Matrix([[4, 5, 6], [6, 7, 8]])
fill_in_kernel(mat)
for i in range(n):
for j in range(m):
for p in range(2):
for q in range(3):
assert val[i, j][p, q] == mat(p, q)
assert (val[i, j] == mat).all()

0 comments on commit 9090963

Please sign in to comment.