Skip to content

Commit

Permalink
Check for scipy in elemwise
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Schweer committed Nov 25, 2024
1 parent 806b189 commit 7fec045
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion pytensor/link/pytorch/dispatch/elemwise.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import importlib

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
Expand All @@ -10,8 +12,20 @@
def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

def check_special_scipy(func_name):
if "scipy." not in func_name:
return False

Check warning on line 18 in pytensor/link/pytorch/dispatch/elemwise.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/elemwise.py#L18

Added line #L18 was not covered by tests
loc = func_name.split(".")[1:]
try:
mod = importlib.import_module(".".join(loc[:-1]), "torch")
return getattr(mod, loc[-1], False)

Check warning on line 22 in pytensor/link/pytorch/dispatch/elemwise.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/elemwise.py#L22

Added line #L22 was not covered by tests
except ImportError:
return False

if hasattr(scalar_op, "nfunc_spec") and (
hasattr(torch, scalar_op.nfunc_spec[0]) or "scipy." in scalar_op.nfunc_spec[0]
hasattr(torch, scalar_op.nfunc_spec[0])
or check_special_scipy(scalar_op.nfunc_spec[0])
):
# torch can handle this scalar
# broadcast, we'll let it.
Expand Down

0 comments on commit 7fec045

Please sign in to comment.