Skip to content

Commit

Permalink
Split and inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 12, 2024
1 parent 231a977 commit 7ce2a5d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
9 changes: 9 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Eye,
Join,
MakeVector,
Split,
TensorFromScalar,
)

Expand Down Expand Up @@ -185,3 +186,11 @@ def tensorfromscalar(x):
return torch.as_tensor(x)

return tensorfromscalar


@pytorch_funcify.register(Split)
def pytorch_funcify_Split(op, node, **kwargs):
def inner_fn(x, dim, split_amounts):
return x.split(split_amounts.tolist(), dim=dim.item())

return inner_fn
6 changes: 6 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
Cast,
Invert,
ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Softplus


@pytorch_funcify.register(Invert)
def pytorch_funcify_invert(op, node, **kwargs):
return torch.bitwise_not


@pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs):
"""Return pytorch function that implements the same computation as the Scalar Op.
Expand Down

0 comments on commit 7ce2a5d

Please sign in to comment.