From 10fcb9c547523b6dbc85131f4c07017839c9232b Mon Sep 17 00:00:00 2001
From: Ye Kuang
Date: Sun, 27 Sep 2020 16:40:08 +0900
Subject: [PATCH] [perf] Wrap clear loss/grad into a kernel

---
 python/taichi/lang/__init__.py | 6 ++++--
 python/taichi/lang/meta.py     | 8 ++++++++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py
index 3eb0855ef8da4..5414c47e8098d 100644
--- a/python/taichi/lang/__init__.py
+++ b/python/taichi/lang/__init__.py
@@ -273,8 +273,10 @@ def Tape(loss, clear_gradients=True):
             ' for all fields that are required by autodiff.')
     if clear_gradients:
         clear_all_gradients()
-    loss[None] = 0
-    loss.grad[None] = 1
+
+    from .meta import clear_loss
+    clear_loss(loss)
+
     return runtime.get_tape(loss)
 
 
diff --git a/python/taichi/lang/meta.py b/python/taichi/lang/meta.py
index d314161c03ccb..243c27153f8bc 100644
--- a/python/taichi/lang/meta.py
+++ b/python/taichi/lang/meta.py
@@ -82,6 +82,14 @@ def clear_gradients(vars: ti.template()):
             ti.Expr(s)[I] = 0
 
 
+@ti.kernel
+def clear_loss(l: ti.template()):
+    # Using SNode writers would result in a forced sync, therefore we wrap these
+    # writes into a kernel.
+    l[None] = 0
+    l.grad[None] = 1
+
+
 @ti.kernel
 def fill_matrix(mat: ti.template(), vals: ti.template()):
     for I in ti.grouped(mat):
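
For context, a minimal usage sketch of the code path this patch optimizes, assuming the Taichi 0.6/0.7-era Python API; the field and kernel names below are illustrative and not part of the patch. ti.Tape() now resets loss and loss.grad through the new clear_loss kernel instead of two SNode writer accesses, each of which forces a device sync.

    import taichi as ti

    ti.init(arch=ti.cpu)

    # Illustrative fields/kernel, not part of the patch.
    x = ti.field(dtype=ti.f32, shape=4, needs_grad=True)
    loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

    @ti.kernel
    def compute_loss():
        for i in x:
            loss[None] += x[i] ** 2

    x.fill(1.0)

    # Entering the tape clears loss / loss.grad (now inside one kernel).
    with ti.Tape(loss):
        compute_loss()

    print(x.grad[3])  # d(loss)/dx_i = 2 * x_i, so this prints 2.0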