diff --git a/doc/extending/op.rst b/doc/extending/op.rst index ddd397dee9..b1585c4ecd 100644 --- a/doc/extending/op.rst +++ b/doc/extending/op.rst @@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`. the outputs) back to their corresponding shapes and return them as the output of the :meth:`Op.R_op` method. - :ref:`List of op with r op support `. diff --git a/doc/library/gradient.rst b/doc/library/gradient.rst deleted file mode 100644 index f823a1c381..0000000000 --- a/doc/library/gradient.rst +++ /dev/null @@ -1,76 +0,0 @@ -.. _libdoc_gradient: - -=========================================== -:mod:`gradient` -- Symbolic Differentiation -=========================================== - -.. module:: gradient - :platform: Unix, Windows - :synopsis: low-level automatic differentiation -.. moduleauthor:: LISA - -.. testsetup:: * - - from pytensor.gradient import * - -Symbolic gradient is usually computed from :func:`gradient.grad`, which offers a -more convenient syntax for the common case of wanting the gradient of some -scalar cost with respect to some input expressions. The :func:`grad_sources_inputs` -function does the underlying work, and is more flexible, but is also more -awkward to use when :func:`gradient.grad` can do the job. - - -Gradient related functions -========================== - -.. automodule:: pytensor.gradient - :members: - -.. _R_op_list: - - -List of Implemented R op -======================== - - -See the :ref:`gradient tutorial ` for the R op documentation. - -list of ops that support R-op: - * with test - * SpecifyShape - * MaxAndArgmax - * Subtensor - * IncSubtensor set_subtensor too - * Alloc - * Dot - * Elemwise - * Sum - * Softmax - * Shape - * Join - * Rebroadcast - * Reshape - * DimShuffle - * Scan [In tests/scan/test_basic.test_rop] - - * without test - * Split - * ARange - * ScalarFromTensor - * AdvancedSubtensor1 - * AdvancedIncSubtensor1 - * AdvancedIncSubtensor - -Partial list of ops without support for R-op: - - * All sparse ops - * All linear algebra ops. - * PermuteRowElements - * AdvancedSubtensor - * TensorDot - * Outer - * Prod - * MulwithoutZeros - * ProdWithoutZeros - * CAReduce(for max,... done for MaxAndArgmax op) - * MaxAndArgmax(only for matrix on axis 0 or 1) diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst index 8d22c1e577..4f087b6788 100644 --- a/doc/library/tensor/basic.rst +++ b/doc/library/tensor/basic.rst @@ -1791,5 +1791,3 @@ Gradient / Differentiation :members: grad :noindex: -See the :ref:`gradient ` page for complete documentation -of the gradient module. diff --git a/doc/tutorial/gradients.rst b/doc/tutorial/gradients.rst index edb38bb018..5f16be3cec 100644 --- a/doc/tutorial/gradients.rst +++ b/doc/tutorial/gradients.rst @@ -86,9 +86,7 @@ of symbolic differentiation). ``i`` of the output list is the gradient of the first argument of `pt.grad` with respect to the ``i``-th element of the list given as second argument. The first argument of `pt.grad` has to be a scalar (a tensor - of size 1). For more information on the semantics of the arguments of - `pt.grad` and details about the implementation, see - :ref:`this` section of the library. + of size 1). Additional information on the inner workings of differentiation may also be found in the more advanced tutorial :ref:`Extending PyTensor`. @@ -204,7 +202,21 @@ you need to do something similar to this: >>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1]) array([ 2., 2.]) -:ref:`List ` of Op that implement Rop. 
+By default, the R-operator is implemented as a double application of the L_operator +(see `reference `). +In most cases this should be as performant as a specialized implementation of the R-operator. +However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators, +such as Scan and OpFromGraph, that would be more easily avoidable in a direct implentation of the R-operator. + +When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing +`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method. + + +>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True) +>>> f = pytensor.function([W, V, x], JV) +>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1]) +array([ 2., 2.]) + L-operator ---------- @@ -234,7 +246,6 @@ array([[ 0., 0.], as the input parameter, while the result of the R-operator has a shape similar to that of the output. - :ref:`List of op with r op support `. Hessian times a Vector ====================== diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 49baa3bb26..a4a3d1840a 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -340,6 +340,12 @@ def __init__( ``None``, this will be used as the connection_pattern for this :class:`Op`. + .. warning:: + + rop overrides is ignored when `pytensor.gradient.Rop` is called with + `use_op_rop_implementation=False` (default). In this case the Lop + is used twice to obtain a mathematically equivalent Rop. + strict: bool, default False If true, it raises when any variables needed to compute the inner graph are not provided as explici inputs. This can only happen for graphs with @@ -641,7 +647,12 @@ def _build_and_cache_rop_op(self): return rop_overrides eval_points = [inp_t() for inp_t in self.input_types] - fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points) + fn_rop = partial( + Rop, + wrt=inner_inputs, + eval_points=eval_points, + use_op_rop_implementation=True, + ) callable_args = (inner_inputs, eval_points) if rop_overrides is None: diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 13ca943383..04572b29d0 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -142,13 +142,50 @@ def __str__(self): disconnected_type = DisconnectedType() -def Rop( - f: Variable | Sequence[Variable], - wrt: Variable | Sequence[Variable], - eval_points: Variable | Sequence[Variable], +def pushforward_through_pullback( + outputs: Sequence[Variable], + inputs: Sequence[Variable], + tangents: Sequence[Variable], disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise", return_disconnected: Literal["none", "zero", "disconnected"] = "zero", -) -> Variable | None | Sequence[Variable | None]: +) -> Sequence[Variable | None]: + """Compute the pushforward (Rop) through two applications of a pullback (Lop) operation. + + References + ---------- + .. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017. + Available: https://j-towns.github.io/2017/06/12/A-new-trick.html + + """ + # Cotangents are just auxiliary variables that should be pruned from the final graph, + # but that would require a graph rewrite before the user tries to compile a pytensor function. + # To avoid trouble we use .zeros_like() instead of .type(), which does not create a new root variable. 
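+    # The identity used here (see [1]): for auxiliary cotangents u with the shape
+    # of the outputs, Lop(f, x, u) = u^T J is linear in u, so a second Lop with
+    # respect to u, contracted against the tangents v, recovers the pushforward:
+    # Rop(f, x, v) = Lop(Lop(f, x, u), u, v) = J v.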
+ cotangents = [out.zeros_like(dtype=config.floatX) for out in outputs] # type: ignore + + input_cotangents = Lop( + f=outputs, + wrt=inputs, + eval_points=cotangents, + disconnected_inputs=disconnected_outputs, + return_disconnected="zero", + ) + + return Lop( + f=input_cotangents, # type: ignore + wrt=cotangents, + eval_points=tangents, + disconnected_inputs="ignore", + return_disconnected=return_disconnected, + ) + + +def _rop_legacy( + f: Sequence[Variable], + wrt: Sequence[Variable], + eval_points: Sequence[Variable], + disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise", + return_disconnected: Literal["none", "zero", "disconnected"] = "zero", +) -> Sequence[Variable | None]: """Computes the R-operator applied to `f` with respect to `wrt` at `eval_points`. Mathematically this stands for the Jacobian of `f` right multiplied by the @@ -190,38 +227,6 @@ def Rop( If `f` is a list/tuple, then return a list/tuple with the results. """ - if not isinstance(wrt, list | tuple): - _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)] - else: - _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] - - if not isinstance(eval_points, list | tuple): - _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)] - else: - _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points] - - if not isinstance(f, list | tuple): - _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)] - else: - _f = [pytensor.tensor.as_tensor_variable(x) for x in f] - - if len(_wrt) != len(_eval_points): - raise ValueError("`wrt` must be the same length as `eval_points`.") - - # Check that each element of wrt corresponds to an element - # of eval_points with the same dimensionality. - for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)): - try: - if wrt_elem.type.ndim != eval_point.type.ndim: - raise ValueError( - f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: " - f"{wrt_elem.type.ndim} and {eval_point.type.ndim}" - ) - except AttributeError: - # wrt_elem and eval_point don't always have ndim like random type - # Tensor, Sparse have the ndim attribute - pass - seen_nodes: dict[Apply, Sequence[Variable]] = {} def _traverse(node): @@ -237,8 +242,8 @@ def _traverse(node): # inputs of the node local_eval_points = [] for inp in inputs: - if inp in _wrt: - local_eval_points.append(_eval_points[_wrt.index(inp)]) + if inp in wrt: + local_eval_points.append(eval_points[wrt.index(inp)]) elif inp.owner is None: try: local_eval_points.append(inp.zeros_like()) @@ -292,13 +297,13 @@ def _traverse(node): # end _traverse # Populate the dictionary - for out in _f: + for out in f: _traverse(out.owner) rval: list[Variable | None] = [] - for out in _f: - if out in _wrt: - rval.append(_eval_points[_wrt.index(out)]) + for out in f: + if out in wrt: + rval.append(eval_points[wrt.index(out)]) elif ( seen_nodes.get(out.owner, None) is None or seen_nodes[out.owner][out.owner.outputs.index(out)] is None @@ -337,6 +342,116 @@ def _traverse(node): else: rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)]) + return rval + + +def Rop( + f: Variable | Sequence[Variable], + wrt: Variable | Sequence[Variable], + eval_points: Variable | Sequence[Variable], + disconnected_outputs: Literal["ignore", "warn", "raise"] = "raise", + return_disconnected: Literal["none", "zero", "disconnected"] = "zero", + use_op_rop_implementation: bool = False, +) -> Variable | None | Sequence[Variable | None]: + """Computes the R-operator 
applied to `f` with respect to `wrt` at `eval_points`. + + Mathematically this stands for the Jacobian of `f` right multiplied by the + `eval_points`. + + By default, the R-operator is implemented as a double application of the L_operator [1]_. + In most cases this should be as performant as a specialized implementation of the R-operator. + However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators, + such as Scan and OpFromGraph, that would be more easily avoidable in a direct implentation of the R-operator. + + When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing + `use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`s that don't implement this method. + + Parameters + ---------- + f + The outputs of the computational graph to which the R-operator is + applied. + wrt + Variables for which the R-operator of `f` is computed. + eval_points + Points at which to evaluate each of the variables in `wrt`. + disconnected_outputs + Defines the behaviour if some of the variables in `f` + have no dependency on any of the variable in `wrt` (or if + all links are non-differentiable). The possible values are: + + - ``'ignore'``: considers that the gradient on these parameters is zero. + - ``'warn'``: consider the gradient zero, and print a warning. + - ``'raise'``: raise `DisconnectedInputError`. + + return_disconnected + - ``'zero'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be + ``wrt[i].zeros_like()``. + - ``'none'`` : If ``wrt[i]`` is disconnected, return value ``i`` will be + ``None`` + - ``'disconnected'`` : returns variables of type `DisconnectedType` + use_op_lop_implementation: bool, default=True + If `True`, we obtain Rop via double application of Lop. + If `False`, the legacy Rop implementation is used. The number of graphs that support this form + is much more restricted, and the generated graphs may be less optimized. + + Returns + ------- + :class:`~pytensor.graph.basic.Variable` or list/tuple of Variables + A symbolic expression such obeying + ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``, + where the indices in that expression are magic multidimensional + indices that specify both the position within a list and all + coordinates of the tensor elements. + If `f` is a list/tuple, then return a list/tuple with the results. + + References + ---------- + .. [1] J. Towns, "A new trick for calculating Jacobian vector products", 2017. + Available: https://j-towns.github.io/2017/06/12/A-new-trick.html + """ + + if not isinstance(wrt, list | tuple): + _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)] + else: + _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] + + if not isinstance(eval_points, list | tuple): + _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)] + else: + _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points] + + if not isinstance(f, list | tuple): + _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)] + else: + _f = [pytensor.tensor.as_tensor_variable(x) for x in f] + + if len(_wrt) != len(_eval_points): + raise ValueError("`wrt` must be the same length as `eval_points`.") + + # Check that each element of wrt corresponds to an element + # of eval_points with the same dimensionality. 
+ for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points, strict=True)): + try: + if wrt_elem.type.ndim != eval_point.type.ndim: + raise ValueError( + f"Elements {i} of `wrt` and `eval_point` have mismatched dimensionalities: " + f"{wrt_elem.type.ndim} and {eval_point.type.ndim}" + ) + except AttributeError: + # wrt_elem and eval_point don't always have ndim like random type + # Tensor, Sparse have the ndim attribute + pass + + if use_op_rop_implementation: + rval = _rop_legacy( + _f, _wrt, _eval_points, disconnected_outputs, return_disconnected + ) + else: + rval = pushforward_through_pullback( + _f, _wrt, _eval_points, disconnected_outputs, return_disconnected + ) + using_list = isinstance(f, list) using_tuple = isinstance(f, tuple) return as_list_or_tuple(using_list, using_tuple, rval) @@ -348,6 +463,7 @@ def Lop( eval_points: Variable | Sequence[Variable], consider_constant: Sequence[Variable] | None = None, disconnected_inputs: Literal["ignore", "warn", "raise"] = "raise", + return_disconnected: Literal["none", "zero", "disconnected"] = "zero", ) -> Variable | None | Sequence[Variable | None]: """Computes the L-operator applied to `f` with respect to `wrt` at `eval_points`. @@ -404,6 +520,7 @@ def Lop( consider_constant=consider_constant, wrt=_wrt, disconnected_inputs=disconnected_inputs, + return_disconnected=return_disconnected, ) using_list = isinstance(wrt, list) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index a01347ef9c..87cadac8cf 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -72,6 +72,7 @@ from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.replace import clone_replace +from pytensor.graph.type import HasShape from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.link.c.basic import CLinker from pytensor.printing import op_debug_information @@ -2509,13 +2510,25 @@ def compute_all_gradients(known_grads): return rval var_mappings = self.get_oinp_iinp_iout_oout_mappings() - dC_dinps_t = [None for inp in diff_inputs] disconnected_dC_dinps_t = [True for inp in diff_inputs] + + n_mit_mot_outs = info.n_mit_mot_outs + # In the case of mit-mot there can be more inner outputs than outer ones + n_extra_mit_mot_outs = n_mit_mot_outs - info.n_mit_mot + idx_nitsot_out_start = n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot + idx_nitsot_out_end = idx_nitsot_out_start + info.n_nit_sot + + # Create dummy variables for the internal input gradients + states = ( + self.inner_mitmot(self_inputs) + + self.inner_mitsot(self_inputs) + + self.inner_sitsot(self_inputs) + ) dC_dXts = [] Xts = [] for idx, Xt in enumerate(diff_outputs): # We are looking for x[t-1] for a given x[t] - if idx >= info.n_mit_mot_outs: + if idx >= n_mit_mot_outs: Xt_placeholder = safe_new(Xt) Xts.append(Xt_placeholder) @@ -2523,9 +2536,7 @@ def compute_all_gradients(known_grads): # or not. NOTE : This cannot be done by using # "if Xt not in self.inner_nitsot_outs(self_outputs)" because # the exact same variable can be used as multiple outputs. 
- idx_nitsot_start = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot - idx_nitsot_end = idx_nitsot_start + info.n_nit_sot - if idx < idx_nitsot_start or idx >= idx_nitsot_end: + if idx < idx_nitsot_out_start or idx >= idx_nitsot_out_end: # What we do here is loop through dC_douts and collect all # those that are connected to the specific one and do an # upcast on all of their dtypes to get the dtype for this @@ -2533,12 +2544,6 @@ def compute_all_gradients(known_grads): # specific previous step is defined or not is done somewhere # else. dtypes = [] - states = ( - self.inner_mitmot(self_inputs) - + self.inner_mitsot(self_inputs) - + self.inner_sitsot(self_inputs) - ) - for pos, inp in enumerate(states): if inp in graph_inputs([Xt]): # Get the index of the outer output that to which @@ -2555,35 +2560,43 @@ def compute_all_gradients(known_grads): new_dtype = config.floatX dC_dXt = safe_new(Xt, dtype=new_dtype) else: - if isinstance(dC_douts[idx].type, DisconnectedType): + # nit-sot outputs + # If not disconnected assume the output gradient type is a valid type for the input gradient + if isinstance( + dC_douts[idx - n_extra_mit_mot_outs].type, DisconnectedType + ): continue - dC_dXt = safe_new(dC_douts[idx][0]) + dC_dXt = safe_new(dC_douts[idx - n_extra_mit_mot_outs][0]) dC_dXts.append(dC_dXt) + # Handle cases where the very same variable may be used as different outputs + # TODO: Couldn't we add a view Op to avoid this when building the Scan graph? known_grads = {} dc_dxts_idx = 0 for i in range(len(diff_outputs)): - if i < idx_nitsot_start or i >= idx_nitsot_end: - if diff_outputs[i] in known_grads: - known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] - else: - known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] - dc_dxts_idx += 1 + if not (i < idx_nitsot_out_start or i >= idx_nitsot_out_end) and isinstance( + dC_douts[i - n_extra_mit_mot_outs].type, DisconnectedType + ): + # Special case where we don't have a dC_dXt for disconnected nitsot outputs + continue + + # Just some trouble to avoid a +0 + if diff_outputs[i] in known_grads: + known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] else: - if isinstance(dC_douts[i].type, DisconnectedType): - continue - else: - if diff_outputs[i] in known_grads: - known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] - else: - known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] - dc_dxts_idx += 1 + known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] + dc_dxts_idx += 1 + dC_dinps_t = compute_all_gradients(known_grads) # mask inputs that get no gradients for dx in range(len(dC_dinps_t)): - if not dC_dinps_t[dx]: - dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx]) + if dC_dinps_t[dx] is None: + dC_dinps_t[dx] = ( + pt.zeros_like(diff_inputs[dx]) + if isinstance(diff_inputs[dx].type, HasShape) + else pt.zeros(()) + ) else: disconnected_dC_dinps_t[dx] = False for Xt, Xt_placeholder in zip( @@ -2846,7 +2859,6 @@ def compute_all_gradients(known_grads): for idx in range(info.n_sit_sot): mitmot_inp_taps.append([0, 1]) mitmot_out_taps.append([1]) - through_shared = False if not isinstance(dC_douts[idx + offset].type, DisconnectedType): outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) else: @@ -2883,7 +2895,7 @@ def compute_all_gradients(known_grads): elif through_shared: type_outs.append("through_shared") elif disconnected_dC_dinps_t[ins_pos]: - type_outs.append("disconnected") + type_outs.append("disconnected ") else: type_outs.append("connected") @@ -2934,7 +2946,7 @@ def compute_all_gradients(known_grads): if through_shared: 
type_outs.append("through_shared") elif disconnected_dC_dinps_t[_p]: - type_outs.append("disconnected") + type_outs.append("disconnected nitsot") else: type_outs.append("connected") @@ -2958,7 +2970,8 @@ def compute_all_gradients(known_grads): else: outer_inp_sitsot.append( pt.zeros( - [grad_steps + 1] + [x.shape[i] for i in range(x.ndim)], + [grad_steps + 1] + + (list(x.shape) if isinstance(x.type, HasShape) else []), dtype=y.dtype, ) ) @@ -3007,9 +3020,7 @@ def compute_all_gradients(known_grads): name=f"grad_of_{self.name}" if self.name else None, allow_gc=self.allow_gc, ) - outputs = local_op(*outer_inputs) - if not isinstance(outputs, list | tuple): - outputs = [outputs] + outputs = local_op(*outer_inputs, return_list=True) # Re-order the gradients correctly gradients = [DisconnectedType()()] @@ -3095,7 +3106,6 @@ def compute_all_gradients(known_grads): ) ) - start = len(gradients) gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)] begin = end @@ -3155,7 +3165,12 @@ def R_op(self, inputs, eval_points): rop_self_outputs = self_outputs if info.n_shared_outs > 0: rop_self_outputs = rop_self_outputs[: -info.n_shared_outs] - rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points) + rop_outs = Rop( + rop_self_outputs, + rop_of_inputs, + inner_eval_points, + use_op_rop_implementation=True, + ) if not isinstance(rop_outs, list | tuple): rop_outs = [rop_outs] # Step 2. Figure out what corresponds to what in the scan diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index b185f686bc..9fa823feb8 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -431,20 +431,25 @@ def L_op(self, inputs, outputs, grads): return (g_x,) def R_op(self, inputs, eval_points): + [x] = inputs if eval_points[0] is None: - return [None, None] - if len(self.axis) != 1: - raise ValueError("R_op supported for max only for one axis!") - if self.axis[0] > 1: - raise ValueError("R_op supported for max only when axis is 0 or 1") + return [None] + axis = tuple(range(x.ndim) if self.axis is None else self.axis) + if isinstance(axis, int): + axis = [axis] + if len(axis) != 1: + raise NotImplementedError("R_op supported for max only for one axis!") + if axis[0] > 1: + raise NotImplementedError("R_op supported for max only when axis is 0 or 1") if inputs[0].ndim != 2: - raise ValueError("R_op supported for max only when input is a matrix") - max_pos = Argmax(self.axis).make_node(*inputs).outputs - # print(eval_points[0].eval()) + raise NotImplementedError( + "R_op supported for max only when input is a matrix" + ) + max_pos = Argmax(self.axis)(*inputs) if self.axis[0] == 0: - return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None] + return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]] else: - return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None] + return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]] class Min(NonZeroDimsCAReduce): diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 8fc2a529df..ba0257cdda 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -306,7 +306,8 @@ def lop_ov(inps, outs, grads): @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] ) - def test_rop(self, cls_ofg): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_rop(self, cls_ofg, use_op_rop_implementation): a = vector() M = matrix() b = dot(a, M) @@ -315,7 +316,7 @@ def test_rop(self, cls_ofg): W = matrix() y = 
op_matmul(x, W) du = vector() - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) xval = np.random.random((16,)).astype(config.floatX) Wval = np.random.random((16, 16)).astype(config.floatX) @@ -324,7 +325,8 @@ def test_rop(self, cls_ofg): dvval2 = fn(xval, Wval, duval) np.testing.assert_array_almost_equal(dvval2, dvval, 4) - def test_rop_multiple_outputs(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_rop_multiple_outputs(self, use_op_rop_implementation): a = vector() M = matrix() b = dot(a, M) @@ -339,21 +341,21 @@ def test_rop_multiple_outputs(self): duval = np.random.random((16,)).astype(config.floatX) y = op_matmul(x, W)[0] - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = np.dot(duval, Wval) np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4) y = op_matmul(x, W)[1] - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = -np.dot(duval, Wval) np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4) y = pt.add(*op_matmul(x, W)) - dv = Rop(y, x, du) + dv = Rop(y, x, du, use_op_rop_implementation=use_op_rop_implementation) fn = function([x, W, du], dv) result_dvval = fn(xval, Wval, duval) expected_dvval = np.zeros_like(np.dot(duval, Wval)) @@ -362,7 +364,16 @@ def test_rop_multiple_outputs(self): @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] ) - def test_rop_override(self, cls_ofg): + @pytest.mark.parametrize( + "use_op_rop_implementation", + [ + True, + pytest.param( + False, marks=pytest.mark.xfail(reason="Custom ROp is ignored") + ), + ], + ) + def test_rop_override(self, cls_ofg, use_op_rop_implementation): x, y = vectors("xy") def ro(inps, epts): @@ -380,7 +391,12 @@ def ro(inps, epts): du, dv = vector("du"), vector("dv") for op in [op_mul, op_mul2]: zz = op_mul(xx, yy) - dw = Rop(zz, [xx, yy], [du, dv]) + dw = Rop( + zz, + [xx, yy], + [du, dv], + use_op_rop_implementation=use_op_rop_implementation, + ) fn = function([xx, yy, du, dv], dw) vals = np.random.random((4, 32)).astype(config.floatX) dwval = fn(*vals) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 9fa893ab27..5d0d220603 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -1922,7 +1922,8 @@ def inner_fn(): fgrad = function([], g_sh) assert fgrad() == 1 - def test_R_op(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op(self, use_op_rop_implementation): seed = utt.fetch_seed() rng = np.random.default_rng(seed) floatX = config.floatX @@ -1957,9 +1958,9 @@ def rnn_fn(_u, _y, _W): eh0 = vector("eh0") eW = matrix("eW") - nwo_u = Rop(o, _u, eu) - nwo_h0 = Rop(o, _h0, eh0) - nwo_W = Rop(o, _W, eW) + nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation) + nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation) + nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation) fn_rop = function( [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore" ) @@ -1992,12 +1993,13 @@ def rnn_fn(_u, _y, _W): vnu, vnh0, vnW = fn_rop(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) tnu, tnh0, tnW = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) - utt.assert_allclose(vnu, tnu, 
atol=1e-6) - utt.assert_allclose(vnh0, tnh0, atol=1e-6) - utt.assert_allclose(vnW, tnW, atol=1e-6) + np.testing.assert_allclose(vnu, tnu, atol=1e-6) + np.testing.assert_allclose(vnh0, tnh0, atol=1e-6) + np.testing.assert_allclose(vnW, tnW, atol=1e-6) @pytest.mark.slow - def test_R_op_2(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op_2(self, use_op_rop_implementation): seed = utt.fetch_seed() rng = np.random.default_rng(seed) floatX = config.floatX @@ -2040,9 +2042,9 @@ def rnn_fn(_u, _y, _W): eh0 = vector("eh0") eW = matrix("eW") - nwo_u = Rop(o, _u, eu) - nwo_h0 = Rop(o, _h0, eh0) - nwo_W = Rop(o, _W, eW) + nwo_u = Rop(o, _u, eu, use_op_rop_implementation=use_op_rop_implementation) + nwo_h0 = Rop(o, _h0, eh0, use_op_rop_implementation=use_op_rop_implementation) + nwo_W = Rop(o, _W, eW, use_op_rop_implementation=use_op_rop_implementation) fn_rop = function( [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W, o], on_unused_input="ignore" ) @@ -2074,11 +2076,12 @@ def rnn_fn(_u, _y, _W): ) tnu, tnh0, tnW, tno = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW) - utt.assert_allclose(vnu, tnu, atol=1e-6) - utt.assert_allclose(vnh0, tnh0, atol=1e-6) - utt.assert_allclose(vnW, tnW, atol=2e-6) + np.testing.assert_allclose(vnu, tnu, atol=1e-6) + np.testing.assert_allclose(vnh0, tnh0, atol=1e-6) + np.testing.assert_allclose(vnW, tnW, atol=2e-6) - def test_R_op_mitmot(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_R_op_mitmot(self, use_op_rop_implementation): # this test is a copy paste from the script given by Justin Bayer to # reproduce this bug # We have 2 parameter groups with the following shapes. @@ -2094,13 +2097,10 @@ def test_R_op_mitmot(self): W1 = pars[:3].reshape(W1shape) W2 = pars[3:].reshape(W2shape) - # Define recurrent model. We are using a model where each input is a - # tensor - # of shape (T, B, D) where T is the number of timesteps, B is the - # number of - # sequences iterated over in parallel and D is the dimensionality of - # each - # item at a timestep. + # Define recurrent model. We are using a model where each input + # is a tensor of shape (T, B, D) where T is the number of timesteps, + # B is the number of sequences iterated over in parallel and + # D is the dimensionality of each item at a timestep. inpt = tensor3("inpt") target = tensor3("target") @@ -2115,9 +2115,11 @@ def test_R_op_mitmot(self): transfer = sigmoid hidden_rec, _ = scan( - lambda x, h_tm1: transfer(dot(h_tm1, W2) + x), - sequences=hidden, + lambda x, h_tm1, W2: transfer(dot(h_tm1, W2) + x), + sequences=[hidden], outputs_info=[pt.zeros_like(hidden[0])], + non_sequences=[W2], + strict=True, ) hidden_rec.reshape( @@ -2128,7 +2130,64 @@ def test_R_op_mitmot(self): d_cost_wrt_pars = grad(cost, pars) p = dvector() - Rop(d_cost_wrt_pars, pars, p) + # TODO: We should test something about the Rop! 
+ Rop( + d_cost_wrt_pars, + pars, + p, + use_op_rop_implementation=use_op_rop_implementation, + ) + + def test_second_derivative_disconnected_cost_with_mit_mot(self): + # This test is a regression test for a bug that was revealed + # when we computed the pushforward of a Scan gradient via two applications of pullback + seq = pt.vector("seq", shape=(2,)) + z = pt.scalar("z") + x0 = pt.vector("x0", shape=(2,)) + + # When s is 1 and z is 2, xs[-1] is just a sneaky + # x ** 4 (after two nsteps) + # grad should be 4 * x ** 3 + # and grad of grad should be 12 * x ** 2 + def step(s, xtm2, xtm1, z): + return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2) + + xs, _ = scan( + step, + sequences=[seq], + outputs_info=[{"initial": x0, "taps": (-2, -1)}], + non_sequences=[z], + n_steps=2, + ) + last_x = xs[-1] + + g_wrt_x0, g_wrt_z, g_wrt_seq = pt.grad(last_x, [x0, z, seq]) + g = g_wrt_x0.sum() + g_wrt_z.sum() * 0 + g_wrt_seq.sum() * 0 + assert g.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 4 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [2, 2], x0: [1, 1], z: 2}) == 96 + + # Leave out z + g_wrt_x0, g_wrt_seq = pt.grad(last_x, [x0, seq]) + g = g_wrt_x0.sum() + g_wrt_seq.sum() * 0 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [2, 2], x0: [1, 1], z: 2}) == 96 + + # Leave out seq + g_wrt_x0, g_wrt_z = pt.grad(last_x, [x0, z]) + g = g_wrt_x0.sum() + g_wrt_z.sum() * 0 + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2 + + # Leave out z and seq + g_wrt_x0 = pt.grad(last_x, x0) + g = g_wrt_x0.sum() + gg = pt.grad(g, wrt=x0).sum() + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 2}) == 12 + assert gg.eval({seq: [1, 1], x0: [1, 1], z: 1}) == 3 / 2 @pytest.mark.skipif( diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index c9b9afff19..d9a444bf01 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -14,7 +14,7 @@ from pytensor.tensor import swapaxes from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import _allclose, dot, matmul +from pytensor.tensor.math import dot, matmul from pytensor.tensor.nlinalg import ( SVD, Det, @@ -51,9 +51,12 @@ def test_rop_lop(): v = vector("v") y = MatrixInverse()(mx).sum(axis=0) - yv = pytensor.gradient.Rop(y, mx, mv) + yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True) rop_f = function([mx, mv], yv) + yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False) + rop_via_lop_f = function([mx, mv], yv_via_lop) + sy, _ = pytensor.scan( lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(), sequences=pt.arange(y.shape[0]), @@ -65,22 +68,16 @@ def test_rop_lop(): vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) - v1 = rop_f(vx, vv) - v2 = scan_f(vx, vv) - - assert _allclose(v1, v2), f"ROP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref) + np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref) - raised = False - try: + with pytest.raises(ValueError): pytensor.gradient.Rop( - pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv - ) - except ValueError: - raised = True - if not raised: - raise Exception( - "Op did not raised an error 
even though the function" - " is not differentiable" + pytensor.clone_replace(y, replace={mx: break_op(mx)}), + mx, + mv, + use_op_rop_implementation=True, ) vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX) @@ -90,9 +87,9 @@ def test_rop_lop(): sy = pytensor.gradient.grad((v * y).sum(), mx) scan_f = function([mx, v], sy) - v1 = lop_f(vx, vv) - v2 = scan_f(vx, vv) - assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + v = lop_f(vx, vv) + np.testing.assert_allclose(v, v_ref) def test_transinv_to_invtrans(): diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index e85b8cfd46..7700d2b14b 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -603,7 +603,7 @@ def test_validation(self): class TestRopLop(RopLopChecker): def test_shape(self): - self.check_nondiff_rop(self.x.shape[0]) + self.check_nondiff_rop(self.x.shape[0], self.x, self.v) def test_specifyshape(self): self.check_rop_lop(specify_shape(self.x, self.in_shape), self.in_shape) diff --git a/tests/test_rop.py b/tests/test_rop.py index 0b9fe41a1e..e8c7810f6c 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -17,7 +17,13 @@ import pytensor import pytensor.tensor as pt from pytensor import function -from pytensor.gradient import Lop, Rop, grad, grad_undefined +from pytensor.gradient import ( + Lop, + NullTypeGradError, + Rop, + grad, + grad_undefined, +) from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.tensor.math import argmax, dot @@ -72,13 +78,13 @@ def setup_method(self): self.mv = matrix("mv") self.mat_in_shape = (5 + self.rng.integers(3), 5 + self.rng.integers(3)) - def check_nondiff_rop(self, y): + def check_nondiff_rop(self, y, x, v): """ If your op is not differentiable(so you can't define Rop) test that an error is raised. 
""" with pytest.raises(ValueError): - Rop(y, self.x, self.v) + Rop(y, x, v, use_op_rop_implementation=True).dprint() def check_mat_rop_lop(self, y, out_shape): """ @@ -106,8 +112,14 @@ def check_mat_rop_lop(self, y, out_shape): vv = np.asarray( self.rng.uniform(size=self.mat_in_shape), pytensor.config.floatX ) - yv = Rop(y, self.mx, self.mv) + yv = Rop(y, self.mx, self.mv, use_op_rop_implementation=True) rop_f = function([self.mx, self.mv], yv, on_unused_input="ignore") + + yv_through_lop = Rop(y, self.mx, self.mv, use_op_rop_implementation=False) + rop_through_lop_f = function( + [self.mx, self.mv], yv_through_lop, on_unused_input="ignore" + ) + sy, _ = pytensor.scan( lambda i, y, x, v: (grad(y[i], x) * v).sum(), sequences=pt.arange(y.shape[0]), @@ -115,13 +127,14 @@ def check_mat_rop_lop(self, y, out_shape): ) scan_f = function([self.mx, self.mv], sy, on_unused_input="ignore") - v1 = rop_f(vx, vv) - v2 = scan_f(vx, vv) - - assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref) + np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref) self.check_nondiff_rop( - pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}) + pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}), + self.mx, + self.mv, ) vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX) @@ -131,11 +144,11 @@ def check_mat_rop_lop(self, y, out_shape): sy = grad((self.v * y).sum(), self.mx) scan_f = function([self.mx, self.v], sy) - v1 = lop_f(vx, vv) - v2 = scan_f(vx, vv) - assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}" + v = lop_f(vx, vv) + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(v, v_ref) - def check_rop_lop(self, y, out_shape): + def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True): """ As check_mat_rop_lop, except the input is self.x which is a vector. The output is still a vector. @@ -144,32 +157,32 @@ def check_rop_lop(self, y, out_shape): vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) - yv = Rop(y, self.x, self.v) + yv = Rop(y, self.x, self.v, use_op_rop_implementation=True) rop_f = function([self.x, self.v], yv, on_unused_input="ignore") + + yv_through_lop = Rop(y, self.x, self.v, use_op_rop_implementation=False) + rop_through_lop_f = function( + [self.x, self.v], yv_through_lop, on_unused_input="ignore" + ) + J, _ = pytensor.scan( lambda i, y, x: grad(y[i], x), sequences=pt.arange(y.shape[0]), non_sequences=[y, self.x], ) sy = dot(J, self.v) - scan_f = function([self.x, self.v], sy, on_unused_input="ignore") - v1 = rop_f(vx, vv) - v2 = scan_f(vx, vv) - assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}" + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref) + np.testing.assert_allclose(rop_through_lop_f(vx, vv), v_ref) - try: - Rop( + if check_nondiff_rop: + self.check_nondiff_rop( pytensor.clone_replace(y, replace={self.x: break_op(self.x)}), self.x, self.v, ) - except ValueError: - pytest.skip( - "Rop does not handle non-differentiable inputs " - "correctly. Bug exposed by fixing Add.grad method." 
- ) vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX) vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX) @@ -182,22 +195,20 @@ def check_rop_lop(self, y, out_shape): non_sequences=[y, self.x], ) sy = dot(self.v, J) - scan_f = function([self.x, self.v], sy) - v1 = lop_f(vx, vv) - v2 = scan_f(vx, vv) - assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}" + v = lop_f(vx, vv) + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(v, v_ref) class TestRopLop(RopLopChecker): def test_max(self): - # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ()) self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],)) self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],)) def test_argmax(self): - self.check_nondiff_rop(argmax(self.mx, axis=1)) + self.check_nondiff_rop(argmax(self.mx, axis=1), self.mx, self.mv) def test_subtensor(self): self.check_rop_lop(self.x[:4], (4,)) @@ -252,10 +263,14 @@ def test_dot(self): insh = self.in_shape[0] vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX) W = pytensor.shared(vW) - self.check_rop_lop(dot(self.x, W), self.in_shape) + # check_nondiff_rop shows an error in the legacy Rop logic + # See: test_Rop_partially_differentiable_paths + self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False) def test_elemwise0(self): - self.check_rop_lop((self.x + 1) ** 2, self.in_shape) + # check_nondiff_rop shows an error in the legacy Rop logic + # See: test_Rop_partially_differentiable_paths + self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False) def test_elemwise1(self): self.check_rop_lop(self.x + pt.cast(self.x, "int32"), self.in_shape) @@ -287,18 +302,18 @@ def test_alloc(self): self.mat_in_shape[0] * self.mat_in_shape[1] * self.in_shape[0], ) - def test_invalid_input(self): - success = False - - try: - Rop(0.0, [matrix()], [vector()]) - success = True - except ValueError: - pass - - assert not success + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_invalid_input(self, use_op_rop_implementation): + with pytest.raises(ValueError): + Rop( + 0.0, + [matrix()], + [vector()], + use_op_rop_implementation=use_op_rop_implementation, + ) - def test_multiple_outputs(self): + @pytest.mark.parametrize("use_op_rop_implementation", [True, False]) + def test_multiple_outputs(self, use_op_rop_implementation): m = matrix("m") v = vector("v") m_ = matrix("m_") @@ -309,10 +324,20 @@ def test_multiple_outputs(self): m_val = self.rng.uniform(size=(3, 7)).astype(pytensor.config.floatX) v_val = self.rng.uniform(size=(7,)).astype(pytensor.config.floatX) - rop_out1 = Rop([m, v, m + v], [m, v], [m_, v_]) + rop_out1 = Rop( + [m, v, m + v], + [m, v], + [m_, v_], + use_op_rop_implementation=use_op_rop_implementation, + ) assert isinstance(rop_out1, list) assert len(rop_out1) == 3 - rop_out2 = Rop((m, v, m + v), [m, v], [m_, v_]) + rop_out2 = Rop( + (m, v, m + v), + [m, v], + [m_, v_], + use_op_rop_implementation=use_op_rop_implementation, + ) assert isinstance(rop_out2, tuple) assert len(rop_out2) == 3 @@ -322,12 +347,66 @@ def test_multiple_outputs(self): f = pytensor.function([m, v, m_, v_], all_outs) f(mval, vval, m_val, v_val) - def test_Rop_dot_bug_18Oct2013_Jeremiah(self): + @pytest.mark.parametrize( + "use_op_rop_implementation", + [pytest.param(True, marks=pytest.mark.xfail()), False], + ) + def test_Rop_partially_differentiable_paths(self, use_op_rop_implementation): # This test refers to a bug reported by 
Jeremiah Lowin on 18th Oct # 2013. The bug consists when through a dot operation there is only # one differentiable path (i.e. there is no gradient wrt to one of # the inputs). x = pt.arange(20.0).reshape([1, 20]) - v = pytensor.shared(np.ones([20])) + v = pytensor.shared(np.ones([20]), name="v") d = dot(x, v).sum() - Rop(grad(d, v), v, v) + + pytensor.dprint(grad(d, v)) + Rop( + grad(d, v), + v, + v, + use_op_rop_implementation=use_op_rop_implementation, + # 2025: This is a tricky case, the gradient of the gradient does not depend on v + # although v still exists in the graph inside a `Second` operator. + # The original test was checking that Rop wouldn't raise an error, but Lop does. + # Since the correct behavior is ambiguous, I let both implementations off the hook. + disconnected_outputs="raise" if use_op_rop_implementation else "ignore", + ) + + # 2025: Here is an unambiguous test for the original commented issue: + x = pt.matrix("x") + y = pt.matrix("y") + out = dot(x, break_op(y)).sum() + # Should not raise an error + Rop( + out, + [x], + [x.type()], + use_op_rop_implementation=use_op_rop_implementation, + disconnected_outputs="raise", + ) + + # More extensive testing shows that the legacy Rop implementation FAILS to raise when + # the cost is linked through strictly non-differentiable paths. + # This is not Dot specific, we would observe the same with any operation where the gradient + # with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...) + out = dot(break_op(x), y).sum() + with pytest.raises((ValueError, NullTypeGradError)): + Rop( + out, + [x], + [x.type()], + use_op_rop_implementation=use_op_rop_implementation, + disconnected_outputs="raise", + ) + + # Only when both paths are non-differentiable is an error correctly raised again. + out = dot(break_op(x), break_op(y)).sum() + with pytest.raises((ValueError, NullTypeGradError)): + Rop( + out, + [x], + [x.type()], + use_op_rop_implementation=use_op_rop_implementation, + disconnected_outputs="raise", + )
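
For reference, a minimal doctest-style sketch mirroring the tutorial example added above (variable names here are illustrative only): it builds the same Jacobian-vector product through the default double-Lop path and the legacy `Op.R_op` path and checks that the two agree.

>>> import numpy as np
>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor.gradient import Rop
>>> x, v = pt.vector("x"), pt.vector("v")
>>> W = pt.matrix("W")
>>> y = pt.dot(x, W)
>>> jv_lop = Rop(y, x, v)                                  # default: double application of Lop
>>> jv_rop = Rop(y, x, v, use_op_rop_implementation=True)  # legacy: per-Op R_op methods
>>> f = pytensor.function([x, W, v], [jv_lop, jv_rop])
>>> out_lop, out_rop = f([1.0, 1.0], [[1.0, 2.0], [3.0, 4.0]], [1.0, 0.0])
>>> np.testing.assert_allclose(out_lop, out_rop)  # both paths return dot(v, W)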