Multiple additions to execute_with_gradients #6006

Merged · merged 1 commit on Oct 22, 2022
64 changes: 37 additions & 27 deletions ivy/functional/backends/jax/gradients.py
@@ -13,8 +13,10 @@
# local
import ivy
from ivy.functional.ivy.gradients import (
_get_native_arrays_and_indices,
_zero_gradients_to_none_and_to_ivy,
_arrays_to_float_variables,
_get_required_native_variables,
_get_native_variables_and_indices,
_remove_zeros_and_nones,
_stop_grad_and_index,
)

@@ -49,47 +51,55 @@ def _set_duplicates(xs, duplicate_key_chains):
return xs


def _forward_fn(xs, func, duplicate_key_chains):
def _forward_fn(xs, x, xs_grad_idxs, func, duplicate_key_chains):
if xs_grad_idxs is not None:
ivy.set_nest_at_indices(xs, xs_grad_idxs, x)
else:
xs = x
if isinstance(xs, ivy.Container):
xs = _set_duplicates(xs, duplicate_key_chains)
ret = func(xs)
_, arr_values = _get_native_arrays_and_indices(ret)
if isinstance(arr_values, list) and len(arr_values) == 1:
arr_values = arr_values[0]
return arr_values
_, ret_values = _get_native_variables_and_indices(ret)
if isinstance(ret_values, list) and len(ret_values) == 1:
ret_values = ret_values[0]
return ret_values


def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
def execute_with_gradients(
func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
):
xs = _arrays_to_float_variables(xs)
func_ret = func(xs)
xs_required = _get_required_native_variables(ivy.copy_nest(xs), xs_grad_idxs)
xs = ivy.to_native(xs)
arr_idxs, arr_values = _get_native_arrays_and_indices(func_ret)

if arr_values is None or (isinstance(arr_values, list) and len(arr_values) == 0):
ret_idxs, ret_values = _get_native_variables_and_indices(func_ret)
if ret_values is None or (isinstance(ret_values, list) and len(ret_values) == 0):
return func_ret, {}
if isinstance(arr_values, list) and len(arr_values) == 1:
y = arr_values[0]
if isinstance(ret_values, list) and len(ret_values) == 1:
y = ret_values[0]
else:
y = arr_values

y = ret_values
duplicate_key_chains = ()
if isinstance(xs, ivy.Container):
duplicate_key_chains = xs.duplicate_array_keychains()

if isinstance(y, ivy.NativeArray):
grad_fn = jax.grad(lambda x: _forward_fn(x, func, duplicate_key_chains))
grads = grad_fn(xs)
grad_fn = jax.grad(
lambda x: _forward_fn(xs, x, xs_grad_idxs, func, duplicate_key_chains)
)
grads = grad_fn(xs_required)
else:
grad_fn = jax.jacrev(lambda x: _forward_fn(x, func, duplicate_key_chains))
grads_ = grad_fn(xs)
grad_fn = jax.jacrev(
lambda x: _forward_fn(xs, x, xs_grad_idxs, func, duplicate_key_chains)
)
grads_ = grad_fn(xs_required)
grads = grads_
if isinstance(arr_idxs, list) and len(arr_idxs):
grads = {arr_idxs[i]: grad for i, grad in enumerate(grads_)}

if isinstance(ret_idxs, list) and len(ret_idxs):
grads = {ret_idxs[i]: grad for i, grad in enumerate(grads_)}
if isinstance(xs, ivy.Container):
grads = _set_duplicates(grads, duplicate_key_chains)

grads = _zero_gradients_to_none_and_to_ivy(grads)
func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, grad_idxs)
grads = _remove_zeros_and_nones(grads, grads)
func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, ret_grad_idxs)
grads = ivy.to_ivy(grads)
return func_ret, grads


@@ -99,7 +109,7 @@ def value_and_grad(func):
def callback_fn(xs):
xs = ivy.nested_map(xs, lambda x: ivy.to_native(x), include_derived=True)
ret = jax.value_and_grad(grad_fn)(xs)
ret = _zero_gradients_to_none_and_to_ivy(ret)
ret = _remove_zeros_and_nones(ret, ret)
return ret

return callback_fn
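For context, a minimal usage sketch of the extended signature (assumptions, not part of this diff: ivy.execute_with_gradients is the public wrapper that forwards these keyword arguments to the backend, and xs_grad_idxs takes nest index chains of the kind consumed by ivy.set_nest_at_indices, so [["w"]] below is an assumed format):

import ivy

ivy.set_backend("jax")

def loss_fn(xs):
    # scalar objective, so the backend takes the jax.grad branch above
    return ivy.sum(ivy.square(xs["w"])) + ivy.sum(xs["b"])

xs = ivy.Container({"w": ivy.array([1.0, 2.0, 3.0]), "b": ivy.array([0.5])})

# xs_grad_idxs limits differentiation to the "w" leaf of the inputs;
# ret_grad_idxs (left as None here) would limit which output leaves are indexed.
ret, grads = ivy.execute_with_gradients(loss_fn, xs, xs_grad_idxs=[["w"]])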
4 changes: 3 additions & 1 deletion ivy/functional/backends/numpy/gradients.py
@@ -23,7 +23,9 @@ def variable_data(x):
return x


def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
def execute_with_gradients(
func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
):
logging.warning(
"NumPy does not support autograd, "
"'execute_with_gradients' returns None in place of function gradients."
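Since NumPy has no autograd, this backend only logs the warning above; a quick sketch of what a call would look like (assuming the same public ivy.execute_with_gradients wrapper):

import ivy

ivy.set_backend("numpy")
# Logs the warning above; None is returned in place of the function gradients.
ret, grads = ivy.execute_with_gradients(
    lambda xs: ivy.sum(ivy.square(xs)), ivy.array([1.0, 2.0, 3.0])
)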
46 changes: 24 additions & 22 deletions ivy/functional/backends/tensorflow/gradients.py
@@ -10,8 +10,10 @@
# local
import ivy
from ivy.functional.ivy.gradients import (
_get_native_arrays_and_indices,
_zero_gradients_to_none_and_to_ivy,
_arrays_to_float_variables,
_get_required_native_variables,
_get_native_variables_and_indices,
_remove_zeros_and_nones,
_stop_grad_and_index,
)

@@ -29,23 +31,22 @@ def variable_data(x):
return x.value()


def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
def execute_with_gradients(
func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
):
xs = _arrays_to_float_variables(xs)
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
tape.watch(ivy.to_native(xs))
func_ret = func(xs)
arr_idxs, arr_values = _get_native_arrays_and_indices(func_ret, reshape=False)

if arr_values is None or (isinstance(arr_values, list) and len(arr_values) == 0):
xs = _get_required_native_variables(xs, xs_grad_idxs)
ret_idxs, ret_values = _get_native_variables_and_indices(func_ret, reshape=False)
if ret_values is None or (isinstance(ret_values, list) and len(ret_values) == 0):
return func_ret, {}
if isinstance(arr_values, list) and len(arr_values) == 1:
y = arr_values[0]
if isinstance(ret_values, list) and len(ret_values) == 1:
y = ret_values[0]
else:
y = arr_values

def grad_func(y):
ret = tape.gradient(y, ivy.to_native(xs))
return ret

y = ret_values
grad_func = lambda y: tape.gradient(y, ivy.to_native(xs))
if isinstance(y, ivy.NativeArray):
grads = ivy.to_ivy(grad_func(y))
else:
@@ -57,16 +58,15 @@ def grad_func(y):
y = []
else:
y = ivy.multi_index_nest(y, array_idxs)

grads_ = ivy.nested_map(y, grad_func, include_derived=True)
grads = grads_
if isinstance(arr_idxs, list) and len(arr_idxs):
grads = {arr_idxs[i]: grad for i, grad in enumerate(grads_)}

grads = _zero_gradients_to_none_and_to_ivy(grads)
func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, grad_idxs)
if isinstance(ret_idxs, list) and len(ret_idxs):
grads = {ret_idxs[i]: grad for i, grad in enumerate(grads_)}
grads = _remove_zeros_and_nones(grads, grads)
func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, ret_grad_idxs)
if not retain_grads:
del tape
grads = ivy.to_ivy(grads)
return func_ret, grads


@@ -84,7 +84,8 @@ def grad_fn(xs):
lambda x: ivy.to_ivy(x),
include_derived=True,
)
grads = _zero_gradients_to_none_and_to_ivy(grads_)
grads_ = _remove_zeros_and_nones(grads_, grads_)
grads_ = ivy.to_ivy(grads_)
grad_idxs = ivy.nested_argwhere(grads_, lambda x: ivy.is_ivy_array(x))
grad_array_vals = list(ivy.multi_index_nest(grads_, grad_idxs))
xs = ivy.to_ivy(xs)
@@ -132,6 +133,7 @@ def callback_fn(x_in):
x_in = ivy.to_native(ivy.array(x_in))
tape.watch(x_in)
y = grad_fn(x_in)
return _zero_gradients_to_none_and_to_ivy(ivy.to_ivy(tape.gradient(y, x_in)))
grad_ = ivy.to_ivy(tape.gradient(y, x_in))
return _remove_zeros_and_nones(grad_, grad_)

return callback_fn
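The TensorFlow backend relies on a persistent GradientTape so tape.gradient can be called once per output leaf and the tape freed afterwards; a small standalone TensorFlow sketch of that pattern (plain tf, independent of ivy):

import tensorflow as tf

x = tf.Variable([1.0, 2.0, 3.0])
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
    tape.watch(x)
    y1 = tf.reduce_sum(x * x)      # one output leaf
    y2 = tf.reduce_sum(tf.sin(x))  # another output leaf

g1 = tape.gradient(y1, x)  # persistent=True allows repeated gradient calls
g2 = tape.gradient(y2, x)
del tape  # analogous to the `del tape` above when retain_grads is False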
37 changes: 20 additions & 17 deletions ivy/functional/backends/torch/gradients.py
@@ -9,8 +9,10 @@
# local
import ivy
from ivy.functional.ivy.gradients import (
_get_native_arrays_and_indices,
_zero_gradients_to_none_and_to_ivy,
_arrays_to_float_variables,
_get_required_native_variables,
_get_native_variables_and_indices,
_remove_zeros_and_nones,
_stop_grad_and_index,
)

@@ -31,7 +33,6 @@ def variable_data(x):

def _forward_fn(xs, func):
xs = ivy.Container(xs)
print("xs", xs)
ret = func(xs)

if isinstance(ret, ivy.Array):
@@ -50,17 +51,19 @@ def _forward_fn(xs, func):


# noinspection PyShadowingNames
def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
def execute_with_gradients(
func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
):
xs = _arrays_to_float_variables(xs)
func_ret = func(xs)
xs = ivy.to_native(xs)
arr_idxs, arr_values = _get_native_arrays_and_indices(func_ret)

if arr_values is None or (isinstance(arr_values, list) and len(arr_values) == 0):
xs = _get_required_native_variables(xs, xs_grad_idxs)
ret_idxs, ret_values = _get_native_variables_and_indices(func_ret)
if ret_values is None or (isinstance(ret_values, list) and len(ret_values) == 0):
return func_ret, {}
if isinstance(arr_values, list) and len(arr_values) == 1:
y = arr_values[0]
if isinstance(ret_values, list) and len(ret_values) == 1:
y = ret_values[0]
else:
y = arr_values
y = ret_values

def grad_func(y):
if isinstance(xs, ivy.Container):
@@ -102,11 +105,11 @@ def grad_func(y):
grad_arr_values = ivy.multi_index_nest(y, grad_arr_idxs)
grads_ = [grad_func(torch.clone(arr_value)) for arr_value in grad_arr_values]
grads = grads_
if isinstance(arr_idxs, list) and len(arr_idxs):
grads = {arr_idxs[i]: grad for i, grad in enumerate(grads_)}

grads = _zero_gradients_to_none_and_to_ivy(grads)
func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, grad_idxs)
if isinstance(ret_idxs, list) and len(ret_idxs):
grads = {ret_idxs[i]: grad for i, grad in enumerate(grads_)}
grads = _remove_zeros_and_nones(grads, grads)
func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, ret_grad_idxs)
grads = ivy.to_ivy(grads)
return func_ret, grads


@@ -125,7 +128,7 @@ def autograd_fn(x):
else ivy.to_native(ivy.zeros_like(ivy.to_ivy(x)))
)
grad = ivy.to_ivy(grad)
grad = _zero_gradients_to_none_and_to_ivy(grad)
grad = _remove_zeros_and_nones(grad, grad)
return grad

grads = ivy.nested_map(
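grad_func's body is collapsed in the view above, so as a point of reference only, here is the kind of per-output gradient call this pattern usually reduces to in plain PyTorch (an assumption about the approach, not the backend's actual code):

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y1 = (x * x).sum()
y2 = torch.sin(x).sum()

# retain_graph=True keeps the graph alive so gradients can be taken
# for more than one output value from the same forward pass.
g1 = torch.autograd.grad(y1, x, retain_graph=True)[0]
g2 = torch.autograd.grad(y2, x)[0]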