Skip to content

Commit

Permalink
Keep stack trace in random_make_inplace
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 8, 2024
1 parent e4606f1 commit ff8aece
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 4 additions & 1 deletion pytensor/tensor/random/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def random_make_inplace(fgraph, node):
props = op._props_dict()
props["inplace"] = True
new_op = type(op)(**props)
return new_op.make_node(*node.inputs).outputs
new_outputs = new_op.make_node(*node.inputs).outputs
for old_out, new_out in zip(node.outputs, new_outputs):
copy_stack_trace(old_out, new_out)
return new_outputs

return False

Expand Down
4 changes: 2 additions & 2 deletions tests/tensor/random/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pytensor.compile.mode import Mode
from pytensor.graph.basic import Constant, Variable, ancestors
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter, check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import constant
from pytensor.tensor.elemwise import DimShuffle
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_inplace_rewrites(rv_op):
for a, b in zip(new_op.dist_params(new_node), op.dist_params(node))
)
assert np.array_equal(new_op.size_param(new_node).data, op.size_param(node).data)

assert check_stack_trace(f)

@config.change_flags(compute_test_value="raise")
@pytest.mark.parametrize(
Expand Down

0 comments on commit ff8aece

Please sign in to comment.