Skip to content

Commit

Permalink
Do not reject PatternNodeRewriter due unrelated multiple clients
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 30, 2024
1 parent 2143d85 commit 1e9b9c5
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 17 deletions.
17 changes: 9 additions & 8 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,14 +1616,6 @@ def transform(self, fgraph, node, get_nodes=True):
from etuples.core import ExpressionTuple
from unification import reify, unify

# TODO: We shouldn't need to iterate like this.
if not self.allow_multiple_clients and any(
len(fgraph.clients.get(v)) > 1
for v in vars_between(fgraph.inputs, node.outputs)
if v not in fgraph.inputs
):
return False

if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(fgraph, node):
if real_node == "output":
Expand All @@ -1648,6 +1640,15 @@ def transform(self, fgraph, node, get_nodes=True):
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx

if not self.allow_multiple_clients:
input_vars = list(s.values())
if any(
len(fgraph.clients[v]) > 1
for v in vars_between(input_vars, node.inputs)
if v not in input_vars
):
return False

if ret.owner:
if not (
len(node.outputs) == len(ret.owner.outputs)
Expand Down
81 changes: 72 additions & 9 deletions tests/graph/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
raise AssertionError()


def OpKeyPatternNodeRewriter(p1, p2, ign=False):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def OpKeyPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False):
return OpKeyGraphRewriter(
PatternNodeRewriter(p1, p2, allow_multiple_clients=allow_multiple_clients),
ignore_newtrees=ign,
)


def WalkingPatternNodeRewriter(p1, p2, ign=True):
Expand Down Expand Up @@ -207,13 +210,70 @@ def constraint(r):
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"

def test_allow_multiple_clients(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e0 = op1(x, y)
# `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
e = op3(op4(e0), e0)
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op4, (op1, "x", "y")), (op3, "x", "y")).rewrite(g)
assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
x, y, z = inputs = MyVariable("x"), MyVariable("y"), MyVariable("z")
w = op1(x, y)
# `w` has multiple clients (i.e. the `op4` and `op3` nodes)
e = op3(op4(w), w)

# By default, allow_multiple_clients is False
# So the replacement should fail
outputs = [e]
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op1, "x", "y")),
(op3, "x", "y"),
).rewrite(g)
assert equal_computations(g.outputs, outputs)

# Now it should be fine
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op1, "x", "y")),
(op3, "x", "y"),
allow_multiple_clients=True,
).rewrite(g)
assert equal_computations(g.outputs, [op3(op3(x, y), w)])

# The fact that the inputs of the pattern have multiple clients should not matter
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op3, (op4, "w"), "w"),
(op3, "w", "w"),
allow_multiple_clients=False,
).rewrite(g)
assert equal_computations(g.outputs, [op3(w, w)])

# The fact that are multiple clients above the inputs of the pattern should not matter
v = op4(e)
e1 = op4(v)
e2 = op1(x, x) # Irrelevant reuse of x that should not block rewrite either
e3 = op1(v, v) # Relevant reuse of v that should block rewrite

outputs = [e1, e2]
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=False,
).rewrite(g)
assert equal_computations(g.outputs, [e, e2])

outputs = [e1, e3]
g = FunctionGraph([x, y, z], outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=False,
).rewrite(g)
assert equal_computations(g.outputs, outputs)

g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=True,
).rewrite(g)
assert equal_computations(g.outputs, [e, e3])

def test_eq(self):
# replacing the whole graph
Expand All @@ -226,6 +286,9 @@ def test_eq(self):
str_g = str(g)
assert str_g == "FunctionGraph(Op4(z, y))"

def test_unrelated_variables_with_multiple_clients(self):
"""Test the rewrite still applies w"""


def KeyedSubstitutionNodeRewriter(op1, op2):
return OpKeyGraphRewriter(SubstitutionNodeRewriter(op1, op2))
Expand Down

0 comments on commit 1e9b9c5

Please sign in to comment.