From 368320774e36b0e5988f8b469731294c94f34b50 Mon Sep 17 00:00:00 2001
From: Ved Patwardhan <54766411+VedPatwardhan@users.noreply.github.com>
Date: Sat, 22 Oct 2022 08:10:52 +0000
Subject: [PATCH] Multiple additions to execute_with_gradients

---
 ivy/functional/backends/jax/gradients.py   |  64 ++++----
 ivy/functional/backends/numpy/gradients.py |   4 +-
 .../backends/tensorflow/gradients.py       |  46 +++---
 ivy/functional/backends/torch/gradients.py |  37 +++--
 ivy/functional/ivy/gradients.py            | 153 +++++++++++-------
 5 files changed, 183 insertions(+), 121 deletions(-)

diff --git a/ivy/functional/backends/jax/gradients.py b/ivy/functional/backends/jax/gradients.py
index 7e9fdda97a15b..90cf2a2178ea9 100644
--- a/ivy/functional/backends/jax/gradients.py
+++ b/ivy/functional/backends/jax/gradients.py
@@ -13,8 +13,10 @@
 # local
 import ivy
 from ivy.functional.ivy.gradients import (
-    _get_native_arrays_and_indices,
-    _zero_gradients_to_none_and_to_ivy,
+    _arrays_to_float_variables,
+    _get_required_native_variables,
+    _get_native_variables_and_indices,
+    _remove_zeros_and_nones,
     _stop_grad_and_index,
 )

@@ -49,47 +51,55 @@ def _set_duplicates(xs, duplicate_key_chains):
     return xs


-def _forward_fn(xs, func, duplicate_key_chains):
+def _forward_fn(xs, x, xs_grad_idxs, func, duplicate_key_chains):
+    if xs_grad_idxs is not None:
+        ivy.set_nest_at_indices(xs, xs_grad_idxs, x)
+    else:
+        xs = x
     if isinstance(xs, ivy.Container):
         xs = _set_duplicates(xs, duplicate_key_chains)
     ret = func(xs)
-    _, arr_values = _get_native_arrays_and_indices(ret)
-    if isinstance(arr_values, list) and len(arr_values) == 1:
-        arr_values = arr_values[0]
-    return arr_values
+    _, ret_values = _get_native_variables_and_indices(ret)
+    if isinstance(ret_values, list) and len(ret_values) == 1:
+        ret_values = ret_values[0]
+    return ret_values


-def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
+def execute_with_gradients(
+    func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
+):
+    xs = _arrays_to_float_variables(xs)
     func_ret = func(xs)
+    xs_required = _get_required_native_variables(ivy.copy_nest(xs), xs_grad_idxs)
     xs = ivy.to_native(xs)
-    arr_idxs, arr_values = _get_native_arrays_and_indices(func_ret)
-
-    if arr_values is None or (isinstance(arr_values, list) and len(arr_values) == 0):
+    ret_idxs, ret_values = _get_native_variables_and_indices(func_ret)
+    if ret_values is None or (isinstance(ret_values, list) and len(ret_values) == 0):
         return func_ret, {}
-    if isinstance(arr_values, list) and len(arr_values) == 1:
-        y = arr_values[0]
+    if isinstance(ret_values, list) and len(ret_values) == 1:
+        y = ret_values[0]
     else:
-        y = arr_values
-
+        y = ret_values
     duplicate_key_chains = ()
     if isinstance(xs, ivy.Container):
         duplicate_key_chains = xs.duplicate_array_keychains()
-
     if isinstance(y, ivy.NativeArray):
-        grad_fn = jax.grad(lambda x: _forward_fn(x, func, duplicate_key_chains))
-        grads = grad_fn(xs)
+        grad_fn = jax.grad(
+            lambda x: _forward_fn(xs, x, xs_grad_idxs, func, duplicate_key_chains)
+        )
+        grads = grad_fn(xs_required)
     else:
-        grad_fn = jax.jacrev(lambda x: _forward_fn(x, func, duplicate_key_chains))
-        grads_ = grad_fn(xs)
+        grad_fn = jax.jacrev(
+            lambda x: _forward_fn(xs, x, xs_grad_idxs, func, duplicate_key_chains)
+        )
+        grads_ = grad_fn(xs_required)
         grads = grads_
-        if isinstance(arr_idxs, list) and len(arr_idxs):
-            grads = {arr_idxs[i]: grad for i, grad in enumerate(grads_)}
-
+        if isinstance(ret_idxs, list) and len(ret_idxs):
+            grads = {ret_idxs[i]: grad for i, grad in enumerate(grads_)}
     if isinstance(xs, ivy.Container):
         grads = _set_duplicates(grads, duplicate_key_chains)
-
-    grads = _zero_gradients_to_none_and_to_ivy(grads)
-    func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, grad_idxs)
+    grads = _remove_zeros_and_nones(grads, grads)
+    func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, ret_grad_idxs)
+    grads = ivy.to_ivy(grads)
     return func_ret, grads


@@ -99,7 +109,7 @@ def value_and_grad(func):
     def callback_fn(xs):
         xs = ivy.nested_map(xs, lambda x: ivy.to_native(x), include_derived=True)
         ret = jax.value_and_grad(grad_fn)(xs)
-        ret = _zero_gradients_to_none_and_to_ivy(ret)
+        ret = _remove_zeros_and_nones(ret, ret)
         return ret

     return callback_fn
diff --git a/ivy/functional/backends/numpy/gradients.py b/ivy/functional/backends/numpy/gradients.py
index 95c7f2264ece8..83bf51ecbe380 100644
--- a/ivy/functional/backends/numpy/gradients.py
+++ b/ivy/functional/backends/numpy/gradients.py
@@ -23,7 +23,9 @@ def variable_data(x):
     return x


-def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
+def execute_with_gradients(
+    func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
+):
     logging.warning(
         "NumPy does not support autograd, "
         "'execute_with_gradients' returns None in place of function gradients."
     )
diff --git a/ivy/functional/backends/tensorflow/gradients.py b/ivy/functional/backends/tensorflow/gradients.py
index eab0091b20382..a0577a38f740d 100644
--- a/ivy/functional/backends/tensorflow/gradients.py
+++ b/ivy/functional/backends/tensorflow/gradients.py
@@ -10,8 +10,10 @@
 # local
 import ivy
 from ivy.functional.ivy.gradients import (
-    _get_native_arrays_and_indices,
-    _zero_gradients_to_none_and_to_ivy,
+    _arrays_to_float_variables,
+    _get_required_native_variables,
+    _get_native_variables_and_indices,
+    _remove_zeros_and_nones,
     _stop_grad_and_index,
 )

@@ -29,23 +31,22 @@ def variable_data(x):
     return x.value()


-def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
+def execute_with_gradients(
+    func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
+):
+    xs = _arrays_to_float_variables(xs)
     with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
         tape.watch(ivy.to_native(xs))
         func_ret = func(xs)
-    arr_idxs, arr_values = _get_native_arrays_and_indices(func_ret, reshape=False)
-
-    if arr_values is None or (isinstance(arr_values, list) and len(arr_values) == 0):
+    xs = _get_required_native_variables(xs, xs_grad_idxs)
+    ret_idxs, ret_values = _get_native_variables_and_indices(func_ret, reshape=False)
+    if ret_values is None or (isinstance(ret_values, list) and len(ret_values) == 0):
         return func_ret, {}
-    if isinstance(arr_values, list) and len(arr_values) == 1:
-        y = arr_values[0]
+    if isinstance(ret_values, list) and len(ret_values) == 1:
+        y = ret_values[0]
     else:
-        y = arr_values
-
-    def grad_func(y):
-        ret = tape.gradient(y, ivy.to_native(xs))
-        return ret
-
+        y = ret_values
+    grad_func = lambda y: tape.gradient(y, ivy.to_native(xs))
     if isinstance(y, ivy.NativeArray):
         grads = ivy.to_ivy(grad_func(y))
     else:
@@ -57,16 +58,15 @@ def grad_func(y):
             y = []
         else:
             y = ivy.multi_index_nest(y, array_idxs)
-
         grads_ = ivy.nested_map(y, grad_func, include_derived=True)
         grads = grads_
-        if isinstance(arr_idxs, list) and len(arr_idxs):
-            grads = {arr_idxs[i]: grad for i, grad in enumerate(grads_)}
-
-    grads = _zero_gradients_to_none_and_to_ivy(grads)
-    func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, grad_idxs)
+        if isinstance(ret_idxs, list) and len(ret_idxs):
+            grads = {ret_idxs[i]: grad for i, grad in enumerate(grads_)}
+    grads = _remove_zeros_and_nones(grads, grads)
+    func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, ret_grad_idxs)
     if not retain_grads:
         del tape
+    grads = ivy.to_ivy(grads)
     return func_ret, grads


@@ -84,7 +84,8 @@ def grad_fn(xs):
             lambda x: ivy.to_ivy(x),
             include_derived=True,
         )
-        grads = _zero_gradients_to_none_and_to_ivy(grads_)
+        grads_ = _remove_zeros_and_nones(grads_, grads_)
+        grads_ = ivy.to_ivy(grads_)
         grad_idxs = ivy.nested_argwhere(grads_, lambda x: ivy.is_ivy_array(x))
         grad_array_vals = list(ivy.multi_index_nest(grads_, grad_idxs))
         xs = ivy.to_ivy(xs)
@@ -132,6 +133,7 @@ def callback_fn(x_in):
         x_in = ivy.to_native(ivy.array(x_in))
         tape.watch(x_in)
         y = grad_fn(x_in)
-        return _zero_gradients_to_none_and_to_ivy(ivy.to_ivy(tape.gradient(y, x_in)))
+        grad_ = ivy.to_ivy(tape.gradient(y, x_in))
+        return _remove_zeros_and_nones(grad_, grad_)

     return callback_fn
diff --git a/ivy/functional/backends/torch/gradients.py b/ivy/functional/backends/torch/gradients.py
index 28ebed23d45ec..8598cecf98b0f 100644
--- a/ivy/functional/backends/torch/gradients.py
+++ b/ivy/functional/backends/torch/gradients.py
@@ -9,8 +9,10 @@
 # local
 import ivy
 from ivy.functional.ivy.gradients import (
-    _get_native_arrays_and_indices,
-    _zero_gradients_to_none_and_to_ivy,
+    _arrays_to_float_variables,
+    _get_required_native_variables,
+    _get_native_variables_and_indices,
+    _remove_zeros_and_nones,
     _stop_grad_and_index,
 )

@@ -31,7 +33,6 @@ def variable_data(x):

 def _forward_fn(xs, func):
     xs = ivy.Container(xs)
-    print("xs", xs)
     ret = func(xs)

     if isinstance(ret, ivy.Array):
@@ -50,17 +51,19 @@


 # noinspection PyShadowingNames
-def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
+def execute_with_gradients(
+    func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
+):
+    xs = _arrays_to_float_variables(xs)
     func_ret = func(xs)
-    xs = ivy.to_native(xs)
-    arr_idxs, arr_values = _get_native_arrays_and_indices(func_ret)
-
-    if arr_values is None or (isinstance(arr_values, list) and len(arr_values) == 0):
+    xs = _get_required_native_variables(xs, xs_grad_idxs)
+    ret_idxs, ret_values = _get_native_variables_and_indices(func_ret)
+    if ret_values is None or (isinstance(ret_values, list) and len(ret_values) == 0):
         return func_ret, {}
-    if isinstance(arr_values, list) and len(arr_values) == 1:
-        y = arr_values[0]
+    if isinstance(ret_values, list) and len(ret_values) == 1:
+        y = ret_values[0]
     else:
-        y = arr_values
+        y = ret_values

     def grad_func(y):
         if isinstance(xs, ivy.Container):
@@ -102,11 +105,11 @@ def grad_func(y):
         grad_arr_values = ivy.multi_index_nest(y, grad_arr_idxs)
         grads_ = [grad_func(torch.clone(arr_value)) for arr_value in grad_arr_values]
         grads = grads_
-        if isinstance(arr_idxs, list) and len(arr_idxs):
-            grads = {arr_idxs[i]: grad for i, grad in enumerate(grads_)}
-
-    grads = _zero_gradients_to_none_and_to_ivy(grads)
-    func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, grad_idxs)
+        if isinstance(ret_idxs, list) and len(ret_idxs):
+            grads = {ret_idxs[i]: grad for i, grad in enumerate(grads_)}
+    grads = _remove_zeros_and_nones(grads, grads)
+    func_ret, grads = _stop_grad_and_index(func_ret, retain_grads, grads, ret_grad_idxs)
+    grads = ivy.to_ivy(grads)
     return func_ret, grads


@@ -125,7 +128,7 @@ def autograd_fn(x):
                 else ivy.to_native(ivy.zeros_like(ivy.to_ivy(x)))
             )
             grad = ivy.to_ivy(grad)
-            grad = _zero_gradients_to_none_and_to_ivy(grad)
+            grad = _remove_zeros_and_nones(grad, grad)
             return grad

         grads = ivy.nested_map(
diff --git a/ivy/functional/ivy/gradients.py b/ivy/functional/ivy/gradients.py
index 570b14bb0f26e..1749bb34c41d5 100644
--- a/ivy/functional/ivy/gradients.py
+++ b/ivy/functional/ivy/gradients.py
@@ -22,65 +22,103 @@
 # ------- #


-def _zero_gradients_to_none_and_to_ivy(grads):
-    if isinstance(grads, ivy.Array):
-        return (
-            None
-            if ivy.all(ivy.abs(grads).astype("float64") < 1e-10)
-            else ivy.to_ivy(grads)
-        )
-    else:
+def _arrays_to_float_variables(xs):
+    def map_fn(x):
+        if ivy.is_array(x, exclusive=True):
+            if ivy.is_int_dtype(x.dtype):
+                x = x.astype(ivy.default_float_dtype())
+            return ivy.variable(x)
+        return x

-        def func(x):
-            if ivy.is_array(x):
-                abs_val = ivy.abs(x)
-                return ivy.all(abs_val.astype("float64") < 1e-10)
-            return x is None
+    return ivy.nested_map(xs, map_fn, include_derived=True)

-        zero_idxs = ivy.nested_argwhere(grads, func)
-        if (
-            not isinstance(zero_idxs, list)
-            or np.asarray(zero_idxs, dtype="object").size == 0
-        ):
-            return ivy.nested_map(grads, ivy.to_ivy, include_derived=True)
-        zero_idxs.reverse()
-        ivy.prune_nest_at_indices(grads, zero_idxs)
-        return ivy.nested_map(grads, ivy.to_ivy, include_derived=True)

+def _get_required_native_variables(xs, xs_grad_idxs):
+    xs = ivy.to_ivy(xs)
+    if xs_grad_idxs is not None:
+        ivy.map_nest_at_indices(xs, xs_grad_idxs, ivy.to_native)
+    else:
+        xs = ivy.nested_map(xs, ivy.to_native)

-def _get_native_arrays_and_indices(func_ret, reshape=True):
     def map_fn(x):
-        if ivy.is_array(x) and ivy.is_variable(x):
-            x = ivy.to_ivy(x) if ivy.is_native_array(x) else x
-            if len(x.shape) == 0:
-                return ivy.to_native(x)
-            elif x.size == 1:
-                if reshape:
-                    return ivy.to_native(ivy.reshape(x, []))
-                return ivy.to_native(x)
+        if ivy.is_native_array(x):
+            return x
+        return None
+
+    xs = ivy.nested_map(xs, map_fn, include_derived=True)
+    none_idxs = ivy.nested_argwhere(xs, lambda x: x is None)
+    if not _check_if_empty(none_idxs):
+        none_idxs.reverse()
+        ivy.prune_nest_at_indices(xs, none_idxs)
+    return xs
+
+
+def _check_if_empty(idxs):
+    return not isinstance(idxs, list) or np.asarray(idxs, dtype="object").size == 0
+
+
+def _remove_zeros_and_nones(grads, x, idx=None):
+    idx = idx if idx is not None else []
+    if ivy.is_array(x):
+        abs_val = ivy.abs(x)
+        if ivy.all(abs_val.astype("float64") < 1e-10):
+            ivy.prune_nest_at_index(grads, idx)
+        return grads
+    if x is None:
+        ivy.prune_nest_at_index(grads, idx)
+    else:
+        keys = [k for k in x]
+        for k in keys:
+            idx.append(k)
+            grads = _remove_zeros_and_nones(grads, x[k], idx)
+            idx.pop()
+
+        keys = [k for k in x]
+        if len(keys) == 0:
+            ivy.prune_nest_at_index(grads, idx)
+    return grads
+
+
+def _idxs_to_str(idxs):
+    final_idxs = []
+    for i in range(len(idxs)):
+        final_idxs.append([str(x) for x in idxs[i]])
+        final_idxs[i] = "_".join(final_idxs[i])
+    return final_idxs
+
+
+def _get_native_variables_and_indices(x, reshape=True):
+    def map_fn(x_):
+        if ivy.is_array(x_):
+            x_ = ivy.to_ivy(x_) if ivy.is_native_array(x_) else x_
+            if len(x_.shape) == 0:
+                return ivy.to_native(x_)
+            if reshape:
+                if x_.size == 1:
+                    return ivy.to_native(ivy.reshape(x_, []))
+                return ivy.to_ivy(x_)
             else:
-                return ivy.to_ivy(x)
-        return x
+                return ivy.to_native(x_)
+        return x_

-    if ivy.is_array(func_ret) and ivy.is_variable(func_ret):
-        return [], map_fn(func_ret)
+    if ivy.is_array(x):
+        return [], map_fn(x)

-    func_ret = ivy.nested_map(func_ret, map_fn, include_derived=True)
-    arr_idxs = ivy.nested_argwhere(func_ret, lambda x: ivy.is_native_array(x))
-    if not isinstance(arr_idxs, list) or np.asarray(arr_idxs, "object").size == 0:
-        arr_values = []
+    x = ivy.nested_map(x, map_fn, include_derived=True)
+    arr_idxs = ivy.nested_argwhere(x, lambda x: ivy.is_native_array(x))
+    if _check_if_empty(arr_idxs):
+        return arr_idxs, []
     else:
-        arr_values = ivy.multi_index_nest(func_ret, arr_idxs)
-        for i in range(len(arr_idxs)):
-            arr_idxs[i] = [str(x) for x in arr_idxs[i]]
-            arr_idxs[i] = "_".join(arr_idxs[i])
-
-    return arr_idxs, arr_values
+        arr_values = ivy.multi_index_nest(x, arr_idxs)
+        arr_idxs = _idxs_to_str(arr_idxs)
+    return arr_idxs, arr_values


 def _stop_grad_and_index(func_ret, retain_grads, grads, grad_idxs):
     if not retain_grads:
-        if isinstance(func_ret, ivy.Array):
+        if ivy.is_array(func_ret):
             func_ret = ivy.stop_gradient(func_ret)
         else:
             func_ret = ivy.nested_map(
@@ -89,9 +127,7 @@ def _stop_grad_and_index(func_ret, retain_grads, grads, grad_idxs):
                 include_derived=True,
             )
     if grad_idxs is not None:
-        for i in range(len(grad_idxs)):
-            grad_idxs[i] = [str(x) for x in grad_idxs[i]]
-            grad_idxs[i] = "_".join(grad_idxs[i])
+        grad_idxs = _idxs_to_str(grad_idxs)
         grads = {idx: grads[idx] for idx in grad_idxs}
     if isinstance(grads, dict):
         grads = ivy.Container(grads)
@@ -470,7 +506,9 @@ def stop_gradient(

 @inputs_to_ivy_arrays
 @handle_exceptions
-def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
+def execute_with_gradients(
+    func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
+):
     """Call function func with input of xs variables, and return the function
     result func_ret and the gradients of each output variable w.r.t each input
     variable,
@@ -483,9 +521,12 @@ def execute_with_gradients(func, xs, /, *, retain_grads=False, grad_idxs=None):
         Variables for which to compute the function gradients with respect to.
     retain_grads
         Whether to retain the gradients of the returned values. (Default value = False)
-    grad_idxs
-        Indices of the returned arrays for which to return computed gradients If None,
-        all gradients are returned. (Default value = None)
+    xs_grad_idxs
+        Indices of the input arrays to compute gradients with respect to. If None,
+        gradients are returned with respect to all input arrays. (Default value = None)
+    ret_grad_idxs
+        Indices of the returned arrays for which to return computed gradients. If None,
+        gradients are returned for all returned arrays. (Default value = None)

     Returns
     -------
     ret
         The function first output func_ret, the gradients grads.

     """
     return current_backend(None).execute_with_gradients(
-        func, xs, retain_grads=retain_grads, grad_idxs=grad_idxs
+        func,
+        xs,
+        retain_grads=retain_grads,
+        xs_grad_idxs=xs_grad_idxs,
+        ret_grad_idxs=ret_grad_idxs,
     )
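
Notes
-----

A minimal usage sketch of the two new keyword arguments. This is an
illustrative sketch, not part of the patch: the function fn, the keys w and b,
and the choice of the torch backend are assumptions.

    import ivy

    ivy.set_backend("torch")  # any autograd-capable backend should do

    def fn(xs):
        # two outputs, so gradients come back as a dict keyed by output path
        return {"loss": ivy.sum(xs["w"] ** 2), "aux": ivy.sum(xs["b"] ** 2)}

    xs = ivy.Container(w=ivy.array([1.0, 2.0]), b=ivy.array([3.0]))

    # differentiate only w.r.t. xs["w"], and only return grads for ret["loss"]
    ret, grads = ivy.execute_with_gradients(
        fn, xs, xs_grad_idxs=[["w"]], ret_grad_idxs=[["loss"]]
    )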
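The string keys of the returned gradients dict come from _idxs_to_str, which
joins each index path into the returned nest with underscores; for example
(illustrative values):

    _idxs_to_str([["w", 0], ["b"]])  # -> ["w_0", "b"]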
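The pruning performed by _remove_zeros_and_nones can be sketched as follows;
the helper is private to ivy.functional.ivy.gradients and the values below are
made up:

    from ivy.functional.ivy.gradients import _remove_zeros_and_nones

    grads = {"w": ivy.array([0.5]), "b": ivy.array([0.0]), "c": None}
    # leaves that are None, or whose magnitudes all fall below 1e-10, are
    # pruned from the nest in place
    grads = _remove_zeros_and_nones(grads, grads)  # -> {"w": ivy.array([0.5])}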