diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py
index 0bed304c29..1f7a58af49 100644
--- a/pytensor/tensor/rewriting/blockwise.py
+++ b/pytensor/tensor/rewriting/blockwise.py
@@ -1,15 +1,19 @@
+from pytensor import Variable
 from pytensor.compile.mode import optdb
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
     register_stabilize,
 )
+from pytensor.tensor.rewriting.uncanonicalize import local_dimshuffle_alloc
+from pytensor.tensor.shape import Reshape
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
 
 
@@ -70,7 +74,7 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
         Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor,
     ):
         # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
-        # These other Ops can't always be trivially vectored at runtime,
+        # These other Ops can't always be trivially vectorized at runtime,
         # since their inputs may imply non-rectangular shapes.
         return local_useless_unbatched_blockwise.fn(fgraph, node)
 
@@ -86,6 +90,18 @@ def _squeeze_left(x, stop_at_dim: int | None = None):
     return x.squeeze(axis=tuple(range(squeeze_ndim)))
 
 
+def alloc_or_expand_dims_of_alloc(var: Variable) -> bool:
+    return var.owner and (
+        isinstance(var.owner.op, Alloc)
+        or (
+            isinstance(var.owner.op, DimShuffle)
+            and var.owner.inputs[0].owner
+            and isinstance(var.owner.inputs[0].owner.op, Alloc)
+        )
+    )
+
+
+@register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
 def local_blockwise_alloc(fgraph, node):
@@ -97,19 +113,25 @@ def local_blockwise_alloc(fgraph, node):
     BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
     """
-    if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
-        return None
-
     op: Blockwise = node.op  # type: ignore
 
     batch_ndim = op.batch_ndim(node)
     if not batch_ndim:
         return None
 
+    if not any(alloc_or_expand_dims_of_alloc(var) for var in node.inputs):
+        return None
+
     new_inputs = []
     batch_shapes = []
     can_push_any_alloc = False
     for inp, inp_sig in zip(node.inputs, op.inputs_sig):
+        if inp.owner and isinstance(inp.owner.op, DimShuffle):
+            # Convert DimShuffle of Alloc to Alloc
+            new_inp = local_dimshuffle_alloc.transform(None, inp.owner)
+            if new_inp:
+                [inp] = new_inp
+
         if inp.owner and isinstance(inp.owner.op, Alloc):
             # Push batch dims from Alloc
             value, *shape = inp.owner.inputs
@@ -167,17 +189,15 @@ def local_blockwise_alloc(fgraph, node):
         missing_ndim = old_out_type.ndim - new_out_type.ndim
         batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
         for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
+            if old_out_type.broadcastable[i]:
+                continue
             for batch_dim in batch_dims:
                 if batch_dim == 1:
                     continue
+                batch_shape[i] = batch_dim
                 if isinstance(batch_dim, Constant):
                     # Give preference to Constants
-                    batch_shape[i] = batch_dim
                     break
-                elif old_out_type.broadcastable[i]:
-                    # Only use non Constant shapes if absolutely necessary
-                    # Otherwise, we use the shape of the non-alloc output
-                    batch_shape[i] = batch_dim
 
         copy_stack_trace(node.outputs, new_outs)
         new_outs = [
@@ -190,3 +210,29 @@ def local_blockwise_alloc(fgraph, node):
         ]
     copy_stack_trace(node.outputs, new_outs)
     return new_outs
+
+
+@register_canonicalize
+@register_specialize
+@node_rewriter([Blockwise])
+def local_blockwise_reshape(fgraph, node):
+    """Rewrite away square Blockwise reshapes.
+
+    Reshape is tricky to vectorize eagerly, because a graph like
+    `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
+    that must be vectorized before we arrive at the reshape operation.
+
+    For the square Reshape case, we must wait for all the intermediate
+    operations to be lifted as Allocs
+    """
+    if not isinstance(node.op.core_op, Reshape):
+        return None
+
+    x, output_shape = node.inputs
+    batch_ndim = node.op.batch_ndim(node)
+    if all(output_shape.type.broadcastable[:batch_ndim]):
+        batched_shape = x.shape[:batch_ndim]
+        core_reshape = _squeeze_left(output_shape, batch_ndim)
+        new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
+        copy_stack_trace(node.outputs[0], new_out)
+        return [new_out]
diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py
index edac16bdee..b9e7502156 100644
--- a/pytensor/tensor/rewriting/shape.py
+++ b/pytensor/tensor/rewriting/shape.py
@@ -32,7 +32,6 @@
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
-    register_stabilize,
     register_useless,
     topo_constant_folding,
 )
@@ -749,51 +748,43 @@ def apply(self, fgraph):
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
 
 
-def local_reshape_chain(op):
-    @node_rewriter([op])
-    def f(fgraph, node):
-        """
-        Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
-
-        """
-        if not check_chain(node, op, op):
-            return False
-
-        # TODO: this can permit a failing program to run by eliminating
-        # the lower reshape
-        rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
-
-        # Copy over stacktrace from previous output node, as any error
-        # in new computational graph would have been caused by last op
-        # in the old computational graph.
-        copy_stack_trace(node.outputs, rval)
-
-        # It might happen that the desired output of this node has a
-        # broadcastable pattern that does not match that of 'rval'. This is
-        # when originally, we were able to figure out that one of the
-        # dimensions of the reshape is one, but some other transformation
-        # replaced the shape by one for which this cannot be guessed.
-        # We should try to figure out why we lost the information about this
-        # constant value... but in the meantime, better not apply this
-        # rewrite.
-        if rval.type.ndim == node.outputs[0].type.ndim and all(
-            s1 == s2
-            for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
-            if s1 == 1 or s2 == 1
-        ):
-            return [rval]
-        else:
-            return False
-
-    return f
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([Reshape])
+def local_reshape_chain(fgraph, node):
+    """
+    Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)
+    """
+    if not check_chain(node, Reshape, Reshape):
+        return False
 
-register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
+    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
+
+    # Copy over stacktrace from previous output node, as any error
+    # in new computational graph would have been caused by last op
+    # in the old computational graph.
+    copy_stack_trace(node.outputs, rval)
+
+    # It might happen that the desired output of this node has a
+    # broadcastable pattern that does not match that of 'rval'. This is
+    # when originally, we were able to figure out that one of the
+    # dimensions of the reshape is one, but some other transformation
+    # replaced the shape by one for which this cannot be guessed.
+    # We should try to figure out why we lost the information about this
+    # constant value... but in the meantime, better not apply this
+    # rewrite.
+    if rval.type.ndim == node.outputs[0].type.ndim and all(
+        s1 == s2
+        for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
+        if s1 == 1 or s2 == 1
+    ):
+        return [rval]
 
 
-@register_useless
-@register_canonicalize
-@register_stabilize
+@register_useless("shape_unsafe")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
 @node_rewriter([Reshape])
 def local_useless_reshape(fgraph, node):
     """Remove two kinds of useless `Reshape`.
@@ -802,24 +793,17 @@ def local_useless_reshape(fgraph, node):
     - Remove `Reshape` when reshaping to the shape of the input.
 
     """
-    inp = node.inputs[0]
-    output = node.outputs[0]
-    output_shape = node.inputs[1]
+    inp, output_shape = node.inputs
+    [output] = node.outputs
 
     if inp.type.ndim != output.type.ndim:
         return False
 
     # Simple case: both input and output have a single dimension.
-    # TODO FIXME XXX: This could hide errors if the user provides inconsistent
-    # shapes.
     if (
         inp.type.ndim == 1
         and output.type.ndim == 1
-        and all(
-            s1 == s2
-            for s1, s2 in zip(inp.type.shape, output.type.shape)
-            if s1 == 1 or s2 == 1
-        )
+        and inp.type.broadcastable == output.type.broadcastable
     ):
         return [inp]
 
@@ -832,8 +816,15 @@ def local_useless_reshape(fgraph, node):
 
     # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
-    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
-        output_shape_is = output_shape.owner.inputs
+    if isinstance(output_shape, Constant) or (
+        output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
+    ):
+        if isinstance(output_shape, Constant):
+            output_shape_is = [
+                as_tensor_variable(dim, ndim=0) for dim in output_shape.data
+            ]
+        else:
+            output_shape_is = output_shape.owner.inputs
 
         shape_feature = getattr(fgraph, "shape_feature", None)
 
@@ -865,9 +856,9 @@ def local_useless_reshape(fgraph, node):
                         shape_match[dim] = True
                         continue
 
-            # Match 1 if input.type.shape[dim] == 1
+            # Match constant if input.type.shape[dim] == constant
             cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
-            if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
+            if inp.type.shape[dim] == cst_outshp_i:
                 shape_match[dim] = True
                 continue
 
@@ -881,17 +872,18 @@ def local_useless_reshape(fgraph, node):
             if shape_feature:
                 inpshp_i = shape_feature.get_shape(inp, dim)
                 if inpshp_i == outshp_i or (
-                    extract_constant(inpshp_i, only_process_constants=1)
-                    == extract_constant(outshp_i, only_process_constants=1)
+                    extract_constant(inpshp_i, only_process_constants=True)
+                    == extract_constant(outshp_i, only_process_constants=True)
                 ):
                     shape_match[dim] = True
                     continue
 
-        if all(shape_match) and nb_m1 <= 1:
+        if nb_m1 <= 1 and all(shape_match):
+            return [inp]
+
+        if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
             return [inp]
 
-    # TODO later: if all the shapes except one match, we may want to
-    # consider it useless as well, like we do in the 1-dim case.
     return False
 
 
@@ -910,9 +902,8 @@ def local_reshape_to_dimshuffle(fgraph, node):
         -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
     """
     op = node.op
-    inp = node.inputs[0]
-    output = node.outputs[0]
-    output_shape = node.inputs[1]
+    inp, output_shape = node.inputs
+    [output] = node.outputs
 
     dimshuffle_new_order = []
     new_output_shape = []
@@ -944,7 +935,7 @@
 
 
 @register_canonicalize
-@register_stabilize
+@register_specialize
 @node_rewriter([Reshape])
 def local_reshape_lift(fgraph, node):
     """
diff --git a/tests/tensor/rewriting/test_blockwise.py b/tests/tensor/rewriting/test_blockwise.py
index d5ea6e2b4e..6653734ee1 100644
--- a/tests/tensor/rewriting/test_blockwise.py
+++ b/tests/tensor/rewriting/test_blockwise.py
@@ -1,7 +1,9 @@
 from functools import partial
 
-from pytensor import function
-from pytensor.graph import FunctionGraph, rewrite_graph
+import numpy as np
+
+from pytensor import Mode, function
+from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
 from pytensor.graph.basic import equal_computations
 from pytensor.scalar import log as scalar_log
 from pytensor.tensor import add, alloc, matrix, tensor, tensor3
@@ -9,6 +11,7 @@
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.nlinalg import MatrixPinv
 from pytensor.tensor.rewriting.blockwise import local_useless_blockwise
+from pytensor.tensor.shape import Reshape
 
 
 def test_useless_blockwise_of_elemwise():
@@ -118,3 +121,27 @@ def test_blockwise_alloc():
     out = vector_add(x, alloc(y, 5))
     expected_out = out
     assert equal([rewrite(out)], [expected_out])
+
+
+def test_blockwise_reshape():
+    x = tensor("x", shape=(None, None, None))
+    y = x.reshape([x.shape[0] * x.shape[1], -1])
+
+    new_x = tensor("x", shape=(None, None, None, None))
+    new_y = vectorize_graph(y, {x: new_x})
+    assert not isinstance(new_y.owner.op, Reshape)
+    assert isinstance(new_y.owner.op, Blockwise) and isinstance(
+        new_y.owner.op.core_op, Reshape
+    )
+
+    rewritten_y = rewrite_graph(
+        new_y, include=("canonicalize", "specialize"), clone=True
+    )
+    assert isinstance(rewritten_y.owner.op, Reshape)
+
+    no_rewrites = Mode(linker="py", optimizer=None)
+    test_x = np.arange(5 * 4 * 3 * 2).reshape(5, 4, 3, 2)
+    np.testing.assert_allclose(
+        new_y.eval({"x": test_x}, mode=no_rewrites),
+        rewritten_y.eval({"x": test_x}, mode=no_rewrites),
+    )
diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py
index 8392423a58..4a2c83d33f 100644
--- a/tests/tensor/rewriting/test_shape.py
+++ b/tests/tensor/rewriting/test_shape.py
@@ -336,6 +336,52 @@ def test_m1(self):
         topo = f2.maker.fgraph.toposort()
         assert not any(isinstance(n.op, Reshape) for n in topo)
 
+    def test_constant_shape(self):
+        # Where reshape is a constant that matches the shape
+        x = matrix(shape=(2, 3))
+        shape = pt.as_tensor(np.array([2, 3]))
+        out = reshape(x, shape)
+        new_out = rewrite_graph(out)
+        assert new_out is x
+
+        x = matrix(shape=(2, 3))
+        shape = pt.as_tensor(np.array([-1, 3]))
+        out = reshape(x, shape)
+        new_out = rewrite_graph(out)
+        assert new_out is x
+
+        x = matrix(shape=(None, 3))
+        shape = pt.as_tensor(np.array([-1, 3]))
+        out = reshape(x, shape)
+        new_out = rewrite_graph(out)
+        assert new_out is x
+
+        x = matrix(shape=(None, 3))
+        shape = pt.as_tensor(np.array([2, 3]))
+        out = reshape(x, shape)
+        new_out = rewrite_graph(out)
+        # This could be rewritten as a specify_shape(x, (2, 3))
+        assert new_out is not x
+
+        x = matrix(shape=(2, 3))
+        shape = pt.as_tensor(np.array([3, 2]))
+        out = reshape(x, shape)
+        new_out = rewrite_graph(out)
+        assert new_out is not x
+
+    def test_all_but_one_match(self):
+        x = matrix(shape=(None, None))
+        shape = [x.shape[0], 3]
+        out = reshape(x, shape)
+        new_out = rewrite_graph(out)
+        assert equal_computations([new_out], [specify_shape(x, (None, 3))])
+
+        # Rewrite does not apply if there's also a -1
+        shape = [-1, 3]
+        out = reshape(x, shape)
+        new_out = rewrite_graph(out)
+        assert new_out is out
+
 
 class TestLocalReshapeToDimshuffle:
     def setup_method(self):
diff --git a/tests/tensor/test_einsum.py b/tests/tensor/test_einsum.py
index bb543cccca..cb36289936 100644
--- a/tests/tensor/test_einsum.py
+++ b/tests/tensor/test_einsum.py
@@ -4,9 +4,30 @@
 import numpy as np
 import pytest
 
+import pytensor
 import pytensor.tensor as pt
 from pytensor import Mode, function
+from pytensor.graph import FunctionGraph
+from pytensor.graph.op import HasInnerGraph
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
+from pytensor.tensor.shape import Reshape
+
+
+def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None:
+    for node in fgraph.apply_nodes:
+        if isinstance(node.op, Blockwise):
+            if core_op is None:
+                raise AssertionError
+            assert not isinstance(node.op.core_op, core_op)
+
+        if isinstance(node.op, HasInnerGraph):
+            # InnerGraph Ops can be rewritten without modifying the original fgraph
+            if hasattr(node.op, "_fn"):
+                inner_fgraph = node.op._fn.maker.fgraph
+            else:
+                inner_fgraph = node.op.fgraph
+            assert_no_blockwise_in_graph(inner_fgraph, core_op=core_op)
 
 
 def test_iota():
@@ -41,9 +62,7 @@ def test_delta():
 
 
 def test_general_dot():
-    mode = Mode(linker="py", optimizer=None)
     rng = np.random.default_rng(45)
-
     signature = "(l0,a0,a1,l1),(a1,r0,r1,a0)->(l0,l1,r0,r1)"
     tensordot_axes = [(-3, -2), (-1, -4)]
 
@@ -53,21 +72,17 @@ def test_general_dot():
     y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
     out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])
 
-    # FIXME: Not a satisfactory graph!
-    # import pytensor
-    # fn = pytensor.function([x, y], out)
-    # print()
-    # pytensor.dprint(fn, print_type=True)
-
-    x_test = rng.normal(size=x.type.shape)
-    y_test = rng.normal(size=y.type.shape)
+    fn = pytensor.function([x, y], out)
+    # fn.dprint(print_type=True)
+    assert_no_blockwise_in_graph(fn.maker.fgraph, Reshape)
 
     np_batched_tensordot = np.vectorize(
         partial(np.tensordot, axes=tensordot_axes), signature=signature
    )
-
+    x_test = rng.normal(size=x.type.shape)
+    y_test = rng.normal(size=y.type.shape)
     np.testing.assert_allclose(
-        out.eval({x: x_test, y: y_test}, mode=mode),
+        fn(x_test, y_test),
         np_batched_tensordot(x_test, y_test),
     )
@@ -121,10 +136,9 @@ def test_einsum_signatures(static_shape_known, signature):
     fn = function(operands, out)
     pt_out = fn(*test_values)
 
-    # import pytensor
-    # print(); pytensor.dprint(fn, print_type=True)
+    # print(); fn.dprint(print_type=True)
 
-    # assert out.type.shape == np_out.shape  # Reshape operations lose static shape
+    assert_no_blockwise_in_graph(fn.maker.fgraph)
     np.testing.assert_allclose(pt_out, np_out)
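
Illustration for reviewers, not part of the patch: a minimal sketch of the behaviour the new local_blockwise_reshape rewrite enables, mirroring the test_blockwise_reshape test added above. It uses only APIs that appear in this diff (pt.tensor, vectorize_graph, rewrite_graph, Blockwise, Reshape); variable names are illustrative.

# Sketch only; mirrors test_blockwise_reshape from this patch.
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph, vectorize_graph
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.shape import Reshape

x = pt.tensor("x", shape=(None, None, None))
y = x.reshape([x.shape[0] * x.shape[1], -1])

# Vectorizing over a new leading batch dim wraps the core Reshape in a Blockwise
new_x = pt.tensor("new_x", shape=(None, None, None, None))
new_y = vectorize_graph(y, {x: new_x})
assert isinstance(new_y.owner.op, Blockwise)
assert isinstance(new_y.owner.op.core_op, Reshape)

# After canonicalize/specialize, the Blockwise collapses back to a plain Reshape
rewritten_y = rewrite_graph(new_y, include=("canonicalize", "specialize"), clone=True)
assert isinstance(rewritten_y.owner.op, Reshape)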