Skip to content

Commit

Permalink
Allow fill_sink rewrite to accomodate changes in broadcastability
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 28, 2024
1 parent a6255d6 commit fc21336
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
5 changes: 1 addition & 4 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,7 @@ def local_fill_sink(fgraph, node):
# Check if we need to propagate the fill to the new outputs
# It's enough to check the first output, as Elemwise outputs must all have the same shapes
# Note: There are orderings that may require fewer fills.
old_bcast_pattern = node.outputs[0].type.broadcastable
models_iter = iter(models)
while old_bcast_pattern != outputs[0].type.broadcastable:
model = next(models_iter)
for model in models:
# Only apply this model if it would actually do anything
if broadcasted_by(outputs[0], model):
outputs = [fill(model, output) for output in outputs]
Expand Down
18 changes: 17 additions & 1 deletion tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytensor
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor import shared
from pytensor import graph_replace, shared
from pytensor.compile import optdb
from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode
Expand Down Expand Up @@ -2010,3 +2010,19 @@ def test_topological_fill_sink_multi_output_client():
[new_out] = fg.outputs
expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
assert equal_computations([new_out], [expected_out])


def test_topological_fill_sink_broadcastable_change():
"""Test rewrite doesn't fail after a graph replacement that provides a broadcastable change."""
a = vector("a", shape=(1,))
b = vector("b", shape=(1,))
zeros = pt.vector("zeros", shape=(None,))
initial_out = pt.full_like(zeros, a) + b

# Make broadcast to zeros irrelevant
out = graph_replace(initial_out, {zeros: pt.zeros((1,))}, strict=False)

fg = FunctionGraph([a, b], [out], copy_inputs=False)
topological_fill_sink.rewrite(fg)
[new_out] = fg.outputs
assert equal_computations([new_out], [a + b])

0 comments on commit fc21336

Please sign in to comment.