Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ch0ronomato authored and Ian Schweer committed Nov 25, 2024
1 parent d14d152 commit d17d4a9
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
from pytensor.tensor.type import matrices, matrix, scalar, vector


Expand Down Expand Up @@ -343,3 +343,17 @@ def test_pytorch_OpFromGraph():

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])


def test_pytorch_scipy():
x = vector("a", shape=(3,))
out = expit(x)
f = FunctionGraph([x], [out])
compare_pytorch_and_py(f, [np.random.rand(3)])


def test_pytorch_softplus():
x = vector("a", shape=(3,))
out = softplus(x)
f = FunctionGraph([x], [out])
compare_pytorch_and_py(f, [np.random.rand(3)])

0 comments on commit d17d4a9

Please sign in to comment.