Skip to content

Commit

Permalink
Remove test value
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Schweer committed Aug 15, 2024
1 parent 2766457 commit ef9277b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
Expand Down Expand Up @@ -310,10 +310,9 @@ def test_pytorch_ifelse():

for test_value, cond in [(0.2, 0.5), (0.5, 0.4)]:
a = scalar("a")
a.tag.test_value = np.array(test_value, dtype=config.floatX)
x = ifelse(
a < cond, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])
)
x_fg = FunctionGraph([a], x)

compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
compare_pytorch_and_py(x_fg, np.array(test_value, dtype=config.floatX))

0 comments on commit ef9277b

Please sign in to comment.