From 7fec04564748c7c6654fcd8c106c5930e4023c4d Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 24 Nov 2024 18:47:08 -0800 Subject: [PATCH] Check for scipy in elemwise --- pytensor/link/pytorch/dispatch/elemwise.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 4ae74d9bb6..79ca5beec1 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -1,3 +1,5 @@ +import importlib + import torch from pytensor.link.pytorch.dispatch.basic import pytorch_funcify @@ -10,8 +12,20 @@ def pytorch_funcify_Elemwise(op, node, **kwargs): scalar_op = op.scalar_op base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) + + def check_special_scipy(func_name): + if "scipy." not in func_name: + return False + loc = func_name.split(".")[1:] + try: + mod = importlib.import_module(".".join(loc[:-1]), "torch") + return getattr(mod, loc[-1], False) + except ImportError: + return False + if hasattr(scalar_op, "nfunc_spec") and ( - hasattr(torch, scalar_op.nfunc_spec[0]) or "scipy." in scalar_op.nfunc_spec[0] + hasattr(torch, scalar_op.nfunc_spec[0]) + or check_special_scipy(scalar_op.nfunc_spec[0]) ): # torch can handle this scalar # broadcast, we'll let it.