Rewrite dots as multiplication without summation
ricardoV94 committed Jan 27, 2025
1 parent 92ebf60 commit e0cb086
Showing 6 changed files with 165 additions and 15 deletions.
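
The idea behind the change: when the contracted dimension of a dot product is statically known to have length 1, the "sum" in the dot has a single term, so the whole operation is just a broadcasted elementwise multiplication (the matrix-matrix case with an inner dimension of 1 is an outer product). A minimal NumPy sketch of the identity the new rewrite exploits (illustrative only, not part of the commit):

import numpy as np

a = np.random.normal(size=(3, 1))  # (m, k) with k == 1
b = np.random.normal(size=(1, 4))  # (k, n) with k == 1

# matrix @ matrix with k == 1 is the outer product, i.e. plain broadcasting
np.testing.assert_allclose(a @ b, a * b)

# the inner product of two length-1 vectors is a scalar multiplication
u = np.random.normal(size=(1,))
v = np.random.normal(size=(1,))
np.testing.assert_allclose(np.dot(u, v), (u * v).squeeze())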
42 changes: 34 additions & 8 deletions pytensor/tensor/math.py
@@ -29,7 +29,7 @@
stack,
switch,
)
-from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
+from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import (
CAReduce,
Elemwise,
@@ -2726,6 +2726,22 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))


# Predefine all batched variations of Dot
_inner_prod = Blockwise(
_dot,
signature="(n),(n)->()",
)

_matrix_vec_prod = Blockwise(
_dot,
signature="(m,k),(k)->(m)",
)

_vec_matrix_prod = Blockwise(
_dot,
signature="(k),(k,n)->(n)",
)

_matrix_matrix_matmul = Blockwise(
_dot,
signature="(m,k),(k,n)->(m,n)",
@@ -2795,14 +2811,24 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None


@_vectorize_node.register(Dot)
-def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
+def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
-    if old_x.type.ndim == 2 and old_y.type.ndim == 2:
-        # If original input is equivalent to a matrix-matrix product,
-        # return specialized Matmul Op to avoid unnecessary new Ops.
-        return matmul(batched_x, batched_y).owner
-    else:
-        return vectorize_node_fallback(op, node, batched_x, batched_y)
+    old_x_ndim = old_x.type.ndim
+    old_y_ndim = old_y.type.ndim
+    match (old_x_ndim, old_y_ndim):
+        case (1, 1):
+            batch_op = _inner_prod
+        case (2, 1):
+            batch_op = _matrix_vec_prod
+        case (1, 2):
+            batch_op = _vec_matrix_prod
+        case (2, 2):
+            batch_op = _matrix_matrix_matmul
+        case _:
+            raise ValueError(
+                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
+            )
+    return batch_op(batched_x, batched_y).owner


def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
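The math.py changes above predefine Blockwise variants of Dot for every core dimensionality (inner product, matrix-vector, vector-matrix, matrix-matrix) and make vectorize_node_dot dispatch on the core input ndims instead of falling back to a generic Blockwise. A rough usage sketch of the new dispatch (variable names assumed; API usage mirrors the new test further down):

import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

x = pt.vector("x")
y = pt.vector("y")
out = pt.dot(x, y)  # core inner product: (n),(n)->()

batch_x = pt.matrix("batch_x")  # (batch, n)
batch_y = pt.matrix("batch_y")  # (batch, n)
batched_out = vectorize_graph(out, {x: batch_x, y: batch_y})
# batched_out.owner.op should now be the Blockwise inner-product variant
# (signature "(n),(n)->()") rather than a generic fallback
print(batched_out.owner.op)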
65 changes: 65 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -44,6 +44,11 @@
Prod,
Sum,
_conj,
_dot,
_inner_prod,
_matrix_matrix_matmul,
_matrix_vec_prod,
_vec_matrix_prod,
add,
digamma,
dot,
@@ -242,6 +247,66 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
return None


@register_canonicalize
@register_specialize
@node_rewriter(
[_dot, _inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul]
)
def local_dot_to_mul(fgraph, node):
"""Rewrite dots that correspond to multiplication without summation."""
a, b = node.inputs
a_st_shape = a.type.shape
b_st_shape = b.type.shape
if isinstance(node.op, Dot):
core_a_ndim = a.type.ndim
core_b_ndim = b.type.ndim
else:
# Blockwise variants of Dot
core_a_ndim = len(node.op.inputs_sig[0])
core_b_ndim = len(node.op.inputs_sig[1])

if core_a_ndim > 2 or core_b_ndim > 2:
# Shouldn't happen, but here just in case
return None

if core_b_ndim == 1:
if a_st_shape[-1] == 1 or b_st_shape[-1] == 1:
if core_a_ndim == 1:
# inner product: (..., 1) * (..., 1) -> (...)
# just squeeze the last dimensions of a and b
new_a = a.squeeze(-1)
new_b = b.squeeze(-1)
else:
# matrix-vector product: (..., m, 1) * (..., 1) -> (..., m)
# the last dimension of b is already aligned for the elemwise multiplication
# after we squeeze the last dimension of a
new_a = a.squeeze(-1)
new_b = b
else:
return None

else:
if a_st_shape[-1] == 1 or b_st_shape[-2] == 1:
if core_a_ndim == 1:
# vector-matrix product: (..., 1) * (..., 1, n) -> (..., n)
# the last dimension of a is already aligned for the elemwise multiplication
# after we squeeze the second-to-last dimension of b
new_a = a
new_b = b.squeeze(-2)
else:
# matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
# the dimensions of a and b are already aligned for the elemwise multiplication
new_a = a
new_b = b
else:
return None

new_a = copy_stack_trace(a, new_a)
new_b = copy_stack_trace(b, new_b)
new_out = copy_stack_trace(node.out, mul(new_a, new_b))
return [new_out]


def is_inverse_pair(node_op, prev_op, inv_pair):
"""
Given two consecutive operations, check if they are the
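In practice the new local_dot_to_mul rewrite turns graphs such as an outer product spelled as a dot of an (m, 1) column with a (1, n) row into a plain broadcasted multiplication, which is why several tests below now exclude "local_dot_to_mul" or spell the outer product explicitly as a dot. An end-to-end sketch of the expected effect (behavior under the default rewrite mode assumed; mirrors test_local_dot_to_mul):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, 1), dtype="float64")
y = pt.tensor("y", shape=(1, 4), dtype="float64")
out = pt.dot(x, y)

f = pytensor.function([x, y], out)  # default rewrites should apply local_dot_to_mul
pytensor.dprint(f)  # expected to show an Elemwise mul and no Dot/Blockwise node

x_val = np.random.normal(size=(3, 1))
y_val = np.random.normal(size=(1, 4))
np.testing.assert_allclose(f(x_val, y_val), x_val * y_val)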
10 changes: 8 additions & 2 deletions tests/compile/test_profiling.py
@@ -6,7 +6,7 @@
import numpy as np

import pytensor.tensor as pt
-from pytensor.compile import ProfileStats
+from pytensor.compile import ProfileStats, get_mode
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.ifelse import ifelse
@@ -28,7 +28,10 @@ def test_profiling(self):
x = [fvector(f"val{i}") for i in range(3)]

z = []
-z += [pt.outer(x[i], x[i + 1]).sum(axis=1) for i in range(len(x) - 1)]
+z += [
+    pt.dot(x[i][:, None], x[i + 1][None, :]).sum(axis=1)
+    for i in range(len(x) - 1)
+]
z += [x[i] + x[i + 1] for i in range(len(x) - 1)]

p = ProfileStats(False, gpu_checks=False)
@@ -38,6 +41,9 @@
else:
m = None

# This test requires that the outer product, written as a dot, is not rewritten to a mul
m = get_mode(m).excluding("local_dot_to_mul")

f = function(x, z, profile=p, name="test_profiling", mode=m)

inp = [np.arange(1024, dtype="float32") + 1 for i in range(len(x))]
51 changes: 50 additions & 1 deletion tests/tensor/rewriting/test_math.py
@@ -16,7 +16,8 @@
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph import vectorize_graph
+from pytensor.graph.basic import Apply, ancestors, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import (
SequentialNodeRewriter,
@@ -4571,3 +4572,51 @@ def test_log_kv_stabilization():
out.eval({x: 1000.0}, mode=mode),
-1003.2180912984705,
)


@pytest.mark.parametrize(
"a_shape,b_shape",
[
((1,), (1,)),
((3, 1), (1,)),
((1,), (1, 3)),
((3, 1), (1, 3)),
],
)
@pytest.mark.parametrize("batched", (False, True))
def test_local_dot_to_mul(batched, a_shape, b_shape):
a = tensor("a", shape=a_shape)
b = tensor("b", shape=b_shape)

out = dot(a, b)
if batched:
batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
out = vectorize_graph(out, {a: batch_a, b: batch_b})
a = batch_a
b = batch_b

assert (
sum(
isinstance(var.owner.op, (Blockwise | Dot))
for var in ancestors([out])
if var.owner
)
== 1
)

rewritten_out = rewrite_graph(out)
assert rewritten_out.type.shape == out.type.shape
assert not any(
isinstance(var.owner.op, (Blockwise | Dot))
for var in ancestors([rewritten_out])
if var.owner
)

a_test = np.random.normal(size=a.type.shape)
b_test = np.random.normal(size=b.type.shape)
test_mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
out.eval({a: a_test, b: b_test}, mode=test_mode),
rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
)
4 changes: 2 additions & 2 deletions tests/tensor/test_basic.py
@@ -770,9 +770,9 @@ def test_alloc_constant_folding(self):
self.allocs,
[
# IncSubtensor1
-(some_matrix[:60], 2),
+(some_matrix[:60], 1),
# AdvancedIncSubtensor1
-(some_matrix[arange(60)], 2),
+(some_matrix[arange(60)], 1),
# AdvancedIncSubtensor
(some_matrix[idx, idx], 1),
],
8 changes: 6 additions & 2 deletions tests/tensor/test_blas.py
@@ -1723,7 +1723,7 @@ class TestGer(unittest_tools.OptimizationTestMixin):

def setup_method(self):
self.mode = pytensor.compile.get_default_mode().including("fast_run")
-self.mode = self.mode.excluding("c_blas", "scipy_blas")
+self.mode = self.mode.excluding("c_blas", "scipy_blas", "local_dot_to_mul")
dtype = self.dtype = "float64" # optimization isn't dtype-dependent
self.A = tensor(dtype=dtype, shape=(None, None))
self.a = tensor(dtype=dtype, shape=())
@@ -1795,7 +1795,11 @@ def test_b_nonconst_does_not_triggers_ger(self):

def test_outer(self):
rng = np.random.default_rng(unittest_tools.fetch_seed())
-f = self.function([self.x, self.y], outer(self.x, self.y))
+f = self.function(
+    [self.x, self.y],
+    # Old outer used to be written like this
+    pt.dot(self.x[:, None], self.y[None, :]),
+)
self.assertFunctionContains(f, self.ger_destructive)
f(
rng.random(5).astype(self.dtype),
