Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autodiff] Check not placed field.grad when needs_grad = True #5295

Merged
merged 7 commits into from
Jul 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def __init__(self, kernels=None):
self.inside_kernel = False
self.current_kernel = None
self.global_vars = []
self.grad_vars = []
self.matrix_fields = []
self.default_fp = f32
self.default_ip = i32
Expand Down Expand Up @@ -300,6 +301,22 @@ def _check_field_not_placed(self):
f'{bar}Please consider specifying a shape for them. E.g.,' +
'\n\n x = ti.field(float, shape=(2, 3))')

def _check_grad_field_not_placed(self):
not_placed = set()
for _var in self.grad_vars:
if _var.ptr.snode() is None:
not_placed.add(self._get_tb(_var))

if len(not_placed):
bar = '=' * 44 + '\n'
raise RuntimeError(
f'These field(s) requrie `needs_grad=True`, however their grad field(s) are not placed:\n{bar}'
+ f'{bar}'.join(not_placed) +
f'{bar}Please consider place the grad field(s). E.g.,' +
'\n\n ti.root.dense(ti.i, 1).place(x.grad)' +
'\n\n Or specify a shape for the field(s). E.g.,' +
'\n\n x = ti.field(float, shape=(2, 3), needs_grad=True)')

def _check_matrix_field_member_shape(self):
for _field in self.matrix_fields:
shapes = [
Expand All @@ -321,9 +338,11 @@ def materialize(self):
self.materialized = True

self._check_field_not_placed()
self._check_grad_field_not_placed()
self._check_matrix_field_member_shape()
self._calc_matrix_field_dynamic_index_stride()
self.global_vars = []
self.grad_vars = []
self.matrix_fields = []

def _register_signal_handlers(self):
Expand Down Expand Up @@ -486,7 +505,7 @@ def __repr__(self):


@python_scope
def create_field_member(dtype, name):
def create_field_member(dtype, name, needs_grad):
dtype = cook_dtype(dtype)

# primal
Expand All @@ -499,20 +518,27 @@ def create_field_member(dtype, name):

x_grad = None
x_dual = None
# TODO: replace the name of `_ti_core.needs_grad`, it only checks whether the dtype can be accepted to compute derivatives
if _ti_core.needs_grad(dtype):
# adjoint
x_grad = Expr(get_runtime().prog.make_id_expr(""))
x_grad.declaration_tb = get_traceback(stacklevel=4)
x_grad.ptr = _ti_core.global_new(x_grad.ptr, dtype)
x_grad.ptr.set_name(name + ".grad")
x_grad.ptr.set_is_primal(False)
x.ptr.set_adjoint(x_grad.ptr)
if needs_grad:
pytaichi.grad_vars.append(x_grad)

# dual
x_dual = Expr(get_runtime().prog.make_id_expr(""))
x_dual.ptr = _ti_core.global_new(x_dual.ptr, dtype)
x_dual.ptr.set_name(name + ".dual")
x_dual.ptr.set_is_primal(False)
x.ptr.set_dual(x_dual.ptr)
elif needs_grad:
raise TaichiRuntimeError(
f'{dtype} is not supported for field with `needs_grad=True`.')

return x, x_grad, x_dual

Expand All @@ -533,7 +559,7 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False):
shape (Union[int, tuple[int]], optional): shape of the field.
name (str, optional): name of the field.
offset (Union[int, tuple[int]], optional): offset of the field domain.
needs_grad (bool, optional): whether this field participates in autodiff
needs_grad (bool, optional): whether this field participates in autodiff (reverse mode)
and thus needs an adjoint field to store the gradients.

Example::
Expand Down Expand Up @@ -561,9 +587,10 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False):
assert (offset is None or shape
is not None), 'The shape cannot be None when offset is being set'

x, x_grad, x_dual = create_field_member(dtype, name)
x, x_grad, x_dual = create_field_member(dtype, name, needs_grad)
x, x_grad, x_dual = ScalarField(x), ScalarField(x_grad), ScalarField(
x_dual)

x._set_grad(x_grad)
x._set_dual(x_dual)

Expand Down
18 changes: 14 additions & 4 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ def field(cls,
name (string, optional): The custom name of the field.
offset (Union[int, tuple of int], optional): The coordinate offset
of all elements in a field.
needs_grad (bool, optional): Whether the Matrix need gradients.
needs_grad (bool, optional): Whether the Matrix need grad field (reverse mode autodiff).
layout (Layout, optional): The field layout, either Array Of
Structure (AOS) or Structure Of Array (SOA).

Expand All @@ -1095,24 +1095,34 @@ def field(cls,
) == n, f'Please set correct dtype list for Vector. The shape of dtype list should be ({n}, ) instead of {np.shape(dtype)}'
for i in range(n):
entries.append(
impl.create_field_member(dtype[i], name=name))
impl.create_field_member(dtype[i],
name=name,
needs_grad=needs_grad))
else:
assert len(np.shape(dtype)) == 2 and len(dtype) == n and len(
dtype[0]
) == m, f'Please set correct dtype list for Matrix. The shape of dtype list should be ({n}, {m}) instead of {np.shape(dtype)}'
for i in range(n):
for j in range(m):
entries.append(
impl.create_field_member(dtype[i][j], name=name))
impl.create_field_member(dtype[i][j],
name=name,
needs_grad=needs_grad))
else:
for _ in range(n * m):
entries.append(impl.create_field_member(dtype, name=name))
entries.append(
impl.create_field_member(dtype,
name=name,
needs_grad=needs_grad))
entries, entries_grad, entries_dual = zip(*entries)

entries, entries_grad, entries_dual = MatrixField(
entries, n, m), MatrixField(entries_grad, n,
m), MatrixField(entries_grad, n, m)

entries._set_grad(entries_grad)
entries._set_dual(entries_dual)

impl.get_runtime().matrix_fields.append(entries)

if shape is None:
Expand Down
14 changes: 12 additions & 2 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def field(cls,
offset (Tuple[int]): offset of the indices of the created field.
For example if `offset=(-10, -10)` the indices of the field
will start at `(-10, -10)`, not `(0, 0)`.
needs_grad (bool): enabling gradient field or not.
needs_grad (bool): enabling grad field (reverse mode autodiff) or not.
layout: AOS or SOA.

Example:
Expand Down Expand Up @@ -356,6 +356,7 @@ def field(cls,
grads = tuple(e.grad for e in field_dict.values())
impl.root.dense(impl.index_nd(dim),
shape).place(*grads, offset=offset)

return StructField(field_dict, methods, name=name)


Expand Down Expand Up @@ -387,11 +388,20 @@ class StructField(Field):
to each struct instance in the field.
name (string, optional): The custom name of the field.
"""
def __init__(self, field_dict, struct_methods, name=None):
def __init__(self, field_dict, struct_methods, name=None, is_primal=True):
# will not call Field initializer
self.field_dict = field_dict
self.struct_methods = struct_methods
self.name = name
self.grad = None
if is_primal:
grad_field_dict = {}
for k, v in self.field_dict.items():
grad_field_dict[k] = v.grad
self.grad = StructField(grad_field_dict,
struct_methods,
name + ".grad",
is_primal=False)
self._register_fields()

@property
Expand Down
28 changes: 28 additions & 0 deletions tests/python/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,34 @@ def func():
func()


@test_utils.test()
def test_field_needs_grad_dtype():
with pytest.raises(
RuntimeError,
match=r".* is not supported for field with `needs_grad=True`."):
a = ti.field(int, shape=1, needs_grad=True)
with pytest.raises(
RuntimeError,
match=r".* is not supported for field with `needs_grad=True`."):
b = ti.Vector.field(3, int, shape=1, needs_grad=True)
with pytest.raises(
RuntimeError,
match=r".* is not supported for field with `needs_grad=True`."):
c = ti.Matrix.field(2, 3, int, shape=1, needs_grad=True)
with pytest.raises(
RuntimeError,
match=r".* is not supported for field with `needs_grad=True`."):
d = ti.Struct.field(
{
"pos": ti.types.vector(3, int),
"vel": ti.types.vector(3, float),
"acc": ti.types.vector(3, float),
"mass": ti.f32,
},
shape=1,
needs_grad=True)


@pytest.mark.parametrize('dtype', [ti.f32, ti.f64])
def test_default_fp(dtype):
ti.init(default_fp=dtype)
Expand Down
34 changes: 34 additions & 0 deletions tests/python/test_fields_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,37 @@ def test_field_initialize_zero():
fb1.dense(ti.i, 1).place(b)
d = fb1.finalize()
assert b[0] == 0


@test_utils.test(exclude=[ti.opengl, ti.cc])
def test_field_builder_place_grad():
@ti.kernel
def mul(arr: ti.template(), out: ti.template()):
for i in arr:
out[i] = arr[i] * 2.0

@ti.kernel
def calc_loss(arr: ti.template(), loss: ti.template()):
for i in arr:
loss[None] += arr[i]

arr = ti.field(ti.f32, needs_grad=True)
fb0 = ti.FieldsBuilder()
fb0.dense(ti.i, 10).place(arr, arr.grad)
snode0 = fb0.finalize()
out = ti.field(ti.f32)
fb1 = ti.FieldsBuilder()
fb1.dense(ti.i, 10).place(out, out.grad)
snode1 = fb1.finalize()
loss = ti.field(ti.f32)
fb2 = ti.FieldsBuilder()
fb2.place(loss, loss.grad)
snode2 = fb2.finalize()
arr.fill(1.0)
mul(arr, out)
calc_loss(out, loss)
loss.grad[None] = 1.0
calc_loss.grad(out, loss)
mul.grad(arr, out)
for i in range(10):
assert arr.grad[i] == 2.0
75 changes: 75 additions & 0 deletions tests/python/test_materialize_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,81 @@ def foo():
foo()


@test_utils.test()
def test_check_grad_field_not_placed():
a = ti.field(ti.f32, needs_grad=True)
ti.root.dense(ti.i, 1).place(a)

@ti.kernel
def foo():
pass

with pytest.raises(
RuntimeError,
match=
r"These field\(s\) requrie `needs_grad=True`, however their grad field\(s\) are not placed.*"
):
foo()


@test_utils.test()
def test_check_grad_vector_field_not_placed():
b = ti.Vector.field(3, ti.f32, needs_grad=True)
ti.root.dense(ti.i, 1).place(b)

@ti.kernel
def foo():
pass

with pytest.raises(
RuntimeError,
match=
r"These field\(s\) requrie `needs_grad=True`, however their grad field\(s\) are not placed.*"
):
foo()


@test_utils.test()
def test_check_grad_matrix_field_not_placed():
c = ti.Matrix.field(2, 3, ti.f32, needs_grad=True)
ti.root.dense(ti.i, 1).place(c)

@ti.kernel
def foo():
pass

with pytest.raises(
RuntimeError,
match=
r"These field\(s\) requrie `needs_grad=True`, however their grad field\(s\) are not placed.*"
):
foo()


@test_utils.test()
def test_check_grad_struct_field_not_placed():
d = ti.Struct.field(
{
"pos": ti.types.vector(3, float),
"vel": ti.types.vector(3, float),
"acc": ti.types.vector(3, float),
"mass": ti.f32,
},
needs_grad=True)
ti.root.dense(ti.i, 1).place(d)

@ti.kernel
def foo():
pass

with pytest.raises(
RuntimeError,
match=
r"These field\(s\) requrie `needs_grad=True`, however their grad field\(s\) are not placed.*"
):
foo()


@test_utils.test()
def test_check_matrix_field_member_shape():
a = ti.Matrix.field(2, 2, ti.i32)
Expand Down