From bfaaa0ce6283073634f2869a93252b0fd741b731 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 30 Apr 2024 14:05:31 +0200 Subject: [PATCH] Keep stack trace in random_make_inplace --- pytensor/tensor/random/rewriting/basic.py | 5 ++++- tests/tensor/random/rewriting/test_basic.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index c6cfda443f..58d2668e79 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -50,7 +50,10 @@ def random_make_inplace(fgraph, node): props = op._props_dict() props["inplace"] = True new_op = type(op)(**props) - return new_op.make_node(*node.inputs).outputs + new_outputs = new_op.make_node(*node.inputs).outputs + for old_out, new_out in zip(node.outputs, new_outputs): + copy_stack_trace(old_out, new_out) + return new_outputs return False diff --git a/tests/tensor/random/rewriting/test_basic.py b/tests/tensor/random/rewriting/test_basic.py index 66fe5287d2..f342d5b81c 100644 --- a/tests/tensor/random/rewriting/test_basic.py +++ b/tests/tensor/random/rewriting/test_basic.py @@ -9,7 +9,7 @@ from pytensor.compile.mode import Mode from pytensor.graph.basic import Constant, Variable, ancestors from pytensor.graph.fg import FunctionGraph -from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter +from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter, check_stack_trace from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.tensor import constant from pytensor.tensor.elemwise import DimShuffle @@ -143,6 +143,7 @@ def test_inplace_rewrites(rv_op): for a, b in zip(new_op.dist_params(new_node), op.dist_params(node)) ) assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data) + assert check_stack_trace(f) @config.change_flags(compute_test_value="raise")