
Commit

add ellipsis to einmix
arogozhnikov committed Dec 20, 2024
1 parent 731d17d commit a230557
Showing 2 changed files with 160 additions and 19 deletions.
89 changes: 70 additions & 19 deletions einops/layers/_einmix.py
@@ -1,7 +1,7 @@
 from typing import Any, List, Optional, Dict
 
 from einops import EinopsError
-from einops.parsing import ParsedExpression
+from einops.parsing import ParsedExpression, _ellipsis
 import warnings
 import string
 from ..einops import _product
@@ -71,9 +71,13 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
             set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
             "Unrecognized identifiers on the right side of EinMix {}",
         )
-
-        if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis:
-            raise EinopsError("Ellipsis is not supported in EinMix (right now)")
+        if weight.has_ellipsis:
+            raise EinopsError("Ellipsis is not supported in weight, as its shape should be fully specified")
+        if left.has_ellipsis or right.has_ellipsis:
+            if not (left.has_ellipsis and right.has_ellipsis):
+                raise EinopsError(f"Ellipsis in EinMix should be on both sides, {pattern}")
+            if left.has_ellipsis_parenthesized:
+                raise EinopsError(f"Ellipsis on left side can't be in parenthesis, got {pattern}")
         if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
             raise EinopsError("Anonymous axes (numbers) are not allowed in EinMix")
         if "(" in weight_shape or ")" in weight_shape:
@@ -86,16 +90,18 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
             names: List[str] = []
             for group in left.composition:
                 names += group
+            names = [name if name != _ellipsis else "..." for name in names]
             composition = " ".join(names)
-            pre_reshape_pattern = f"{left_pattern}->{composition}"
+            pre_reshape_pattern = f"{left_pattern} -> {composition}"
             pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names}
 
-        if any(len(group) != 1 for group in right.composition):
+        if any(len(group) != 1 for group in right.composition) or right.has_ellipsis_parenthesized:
             names = []
             for group in right.composition:
                 names += group
+            names = [name if name != _ellipsis else "..." for name in names]
             composition = " ".join(names)
-            post_reshape_pattern = f"{composition}->{right_pattern}"
+            post_reshape_pattern = f"{composition} -> {right_pattern}"
 
         self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})
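
When either side needs a reshape, the ellipsis now flows through all three stages; a sketch of the decomposition for one pattern exercised in the tests below:

    # pattern: "(b a) ... -> b c (...)" with weight_shape="b a c"
    # pre-reshape:  "(b a) ... -> b a ..."    (ungroup composed axes on the left)
    # einsum:       "ba...,bac->bc..."        (identifiers remapped to letters, ellipsis kept)
    # post-reshape: "b c ... -> b c (...)"    (regroup composed axes on the right)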

@@ -116,22 +122,36 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
         # single output element is a combination of fan_in input elements
         _fan_in = _product([axes_lengths[axis] for (axis,) in weight.composition if axis not in right.identifiers])
         if bias_shape is not None:
+            # maybe I should put ellipsis in the beginning for simplicity?
             if not isinstance(bias_shape, str):
                 raise EinopsError("bias shape should be string specifying which axes bias depends on")
             bias = ParsedExpression(bias_shape)
-            _report_axes(set.difference(bias.identifiers, right.identifiers), "Bias axes {} not present in output")
+            _report_axes(
+                set.difference(bias.identifiers, right.identifiers),
+                "Bias axes {} not present in output",
+            )
             _report_axes(
                 set.difference(bias.identifiers, set(axes_lengths)),
                 "Sizes not provided for bias axes {}",
             )
 
             _bias_shape = []
+            used_non_trivial_size = False
             for axes in right.composition:
-                for axis in axes:
-                    if axis in bias.identifiers:
-                        _bias_shape.append(axes_lengths[axis])
-                    else:
-                        _bias_shape.append(1)
+                if axes == _ellipsis:
+                    if used_non_trivial_size:
+                        raise EinopsError("all bias dimensions should go after ellipsis in the output")
+                else:
+                    # handles ellipsis correctly
+                    for axis in axes:
+                        if axis == _ellipsis:
+                            if used_non_trivial_size:
+                                raise EinopsError("all bias dimensions should go after ellipsis in the output")
+                        elif axis in bias.identifiers:
+                            _bias_shape.append(axes_lengths[axis])
+                            used_non_trivial_size = True
+                        else:
+                            _bias_shape.append(1)
         else:
             _bias_shape = None
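
The ordering restriction follows from broadcasting: the stored bias has no entries for ellipsis dimensions, so it can only align with the einsum output from the right. A worked example matching mixin6 in the tests below:

    # output "b ... (a d)" with bias_shape="a d", a=1, d=4
    # einsum output (before regrouping): "b ... a d"
    # stored bias shape: [1, 1, 4]   # (b) a d; ellipsis dims are skipped
    # right-aligned broadcasting against [b, *ellipsis dims, 1, 4] is only
    # sound because every non-unit bias dimension comes after the ellipsis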

@@ -142,15 +162,26 @@ def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optiona
         # rewrite einsum expression with single-letter latin identifiers so that
         # expression will be understood by any framework
         mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers}
+        if _ellipsis in mapped_identifiers:
+            mapped_identifiers.remove(_ellipsis)
         mapped_identifiers = list(sorted(mapped_identifiers))
         mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)}
+        mapping2letters[_ellipsis] = "..."  # preserve ellipsis
 
-        def write_flat(axes: list):
-            return "".join(mapping2letters[axis] for axis in axes)
+        def write_flat_remapped(axes: ParsedExpression):
+            result = []
+            for composed_axis in axes.composition:
+                if isinstance(composed_axis, list):
+                    result.extend([mapping2letters[axis] for axis in composed_axis])
+                else:
+                    assert composed_axis == _ellipsis
+                    result.append("...")
+            return "".join(result)
 
         self.einsum_pattern: str = "{},{}->{}".format(
-            write_flat(left.flat_axes_order()),
-            write_flat(weight.flat_axes_order()),
-            write_flat(right.flat_axes_order()),
+            write_flat_remapped(left),
+            write_flat_remapped(weight),
+            write_flat_remapped(right),
         )
 
     def _create_rearrange_layers(
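
A concrete remapping (mixin4 in the tests below):

    # "b a ... -> b c ..." with weight_shape="b a c"
    # sorted identifiers [a, b, c] -> letters [a, b, c]; _ellipsis -> "..."
    # einsum_pattern: "ba...,bac->bc..."
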
@@ -174,3 +205,23 @@ def __repr__(self):
         for axis, length in self.axes_lengths.items():
             params += ", {}={}".format(axis, length)
         return "{}({})".format(self.__class__.__name__, params)
+
+
+class _EinmixDebugger(_EinmixMixin):
+    """Used only to test mixin"""
+
+    def _create_rearrange_layers(
+        self,
+        pre_reshape_pattern: Optional[str],
+        pre_reshape_lengths: Optional[Dict],
+        post_reshape_pattern: Optional[str],
+        post_reshape_lengths: Optional[Dict],
+    ):
+        self.pre_reshape_pattern = pre_reshape_pattern
+        self.pre_reshape_lengths = pre_reshape_lengths
+        self.post_reshape_pattern = post_reshape_pattern
+        self.post_reshape_lengths = post_reshape_lengths
+
+    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
+        self.saved_weight_shape = weight_shape
+        self.saved_bias_shape = bias_shape
90 changes: 90 additions & 0 deletions einops/tests/test_layers.py
@@ -343,3 +343,93 @@ def eval_at_point(params):
     # check serialization
     fbytes = flax.serialization.to_bytes(params)
     _loaded = flax.serialization.from_bytes(params, fbytes)
+
+
+def test_einmix_decomposition():
+    """
+    Testing that einmix correctly decomposes into smaller transformations.
+    """
+    from einops.layers._einmix import _EinmixDebugger
+
+    mixin1 = _EinmixDebugger(
+        "a b c d e -> e d c b a",
+        weight_shape="d a b",
+        d=2, a=3, b=5,
+    )  # fmt: off
+    assert mixin1.pre_reshape_pattern is None
+    assert mixin1.post_reshape_pattern is None
+    assert mixin1.einsum_pattern == "abcde,dab->edcba"
+    assert mixin1.saved_weight_shape == [2, 3, 5]
+    assert mixin1.saved_bias_shape is None
+
+    mixin2 = _EinmixDebugger(
+        "a b c d e -> e d c b a",
+        weight_shape="d a b",
+        bias_shape="a b c d e",
+        a=1, b=2, c=3, d=4, e=5,
+    )  # fmt: off
+    assert mixin2.pre_reshape_pattern is None
+    assert mixin2.post_reshape_pattern is None
+    assert mixin2.einsum_pattern == "abcde,dab->edcba"
+    assert mixin2.saved_weight_shape == [4, 1, 2]
+    assert mixin2.saved_bias_shape == [5, 4, 3, 2, 1]
+
+    mixin3 = _EinmixDebugger(
+        "... -> ...",
+        weight_shape="",
+        bias_shape="",
+    )  # fmt: off
+    assert mixin3.pre_reshape_pattern is None
+    assert mixin3.post_reshape_pattern is None
+    assert mixin3.einsum_pattern == "...,->..."
+    assert mixin3.saved_weight_shape == []
+    assert mixin3.saved_bias_shape == []
+
+    mixin4 = _EinmixDebugger(
+        "b a ... -> b c ...",
+        weight_shape="b a c",
+        a=1, b=2, c=3,
+    )  # fmt: off
+    assert mixin4.pre_reshape_pattern is None
+    assert mixin4.post_reshape_pattern is None
+    assert mixin4.einsum_pattern == "ba...,bac->bc..."
+    assert mixin4.saved_weight_shape == [2, 1, 3]
+    assert mixin4.saved_bias_shape is None
+
+    mixin5 = _EinmixDebugger(
+        "(b a) ... -> b c (...)",
+        weight_shape="b a c",
+        a=1, b=2, c=3,
+    )  # fmt: off
+    assert mixin5.pre_reshape_pattern == "(b a) ... -> b a ..."
+    assert mixin5.pre_reshape_lengths == dict(a=1, b=2)
+    assert mixin5.post_reshape_pattern == "b c ... -> b c (...)"
+    assert mixin5.einsum_pattern == "ba...,bac->bc..."
+    assert mixin5.saved_weight_shape == [2, 1, 3]
+    assert mixin5.saved_bias_shape is None
+
+    mixin6 = _EinmixDebugger(
+        "b ... (a c) -> b ... (a d)",
+        weight_shape="c d",
+        bias_shape="a d",
+        a=1, c=3, d=4,
+    )  # fmt: off
+    assert mixin6.pre_reshape_pattern == "b ... (a c) -> b ... a c"
+    assert mixin6.pre_reshape_lengths == dict(a=1, c=3)
+    assert mixin6.post_reshape_pattern == "b ... a d -> b ... (a d)"
+    assert mixin6.einsum_pattern == "b...ac,cd->b...ad"
+    assert mixin6.saved_weight_shape == [3, 4]
+    assert mixin6.saved_bias_shape == [1, 1, 4]  # (b) a d, ellipsis does not participate
+
+    mixin7 = _EinmixDebugger(
+        "a ... (b c) -> a (... d b)",
+        weight_shape="c d b",
+        bias_shape="d b",
+        b=2, c=3, d=4,
+    )  # fmt: off
+    assert mixin7.pre_reshape_pattern == "a ... (b c) -> a ... b c"
+    assert mixin7.pre_reshape_lengths == dict(b=2, c=3)
+    assert mixin7.post_reshape_pattern == "a ... d b -> a (... d b)"
+    assert mixin7.einsum_pattern == "a...bc,cdb->a...db"
+    assert mixin7.saved_weight_shape == [3, 4, 2]
+    assert mixin7.saved_bias_shape == [1, 4, 2]  # (a) d b, ellipsis does not participate
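
An end-to-end sketch of what the commit enables (assuming the torch backend and an einops version containing this change; axis names and sizes here are illustrative):

    import torch
    from einops.layers.torch import EinMix

    # a per-feature linear map that tolerates any number of intermediate dims
    mix = EinMix("b ... d_in -> b ... d_out", weight_shape="d_in d_out",
                 bias_shape="d_out", d_in=32, d_out=16)
    x = torch.randn(4, 7, 9, 32)   # ellipsis absorbs the (7, 9) dims
    y = mix(x)
    assert y.shape == (4, 7, 9, 16)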
