From 0e855d0876dffe0fc413e0b015971d5c22370e66 Mon Sep 17 00:00:00 2001
From: Ch0ronomato
Date: Fri, 1 Nov 2024 21:47:53 -0700
Subject: [PATCH 1/5] Allow for scipy module resolution

---
 pytensor/link/pytorch/dispatch/scalar.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py
index a977c6d4b2..c1ff6d0e53 100644
--- a/pytensor/link/pytorch/dispatch/scalar.py
+++ b/pytensor/link/pytorch/dispatch/scalar.py
@@ -1,3 +1,5 @@
+import importlib
+
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -19,9 +21,14 @@ def pytorch_funcify_ScalarOp(op, node, **kwargs):
     if nfunc_spec is None:
         raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
 
-    func_name = nfunc_spec[0]
+    func_name = nfunc_spec[0].replace("scipy.", "")
 
-    pytorch_func = getattr(torch, func_name)
+    if "." in func_name:
+        loc = func_name.split(".")
+        mod = importlib.import_module(".".join(["torch", *loc[:-1]]))
+        pytorch_func = getattr(mod, loc[-1])
+    else:
+        pytorch_func = getattr(torch, func_name)
 
     if len(node.inputs) > op.nfunc_spec[1]:
         # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,

From 24360b61d45b083bfb2dca3a95ff220b00c02461 Mon Sep 17 00:00:00 2001
From: Ch0ronomato
Date: Fri, 1 Nov 2024 21:48:30 -0700
Subject: [PATCH 2/5] Add softplus

---
 pytensor/link/pytorch/dispatch/scalar.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py
index c1ff6d0e53..1416e58f55 100644
--- a/pytensor/link/pytorch/dispatch/scalar.py
+++ b/pytensor/link/pytorch/dispatch/scalar.py
@@ -7,6 +7,7 @@
     Cast,
     ScalarOp,
 )
+from pytensor.scalar.math import Softplus
 
 
 @pytorch_funcify.register(ScalarOp)
@@ -56,3 +57,8 @@ def cast(x):
         return x.to(dtype=dtype)
 
     return cast
+
+
+@pytorch_funcify.register(Softplus)
+def pytorch_funcify_Softplus(op, node, **kwargs):
+    return torch.nn.Softplus()

From ab6cb21f07f72cccf4d1ee7684c28ca4a8415ef1 Mon Sep 17 00:00:00 2001
From: Ch0ronomato
Date: Wed, 6 Nov 2024 20:43:09 -0800
Subject: [PATCH 3/5] Add tests

---
 tests/link/pytorch/test_basic.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index 25827d23f9..83249d021b 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -17,7 +17,7 @@
 from pytensor.ifelse import ifelse
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor import alloc, arange, as_tensor, empty, eye
+from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
 from pytensor.tensor.type import matrices, matrix, scalar, vector
 
 
@@ -374,3 +374,17 @@ def inner_fn(x):
     f = function([x], out, mode="PYTORCH")
     f(torch.ones(3))
     assert "inner_fn" not in dir(m), "function call reference leaked"
+
+
+def test_pytorch_scipy():
+    x = vector("a", shape=(3,))
+    out = expit(x)
+    f = FunctionGraph([x], [out])
+    compare_pytorch_and_py(f, [np.random.rand(3)])
+
+
+def test_pytorch_softplus():
+    x = vector("a", shape=(3,))
+    out = softplus(x)
+    f = FunctionGraph([x], [out])
+    compare_pytorch_and_py(f, [np.random.rand(3)])

From 806b189c7cf2f848111a59844d12bf6bcfcbb140 Mon Sep 17 00:00:00 2001
From: Ian Schweer
Date: Sun, 24 Nov 2024 18:03:14 -0800
Subject: [PATCH 4/5] Allow scipy scalar handling

---
 pytensor/link/pytorch/dispatch/elemwise.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py
index 72f97af1fa..4ae74d9bb6 100644
--- a/pytensor/link/pytorch/dispatch/elemwise.py
+++ b/pytensor/link/pytorch/dispatch/elemwise.py
@@ -10,13 +10,15 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
-
-    if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
+    if hasattr(scalar_op, "nfunc_spec") and (
+        hasattr(torch, scalar_op.nfunc_spec[0]) or "scipy." in scalar_op.nfunc_spec[0]
+    ):
         # torch can handle this scalar
         # broadcast, we'll let it.
         def elemwise_fn(*inputs):
             Elemwise._check_runtime_broadcast(node, inputs)
             return base_fn(*inputs)
+
     else:
 
         def elemwise_fn(*inputs):

From 7fec04564748c7c6654fcd8c106c5930e4023c4d Mon Sep 17 00:00:00 2001
From: Ian Schweer
Date: Sun, 24 Nov 2024 18:47:08 -0800
Subject: [PATCH 5/5] Check for scipy in elemwise

---
 pytensor/link/pytorch/dispatch/elemwise.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py
index 4ae74d9bb6..79ca5beec1 100644
--- a/pytensor/link/pytorch/dispatch/elemwise.py
+++ b/pytensor/link/pytorch/dispatch/elemwise.py
@@ -1,3 +1,5 @@
+import importlib
+
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -10,8 +12,20 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
+
+    def check_special_scipy(func_name):
+        if "scipy." not in func_name:
+            return False
+        loc = func_name.split(".")[1:]
+        try:
+            mod = importlib.import_module(".".join(["torch", *loc[:-1]]))
+            return getattr(mod, loc[-1], False)
+        except ImportError:
+            return False
+
     if hasattr(scalar_op, "nfunc_spec") and (
-        hasattr(torch, scalar_op.nfunc_spec[0]) or "scipy." in scalar_op.nfunc_spec[0]
+        hasattr(torch, scalar_op.nfunc_spec[0])
+        or check_special_scipy(scalar_op.nfunc_spec[0])
     ):
         # torch can handle this scalar
         # broadcast, we'll let it.
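
For illustration, a minimal end-to-end sketch of what this series enables. It
is not part of the patches; it assumes an environment where torch ships
torch.special.expit and relies on pytensor's Sigmoid scalar op declaring
nfunc_spec = ("scipy.special.expit", 1, 1), so the dispatch above resolves the
scipy-qualified name to the torch.special module:

    import numpy as np
    import torch

    from pytensor import config, function
    from pytensor.tensor import expit
    from pytensor.tensor.type import vector

    # Build a graph around a scalar op whose nfunc_spec names a scipy function.
    x = vector("x", shape=(3,))
    out = expit(x)

    # mode="PYTORCH" routes Elemwise through the dispatch changed above, so
    # "scipy.special.expit" is resolved to torch.special.expit via importlib.
    f = function([x], out, mode="PYTORCH")

    data = np.random.rand(3).astype(config.floatX)
    expected = torch.special.expit(torch.from_numpy(data)).numpy()
    np.testing.assert_allclose(np.asarray(f(data)), expected, rtol=1e-6)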