Implement unconditional constant_folding rewrite #1068

Merged · 1 commit · Nov 11, 2024

This PR renames the core constant-folding node rewriter to `unconditional_constant_folding`, which folds any node whose inputs are all constants regardless of what `Op.do_constant_folding` reports, and reintroduces `constant_folding` as a thin wrapper that checks `do_constant_folding` first and then delegates to it. It also registers a new graph rewriter, `topo_unconditional_constant_folding`, which applies the unconditional variant across the whole graph and warns-and-ignores Ops that fail to evaluate (for example, Ops without a usable `perform` method).
Changes from all commits
pytensor/tensor/rewriting/basic.py (23 changes: 19 additions & 4 deletions)
@@ -32,6 +32,7 @@
 from pytensor.graph import FunctionGraph
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.rewriting.basic import (
+    NodeProcessingGraphRewriter,
     NodeRewriter,
     RemovalNodeRewriter,
     Rewriter,
@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node):


 @node_rewriter(None)
-def constant_folding(fgraph, node):
-    if not node.op.do_constant_folding(fgraph, node):
-        return False
-
+def unconditional_constant_folding(fgraph, node):
     if not all(isinstance(inp, Constant) for inp in node.inputs):
         return False

@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node):
     return rval


+topo_unconditional_constant_folding = in2out(
+    unconditional_constant_folding,
+    ignore_newtrees=True,
+    name="topo_unconditional_constant_folding",
+    # Not all Ops have a perform method, so we ignore failures to constant_fold
+    failure_callback=NodeProcessingGraphRewriter.warn_ignore,
+)
+
+
+@node_rewriter(None)
+def constant_folding(fgraph, node):
+    if not node.op.do_constant_folding(fgraph, node):
+        return False
+
+    return unconditional_constant_folding.transform(fgraph, node)
+
+
 topo_constant_folding = in2out(
     constant_folding, ignore_newtrees=True, name="topo_constant_folding"
 )
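
For illustration (not part of the diff), here is how the two graph rewriters now differ on an Alloc that feeds a graph output. This sketch mirrors the new test added below and assumes the names above are importable from pytensor.tensor.rewriting.basic:

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.rewriting.basic import (
    topo_constant_folding,
    topo_unconditional_constant_folding,
)

# Alloc declines folding when its output is a graph output
# (its do_constant_folding returns False there, to avoid
# materializing large arrays), so the conditional rewriter
# leaves the node in place.
x = pt.alloc(np.e, 3, 5)
fg = FunctionGraph(outputs=[x], clone=False)
topo_constant_folding.apply(fg)
assert not isinstance(fg.outputs[0], Constant)

# The unconditional variant folds it regardless.
topo_unconditional_constant_folding.apply(fg)
assert isinstance(fg.outputs[0], Constant)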
tests/tensor/rewriting/test_basic.py (132 changes: 86 additions & 46 deletions)
@@ -12,7 +12,8 @@
 from pytensor.compile.mode import get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import equal_computations
+from pytensor.graph import Op
+from pytensor.graph.basic import Constant, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -29,6 +30,7 @@
     TensorFromScalar,
     as_tensor,
     cast,
+    constant,
     join,
     tile,
 )
@@ -65,6 +67,8 @@
     local_merge_alloc,
     local_useless_alloc,
     local_useless_elemwise,
+    topo_constant_folding,
+    topo_unconditional_constant_folding,
     topological_fill_sink,
 )
 from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
@@ -742,56 +746,92 @@ def test_upcast(self):
     ) or (len(topo) > 1)


-def test_constant_folding():
-    # Test that constant folding gets registered at fast_compile.
-    # A bug had once removed that registration.
-    x = dvector()
-    mode = get_mode("FAST_COMPILE").excluding("fusion")
-    f = function([x], [x * 2, x + x], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert len(topo) == 2
-
-    # Test that we do not crash when constant folding an elemwise scalar,
-    # as those should not generate C code.
-
-    x = pt.constant(3)
-    assert x.ndim == 0
-    mode = get_mode("FAST_COMPILE").excluding("fusion")
-    f = function([], [x * 2, x + x], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert len(topo) == 2
-    assert all(isinstance(n.op, DeepCopyOp) for n in topo)
+class TestConstantFolding:
+    def test_constant_folding(self):
+        # Test that constant folding gets registered at fast_compile.
+        # A bug had once removed that registration.
+        x = dvector()
+        mode = get_mode("FAST_COMPILE").excluding("fusion")
+        f = function([x], [x * 2, x + x], mode=mode)
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) == 2
+
+        # Test that we do not crash when constant folding an elemwise scalar,
+        # as those should not generate C code.
+
+        x = pt.constant(3)
+        assert x.ndim == 0
+        mode = get_mode("FAST_COMPILE").excluding("fusion")
+        f = function([], [x * 2, x + x], mode=mode)
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) == 2
+        assert all(isinstance(n.op, DeepCopyOp) for n in topo)


-@pytest.mark.xfail(
-    reason="PyTensor rewrites constants before stabilization. "
-    "This breaks stabilization rewrites in some cases. See #504.",
-    raises=AssertionError,
-)
-def test_constant_get_stabilized():
-    # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
-    # This caused some stabilization rewrites to not be activated and that
-    # caused inf values to appear when they should not.
-
-    # We can't simply move the `constant_folding` rewrite to
-    # specialize since this will break other rewrites. We will need to
-    # partially duplicate some canonicalize rewrites to fix this issue.
-
-    x2 = scalar()
-    y2 = log(1 + exp(x2))
-    mode = get_default_mode()
-    mode.check_isfinite = False
-    f2 = function([x2], y2, mode=mode)
-
-    assert len(f2.maker.fgraph.toposort()) == 1
-    assert f2.maker.fgraph.toposort()[0].op == softplus
-    assert f2(800) == 800
-
-    x = pt.as_tensor_variable(800)
-    y = log(1 + exp(x))
-    f = function([], y, mode=mode)
-    # When this error is fixed, the following line should be ok.
-    assert f() == 800, f()
+    @pytest.mark.xfail(
+        reason="PyTensor rewrites constants before stabilization. "
+        "This breaks stabilization rewrites in some cases. See #504.",
+        raises=AssertionError,
+    )
+    def test_constant_get_stabilized(self):
+        # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
+        # This caused some stabilization rewrites to not be activated and that
+        # caused inf values to appear when they should not.
+
+        # We can't simply move the `constant_folding` rewrite to
+        # specialize since this will break other rewrites. We will need to
+        # partially duplicate some canonicalize rewrites to fix this issue.
+
+        x2 = scalar()
+        y2 = log(1 + exp(x2))
+        mode = get_default_mode()
+        mode.check_isfinite = False
+        f2 = function([x2], y2, mode=mode)
+
+        assert len(f2.maker.fgraph.toposort()) == 1
+        assert f2.maker.fgraph.toposort()[0].op == softplus
+        assert f2(800) == 800
+
+        x = pt.as_tensor_variable(800)
+        y = log(1 + exp(x))
+        f = function([], y, mode=mode)
+        # When this error is fixed, the following line should be ok.
+        assert f() == 800, f()
+
+    def test_unconditional(self):
+        x = pt.alloc(np.e, *(3, 5))
+        fg = FunctionGraph(outputs=[x], clone=False)
+
+        # Default constant folding doesn't apply to an Alloc used as an output
+        topo_constant_folding.apply(fg)
+        assert not isinstance(fg.outputs[0], Constant)
+
+        # Unconditional constant folding does apply
+        topo_unconditional_constant_folding.apply(fg)
+        assert isinstance(fg.outputs[0], Constant)
+        np.testing.assert_allclose(fg.outputs[0].data, np.full((3, 5), np.e))
+
+    def test_unconditional_no_perform_method(self):
+        """Test that errors are caught when the Op does not have a usable perform method."""
+
+        class OpNoPerform(Op):
+            itypes = [scalar(dtype="float64").type]
+            otypes = [scalar(dtype="float64").type]
+
+            def perform(self, *args, **kwargs):
+                raise NotImplementedError("This Op cannot be evaluated")
+
+        x = constant(np.array(5.0))
+        out = OpNoPerform()(x)
+
+        fg = FunctionGraph(outputs=[out], clone=False)
+        # Default constant_folding will raise
+        with pytest.raises(NotImplementedError):
+            topo_constant_folding.apply(fg)
+
+        # Unconditional constant folding silently ignores the failure
+        topo_unconditional_constant_folding.apply(fg)
+        assert not isinstance(fg.outputs[0], Constant)
+        assert isinstance(fg.outputs[0].owner.op, OpNoPerform)


 class TestLocalSwitchSink:
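
As a closing aside (not from the PR), the unconditional rewriter can also collapse an entirely constant graph down to a single Constant, assuming in2out visits nodes from inputs to outputs so that each fold turns the next node's inputs into constants. A rough sketch under those assumptions:

import pytensor.tensor as pt
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding

# Every node here has constant inputs once its predecessors fold,
# so the whole graph should collapse to one Constant output.
y = pt.arange(6).reshape((2, 3)).sum(axis=0)
fg = FunctionGraph(outputs=[y], clone=False)
topo_unconditional_constant_folding.apply(fg)
assert isinstance(fg.outputs[0], Constant)
print(fg.outputs[0].data)  # expected: [3 5 7]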