Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Remove useless SpecifyShape #885

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions pytensor/tensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)
from pytensor.tensor.subtensor import Subtensor, get_idx_list
from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.type_other import NoneConst, NoneTypeT


class ShapeFeature(Feature):
Expand Down Expand Up @@ -974,6 +974,35 @@ def local_reshape_lift(fgraph, node):
return [e]


@register_useless
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([SpecifyShape])
def local_useless_specify_shape(fgraph, node):
    """Remove a `SpecifyShape` whose asserted shape is already part of the input's static type.

    The rewrite only fires when every specified dimension is either left
    unspecified (``NoneConst``) or is a constant equal to the corresponding
    known static dimension of the input. In that case the `SpecifyShape`
    adds no information and the input can be returned directly.
    """
    x = node.inputs[0]
    specified_dims = node.inputs[1:]
    for static_dim, spec in zip(x.type.shape, specified_dims, strict=True):
        if isinstance(spec.type, NoneTypeT):
            # Nothing is asserted for this dimension
            continue
        # The assertion is redundant only when the static dim is known and
        # the specified dim is a constant matching it exactly.
        redundant = (
            static_dim is not None
            and isinstance(spec, Constant)
            and spec.data == static_dim
        )
        if not redundant:
            # Either the static dim is unknown, or the specified dim is
            # non-constant / mismatched — the SpecifyShape must be kept.
            return None

    # Every specified dim was already encoded in the static type;
    # the SpecifyShape is useless.
    copy_stack_trace(node.outputs[0], x)
    return [x]


@register_infer_shape
@register_useless
@register_canonicalize
Expand Down Expand Up @@ -1189,10 +1218,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
@register_specialize
@node_rewriter([Unbroadcast])
def local_useless_unbroadcast(fgraph, node):
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern.

TODO: Implement equivalent rewrite for SpecifyShape
"""
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern."""
if isinstance(node.op, Unbroadcast):
x = node.inputs[0]
if x.type.ndim == node.outputs[0].type.ndim and all(
Expand Down
25 changes: 25 additions & 0 deletions tests/tensor/rewriting/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ShapeFeature,
local_reshape_to_dimshuffle,
local_useless_reshape,
local_useless_specify_shape,
)
from pytensor.tensor.shape import (
Reshape,
Expand Down Expand Up @@ -476,6 +477,30 @@ def test_vector_dim_err(self):
shape_feature.same_shape(x, o, 0, 1)


def test_useless_specify_shape():
    """Check `local_useless_specify_shape` fires only when the asserted shape is redundant."""
    x = tensor("x", shape=(None, 5, 3))

    # We avoid the helper specify_shape that optimizes some (but not all) cases eagerly
    ss = SpecifyShape()

    # All specified dims are either None or match the static shape: rewrite applies
    out = ss(x, None, 5, None)
    assert isinstance(out.owner.op, SpecifyShape)
    ret = local_useless_specify_shape.transform(None, out.owner)
    assert ret == [x]

    # SpecifyShape is needed to enforce that the unknown dim is 3
    out = ss(x, 3, 5, None)
    assert isinstance(out.owner.op, SpecifyShape)
    ret = local_useless_specify_shape.transform(None, out.owner)
    assert ret is None

    # SpecifyShape is needed to raise a mismatch between the static and specified dim
    out = ss(x, None, 5, 4)
    assert isinstance(out.owner.op, SpecifyShape)
    ret = local_useless_specify_shape.transform(None, out.owner)
    assert ret is None


@pytest.mark.parametrize(
"shape",
[lscalar(), iscalar()],
Expand Down
Loading