Skip to content

Commit

Permalink
move the needs grad check into create field member
Browse files Browse the repository at this point in the history
  • Loading branch information
erizmr committed Jul 1, 2022
1 parent fb76054 commit 646a12a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
18 changes: 8 additions & 10 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def __repr__(self):


@python_scope
def create_field_member(dtype, name):
def create_field_member(dtype, name, needs_grad):
dtype = cook_dtype(dtype)

# primal
Expand All @@ -518,6 +518,7 @@ def create_field_member(dtype, name):

x_grad = None
x_dual = None
# TODO: replace the name of `_ti_core.needs_grad`, it only checks whether the dtype can be accepted to compute derivatives
if _ti_core.needs_grad(dtype):
# adjoint
x_grad = Expr(get_runtime().prog.make_id_expr(""))
Expand All @@ -526,14 +527,18 @@ def create_field_member(dtype, name):
x_grad.ptr.set_name(name + ".grad")
x_grad.ptr.set_is_primal(False)
x.ptr.set_adjoint(x_grad.ptr)
if needs_grad:
pytaichi.grad_vars.append(x_grad)

# dual
x_dual = Expr(get_runtime().prog.make_id_expr(""))
x_dual.declaration_tb = get_traceback(stacklevel=4)
x_dual.ptr = _ti_core.global_new(x_dual.ptr, dtype)
x_dual.ptr.set_name(name + ".dual")
x_dual.ptr.set_is_primal(False)
x.ptr.set_dual(x_dual.ptr)
elif needs_grad:
raise TaichiRuntimeError(
f'{dtype} is not supported for field with `needs_grad=True`.')

return x, x_grad, x_dual

Expand Down Expand Up @@ -582,14 +587,7 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False):
assert (offset is None or shape
is not None), 'The shape cannot be None when offset is being set'

x, x_grad, x_dual = create_field_member(dtype, name)

if needs_grad:
if x_grad is None:
raise TaichiRuntimeError(
f'{dtype} is not supported for field with `needs_grad=True`.')
pytaichi.grad_vars.append(x_grad)

x, x_grad, x_dual = create_field_member(dtype, name, needs_grad)
x, x_grad, x_dual = ScalarField(x), ScalarField(x_grad), ScalarField(
x_dual)

Expand Down
21 changes: 10 additions & 11 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,28 +1095,27 @@ def field(cls,
) == n, f'Please set correct dtype list for Vector. The shape of dtype list should be ({n}, ) instead of {np.shape(dtype)}'
for i in range(n):
entries.append(
impl.create_field_member(dtype[i], name=name))
impl.create_field_member(dtype[i],
name=name,
needs_grad=needs_grad))
else:
assert len(np.shape(dtype)) == 2 and len(dtype) == n and len(
dtype[0]
) == m, f'Please set correct dtype list for Matrix. The shape of dtype list should be ({n}, {m}) instead of {np.shape(dtype)}'
for i in range(n):
for j in range(m):
entries.append(
impl.create_field_member(dtype[i][j], name=name))
impl.create_field_member(dtype[i][j],
name=name,
needs_grad=needs_grad))
else:
for _ in range(n * m):
entries.append(impl.create_field_member(dtype, name=name))
entries.append(
impl.create_field_member(dtype,
name=name,
needs_grad=needs_grad))
entries, entries_grad, entries_dual = zip(*entries)

if needs_grad:
for e in entries_grad:
if e is None:
raise RuntimeError(
f'{dtype} is not supported for field with `needs_grad=True`.'
)
impl.get_runtime().grad_vars.append(e)

entries, entries_grad, entries_dual = MatrixField(
entries, n, m), MatrixField(entries_grad, n,
m), MatrixField(entries_grad, n, m)
Expand Down

0 comments on commit 646a12a

Please sign in to comment.