diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py
index edac16bdee..c395cd70ab 100644
--- a/pytensor/tensor/rewriting/shape.py
+++ b/pytensor/tensor/rewriting/shape.py
@@ -48,7 +48,7 @@
 )
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
-from pytensor.tensor.type_other import NoneConst
+from pytensor.tensor.type_other import NoneConst, NoneTypeT
 
 
 class ShapeFeature(Feature):
@@ -974,6 +974,35 @@ def local_reshape_lift(fgraph, node):
         return [e]
 
 
+@register_useless
+@register_canonicalize
+@register_stabilize
+@register_specialize
+@node_rewriter([SpecifyShape])
+def local_useless_specify_shape(fgraph, node):
+    """Remove SpecifyShape when the asserted shapes are already encoded in the static type of the input."""
+    x, *shape = node.inputs
+    for static_dim, specified_dim in zip(x.type.shape, shape, strict=True):
+        if isinstance(specified_dim.type, NoneTypeT):
+            continue
+        if static_dim is None:
+            # An unknown static dimension is being specified
+            return None
+        if not (
+            isinstance(specified_dim, Constant) and specified_dim.data == static_dim
+        ):
+            # The specified dim is either:
+            # 1. Not constant, or
+            # 2. A constant that does not match the static dim
+            # Either way, we must keep the SpecifyShape
+            return None
+
+    # If we arrived here, the specified shape was already encoded in the static shape,
+    # so the SpecifyShape is not needed
+    copy_stack_trace(node.outputs[0], x)
+    return [x]
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
@@ -1189,10 +1218,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
 @register_specialize
 @node_rewriter([Unbroadcast])
 def local_useless_unbroadcast(fgraph, node):
-    """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.
-
-    TODO: Implement equivalent rewrite for SpecifyShape
-    """
+    """Remove `Unbroadcast` if it does not actually change the broadcasting pattern."""
     if isinstance(node.op, Unbroadcast):
         x = node.inputs[0]
         if x.type.ndim == node.outputs[0].type.ndim and all(
diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py
index 8392423a58..f4c529a0d2 100644
--- a/tests/tensor/rewriting/test_shape.py
+++ b/tests/tensor/rewriting/test_shape.py
@@ -23,6 +23,7 @@
     ShapeFeature,
     local_reshape_to_dimshuffle,
     local_useless_reshape,
+    local_useless_specify_shape,
 )
 from pytensor.tensor.shape import (
     Reshape,
@@ -476,6 +477,30 @@ def test_vector_dim_err(self):
            shape_feature.same_shape(x, o, 0, 1)
 
 
+def test_useless_specify_shape():
+    x = tensor("x", shape=(None, 5, 3))
+
+    # We avoid the specify_shape helper, which eagerly optimizes some (but not all) cases
+    ss = SpecifyShape()
+
+    out = ss(x, None, 5, None)
+    assert isinstance(out.owner.op, SpecifyShape)
+    ret = local_useless_specify_shape.transform(None, out.owner)
+    assert ret == [x]
+
+    # SpecifyShape is needed to enforce that the unknown dim is 3
+    out = ss(x, 3, 5, None)
+    assert isinstance(out.owner.op, SpecifyShape)
+    ret = local_useless_specify_shape.transform(None, out.owner)
+    assert ret is None
+
+    # SpecifyShape is needed to raise on a mismatch between static and specified dims
+    out = ss(x, None, 5, 4)
+    assert isinstance(out.owner.op, SpecifyShape)
+    ret = local_useless_specify_shape.transform(None, out.owner)
+    assert ret is None
+
+
 @pytest.mark.parametrize(
     "shape",
     [lscalar(), iscalar()],
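
For readers who want to see the new rewrite from the user-facing side, here is a minimal end-to-end sketch (my own illustration, not part of the patch, and assuming a PyTensor build with this change applied). It compiles a graph where the specified shape merely restates the input's static type and checks that no `SpecifyShape` node survives compilation:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.shape import SpecifyShape, specify_shape

# Input whose static type already pins the last two dims.
x = pt.tensor("x", shape=(None, 5, 3))

# Fully redundant assertion: every specified dim is either None or
# already guaranteed by x's static shape.
out = specify_shape(x, (None, 5, 3))

fn = pytensor.function([x], out)

# Whether the specify_shape helper short-circuited eagerly or
# local_useless_specify_shape removed the node during rewriting,
# no SpecifyShape op should remain in the compiled graph.
assert not any(
    isinstance(node.op, SpecifyShape) for node in fn.maker.fgraph.toposort()
)
```

Note that the unit test deliberately instantiates `SpecifyShape()` directly rather than going through the helper, precisely because the helper can eagerly return `x` unchanged in the fully redundant case, which would leave the rewrite itself unexercised.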