From 646a12ad406f9684dc4de32f0bddab6c424e54e9 Mon Sep 17 00:00:00 2001 From: mingrui Date: Fri, 1 Jul 2022 17:58:38 +0800 Subject: [PATCH] move the needs grad check into create field member --- python/taichi/lang/impl.py | 18 ++++++++---------- python/taichi/lang/matrix.py | 21 ++++++++++----------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index d18ce81ce1ead..6cbb7bbf9b031 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -505,7 +505,7 @@ def __repr__(self): @python_scope -def create_field_member(dtype, name): +def create_field_member(dtype, name, needs_grad): dtype = cook_dtype(dtype) # primal @@ -518,6 +518,7 @@ def create_field_member(dtype, name): x_grad = None x_dual = None + # TODO: replace the name of `_ti_core.needs_grad`, it only checks whether the dtype can be accepted to compute derivatives if _ti_core.needs_grad(dtype): # adjoint x_grad = Expr(get_runtime().prog.make_id_expr("")) @@ -526,14 +527,18 @@ def create_field_member(dtype, name): x_grad.ptr.set_name(name + ".grad") x_grad.ptr.set_is_primal(False) x.ptr.set_adjoint(x_grad.ptr) + if needs_grad: + pytaichi.grad_vars.append(x_grad) # dual x_dual = Expr(get_runtime().prog.make_id_expr("")) - x_dual.declaration_tb = get_traceback(stacklevel=4) x_dual.ptr = _ti_core.global_new(x_dual.ptr, dtype) x_dual.ptr.set_name(name + ".dual") x_dual.ptr.set_is_primal(False) x.ptr.set_dual(x_dual.ptr) + elif needs_grad: + raise TaichiRuntimeError( + f'{dtype} is not supported for field with `needs_grad=True`.') return x, x_grad, x_dual @@ -582,14 +587,7 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False): assert (offset is None or shape is not None), 'The shape cannot be None when offset is being set' - x, x_grad, x_dual = create_field_member(dtype, name) - - if needs_grad: - if x_grad is None: - raise TaichiRuntimeError( - f'{dtype} is not supported for field with `needs_grad=True`.') - pytaichi.grad_vars.append(x_grad) - + x, x_grad, x_dual = create_field_member(dtype, name, needs_grad) x, x_grad, x_dual = ScalarField(x), ScalarField(x_grad), ScalarField( x_dual) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index f274bf96a81b8..a1992e090a80d 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1095,7 +1095,9 @@ def field(cls, ) == n, f'Please set correct dtype list for Vector. The shape of dtype list should be ({n}, ) instead of {np.shape(dtype)}' for i in range(n): entries.append( - impl.create_field_member(dtype[i], name=name)) + impl.create_field_member(dtype[i], + name=name, + needs_grad=needs_grad)) else: assert len(np.shape(dtype)) == 2 and len(dtype) == n and len( dtype[0] @@ -1103,20 +1105,17 @@ def field(cls, for i in range(n): for j in range(m): entries.append( - impl.create_field_member(dtype[i][j], name=name)) + impl.create_field_member(dtype[i][j], + name=name, + needs_grad=needs_grad)) else: for _ in range(n * m): - entries.append(impl.create_field_member(dtype, name=name)) + entries.append( + impl.create_field_member(dtype, + name=name, + needs_grad=needs_grad)) entries, entries_grad, entries_dual = zip(*entries) - if needs_grad: - for e in entries_grad: - if e is None: - raise RuntimeError( - f'{dtype} is not supported for field with `needs_grad=True`.' - ) - impl.get_runtime().grad_vars.append(e) - entries, entries_grad, entries_dual = MatrixField( entries, n, m), MatrixField(entries_grad, n, m), MatrixField(entries_grad, n, m)