Skip to content

Commit

Permalink
attribute speed and memory optimizations (#245)
Browse files Browse the repository at this point in the history
* Speed and memory optimizations

* fix sequence_scores remap from filtered

* Refactored get_sequences_from_batched_steps

* Add check for stack dimension

* Bump dev version, update tutorial

* Bump ruff style to py39
  • Loading branch information
gsarti authored Jan 17, 2024
1 parent 7503576 commit f434192
Show file tree
Hide file tree
Showing 45 changed files with 1,518 additions and 1,413 deletions.
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
author = "The Inseq Team"

# The short X.Y version
version = "0.5"
version = "0.6"
# The full version, including alpha/beta/rc tags
release = "0.5.0"
release = "0.6.0.dev0"


# Prefix link to point to master, comment this during version release and uncomment below line
Expand Down
8 changes: 3 additions & 5 deletions examples/inseq_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
"source": [
"%%capture\n",
"# Run in Colab to install local packages\n",
"%pip install bitsandbytes accelerate\n",
"%pip install git+https://github.com/huggingface/transformers.git\n",
"%pip install git+https://github.com/inseq-team/inseq.git"
"%pip install bitsandbytes accelerate transformers inseq"
]
},
{
Expand All @@ -28,7 +26,7 @@
"source": [
"This tutorial showcases how to use the [inseq](https://github.com/inseq-team/inseq) library for various interpretability analyses, with a focus on advanced use cases such as aggregation and contrastive attribution. The tutorial was adapted from the [LCL'23 tutorial on Interpretability for NLP](https://github.com/gsarti/lcl23-xnlm-lab) with the goal of updating it whenever new functionalities or breaking changes are introduced.\n",
"\n",
"⚠️ **IMPORTANT** ⚠️ : `inseq` is a very new library and under active development. Current results were obtained using the latest development versions on June 30, 2023. If you find any issue, or you are not able to reproduce the results shown here, we'd love if you could open an issue on [the inseq Github repository](https://github.com/inseq-team/inseq) so that we could update the tutorial accordingly! 🙂\n",
"⚠️ **IMPORTANT** ⚠️ : `inseq` is a very new library and under active development. Current results were obtained using the latest inseq release. If you find any issue, or you are not able to reproduce the results shown here, we'd love if you could open an issue on [the inseq Github repository](https://github.com/inseq-team/inseq) so that we could update the tutorial accordingly! 🙂\n",
"\n",
"# Introduction: Feature Attribution for NLP\n",
"\n",
Expand Down Expand Up @@ -86,7 +84,7 @@
"\n",
"[Inseq](https://github.com/inseq-team/inseq) ([Sarti et al., 2023](https://arxiv.org/abs/2302.13942)) is a toolkit based on the [🤗 Transformers](https//github.com/huggingface/transformers) and [Captum](https://captum.ai/docs/introduction) libraries for intepreting language generation models using feature attribution methods. Inseq allows you to analyze the behavior of a language generation model by computing the importance of each input token for each token in the generated output using the various categories of attribution methods like those described in the previous section (use `inseq.list_feature_attribution_methods()` to list all available methods). You can refer to the [Inseq paper](https://arxiv.org/abs/2302.13942) for an overview of the tool and some usage examples.\n",
"\n",
"The current version of the library (v0.5.0, December 2023) supports all [`AutoModelForSeq2SeqLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForSeq2SeqLM) (among others, [T5](https://huggingface.co/docs/transformers/model_doc/t5), [Bart](https://huggingface.co/docs/transformers/model_doc/bart) and all >1000 [MarianNMT](https://huggingface.co/docs/transformers/model_doc/marian) MT models) and [AutoModelForCausalLM](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) (among others, [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2), [Bloom](https://huggingface.co/docs/transformers/model_doc/bloom) and [LLaMa](https://huggingface.co/docs/transformers/model_doc/llama)), including advanced tools for custom attribution targets and post-processing of attribution matrices.\n",
"Inseq supports all [`AutoModelForSeq2SeqLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForSeq2SeqLM) (among others, [T5](https://huggingface.co/docs/transformers/model_doc/t5), [Bart](https://huggingface.co/docs/transformers/model_doc/bart) and all >1000 [MarianNMT](https://huggingface.co/docs/transformers/model_doc/marian) MT models) and [AutoModelForCausalLM](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) (among others, [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2), [Bloom](https://huggingface.co/docs/transformers/model_doc/bloom) and [LLaMa](https://huggingface.co/docs/transformers/model_doc/llama)), including advanced tools for custom attribution targets and post-processing of attribution matrices.\n",
"\n",
"The following code is a \"Hello world\" equivalent in Inseq: in three lines of code, an English-to-Italian machine translation model is loaded alongside an attribution method, attribution is performed, and results are visualized:"
]
Expand Down
5 changes: 3 additions & 2 deletions inseq/attr/attribution_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
"""Decorators for attribution methods."""

import logging
from collections.abc import Sequence
from functools import wraps
from typing import Any, Callable, List, Optional, Sequence
from typing import Any, Callable, Optional

from ..data.data_utils import TensorWrapper

Expand Down Expand Up @@ -55,7 +56,7 @@ def batched(f: Callable[..., Any]) -> Callable[..., Any]:

@wraps(f)
def batched_wrapper(self, *args, batch_size: Optional[int] = None, **kwargs):
def get_batched(bs: Optional[int], seq: Sequence[Any]) -> List[List[Any]]:
def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]:
if isinstance(seq, str):
seq = [seq]
if isinstance(seq, list):
Expand Down
35 changes: 15 additions & 20 deletions inseq/attr/feat/attribution_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import logging
import math
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from ...utils import extract_signature_args, get_aligned_idx
from ...utils.typing import (
OneOrMoreAttributionSequences,
OneOrMoreIdSequences,
OneOrMoreTokenSequences,
SingleScorePerStepTensor,
StepAttributionTensor,
TextInput,
TokenWithId,
)
Expand Down Expand Up @@ -51,7 +50,7 @@ def check_attribute_positions(
max_length: int,
attr_pos_start: Optional[int] = None,
attr_pos_end: Optional[int] = None,
) -> Tuple[int, int]:
) -> tuple[int, int]:
r"""Checks whether the combination of start/end positions for attribution is valid.
Args:
Expand Down Expand Up @@ -90,8 +89,8 @@ def join_token_ids(
tokens: OneOrMoreTokenSequences,
ids: OneOrMoreIdSequences,
contrast_tokens: Optional[OneOrMoreTokenSequences] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
) -> List[TokenWithId]:
contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
) -> list[TokenWithId]:
"""Joins tokens and ids into a list of TokenWithId objects."""
if contrast_tokens is None:
contrast_tokens = tokens
Expand All @@ -116,10 +115,10 @@ def join_token_ids(
def extract_args(
attribution_method: "FeatureAttribution",
attributed_fn: Callable[..., SingleScorePerStepTensor],
step_scores: List[str],
default_args: List[str],
step_scores: list[str],
default_args: list[str],
**kwargs,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
attribution_args = kwargs.pop("attribution_args", {})
attributed_fn_args = kwargs.pop("attributed_fn_args", {})
step_scores_args = kwargs.pop("step_scores_args", {})
Expand All @@ -143,17 +142,13 @@ def extract_args(


def get_source_target_attributions(
attr: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
is_encoder_decoder: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
if is_encoder_decoder:
if isinstance(attr, tuple) and len(attr) > 1:
return attr[0], attr[1]
elif isinstance(attr, tuple) and len(attr) == 1:
return attr[0], None
) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
if isinstance(attr, tuple):
if is_encoder_decoder:
return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
else:
return attr, None
elif isinstance(attr, tuple):
return None, attr[0]
return (None, attr[0])
else:
return None, attr
return (attr, None) if is_encoder_decoder else (None, attr)
46 changes: 23 additions & 23 deletions inseq/attr/feat/feature_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from jaxtyping import Int
Expand Down Expand Up @@ -173,12 +173,12 @@ def prepare_and_attribute(
pretty_progress: bool = True,
output_step_attributions: bool = False,
attribute_target: bool = False,
step_scores: List[str] = [],
step_scores: list[str] = [],
include_eos_baseline: bool = False,
attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None,
attribution_args: Dict[str, Any] = {},
attributed_fn_args: Dict[str, Any] = {},
step_scores_args: Dict[str, Any] = {},
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
step_scores_args: dict[str, Any] = {},
) -> FeatureAttributionOutput:
r"""Prepares inputs and performs attribution.
Expand Down Expand Up @@ -276,11 +276,11 @@ def format_contrastive_targets(
self,
target_sequences: TextSequences,
target_tokens: OneOrMoreTokenSequences,
attributed_fn_args: Dict[str, Any],
step_scores_args: Dict[str, Any],
attributed_fn_args: dict[str, Any],
step_scores_args: dict[str, Any],
attr_pos_start: int,
attr_pos_end: int,
) -> Tuple[Optional[DecoderOnlyBatch], Optional[List[List[Tuple[int, int]]]], Dict[str, Any], Dict[str, Any]]:
) -> tuple[Optional[DecoderOnlyBatch], Optional[list[list[tuple[int, int]]]], dict[str, Any], dict[str, Any]]:
contrast_batch, contrast_targets_alignments = None, None
contrast_targets = attributed_fn_args.get("contrast_targets", None)
if contrast_targets is None:
Expand Down Expand Up @@ -327,10 +327,10 @@ def attribute(
pretty_progress: bool = True,
output_step_attributions: bool = False,
attribute_target: bool = False,
step_scores: List[str] = [],
attribution_args: Dict[str, Any] = {},
attributed_fn_args: Dict[str, Any] = {},
step_scores_args: Dict[str, Any] = {},
step_scores: list[str] = [],
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
step_scores_args: dict[str, Any] = {},
) -> FeatureAttributionOutput:
r"""Performs the feature attribution procedure using the specified attribution method.
Expand Down Expand Up @@ -501,10 +501,10 @@ def filtered_attribute_step(
attributed_fn: Callable[..., SingleScorePerStepTensor],
target_attention_mask: Optional[Int[torch.Tensor, "batch_size 1"]] = None,
attribute_target: bool = False,
step_scores: List[str] = [],
attribution_args: Dict[str, Any] = {},
attributed_fn_args: Dict[str, Any] = {},
step_scores_args: Dict[str, Any] = {},
step_scores: list[str] = [],
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
step_scores_args: dict[str, Any] = {},
) -> FeatureAttributionStepOutput:
r"""Performs a single attribution step for all the sequences in the batch that
still have valid target_ids, as identified by the target_attention_mask.
Expand Down Expand Up @@ -581,31 +581,31 @@ def filtered_attribute_step(
attribution_args,
)
# Format step scores arguments and calculate step scores
for step_score in step_scores:
for score in step_scores:
step_fn_args = self.attribution_model.formatter.format_step_function_args(
attribution_model=self.attribution_model,
forward_output=output,
target_ids=target_ids,
is_attributed_fn=False,
batch=batch,
)
step_fn_extra_args = get_step_scores_args([step_score], step_scores_args)
step_output.step_scores[step_score] = get_step_scores(step_score, step_fn_args, step_fn_extra_args)
step_fn_extra_args = get_step_scores_args([score], step_scores_args)
step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu")
# Reinsert finished sentences
if target_attention_mask is not None and is_filtered:
step_output.remap_from_filtered(target_attention_mask, orig_batch)
step_output = step_output.detach().to("cpu")
return step_output

def get_attribution_args(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def get_attribution_args(self, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
if hasattr(self, "method") and hasattr(self.method, "attribute"):
return extract_signature_args(kwargs, self.method.attribute, self.ignore_extra_args, return_remaining=True)
return {}, {}

def attribute_step(
self,
attribute_fn_main_args: Dict[str, Any],
attribution_args: Dict[str, Any] = {},
attribute_fn_main_args: dict[str, Any],
attribution_args: dict[str, Any] = {},
) -> FeatureAttributionStepOutput:
r"""Performs a single attribution step for the specified attribution arguments.
Expand Down Expand Up @@ -663,7 +663,7 @@ class DummyAttribution(FeatureAttribution):
method_name = "dummy"

def attribute_step(
self, attribute_fn_main_args: Dict[str, Any], attribution_args: Dict[str, Any] = {}
self, attribute_fn_main_args: dict[str, Any], attribution_args: dict[str, Any] = {}
) -> FeatureAttributionStepOutput:
return FeatureAttributionStepOutput(
source_attributions=None,
Expand Down
12 changes: 6 additions & 6 deletions inseq/attr/feat/gradient_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Gradient-based feature attribution methods."""

import logging
from typing import Any, Dict
from typing import Any

from captum.attr import (
DeepLift,
Expand Down Expand Up @@ -65,8 +65,8 @@ def unhook(self, **kwargs):

def attribute_step(
self,
attribute_fn_main_args: Dict[str, Any],
attribution_args: Dict[str, Any] = {},
attribute_fn_main_args: dict[str, Any],
attribution_args: dict[str, Any] = {},
) -> GranularFeatureAttributionStepOutput:
r"""Performs a single attribution step for the specified attribution arguments.
Expand Down Expand Up @@ -95,9 +95,9 @@ def attribute_step(
attr, self.attribution_model.is_encoder_decoder
)
return GranularFeatureAttributionStepOutput(
source_attributions=source_attributions,
target_attributions=target_attributions,
step_scores={"deltas": deltas} if deltas is not None else None,
source_attributions=source_attributions.to("cpu") if source_attributions is not None else None,
target_attributions=target_attributions.to("cpu") if target_attributions is not None else None,
step_scores={"deltas": deltas.to("cpu")} if deltas is not None else None,
)


Expand Down
14 changes: 8 additions & 6 deletions inseq/attr/feat/internals_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Attention-based feature attribution methods."""

import logging
from typing import Any, Dict, Optional
from typing import Any, Optional

from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import Attribution
Expand Down Expand Up @@ -76,17 +76,19 @@ def attribute(
"""
# We adopt the format [batch_size, sequence_length, num_layers, num_heads]
# for consistency with other multi-unit methods (e.g. gradient attribution)
decoder_self_attentions = decoder_self_attentions[..., -1, :].clone().permute(0, 3, 1, 2)
decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2)
if self.forward_func.is_encoder_decoder:
sequence_scores = {}
if len(inputs) > 1:
target_attributions = decoder_self_attentions
else:
target_attributions = None
sequence_scores["decoder_self_attentions"] = decoder_self_attentions
sequence_scores["encoder_self_attentions"] = encoder_self_attentions.clone().permute(0, 3, 4, 1, 2)
sequence_scores["encoder_self_attentions"] = (
encoder_self_attentions.to("cpu").clone().permute(0, 3, 4, 1, 2)
)
return MultiDimensionalFeatureAttributionStepOutput(
source_attributions=cross_attentions[..., -1, :].clone().permute(0, 3, 1, 2),
source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2),
target_attributions=target_attributions,
sequence_scores=sequence_scores,
_num_dimensions=2, # num_layers, num_heads
Expand All @@ -108,7 +110,7 @@ def __init__(self, attribution_model, **kwargs):

def attribute_step(
self,
attribute_fn_main_args: Dict[str, Any],
attribution_args: Dict[str, Any],
attribute_fn_main_args: dict[str, Any],
attribution_args: dict[str, Any],
) -> MultiDimensionalFeatureAttributionStepOutput:
return self.method.attribute(**attribute_fn_main_args, **attribution_args)
12 changes: 6 additions & 6 deletions inseq/attr/feat/ops/discretized_integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from pathlib import Path
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, Union

import torch
from captum._utils.common import (
Expand Down Expand Up @@ -51,7 +51,7 @@ def load_monotonic_path_builder(
self,
model_name: str,
vocabulary_embeddings: VocabularyEmbeddingsTensor,
special_tokens: List[int],
special_tokens: list[int],
cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "dig_knn",
embedding_scaling: int = 1,
**kwargs,
Expand All @@ -67,7 +67,7 @@ def load_monotonic_path_builder(
)

@staticmethod
def get_inputs_baselines(scaled_features_tpl: Tuple[Tensor, ...], n_steps: int) -> Tuple[Tensor, ...]:
def get_inputs_baselines(scaled_features_tpl: tuple[Tensor, ...], n_steps: int) -> tuple[Tensor, ...]:
# Baseline and inputs are reversed in the path builder
# For every element in the batch, the first embedding of the sub-tensor
# of shape (n_steps x embedding_dim) is the baseline, the last is the input.
Expand Down Expand Up @@ -96,7 +96,7 @@ def attribute( # type: ignore
method: str = "greedy",
internal_batch_size: Union[None, int] = None,
return_convergence_delta: bool = False,
) -> Union[TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]]:
) -> Union[TensorOrTupleOfTensorsGeneric, tuple[TensorOrTupleOfTensorsGeneric, Tensor]]:
n_examples = inputs[0].shape[0]
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
Expand Down Expand Up @@ -146,11 +146,11 @@ def attribute( # type: ignore

def _attribute(
self,
scaled_features_tpl: Tuple[Tensor, ...],
scaled_features_tpl: tuple[Tensor, ...],
target: TargetType = None,
additional_forward_args: Any = None,
n_steps: int = 50,
) -> Tuple[Tensor, ...]:
) -> tuple[Tensor, ...]:
additional_forward_args = _format_additional_forward_args(additional_forward_args)
input_additional_args = (
_expand_additional_forward_args(additional_forward_args, n_steps)
Expand Down
Loading

0 comments on commit f434192

Please sign in to comment.