From fc21336a67334241d1c9a59383084ce453988d4e Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 27 May 2024 17:54:08 +0200
Subject: [PATCH] Allow fill_sink rewrite to accommodate changes in
 broadcastability

---
 pytensor/tensor/rewriting/basic.py   |  5 +----
 tests/tensor/rewriting/test_basic.py | 18 +++++++++++++++++-
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index c24012705d..710f764bc5 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -351,10 +351,7 @@ def local_fill_sink(fgraph, node):
     # Check if we need to propagate the fill to the new outputs
     # It's enough to check the first output, as Elemwise outputs must all have the same shapes
     # Note: There are orderings that may require fewer fills.
-    old_bcast_pattern = node.outputs[0].type.broadcastable
-    models_iter = iter(models)
-    while old_bcast_pattern != outputs[0].type.broadcastable:
-        model = next(models_iter)
+    for model in models:
         # Only apply this model if it would actually do anything
         if broadcasted_by(outputs[0], model):
             outputs = [fill(model, output) for output in outputs]
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index 366e09ed4a..4ff773dbb8 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -6,7 +6,7 @@
 import pytensor
 import pytensor.scalar as ps
 import pytensor.tensor as pt
-from pytensor import shared
+from pytensor import graph_replace, shared
 from pytensor.compile import optdb
 from pytensor.compile.function import function
 from pytensor.compile.mode import get_default_mode, get_mode
@@ -2010,3 +2010,19 @@ def test_topological_fill_sink_multi_output_client():
     [new_out] = fg.outputs
     expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
     assert equal_computations([new_out], [expected_out])
+
+
+def test_topological_fill_sink_broadcastable_change():
+    """Test that the rewrite does not fail after a graph replacement that changes broadcastability."""
+    a = vector("a", shape=(1,))
+    b = vector("b", shape=(1,))
+    zeros = pt.vector("zeros", shape=(None,))
+    initial_out = pt.full_like(zeros, a) + b
+
+    # Make the broadcast to zeros irrelevant
+    out = graph_replace(initial_out, {zeros: pt.zeros((1,))}, strict=False)
+
+    fg = FunctionGraph([a, b], [out], copy_inputs=False)
+    topological_fill_sink.rewrite(fg)
+    [new_out] = fg.outputs
+    assert equal_computations([new_out], [a + b])
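
A minimal standalone sketch, based on the new test, of the scenario this patch handles. The import paths for FunctionGraph and topological_fill_sink are assumptions (inside the test module those names are already in scope); the comments explain how the old while-loop could fail.

import pytensor.tensor as pt
from pytensor import graph_replace
from pytensor.graph import FunctionGraph  # assumed import path
from pytensor.tensor.rewriting.basic import topological_fill_sink  # assumed import path

a = pt.vector("a", shape=(1,))
b = pt.vector("b", shape=(1,))
zeros = pt.vector("zeros", shape=(None,))

# full_like(zeros, a) broadcasts `a` against a vector of unknown length,
# so the addition is built with a non-broadcastable output type.
initial_out = pt.full_like(zeros, a) + b

# Swapping `zeros` for a length-1 vector changes the broadcast pattern of the
# graph without rebuilding the Elemwise nodes above it.
out = graph_replace(initial_out, {zeros: pt.zeros((1,))}, strict=False)

# Before this patch, local_fill_sink compared the stale broadcast pattern of
# the original node against the rebuilt outputs and kept calling
# next(models_iter) until they matched, which could exhaust the iterator when
# the patterns can no longer match. Looping over `models` directly avoids that.
fg = FunctionGraph([a, b], [out], copy_inputs=False)
topological_fill_sink.rewrite(fg)
print(fg.outputs)  # expected to simplify to [a + b]

The for-loop keeps the original behaviour of applying a model only when broadcasted_by says the fill would do something, but it no longer assumes the node's recorded broadcast pattern is still reachable.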