Add explained variance based scheduler #104

Open
wants to merge 4 commits into base: main
17 changes: 9 additions & 8 deletions src/slicegpt/model_adapter.py
@@ -438,21 +438,22 @@ class SlicingConfig:
do_slice_head: bool = False
parallel_blocks: bool = False

# use dict[int, int] instead of list[int] to allow updates in arbitrary order and the use of default dicts
# these apply to both the sequential and parallel blocks cases
embedding_dimensions: dict[int, int] = field(default_factory=dict)

attention_input_dimensions: dict[int, int] = field(default_factory=dict)
attention_output_dimensions: dict[int, int] = field(default_factory=dict)
mlp_output_dimensions: dict[int, int] = field(default_factory=dict)

# the second path, present only in the sequential blocks case
attention_output_dimensions: dict[int, int] = field(default_factory=dict)
mlp_input_dimensions: dict[int, int] = field(default_factory=dict)
mlp_output_dimensions: dict[int, int] = field(default_factory=dict)

head_dimension: int | None = None

const_dimension: int | None = None # to be able to load models without config, sliced with const sparsity
# used when loading models sliced with const sparsity that are missing a json config
const_dimension: int | None = None

@staticmethod
def from_dict(d: dict) -> 'SlicingConfig':
def from_dict(d: dict) -> SlicingConfig:
"""Return a SliceConfig object constructed from the provided dictionary."""

def convert_dict_keys_to_int(d: Any) -> Any:
@@ -470,7 +471,7 @@ def convert_dict_keys_to_int(d: Any) -> Any:
return SlicingConfig(**convert_dict_keys_to_int(d))

@staticmethod
def from_json_string(json_str: str) -> 'SlicingConfig':
def from_json_string(json_str: str) -> SlicingConfig:
"""Return a SliceConfig object constructed from the provided JSON string."""
return SlicingConfig.from_dict(json.loads(json_str))

@@ -485,6 +486,6 @@ def to_json_string(self) -> str:
"""Return a JSON representation of this object."""
return json.dumps(self.to_dict())

def clone(self) -> 'SlicingConfig':
def clone(self) -> SlicingConfig:
"""Return a clone of this object."""
return copy.deepcopy(self)
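
The dict[int, int] fields above are also why from_dict pipes the parsed dictionary through convert_dict_keys_to_int: json.dumps stringifies integer keys, so loading a saved config must turn them back into ints. A minimal standalone sketch of that round trip; the helper below is an illustrative reimplementation and may differ in detail from the one in model_adapter.py:

import json
from typing import Any

def convert_dict_keys_to_int(d: Any) -> Any:
    # Recursively turn digit-only string keys (produced by JSON) back into ints.
    if isinstance(d, dict):
        return {int(k) if isinstance(k, str) and k.isdigit() else k: convert_dict_keys_to_int(v)
                for k, v in d.items()}
    return d

dims = {0: 3072, 1: 2944}  # illustrative values shaped like embedding_dimensions
round_trip = convert_dict_keys_to_int(json.loads(json.dumps({"embedding_dimensions": dims})))
assert round_trip["embedding_dimensions"] == dims  # int keys survive the round trip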
9 changes: 7 additions & 2 deletions src/slicegpt/rotate.py
@@ -166,6 +166,7 @@ def rotate_and_slice_sequential(

# rotate and slice embeddings
eig_val, Q = pca_calc(inps, ignore_masks)
slicing_scheduler.set_embedding_eigenvalues(eig_val.detach().cpu().tolist())
Q = Q.to(device=config.device)
if final_orientation == 'random':
R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_embedding_dimensions()[0])
@@ -193,6 +194,7 @@ def rotate_and_slice_sequential(

mlp_ln_inputs, _ = get_signals(layer_adapter, args, kwargs)
eig_val, Q = pca_calc(mlp_ln_inputs, ignore_masks)
slicing_scheduler.set_mlp_eigenvalues(idx, eig_val.detach().cpu().tolist())
Q = Q.to(device=config.device, dtype=torch.float64)
if final_orientation == 'random':
R = random_orthogonal_upper_left(
@@ -224,6 +226,7 @@ def rotate_and_slice_sequential(
# with slicing between Attention and mlp.
_, inps = get_signals(layer_adapter, args, kwargs)
eig_val, Q = pca_calc(inps, ignore_masks)
slicing_scheduler.set_attention_eigenvalues(idx, eig_val.detach().cpu().tolist())
if final_orientation == 'random':
R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_mlp_output_dimension(idx))
Q = Q @ R.to(Q.device)
@@ -279,7 +282,8 @@ def rotate_and_slice_parallel(
slicing_scheduler.setup(hidden_size=model_adapter.hidden_size, layers_num=len(layers), parallel_blocks=True)

# rotate and slice embeddings
_, Q = pca_calc(inps, ignore_masks)
eig_val, Q = pca_calc(inps, ignore_masks)
slicing_scheduler.set_embedding_eigenvalues(eig_val.detach().cpu().tolist())
Q = Q.to(device=config.device)
if final_orientation == 'random':
R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_embedding_dimensions()[0])
@@ -322,7 +326,8 @@ def rotate_and_slice_parallel(
outputs.append(out)

inps = outputs
_, Q = pca_calc(inps, ignore_masks)
eig_val, Q = pca_calc(inps, ignore_masks)
slicing_scheduler.set_attention_eigenvalues(idx, eig_val.detach().cpu().tolist())

if final_orientation == 'random':
R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_mlp_output_dimension(idx))
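
The functional change in rotate.py is small but consistent: the eigenvalues from every pca_calc are now captured and pushed into the slicing scheduler (set_embedding_eigenvalues, set_attention_eigenvalues, set_mlp_eigenvalues) before the matching dimension query, so a scheduler can base its decision on the spectrum. A runnable toy of that ordering contract; ToyEigenvalueScheduler and its 1%-of-variance rule are invented for illustration and are not slicegpt code:

import numpy as np

class ToyEigenvalueScheduler:
    def __init__(self) -> None:
        self.embedding_eigenvalues: list[float] | None = None

    def set_embedding_eigenvalues(self, eigenvalues: list[float]) -> None:
        self.embedding_eigenvalues = eigenvalues

    def get_embedding_dimensions(self) -> dict[int, int]:
        if self.embedding_eigenvalues is None:
            raise RuntimeError("set_embedding_eigenvalues must be called before the dimension query")
        ev = np.array(self.embedding_eigenvalues)
        dim = int(np.sum(ev / ev.sum() > 0.01))  # toy rule: keep components holding >1% of the variance
        return {0: dim}

eig_val = [6.0, 2.0, 1.5, 0.4, 0.06, 0.04]   # stand-in for the eigenvalues returned by pca_calc
scheduler = ToyEigenvalueScheduler()
scheduler.set_embedding_eigenvalues(eig_val)  # mirrors the new calls added in rotate.py
print(scheduler.get_embedding_dimensions())   # {0: 4}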
135 changes: 106 additions & 29 deletions src/slicegpt/slicing_scheduler.py
@@ -1,7 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable, final

import numpy as np

from slicegpt.model_adapter import SlicingConfig


@@ -20,6 +27,11 @@ def __init__(self, *, do_slice_head: bool = False):
self.slicing_conf: SlicingConfig = SlicingConfig()
self.slicing_conf.do_slice_head = do_slice_head

# eigenvalues obtained from PCA
self.embedding_eigenvalues: list[float] = []
self.attention_eigenvalues: dict[int, list[float]] = {}
self.mlp_eigenvalues: dict[int, list[float]] = {}

@property
def do_slice_head(self) -> bool:
"""Return whether to slice the head."""
@@ -49,17 +61,23 @@ def setup(self, *, hidden_size: int, layers_num: int, parallel_blocks: bool) ->
@final
def get_embedding_dimensions(self) -> dict[int, int]:
"""Return the input embedding dimensions."""
val = self._get_input_embedding_dimensions()
if self.slicing_conf.embedding_dimensions:
return self.slicing_conf.embedding_dimensions

val = self._get_embedding_dimensions()
self.slicing_conf.embedding_dimensions = val
return val

@abstractmethod
def _get_input_embedding_dimensions(self) -> dict[int, int]:
def _get_embedding_dimensions(self) -> dict[int, int]:
raise NotImplementedError

@final
def get_attention_input_dimension(self, idx: int) -> int:
"""Return the attention input dimension for the specified layer index."""
if idx in self.slicing_conf.attention_input_dimensions:
return self.slicing_conf.attention_input_dimensions[idx]

val = self._get_attention_input_dimension(idx)
self.slicing_conf.attention_input_dimensions[idx] = val
return val
@@ -69,12 +87,30 @@ def _get_attention_input_dimension(self, idx: int) -> int:
raise NotImplementedError

@final
def get_attention_output_dimension(self, idx, match_head_dim: bool) -> int:
def get_mlp_output_dimension(self, idx: int) -> int:
"""Return the mlp output dimension for the specified layer index."""
if idx in self.slicing_conf.mlp_output_dimensions:
return self.slicing_conf.mlp_output_dimensions[idx]

use_head_dim = idx == self.layers_num - 1
val = self._get_mlp_output_dimension(idx) if not use_head_dim else self.get_head_dimension()
self.slicing_conf.mlp_output_dimensions[idx] = val
return val

@abstractmethod
def _get_mlp_output_dimension(self, idx: int) -> int:
raise NotImplementedError

@final
def get_attention_output_dimension(self, idx, match_head_dim: bool | None = None) -> int:
"""Return the attention output dimension for the specified layer index."""
if self.parallel_blocks:
return self.get_mlp_output_dimension(idx)

use_head_dim = match_head_dim and idx == self.layers_num - 1
if idx in self.slicing_conf.attention_output_dimensions:
return self.slicing_conf.attention_output_dimensions[idx]

use_head_dim = idx == self.layers_num - 1 and match_head_dim
val = self._get_attention_output_dimension(idx) if not use_head_dim else self.get_head_dimension()
self.slicing_conf.attention_output_dimensions[idx] = val
return val
@@ -89,6 +125,9 @@ def get_mlp_input_dimension(self, idx: int) -> int:
if self.parallel_blocks:
return self.get_attention_input_dimension(idx)

if idx in self.slicing_conf.mlp_input_dimensions:
return self.slicing_conf.mlp_input_dimensions[idx]

val = self._get_mlp_input_dimension(idx)
self.slicing_conf.mlp_input_dimensions[idx] = val
return val
@@ -97,21 +136,12 @@ def get_mlp_input_dimension(self, idx: int) -> int:
def _get_mlp_input_dimension(self, idx: int) -> int:
raise NotImplementedError

@final
def get_mlp_output_dimension(self, idx: int) -> int:
"""Return the mlp output dimension for the specified layer index."""
use_head_dim = idx == self.layers_num - 1
val = self._get_mlp_output_dimension(idx) if not use_head_dim else self.get_head_dimension()
self.slicing_conf.mlp_output_dimensions[idx] = val
return val

@abstractmethod
def _get_mlp_output_dimension(self, idx: int) -> int:
raise NotImplementedError

@final
def get_head_dimension(self) -> int:
"""Return the LM head dimension."""
if self.slicing_conf.head_dimension is not None:
return self.slicing_conf.head_dimension

val = self._get_head_dimension() if self.slicing_conf.do_slice_head else self.hidden_size
self.slicing_conf.head_dimension = val
return val
@@ -120,6 +150,18 @@ def get_head_dimension(self) -> int:
def _get_head_dimension(self) -> int:
raise NotImplementedError

def set_embedding_eigenvalues(self, eigenvalues: list[float]) -> None:
"""Set the eigenvalues of the embeddings PCA."""
self.embedding_eigenvalues = eigenvalues

def set_attention_eigenvalues(self, idx: int, eigenvalues: list[float]) -> None:
"""Set the eigenvalues of the attention layer PCA."""
self.attention_eigenvalues[idx] = eigenvalues

def set_mlp_eigenvalues(self, idx: int, eigenvalues: list[float]) -> None:
"""Set the eigenvalues of the MLP layer PCA."""
self.mlp_eigenvalues[idx] = eigenvalues


class ConfigSlicingScheduler(SlicingScheduler):
"""Slicing scheduler that returns the dimensions specified in the config."""
@@ -128,21 +170,21 @@ def __init__(self, config: SlicingConfig):
super().__init__()
self.slicing_conf = config

def _get_input_embedding_dimensions(self) -> dict[int, int]:
def _get_embedding_dimensions(self) -> dict[int, int]:
return self.slicing_conf.embedding_dimensions

def _get_attention_input_dimension(self, idx: int) -> int:
return self.slicing_conf.attention_input_dimensions[idx]

def _get_mlp_output_dimension(self, idx: int) -> int:
return self.slicing_conf.mlp_output_dimensions[idx]

def _get_attention_output_dimension(self, idx: int) -> int:
return self.slicing_conf.attention_output_dimensions[idx]

def _get_mlp_input_dimension(self, idx: int) -> int:
return self.slicing_conf.mlp_input_dimensions[idx]

def _get_mlp_output_dimension(self, idx: int) -> int:
return self.slicing_conf.mlp_output_dimensions[idx]

def _get_head_dimension(self) -> int:
return self.slicing_conf.head_dimension

@@ -154,19 +196,19 @@ def __init__(self, dimension: int, *, do_slice_head: bool = False):
super().__init__(do_slice_head=do_slice_head)
self.dimension: int = dimension

def _get_input_embedding_dimensions(self) -> dict[int, int]:
def _get_embedding_dimensions(self) -> dict[int, int]:
return defaultdict(lambda: self.dimension)

def _get_attention_input_dimension(self, idx: int) -> int:
return self.dimension

def _get_attention_output_dimension(self, idx: int) -> int:
def _get_mlp_output_dimension(self, idx: int) -> int:
return self.dimension

def _get_mlp_input_dimension(self, idx: int) -> int:
def _get_attention_output_dimension(self, idx: int) -> int:
return self.dimension

def _get_mlp_output_dimension(self, idx: int) -> int:
def _get_mlp_input_dimension(self, idx: int) -> int:
return self.dimension

def _get_head_dimension(self) -> int:
@@ -186,13 +228,13 @@ def __init__(self, *, do_slice_head: bool = False):
def _get_attention_input_dimension(self, idx: int) -> int:
# at the first attention layer, the input dimension is the embedding dimension
if idx == 0:
return self._get_input_embedding_dimensions()[0] # all dimensions are the same there
return self.get_embedding_dimensions()[0] # all dimensions are the same there

return self._get_mlp_output_dimension(idx - 1)
return self.get_mlp_output_dimension(idx - 1)

@final
def _get_mlp_input_dimension(self, idx: int) -> int:
return self._get_attention_output_dimension(idx)
return self.get_attention_output_dimension(idx)


class FunctionSlicingScheduler(ForwardSlicingScheduler):
@@ -222,7 +264,7 @@ def _get_layer_dimension(self, idx: int, is_attn_layer: bool = False) -> int:
val -= val % self.round_interval
return val

def _get_input_embedding_dimensions(self) -> dict[int, int]:
def _get_embedding_dimensions(self) -> dict[int, int]:
return defaultdict(lambda: self._get_layer_dimension(0))

def _get_attention_output_dimension(self, idx: int) -> int:
@@ -242,7 +284,7 @@ def create_linear(
attn_end: float | None = None,
round_interval: int = 1,
do_slice_head: bool = False,
) -> 'FunctionSlicingScheduler':
) -> FunctionSlicingScheduler:
"""Create a linear slicing scheduler, mainly as an example for testing."""

def linear(start: float, end: float) -> Callable[[float], float]:
@@ -259,3 +301,38 @@ def linear_sparsity_func(location: float) -> float:
round_interval=round_interval,
do_slice_head=do_slice_head,
)


class ExplainedVarianceSlicingScheduler(ForwardSlicingScheduler):
"""A slicing scheduler that applies sparsity based on the explained variance from the PCA."""

def __init__(
self,
*,
uev_threshold: float,
round_interval: int = 1,
do_slice_head: bool = False,
):
super().__init__(do_slice_head=do_slice_head)
self.uev_threshold: float = uev_threshold
self.round_interval: int = round_interval

def _get_layer_dimension(self, eigen_vals: list[float], plot: bool = False) -> int:
eigen_vals = np.array(eigen_vals)
cum_var = np.cumsum(eigen_vals) / np.sum(eigen_vals)
dim = np.argmax(cum_var > 1 - self.uev_threshold)
dim -= dim % self.round_interval
dim = int(dim)
Collaborator review comment:
Is it worth having the round_interval parameter passed to slicing schedulers in general, to ensure they always return dims that are rounded accordingly? I guess the constant and function schedulers could ask for constants/functions that take the rounding interval into account already, but what about this scheduler?

return dim
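
One way the reviewer's suggestion might look, purely as an illustration and not part of this PR: keep round_interval on the base scheduler and apply the rounding once in a public accessor, so subclasses only return raw dimensions. A self-contained sketch with invented names:

import numpy as np  # unused here, kept for symmetry with the file above; the sketch needs only the stdlib
from abc import ABC, abstractmethod

class RoundedSchedulerBase(ABC):
    def __init__(self, round_interval: int = 1) -> None:
        self.round_interval = round_interval

    def get_layer_dimension(self, idx: int) -> int:
        raw = self._get_layer_dimension(idx)
        return raw - raw % self.round_interval  # rounding applied once, centrally

    @abstractmethod
    def _get_layer_dimension(self, idx: int) -> int: ...

class Always100(RoundedSchedulerBase):
    def _get_layer_dimension(self, idx: int) -> int:
        return 100

print(Always100(round_interval=8).get_layer_dimension(0))  # 96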

def _get_embedding_dimensions(self) -> dict[int, int]:
return defaultdict(lambda: self._get_layer_dimension(self.embedding_eigenvalues))

def _get_attention_output_dimension(self, idx: int) -> int:
return self._get_layer_dimension(self.mlp_eigenvalues[idx])

def _get_mlp_output_dimension(self, idx: int) -> int:
return self._get_layer_dimension(self.attention_eigenvalues[idx])

def _get_head_dimension(self) -> int:
return self.get_attention_output_dimension(self.layers_num - 1)
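
For reference, a worked example of the selection rule in _get_layer_dimension, re-implemented with plain numpy on a synthetic spectrum rather than by calling the scheduler's private method; the numbers and the commented-out instantiation are illustrative only:

import numpy as np

uev_threshold, round_interval = 0.1, 2
# scheduler = ExplainedVarianceSlicingScheduler(uev_threshold=uev_threshold, round_interval=round_interval)

eigen_vals = np.array([40.0, 25.0, 15.0, 10.0, 5.0, 3.0, 2.0])  # synthetic spectrum, sums to 100
cum_var = np.cumsum(eigen_vals) / np.sum(eigen_vals)            # [0.40, 0.65, 0.80, 0.90, 0.95, 0.98, 1.00]
dim = int(np.argmax(cum_var > 1 - uev_threshold))               # first index strictly above 0.90 is 4
dim -= dim % round_interval                                     # 4 is already a multiple of 2
print(dim)                                                      # 4 dimensions kept at this slicing point

As in FunctionSlicingScheduler, the rounding currently lives inside the scheduler itself, which is what the review comment above asks about; note that rounding down sends any dimension smaller than round_interval to zero.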