From 1ee9365b306389c3fcf472d548ad4ec047979ae8 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Wed, 22 Jul 2020 14:29:32 +0900 Subject: [PATCH] [bug] Update mpm_lagrangian_force and fix Matrix constructor (#1545) * [bug] Update mpm_lagrangian_force and fix Matrix constructor * enforce code format (don't skip ci) * add test * add scalar test Co-authored-by: Taichi Gardener --- examples/mpm_lagrangian_forces.py | 20 ++++++++++++-------- python/taichi/lang/matrix.py | 1 + tests/python/test_field.py | 16 ++++++++++++++++ tests/python/test_matrix.py | 16 ++++++++++++++++ 4 files changed, 45 insertions(+), 8 deletions(-) diff --git a/examples/mpm_lagrangian_forces.py b/examples/mpm_lagrangian_forces.py index eb2f47f3a46b3..342cd8f39ea89 100644 --- a/examples/mpm_lagrangian_forces.py +++ b/examples/mpm_lagrangian_forces.py @@ -17,14 +17,18 @@ mu = 1 la = 1 -x = ti.Vector(dim, dt=ti.f32, shape=n_particles, needs_grad=True) -v = ti.Vector(dim, dt=ti.f32, shape=n_particles) -C = ti.Matrix(dim, dim, dt=ti.f32, shape=n_particles) -grid_v = ti.Vector(dim, dt=ti.f32, shape=(n_grid, n_grid)) -grid_m = ti.var(dt=ti.f32, shape=(n_grid, n_grid)) -restT = ti.Matrix(dim, dim, dt=ti.f32, shape=n_particles, needs_grad=True) -total_energy = ti.var(ti.f32, shape=(), needs_grad=True) -vertices = ti.var(ti.i32, shape=(n_elements, 3)) +x = ti.Vector.field(dim, dtype=ti.f32, shape=n_particles, needs_grad=True) +v = ti.Vector.field(dim, dtype=ti.f32, shape=n_particles) +C = ti.Matrix.field(dim, dim, dtype=ti.f32, shape=n_particles) +grid_v = ti.Vector.field(dim, dtype=ti.f32, shape=(n_grid, n_grid)) +grid_m = ti.field(dtype=ti.f32, shape=(n_grid, n_grid)) +restT = ti.Matrix.field(dim, + dim, + dtype=ti.f32, + shape=n_particles, + needs_grad=True) +total_energy = ti.field(dtype=ti.f32, shape=(), needs_grad=True) +vertices = ti.field(dtype=ti.i32, shape=(n_elements, 3)) @ti.func diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index da1b06c2afa0f..9cbd76de85723 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -99,6 +99,7 @@ def __init__(self, self.n = mat.n self.m = mat.m self.entries = mat.entries + self.grad = mat.grad if self.n * self.m > 32: warning( diff --git a/tests/python/test_field.py b/tests/python/test_field.py index 9358c0454229b..c0af4dae55f57 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -58,3 +58,19 @@ def test_matrix_field(n, m, dtype, shape): assert x.dtype == dtype assert x.n == n assert x.m == m + + +@ti.host_arch_only +def test_field_needs_grad(): + # Just make sure the usage doesn't crash, see https://github.com/taichi-dev/taichi/pull/1545 + n = 8 + m1 = ti.field(ti.f32, n, needs_grad=True) + m2 = ti.field(ti.f32, n, needs_grad=True) + gr = ti.field(ti.f32, n) + + @ti.kernel + def func(): + for i in range(n): + gr[i] = m1.grad[i] + m2.grad[i] + + func() diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 97be2d02d8c03..7f37c798e0806 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -186,3 +186,19 @@ def func(): func() assert np.allclose(m.to_numpy(), np.ones((5, 2, 2), np.int32) * 12) + + +@ti.all_archs +def test_matrix_needs_grad(): + # Just make sure the usage doesn't crash, see https://github.com/taichi-dev/taichi/pull/1545 + n = 8 + m1 = ti.Matrix.field(2, 2, ti.f32, n, needs_grad=True) + m2 = ti.Matrix.field(2, 2, ti.f32, n, needs_grad=True) + gr = ti.Matrix.field(2, 2, ti.f32, n) + + @ti.kernel + def func(): + for i in range(n): + gr[i] = m1.grad[i] + m2.grad[i] + + func()