From d09e827df35690b6e1e30731c5601864552face2 Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Wed, 28 Feb 2024 17:15:44 +0100 Subject: [PATCH] Value Zeroing attribution method (#173) --- CHANGELOG.md | 5 +- README.md | 2 + .../main_classes/feature_attribution.rst | 19 +- inseq/attr/feat/__init__.py | 4 + inseq/attr/feat/attribution_utils.py | 8 +- inseq/attr/feat/feature_attribution.py | 49 ++- inseq/attr/feat/internals_attribution.py | 22 +- inseq/attr/feat/ops/__init__.py | 2 + inseq/attr/feat/ops/value_zeroing.py | 394 ++++++++++++++++++ inseq/attr/feat/perturbation_attribution.py | 101 ++++- inseq/attr/step_functions.py | 3 +- inseq/commands/commands_utils.py | 1 - inseq/data/aggregator.py | 53 +-- inseq/data/attribution.py | 106 +++-- inseq/data/data_utils.py | 6 +- inseq/models/attribution_model.py | 42 +- inseq/models/model_config.py | 20 +- inseq/models/model_config.yaml | 111 ++++- inseq/utils/__init__.py | 10 + inseq/utils/hooks.py | 110 +++++ inseq/utils/misc.py | 10 +- inseq/utils/torch_utils.py | 118 ++++++ inseq/utils/typing.py | 36 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- tests/attr/feat/test_feature_attribution.py | 127 +++++- tests/data/test_aggregator.py | 34 +- tests/fixtures/aggregator.json | 2 + tests/inference_commons.py | 8 + tests/models/test_huggingface_model.py | 19 +- 30 files changed, 1252 insertions(+), 174 deletions(-) create mode 100644 inseq/attr/feat/ops/value_zeroing.py create mode 100644 inseq/utils/hooks.py diff --git a/CHANGELOG.md b/CHANGELOG.md index faeca60e..3fade279 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ - Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238)) - Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237)) +- Added `value_zeroing` (`inseq.attr.feat.perturbation_attribution.ValueZeroingAttribution`) attribution method ([#173](https://github.com/inseq-team/inseq/pull/173)) +- `value_zeroing` and `attention` use scores from the last generation step to produce outputs more efficiently (`is_final_step_method = True`) ([#173](https://github.com/inseq-team/inseq/pull/173)). ## šŸ”§ Fixes & Refactoring @@ -26,4 +28,5 @@ ## šŸ’„ Breaking Changes -*No changes* +- If `attention` is used as attribution method in `model.attribute`, `step_scores` cannot be extracted at the same time since the method does not require iterating over the full sequence anymore. ([#173](https://github.com/inseq-team/inseq/pull/173)) As an alternative, step scores can be extracted separately using the `dummy` attribution method (i.e. no attribution). +- BOS is always included in target-side attribution and generated sequences if present. ([#173](https://github.com/inseq-team/inseq/pull/173)) diff --git a/README.md b/README.md index 79aa0563..5e09d734 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,8 @@ Use the `inseq.list_feature_attribution_methods` function to list all available - `lime`: ["Why Should I Trust You?": Explaining the Predictions of Any Classifier](https://arxiv.org/abs/1602.04938) (Ribeiro et al., 2016) +- `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al. 2023) + #### Step functions Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported: diff --git a/docs/source/main_classes/feature_attribution.rst b/docs/source/main_classes/feature_attribution.rst index 174b405c..1f282626 100644 --- a/docs/source/main_classes/feature_attribution.rst +++ b/docs/source/main_classes/feature_attribution.rst @@ -17,7 +17,7 @@ Attribution Methods .. autoclass:: inseq.attr.FeatureAttribution :members: -Gradient Attribution Methods +Gradient-based Attribution Methods ----------------------------------------------------------------------------------------------------------------------- .. autoclass:: inseq.attr.feat.GradientAttributionRegistry @@ -67,7 +67,7 @@ Layer Attribution Methods :members: -Attention Attribution Methods +Internals-based Attribution Methods ----------------------------------------------------------------------------------------------------------------------- .. autoclass:: inseq.attr.feat.InternalsAttributionRegistry @@ -76,3 +76,18 @@ Attention Attribution Methods .. autoclass:: inseq.attr.feat.AttentionWeightsAttribution :members: + +Perturbation-based Attribution Methods +----------------------------------------------------------------------------------------------------------------------- + +.. autoclass:: inseq.attr.feat.PerturbationAttributionRegistry + :members: + +.. autoclass:: inseq.attr.feat.OcclusionAttribution + :members: + +.. autoclass:: inseq.attr.feat.LimeAttribution + :members: + +.. autoclass:: inseq.attr.feat.ValueZeroingAttribution + :members: \ No newline at end of file diff --git a/inseq/attr/feat/__init__.py b/inseq/attr/feat/__init__.py index cc07f530..2b25778a 100644 --- a/inseq/attr/feat/__init__.py +++ b/inseq/attr/feat/__init__.py @@ -17,6 +17,8 @@ from .perturbation_attribution import ( LimeAttribution, OcclusionAttribution, + PerturbationAttributionRegistry, + ValueZeroingAttribution, ) __all__ = [ @@ -39,4 +41,6 @@ "OcclusionAttribution", "LimeAttribution", "SequentialIntegratedGradientsAttribution", + "ValueZeroingAttribution", + "PerturbationAttributionRegistry", ] diff --git a/inseq/attr/feat/attribution_utils.py b/inseq/attr/feat/attribution_utils.py index 8da4f899..a9679845 100644 --- a/inseq/attr/feat/attribution_utils.py +++ b/inseq/attr/feat/attribution_utils.py @@ -144,11 +144,15 @@ def extract_args( def get_source_target_attributions( attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]], is_encoder_decoder: bool, + has_sequence_scores: bool = False, ) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]: if isinstance(attr, tuple): if is_encoder_decoder: - return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None) + if has_sequence_scores: + return (attr[0], attr[1], attr[2]) + else: + return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None) else: - return (None, attr[0]) + return (None, None, attr[0]) if has_sequence_scores else (None, attr[0]) else: return (attr, None) if is_encoder_decoder else (None, attr) diff --git a/inseq/attr/feat/feature_attribution.py b/inseq/attr/feat/feature_attribution.py index 250cd700..ce0e9300 100644 --- a/inseq/attr/feat/feature_attribution.py +++ b/inseq/attr/feat/feature_attribution.py @@ -114,6 +114,7 @@ def __init__(self, attribution_model: "AttributionModel", hook_to_model: bool = self.use_hidden_states: bool = False self.use_predicted_target: bool = True self.use_model_config: bool = False + self.is_final_step_method: bool = False if hook_to_model: self.hook(**kwargs) @@ -272,6 +273,35 @@ def _run_compatibility_checks(self, attributed_fn) -> None: " method." ) + @staticmethod + def _build_multistep_output_from_single_step( + single_step_output: FeatureAttributionStepOutput, + attr_pos_start: int, + attr_pos_end: int, + ) -> list[FeatureAttributionStepOutput]: + if single_step_output.step_scores: + raise ValueError("step_scores are not supported for final step attribution methods.") + num_seq = len(single_step_output.prefix) + steps = [] + for pos_idx in range(attr_pos_start, attr_pos_end): + step_output = single_step_output.clone_empty() + step_output.source = single_step_output.source + step_output.prefix = [single_step_output.prefix[seq_idx][:pos_idx] for seq_idx in range(num_seq)] + step_output.target = ( + single_step_output.target + if pos_idx == attr_pos_end - 1 + else [[single_step_output.prefix[seq_idx][pos_idx]] for seq_idx in range(num_seq)] + ) + if single_step_output.source_attributions is not None: + step_output.source_attributions = single_step_output.source_attributions[:, :, pos_idx - 1] + if single_step_output.target_attributions is not None: + step_output.target_attributions = single_step_output.target_attributions[:, :pos_idx, pos_idx - 1] + single_step_output.step_scores = {} + if single_step_output.sequence_scores is not None: + step_output.sequence_scores = single_step_output.sequence_scores + steps.append(step_output) + return steps + def format_contrastive_targets( self, target_sequences: TextSequences, @@ -416,9 +446,9 @@ def attribute( target_lengths=targets_lengths, method_name=self.method_name, show=show_progress, - pretty=pretty_progress, + pretty=False if self.is_final_step_method else pretty_progress, attr_pos_start=attr_pos_start, - attr_pos_end=attr_pos_end, + attr_pos_end=1 if self.is_final_step_method else attr_pos_end, ) whitespace_indexes = find_char_indexes(sequences.targets, " ") attribution_outputs = [] @@ -427,6 +457,8 @@ def attribute( # Attribution loop for generation for step in range(attr_pos_start, iter_pos_end): + if self.is_final_step_method and step != iter_pos_end - 1: + continue tgt_ids, tgt_mask = batch.get_step_target(step, with_attention=True) step_output = self.filtered_attribute_step( batch[:step], @@ -450,7 +482,7 @@ def attribute( contrast_targets_alignments=contrast_targets_alignments, ) attribution_outputs.append(step_output) - if pretty_progress: + if pretty_progress and not self.is_final_step_method: tgt_tokens = batch.target_tokens skipped_prefixes = tok2string(self.attribution_model, tgt_tokens, end=attr_pos_start) attributed_sentences = tok2string(self.attribution_model, tgt_tokens, attr_pos_start, step + 1) @@ -471,12 +503,17 @@ def attribute( end = datetime.now() close_progress_bar(pbar, show=show_progress, pretty=pretty_progress) batch.detach().to("cpu") + if self.is_final_step_method: + attribution_outputs = self._build_multistep_output_from_single_step( + attribution_outputs[0], + attr_pos_start=attr_pos_start, + attr_pos_end=iter_pos_end, + ) out = FeatureAttributionOutput( sequence_attributions=FeatureAttributionSequenceOutput.from_step_attributions( attributions=attribution_outputs, tokenized_target_sentences=target_tokens_with_ids, - pad_id=self.attribution_model.pad_token, - has_bos_token=self.attribution_model.is_encoder_decoder, + pad_token=self.attribution_model.pad_token, attr_pos_end=attr_pos_end, ), step_attributions=attribution_outputs if output_step_attributions else None, @@ -593,7 +630,7 @@ def filtered_attribute_step( step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu") # Reinsert finished sentences if target_attention_mask is not None and is_filtered: - step_output.remap_from_filtered(target_attention_mask, orig_batch) + step_output.remap_from_filtered(target_attention_mask, orig_batch, self.is_final_step_method) step_output = step_output.detach().to("cpu") return step_output diff --git a/inseq/attr/feat/internals_attribution.py b/inseq/attr/feat/internals_attribution.py index 9c6e8923..003c1869 100644 --- a/inseq/attr/feat/internals_attribution.py +++ b/inseq/attr/feat/internals_attribution.py @@ -16,12 +16,12 @@ import logging from typing import Any, Optional +import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric -from captum.attr._utils.attribution import Attribution from ...data import MultiDimensionalFeatureAttributionStepOutput from ...utils import Registry -from ...utils.typing import MultiLayerMultiUnitScoreTensor +from ...utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor from .feature_attribution import FeatureAttribution logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ class AttentionWeightsAttribution(InternalsAttributionRegistry): method_name = "attention" - class AttentionWeights(Attribution): + class AttentionWeights(InseqAttribution): @staticmethod def has_convergence_delta() -> bool: return False @@ -74,9 +74,14 @@ def attribute( :class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: A step output containing attention weights for each layer and head, with shape :obj:`(batch_size, seq_len, n_layers, n_heads)`. """ - # We adopt the format [batch_size, sequence_length, num_layers, num_heads] + # We adopt the format [batch_size, sequence_length, sequence_length, num_layers, num_heads] # for consistency with other multi-unit methods (e.g. gradient attribution) - decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2) + decoder_self_attentions = decoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2) + decoder_self_attentions = torch.where( + decoder_self_attentions == 0, + (torch.ones_like(decoder_self_attentions) * float("nan")), + decoder_self_attentions, + ) if self.forward_func.is_encoder_decoder: sequence_scores = {} if len(inputs) > 1: @@ -85,10 +90,11 @@ def attribute( target_attributions = None sequence_scores["decoder_self_attentions"] = decoder_self_attentions sequence_scores["encoder_self_attentions"] = ( - encoder_self_attentions.to("cpu").clone().permute(0, 3, 4, 1, 2) + encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2) ) + cross_attentions = cross_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2) return MultiDimensionalFeatureAttributionStepOutput( - source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2), + source_attributions=cross_attentions, target_attributions=target_attributions, sequence_scores=sequence_scores, _num_dimensions=2, # num_layers, num_heads @@ -106,6 +112,8 @@ def __init__(self, attribution_model, **kwargs): self.use_attention_weights = True # Does not rely on predicted output (i.e. decoding strategy agnostic) self.use_predicted_target = False + # Needs only the final generation step to extract scores + self.is_final_step_method = True self.method = self.AttentionWeights(attribution_model) def attribute_step( diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py index 388ab042..7d86167a 100644 --- a/inseq/attr/feat/ops/__init__.py +++ b/inseq/attr/feat/ops/__init__.py @@ -2,10 +2,12 @@ from .lime import Lime from .monotonic_path_builder import MonotonicPathBuilder from .sequential_integrated_gradients import SequentialIntegratedGradients +from .value_zeroing import ValueZeroing __all__ = [ "DiscretetizedIntegratedGradients", "MonotonicPathBuilder", + "ValueZeroing", "Lime", "SequentialIntegratedGradients", ] diff --git a/inseq/attr/feat/ops/value_zeroing.py b/inseq/attr/feat/ops/value_zeroing.py new file mode 100644 index 00000000..c2afdbc7 --- /dev/null +++ b/inseq/attr/feat/ops/value_zeroing.py @@ -0,0 +1,394 @@ +# Copyright 2023 The Inseq Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from enum import Enum +from typing import TYPE_CHECKING, Callable, Optional + +import torch +from captum._utils.typing import TensorOrTupleOfTensorsGeneric +from torch import nn +from torch.utils.hooks import RemovableHandle + +from ....utils import ( + StackFrame, + find_block_stack, + get_post_variable_assignment_hook, + recursive_get_submodule, + validate_indices, +) +from ....utils.typing import ( + EmbeddingsTensor, + InseqAttribution, + MultiLayerEmbeddingsTensor, + MultiLayerScoreTensor, + OneOrMoreIndices, + OneOrMoreIndicesDict, +) + +if TYPE_CHECKING: + from ....models import HuggingfaceModel + +logger = logging.getLogger(__name__) + + +class ValueZeroingSimilarityMetric(Enum): + COSINE = "cosine" + EUCLIDEAN = "euclidean" + + +class ValueZeroingModule(Enum): + DECODER = "decoder" + ENCODER = "encoder" + + +class ValueZeroing(InseqAttribution): + """Value Zeroing method for feature attribution. + + Introduced by `Mohebbi et al. (2023) `__ to quantify context mixing inside + Transformer models. The method is based on the observation that context mixing is regulated by the value vectors + of the attention mechanism. The method consists of two steps: + + 1. Zeroing the value vectors of the attention mechanism for a given token index at a given layer of the model. + 2. Computing the similarity between hidden states produced with and without the zeroing operation, and using it + as a measure of context mixing for the given token at the given layer. + + The method is converted into a feature attribution method by allowing for extraction of value zeroing scores at + specific layers, or by aggregating them across layers. + + Attributes: + SIMILARITY_METRICS (:obj:`Dict[str, Callable]`): + Dictionary of available similarity metrics to be used forvcomputing the distance between hidden states + produced with and without the zeroing operation. Converted to distances as 1 - produced values. + forward_func (:obj:`AttributionModel`): + The attribution model to be used for value zeroing. + clean_block_output_states (:obj:`Dict[int, torch.Tensor]`): + Dictionary to store the hidden states produced by the model without the zeroing operation. + corrupted_block_output_states (:obj:`Dict[int, torch.Tensor]`): + Dictionary to store the hidden states produced by the model with the zeroing operation. + """ + + SIMILARITY_METRICS = { + "cosine": nn.CosineSimilarity(dim=-1), + "euclidean": lambda x, y: torch.cdist(x, y, p=2), + } + + def __init__(self, forward_func: "HuggingfaceModel") -> None: + super().__init__(forward_func) + self.clean_block_output_states: dict[int, EmbeddingsTensor] = {} + self.corrupted_block_output_states: dict[int, EmbeddingsTensor] = {} + + @staticmethod + def get_value_zeroing_hook(varname: str = "value") -> Callable[..., None]: + """Returns a hook to zero the value vectors of the attention mechanism. + + Args: + varname (:obj:`str`, optional): The name of the variable containing the value vectors. The variable + is expected to be a 3D tensor of shape (batch_size, num_heads, seq_len) and is retrieved from the + local variables of the execution frame during the forward pass. + """ + + def value_zeroing_forward_mid_hook( + frame: StackFrame, + zeroed_token_index: Optional[int] = None, + zeroed_units_indices: Optional[OneOrMoreIndices] = None, + batch_size: int = 1, + ) -> None: + if varname not in frame.f_locals: + raise ValueError( + f"Variable {varname} not found in the local frame." + f"Other variable names: {', '.join(frame.f_locals.keys())}" + ) + # Zeroing value vectors corresponding to the given token index + if zeroed_token_index is not None: + values_size = frame.f_locals[varname].size() + if len(values_size) == 3: # Assume merged shape (bsz * num_heads, seq_len, hidden_size) e.g. Whisper + values = frame.f_locals[varname].view(batch_size, -1, *values_size[1:]) + elif len(values_size) == 4: # Assume per-head shape (bsz, num_heads, seq_len, hidden_size) e.g. GPT-2 + values = frame.f_locals[varname].clone() + else: + raise ValueError( + f"Value vector shape {frame.f_locals[varname].size()} not supported. " + "Supported shapes: (batch_size, num_heads, seq_len, hidden_size) or " + "(batch_size * num_heads, seq_len, hidden_size)" + ) + zeroed_units_indices = validate_indices(values, 1, zeroed_units_indices).to(values.device) + zeroed_token_index = torch.tensor(zeroed_token_index, device=values.device) + # Mask heads corresponding to zeroed units and tokens corresponding to zeroed tokens + values[:, zeroed_units_indices, zeroed_token_index] = 0 + if len(values_size) == 3: + frame.f_locals[varname] = values.view(-1, *values_size[1:]) + elif len(values_size) == 4: + frame.f_locals[varname] = values + + return value_zeroing_forward_mid_hook + + def get_states_extract_and_patch_hook(self, block_idx: int, hidden_state_idx: int = 0) -> Callable[..., None]: + """Returns a hook to extract the produced hidden states (corrupted by value zeroing) + and patch them with pre-computed clean states that will be passed onwards in the model forward. + + Args: + block_idx (:obj:`int`): The idx of the block at which the hook is applied, used to store extracted states. + hidden_state_idx (:obj:`int`, optional): The index of the hidden state in the model output tuple. + """ + + def states_extract_and_patch_forward_hook(module, args, output) -> None: + self.corrupted_block_output_states[block_idx] = output[hidden_state_idx].clone().float().detach().cpu() + + # Rebuild the output tuple patching the clean states at the place of the corrupted ones + output = ( + output[:hidden_state_idx] + + (self.clean_block_output_states[block_idx].to(output[hidden_state_idx].device),) + + output[hidden_state_idx + 1 :] + ) + return output + + return states_extract_and_patch_forward_hook + + @staticmethod + def has_convergence_delta() -> bool: + return False + + def compute_modules_post_zeroing_similarity( + self, + inputs: TensorOrTupleOfTensorsGeneric, + additional_forward_args: TensorOrTupleOfTensorsGeneric, + hidden_states: MultiLayerEmbeddingsTensor, + attention_module_name: str, + attributed_seq_len: Optional[int] = None, + similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value, + mode: str = ValueZeroingModule.DECODER.value, + zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + min_score_threshold: float = 1e-5, + use_causal_mask: bool = False, + ) -> MultiLayerScoreTensor: + """Given a ``nn.ModuleList``, computes the similarity between the clean and corrupted states for each block. + + Args: + modules (:obj:`nn.ModuleList`): The list of modules to compute the similarity for. + hidden_states (:obj:`MultiLayerEmbeddingsTensor`): The cached hidden states of the modules to use as clean + counterparts when computing the similarity. + attention_module_name (:obj:`str`): The name of the attention module to zero the values for. + attributed_seq_len (:obj:`int`): The length of the sequence to attribute. If not specified, it is assumed + to be the same as the length of the hidden states. + similarity_metric (:obj:`str`): The name of the similarity metric used. Default: "cosine". + mode (:obj:`str`): The mode of the model to compute the similarity for. Default: "decoder". + zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and + `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads + that should be zeroed to compute corrupted states. + - If None, all attention heads across all layers are zeroed. + - If an integer, the same attention head is zeroed across all layers. + - If a tuple of two integers, the attention heads in the range are zeroed across all layers. + - If a list of integers, the attention heads in the list are zeroed across all layers. + - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for + the corresponding layer. Any missing layer will not be zeroed. + Default: None. + min_score_threshold (:obj:`float`, optional): The minimum score threshold to consider when computing the + similarity. Default: 1e-5. + use_causal_mask (:obj:`bool`, optional): Whether a causal mask is applied to zeroing scores Default: False. + + Returns: + :obj:`MultiLayerScoreTensor`: A tensor of shape ``[batch_size, seq_len, num_layer]`` containing distances + (1 - similarity score) between original and corrupted states for each layer. + """ + if mode == ValueZeroingModule.DECODER.value: + modules: nn.ModuleList = find_block_stack(self.forward_func.get_decoder()) + elif mode == ValueZeroingModule.ENCODER.value: + modules: nn.ModuleList = find_block_stack(self.forward_func.get_encoder()) + else: + raise NotImplementedError(f"Mode {mode} not implemented for value zeroing.") + if attributed_seq_len is None: + attributed_seq_len = hidden_states.size(2) + batch_size = hidden_states.size(0) + generated_seq_len = hidden_states.size(2) + num_layers = len(modules) + + # Store clean hidden states for later use. Starts at 1 since the first element of the modules stack is the + # embedding layer, and we are only interested in the transformer blocks outputs. + self.clean_block_output_states = { + block_idx: hidden_states[:, block_idx + 1, ...].clone().detach().cpu() for block_idx in range(len(modules)) + } + # Scores for every layer of the model + all_scores = torch.ones( + batch_size, num_layers, generated_seq_len, attributed_seq_len, device=hidden_states.device + ) * float("nan") + + # Hooks: + # 1. states_extract_and_patch_hook on the transformer block stores corrupted states and force clean states + # as the output of the block forward pass, i.e. the zeroing is done independently across layers. + # 2. value_zeroing_hook on the attention module performs the value zeroing by replacing the "value" tensor + # during the forward (name is config-dependent) with a zeroed version for the specified token index. + # + # State extraction hooks can be registered only once since they are token-independent + # Skip last block since its states are not used raw, but may have further transformations applied to them + # (e.g. LayerNorm, Dropout). These are extracted separately from the model outputs. + states_extraction_hook_handles: list[RemovableHandle] = [] + for block_idx in range(len(modules) - 1): + states_extract_and_patch_hook = self.get_states_extract_and_patch_hook(block_idx, hidden_state_idx=0) + states_extraction_hook_handles.append( + modules[block_idx].register_forward_hook(states_extract_and_patch_hook) + ) + # Zeroing is done for every token in the sequence separately (O(n) complexity) + for token_idx in range(attributed_seq_len): + value_zeroing_hook_handles: list[RemovableHandle] = [] + # Value zeroing hooks are registered for every token separately since they are token-dependent + for block_idx, block in enumerate(modules): + attention_module = recursive_get_submodule(block, attention_module_name) + if attention_module is None: + raise ValueError(f"Attention module {attention_module_name} not found in block {block_idx}.") + if isinstance(zeroed_units_indices, dict): + if block_idx not in zeroed_units_indices: + continue + zeroed_units_indices_block = zeroed_units_indices[block_idx] + else: + zeroed_units_indices_block = zeroed_units_indices + value_zeroing_hook = get_post_variable_assignment_hook( + module=attention_module, + varname=self.forward_func.config.value_vector, + hook_fn=self.get_value_zeroing_hook(self.forward_func.config.value_vector), + zeroed_token_index=token_idx, + zeroed_units_indices=zeroed_units_indices_block, + batch_size=batch_size, + ) + value_zeroing_hook_handle = attention_module.register_forward_pre_hook(value_zeroing_hook) + value_zeroing_hook_handles.append(value_zeroing_hook_handle) + + # Run forward pass with hooks. Fills self.corrupted_hidden_states with corrupted states across layers + # when zeroing the specified token index. + with torch.no_grad(): + output = self.forward_func.forward_with_output( + *inputs, *additional_forward_args, output_hidden_states=True + ) + # Extract last layer states directly from the model outputs + # This allows us to handle the presence of additional transformations (e.g. LayerNorm, Dropout) + # in the last layer automatically. + corrupted_states_dict = self.forward_func.get_hidden_states_dict(output) + corrupted_decoder_last_hidden_state = ( + corrupted_states_dict[f"{mode}_hidden_states"][:, -1, ...].clone().detach().cpu() + ) + self.corrupted_block_output_states[len(modules) - 1] = corrupted_decoder_last_hidden_state + for handle in value_zeroing_hook_handles: + handle.remove() + for block_idx in range(len(modules)): + similarity_scores = self.SIMILARITY_METRICS[similarity_metric]( + self.clean_block_output_states[block_idx].float(), self.corrupted_block_output_states[block_idx] + ) + if use_causal_mask: + all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores[:, token_idx:] + else: + all_scores[:, block_idx, :, token_idx] = 1 - similarity_scores + self.corrupted_block_output_states = {} + for handle in states_extraction_hook_handles: + handle.remove() + self.clean_block_output_states = {} + all_scores = torch.where(all_scores < min_score_threshold, torch.zeros_like(all_scores), all_scores) + # Normalize scores to sum to 1 + per_token_sum_score = all_scores.nansum(dim=-1, keepdim=True) + per_token_sum_score[per_token_sum_score == 0] = 1 + all_scores = all_scores / per_token_sum_score + + # Final shape: [batch_size, attributed_seq_len, generated_seq_len, num_layers] + return all_scores.permute(0, 3, 2, 1) + + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + additional_forward_args: TensorOrTupleOfTensorsGeneric, + similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value, + encoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + decoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + cross_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + encoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None, + decoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None, + output_decoder_self_scores: bool = True, + output_encoder_self_scores: bool = True, + ) -> TensorOrTupleOfTensorsGeneric: + """Perform attribution using the Value Zeroing method. + + Args: + similarity_metric (:obj:`str`, optional): The similarity metric to use for computing the distance between + hidden states produced with and without the zeroing operation. Default: cosine similarity. + zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and + `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads + that should be zeroed to compute corrupted states. + - If None, all attention heads across all layers are zeroed. + - If an integer, the same attention head is zeroed across all layers. + - If a tuple of two integers, the attention heads in the range are zeroed across all layers. + - If a list of integers, the attention heads in the list are zeroed across all layers. + - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for + the corresponding layer. + + Default: None (all heads are zeroed for every layer). + encoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1, + source_seq_len, hidden_size]`` containing hidden states of the encoder. Available only for + encoder-decoders models. Default: None. + decoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1, + target_seq_len, hidden_size]`` containing hidden states of the decoder. + output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + decoder self-attention value vectors in encoder-decoder models. Cannot be false for decoder-only, or + if target-side attribution is requested using `attribute_target=True`. Default: True. + output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + encoder self-attention value vectors in encoder-decoder models. Default: True. + + Returns: + `TensorOrTupleOfTensorsGeneric`: Attribution outputs for source-only or source + target feature attribution + """ + if similarity_metric not in self.SIMILARITY_METRICS: + raise ValueError( + f"Similarity metric {similarity_metric} not available." + f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}" + ) + decoder_scores = None + if not self.forward_func.is_encoder_decoder or output_decoder_self_scores or len(inputs) > 1: + decoder_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=decoder_hidden_states, + attention_module_name=self.forward_func.config.self_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.DECODER.value, + zeroed_units_indices=decoder_zeroed_units_indices, + use_causal_mask=True, + ) + # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values + # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py + if self.forward_func.is_encoder_decoder: + encoder_scores = None + if output_encoder_self_scores: + encoder_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=encoder_hidden_states, + attention_module_name=self.forward_func.config.self_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.ENCODER.value, + zeroed_units_indices=encoder_zeroed_units_indices, + ) + cross_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=decoder_hidden_states, + attributed_seq_len=encoder_hidden_states.size(2), + attention_module_name=self.forward_func.config.cross_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.DECODER.value, + zeroed_units_indices=cross_zeroed_units_indices, + ) + return encoder_scores, cross_scores, decoder_scores + elif encoder_zeroed_units_indices is not None or cross_zeroed_units_indices is not None: + logger.warning( + "Zeroing indices for encoder and cross-attentions were specified, but the model is not an " + "encoder-decoder. Use `decoder_zeroed_units_indices` to parametrize zeroing for the decoder module." + ) + return (decoder_scores,) diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index fbebb780..c3eb0211 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -6,11 +6,12 @@ from ...data import ( CoarseFeatureAttributionStepOutput, GranularFeatureAttributionStepOutput, + MultiDimensionalFeatureAttributionStepOutput, ) from ...utils import Registry from .attribution_utils import get_source_target_attributions from .gradient_attribution import FeatureAttribution -from .ops import Lime +from .ops import Lime, ValueZeroing logger = logging.getLogger(__name__) @@ -117,3 +118,101 @@ def attribute_step( target_attributions=out.target_attributions, sequence_scores=out.sequence_scores, ) + + +class ValueZeroingAttribution(PerturbationAttributionRegistry): + """Value Zeroing method for feature attribution. + + Introduced by `Mohebbi et al. (2023) `__ to quantify context mixing + in Transformer models. The method is based on the observation that context mixing is regulated by the value vectors + of the attention mechanism. The method consists of two steps: + + 1. Zeroing the value vectors of the attention mechanism for a given token index at a given layer of the model. + 2. Computing the similarity between hidden states produced with and without the zeroing operation, and using it + as a measure of context mixing for the given token at the given layer. + + The method is converted into a feature attribution method by allowing for extraction of value zeroing scores at + specific layers, or by aggregating them across layers. + + Reference implementations: + - Original implementation: `hmohebbi/ValueZeroing `__ + - Encoder-decoder implementation: `hmohebbi/ContextMixingASR `__ + + Args: + similarity_metric (:obj:`str`, optional): The similarity metric to use for computing the distance between + hidden states produced with and without the zeroing operation. Options: cosine, euclidean. Default: cosine. + encoder_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): The indices of + the attention heads that should be zeroed to compute corrupted states in the encoder self-attention module. + Not used for decoder-only models, or if ``output_encoder_self_scores`` is False. Format + + - None: all attention heads across all layers are zeroed. + - int: the same attention head is zeroed across all layers. + - tuple of two integers: the attention heads in the range are zeroed across all layers. + - list of integers: the attention heads in the list are zeroed across all layers. + - dictionary: the keys are the layer indices and the values are the zeroed attention heads for the corresponding layer. + + Default: None (all heads are zeroed for every encoder layer). + decoder_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): Same as + ``encoder_zeroed_units_indices`` but for the decoder self-attention module. Not used for encoder-decoder + models or if ``output_decoder_self_scores`` is False. Default: None (all heads are zeroed for every decoder layer). + cross_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): Same as + ``encoder_zeroed_units_indices`` but for the cross-attention module in encoder-decoder models. Not used + if the model is decoder-only. Default: None (all heads are zeroed for every layer). + output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + decoder self-attention value vectors in encoder-decoder models. Cannot be false for decoder-only, or + if target-side attribution is requested using `attribute_target=True`. Default: True. + output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + encoder self-attention value vectors in encoder-decoder models. Default: True. + + Returns: + :class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: The final dimension returned by the method + is ``[attributed_seq_len, generated_seq_len, num_layers]``. If ``output_decoder_self_scores`` and + ``output_encoder_self_scores`` are True, the respective scores are returned in the ``sequence_scores`` + output dictionary. + """ + + method_name = "value_zeroing" + + def __init__(self, attribution_model, **kwargs): + super().__init__(attribution_model, hook_to_model=False) + # Hidden states will be passed to the attribute_step method + self.use_hidden_states = True + # Does not rely on predicted output (i.e. decoding strategy agnostic) + self.use_predicted_target = False + # Uses model configuration to access attention module and value vector variable + self.use_model_config = True + # Needs only the final generation step to extract scores + self.is_final_step_method = True + self.method = ValueZeroing(attribution_model) + self.hook(**kwargs) + + def attribute_step( + self, + attribute_fn_main_args: dict[str, Any], + attribution_args: dict[str, Any] = {}, + ) -> MultiDimensionalFeatureAttributionStepOutput: + attr = self.method.attribute(**attribute_fn_main_args, **attribution_args) + encoder_self_scores, decoder_cross_scores, decoder_self_scores = get_source_target_attributions( + attr, self.attribution_model.is_encoder_decoder, has_sequence_scores=True + ) + sequence_scores = {} + if self.attribution_model.is_encoder_decoder: + if len(attribute_fn_main_args["inputs"]) > 1: + target_attributions = decoder_self_scores.to("cpu") + else: + target_attributions = None + if decoder_self_scores is not None: + sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu") + if encoder_self_scores is not None: + sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu") + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=decoder_cross_scores.to("cpu"), + target_attributions=target_attributions, + sequence_scores=sequence_scores, + _num_dimensions=1, # num_layers + ) + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=None, + target_attributions=decoder_self_scores, + _num_dimensions=1, # num_layers + ) diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py index d9cd6092..83aa8d6e 100644 --- a/inseq/attr/step_functions.py +++ b/inseq/attr/step_functions.py @@ -462,8 +462,7 @@ def register_step_function( attribution targets by gradient-based feature attribution methods. Args: - fn (:obj:`callable`): The function to be used to compute step scores. Default parameters (use kwargs to capture - unused ones when defining your function): + fn (:obj:`callable`): The function to be used to compute step scores. Default parameters (use kwargs to capture unused ones when defining your function): - :obj:`attribution_model`: an :class:`~inseq.models.AttributionModel` instance, corresponding to the model used for computing the score. diff --git a/inseq/commands/commands_utils.py b/inseq/commands/commands_utils.py index 7701409a..dbfb8ac4 100644 --- a/inseq/commands/commands_utils.py +++ b/inseq/commands/commands_utils.py @@ -18,5 +18,4 @@ def command_args_docstring(cls): field_help = field.metadata.get("help", "") docstring += textwrap.dedent(f"\n**{field.name}** (``{field_type}``): {field_help}\n") cls.__doc__ = docstring - print(docstring) return cls diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py index bb475707..dbae5352 100644 --- a/inseq/data/aggregator.py +++ b/inseq/data/aggregator.py @@ -12,9 +12,10 @@ aggregate_token_sequence, available_classes, extract_signature_args, + validate_indices, ) from ..utils import normalize as normalize_fn -from ..utils.typing import IndexSpan, TokenWithId +from ..utils.typing import IndexSpan, OneOrMoreIndices, TokenWithId from .aggregation_functions import AggregationFunction from .data_utils import TensorWrapper @@ -305,7 +306,7 @@ def _process_attribution_scores( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Union[int, tuple[int, int], list[int], None] = None, + select_idx: Optional[OneOrMoreIndices] = None, normalize: bool = True, **kwargs, ): @@ -366,7 +367,7 @@ def aggregate_source_attributions( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Union[int, tuple[int, int], list[int], None] = None, + select_idx: Optional[OneOrMoreIndices] = None, normalize: bool = True, **kwargs, ): @@ -380,7 +381,7 @@ def aggregate_target_attributions( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Union[int, tuple[int, int], list[int], None] = None, + select_idx: Optional[OneOrMoreIndices] = None, normalize: bool = True, **kwargs, ): @@ -398,7 +399,7 @@ def aggregate_sequence_scores( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Union[int, tuple[int, int], list[int], None] = None, + select_idx: Optional[OneOrMoreIndices] = None, **kwargs, ): if aggregate_fn.takes_sequence_scores: @@ -439,46 +440,12 @@ def is_compatible(attr: "FeatureAttributionSequenceOutput"): def _filter_scores( scores: torch.Tensor, dim: int = -1, - indices: Union[int, tuple[int, int], list[int], None] = None, + indices: Optional[OneOrMoreIndices] = None, ) -> torch.Tensor: - n_units = scores.shape[dim] - - if hasattr(indices, "__iter__"): - if len(indices) == 0: - raise RuntimeError("At least two indices must be specified for aggregation.") - if len(indices) == 1: - indices = indices[0] - + indexed = scores.index_select(dim, validate_indices(scores, dim, indices).to(scores.device)) if isinstance(indices, int): - if indices not in range(-n_units, n_units): - raise IndexError(f"Index out of range. Scores only have {n_units} units.") - indices = indices if indices >= 0 else n_units + indices - return scores.select(dim, torch.tensor(indices, device=scores.device)) - else: - if indices is None: - indices = (0, n_units) - logger.info("No indices specified for extraction. Using all units by default.") - - # Convert negative indices to positive indices - if hasattr(indices, "__iter__"): - indices = type(indices)([h_idx if h_idx >= 0 else n_units + h_idx for h_idx in indices]) - if not hasattr(indices, "__iter__") or ( - len(indices) == 2 and isinstance(indices, tuple) and indices[0] >= indices[1] - ): - raise RuntimeError( - "A (start, end) tuple of indices representing a span, a list of individual indices" - " or a single index must be specified for select_idx." - ) - max_idx_val = n_units if isinstance(indices, list) else n_units + 1 - if not all(h in range(-n_units, max_idx_val) for h in indices): - raise IndexError("One or more index out of range. Scores only have {n_units} units.") - if len(set(indices)) != len(indices): - raise IndexError("Duplicate indices are not allowed.") - if isinstance(indices, tuple): - scores = scores.index_select(dim, torch.arange(indices[0], indices[1], device=scores.device)) - else: - scores = scores.index_select(dim, torch.tensor(indices, device=scores.device)) - return scores + return indexed.squeeze(dim) + return indexed @staticmethod def _aggregate_scores( diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py index f3671244..7841cf7c 100644 --- a/inseq/data/attribution.py +++ b/inseq/data/attribution.py @@ -12,6 +12,7 @@ get_sequences_from_batched_steps, json_advanced_dump, json_advanced_load, + pad_with_nan, pretty_dict, remap_from_filtered, ) @@ -178,9 +179,8 @@ def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callab def from_step_attributions( cls, attributions: list["FeatureAttributionStepOutput"], - tokenized_target_sentences: Optional[list[list[TokenWithId]]] = None, - pad_id: Optional[Any] = None, - has_bos_token: bool = True, + tokenized_target_sentences: list[list[TokenWithId]], + pad_token: Optional[Any] = None, attr_pos_end: Optional[int] = None, ) -> list["FeatureAttributionSequenceOutput"]: """Converts a list of :class:`~inseq.data.attribution.FeatureAttributionStepOutput` objects containing multiple @@ -198,36 +198,35 @@ def from_step_attributions( num_sequences = len(attr.prefix) if not all(len(attr.prefix) == num_sequences for attr in attributions): raise ValueError("All the attributions must include the same number of sequences.") - seq_attributions = [] - sources = None - if attr.source_attributions is not None: - sources = [drop_padding(attr.source[seq_id], pad_id) for seq_id in range(num_sequences)] - targets = [ - drop_padding([a.target[seq_id][0] for a in attributions], pad_id) for seq_id in range(num_sequences) - ] - if tokenized_target_sentences is None: - tokenized_target_sentences = targets - if has_bos_token: - tokenized_target_sentences = [tok_seq[1:] for tok_seq in tokenized_target_sentences] - tokenized_target_sentences = [ - drop_padding(tokenized_target_sentences[seq_id], pad_id) for seq_id in range(num_sequences) - ] + seq_attributions: list[FeatureAttributionSequenceOutput] = [] + sources = [] + targets = [] + pos_start = [] + for seq_idx in range(num_sequences): + if attr.source_attributions is not None: + sources.append(drop_padding(attr.source[seq_idx], pad_token)) + curr_target = [a.target[seq_idx][0] for a in attributions] + targets.append(drop_padding(curr_target, pad_token)) + if all(attr.prefix[seq_idx][0] == pad_token for seq_idx in range(num_sequences)): + tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][:1] + drop_padding( + tokenized_target_sentences[seq_idx][1:], pad_token + ) + else: + tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token) if attr_pos_end is None: attr_pos_end = max(len(t) for t in tokenized_target_sentences) - pos_start = [ - min(len(tokenized_target_sentences[seq_id]), attr_pos_end) - len(targets[seq_id]) - for seq_id in range(num_sequences) - ] - for seq_id in range(num_sequences): - source = tokenized_target_sentences[seq_id][: pos_start[seq_id]] if sources is None else sources[seq_id] - seq_attributions.append( - attr.get_sequence_cls( - source=source, - target=tokenized_target_sentences[seq_id], - attr_pos_start=pos_start[seq_id], - attr_pos_end=attr_pos_end, - ) + for seq_idx in range(num_sequences): + # If the model is decoder-only, the source is the input prefix + curr_pos_start = min(len(tokenized_target_sentences[seq_idx]), attr_pos_end) - len(targets[seq_idx]) + pos_start.append(curr_pos_start) + source = tokenized_target_sentences[seq_idx][:curr_pos_start] if not sources else sources[seq_idx] + curr_seq_attribution: FeatureAttributionSequenceOutput = attr.get_sequence_cls( + source=source, + target=tokenized_target_sentences[seq_idx], + attr_pos_start=pos_start[seq_idx], + attr_pos_end=attr_pos_end, ) + seq_attributions.append(curr_seq_attribution) if attr.source_attributions is not None: source_attributions = get_sequences_from_batched_steps([att.source_attributions for att in attributions]) for seq_id in range(num_sequences): @@ -241,18 +240,13 @@ def from_step_attributions( [att.target_attributions for att in attributions], padding_dims=[1] ) for seq_id in range(num_sequences): - if has_bos_token: - target_attributions[seq_id] = target_attributions[seq_id][1:, ...] start_idx = max(pos_start) - pos_start[seq_id] end_idx = start_idx + len(tokenized_target_sentences[seq_id]) target_attributions[seq_id] = target_attributions[seq_id][ start_idx:end_idx, : len(targets[seq_id]), ... # noqa: E203 ] if target_attributions[seq_id].shape[0] != len(tokenized_target_sentences[seq_id]): - empty_final_row = torch.ones( - 1, *target_attributions[seq_id].shape[1:], device=target_attributions[seq_id].device - ) * float("nan") - target_attributions[seq_id] = torch.cat([target_attributions[seq_id], empty_final_row], dim=0) + target_attributions[seq_id] = pad_with_nan(target_attributions[seq_id], dim=0, pad_size=1) seq_attributions[seq_id].target_attributions = target_attributions[seq_id] if attr.step_scores is not None: step_scores = [{} for _ in range(num_sequences)] @@ -427,47 +421,51 @@ def remap_from_filtered( self, target_attention_mask: TargetIdsTensor, batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + is_final_step_method: bool = False, ) -> None: """Remaps the attributions to the original shape of the input sequence.""" + batch_size = ( + len(batch.sources.input_tokens) if self.source_attributions is not None else len(batch.target_tokens) + ) + source_len = len(batch.sources.input_tokens[0]) + target_len = len(batch.target_tokens[0]) + # Normal per-step attribution outputs have shape (batch_size, seq_len, ...) + other_dims_start_idx = 2 + # Final step attribution outputs have shape (batch_size, seq_len, seq_len, ...) + if is_final_step_method: + other_dims_start_idx += 1 + other_dims = ( + self.source_attributions.shape[other_dims_start_idx:] + if self.source_attributions is not None + else self.target_attributions.shape[other_dims_start_idx:] + ) if self.source_attributions is not None: self.source_attributions = remap_from_filtered( - original_shape=(len(batch.sources.input_tokens), *self.source_attributions.shape[1:]), + original_shape=(batch_size, *self.source_attributions.shape[1:]), mask=target_attention_mask, filtered=self.source_attributions, ) if self.target_attributions is not None: self.target_attributions = remap_from_filtered( - original_shape=(len(batch.target_tokens), *self.target_attributions.shape[1:]), + original_shape=(batch_size, *self.target_attributions.shape[1:]), mask=target_attention_mask, filtered=self.target_attributions, ) if self.step_scores is not None: for score_name, score_tensor in self.step_scores.items(): self.step_scores[score_name] = remap_from_filtered( - original_shape=(len(batch.target_tokens), 1), + original_shape=(batch_size, 1), mask=target_attention_mask, filtered=score_tensor.unsqueeze(-1), ).squeeze(-1) if self.sequence_scores is not None: for score_name, score_tensor in self.sequence_scores.items(): if score_name.startswith("decoder"): - original_shape = ( - len(batch.target_tokens), - self.target_attributions.shape[1], - *self.target_attributions.shape[1:], - ) + original_shape = (batch_size, target_len, target_len, *other_dims) elif score_name.startswith("encoder"): - original_shape = ( - len(batch.sources.input_tokens), - self.source_attributions.shape[1], - *self.source_attributions.shape[1:], - ) + original_shape = (batch_size, source_len, source_len, *other_dims) else: # default case: cross-attention - original_shape = ( - len(batch.sources.input_tokens), - self.target_attributions.shape[1], - *self.source_attributions.shape[1:], - ) + original_shape = (batch_size, source_len, target_len, *other_dims) self.sequence_scores[score_name] = remap_from_filtered( original_shape=original_shape, mask=target_attention_mask, diff --git a/inseq/data/data_utils.py b/inseq/data/data_utils.py index b907d627..d0f90203 100644 --- a/inseq/data/data_utils.py +++ b/inseq/data/data_utils.py @@ -112,7 +112,7 @@ def _torch(attr): def _eq(self_attr: TensorClass, other_attr: TensorClass) -> bool: try: if isinstance(self_attr, torch.Tensor): - return torch.allclose(self_attr, other_attr, equal_nan=True) + return torch.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5) elif isinstance(self_attr, dict): return all(TensorWrapper._eq(self_attr[k], other_attr[k]) for k in self_attr.keys()) else: @@ -175,6 +175,10 @@ def clone(self: TensorClass) -> TensorClass: out_params[field.name] = None return self.__class__(**out_params) + def clone_empty(self: TensorClass) -> TensorClass: + out_params = {k: v for k, v in self.__dict__.items() if k.startswith("_") and v is not None} + return self.__class__(**out_params) + def to_dict(self: TensorClass) -> dict[str, Any]: return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py index c74e82c8..b96db5c2 100644 --- a/inseq/models/attribution_model.py +++ b/inseq/models/attribution_model.py @@ -386,6 +386,24 @@ def attribute( original_device = self.device if device is not None: self.device = device + attribution_method = self.get_attribution_method(method, override_default_attribution) + attributed_fn = self.get_attributed_fn(attributed_fn) + attribution_args, attributed_fn_args, step_scores_args = extract_args( + attribution_method, + attributed_fn, + step_scores, + default_args=self.formatter.get_step_function_reserved_args(), + **kwargs, + ) + if isnotebook(): + logger.debug("Pretty progress currently not supported in notebooks, falling back to tqdm.") + pretty_progress = False + if attribution_method.is_final_step_method: + if step_scores: + raise ValueError( + "Step scores are not supported for final step methods since they do not iterate over the full" + " sequence. Please remove the step scores and compute them separatly passing method='dummy'." + ) input_texts, generated_texts = format_input_texts(input_texts, generated_texts) has_generated_texts = generated_texts is not None if not self.is_encoder_decoder: @@ -411,36 +429,30 @@ def attribute( f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)." ) logger.debug(f"reference_texts={generated_texts}") - attribution_method = self.get_attribution_method(method, override_default_attribution) - attributed_fn = self.get_attributed_fn(attributed_fn) - attribution_args, attributed_fn_args, step_scores_args = extract_args( - attribution_method, - attributed_fn, - step_scores, - default_args=self.formatter.get_step_function_reserved_args(), - **kwargs, - ) - if isnotebook(): - logger.debug("Pretty progress currently not supported in notebooks, falling back to tqdm.") - pretty_progress = False if not self.is_encoder_decoder: assert all( generated_texts[idx].startswith(input_texts[idx]) for idx in range(len(input_texts)) ), "Forced generations with decoder-only models must start with the input texts." if has_generated_texts and len(input_texts) > 1: - logger.info( + logger.warning( "Batched constrained decoding is currently not supported for decoder-only models." " Using batch size of 1." ) batch_size = 1 if len(input_texts) > 1 and (attr_pos_start is not None or attr_pos_end is not None): - logger.info( + logger.warning( "Custom attribution positions are currently not supported when batching generations for" " decoder-only models. Using batch size of 1." ) batch_size = 1 + elif attribution_method.is_final_step_method and len(input_texts) > 1: + logger.warning( + "Batched attribution with encoder-decoder models currently not supported for final-step methods." + " Using batch size of 1." + ) + batch_size = 1 if attribution_method.method_name == "lime": - logger.info("Batched attribution currently not supported for LIME. Using batch size of 1.") + logger.warning("Batched attribution currently not supported for LIME. Using batch size of 1.") batch_size = 1 attribution_outputs = attribution_method.prepare_and_attribute( input_texts, diff --git a/inseq/models/model_config.py b/inseq/models/model_config.py index 05b8a468..52d9d47b 100644 --- a/inseq/models/model_config.py +++ b/inseq/models/model_config.py @@ -1,6 +1,7 @@ import logging from dataclasses import dataclass from pathlib import Path +from typing import Optional import yaml @@ -10,14 +11,25 @@ @dataclass class ModelConfig: """Configuration used by the methods for which the attribute ``use_model_config=True``. + Args: - attention_module (:obj:`str`): - The name of the module performing the attention computation (e.g.``attn`` for the GPT-2 model in - transformers). Can be identified by looking at the name of the attribute instantiating the attention module + self_attention_module (:obj:`str`): + The name of the module performing the self-attention computation (e.g.``attn`` for the GPT-2 model in + transformers). Can be identified by looking at the name of the self-attention module attribute in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2). + cross_attention_module (:obj:`str`): + The name of the module performing the cross-attention computation (e.g.``encoder_attn`` for MarianMT models + in transformers). Can be identified by looking at the name of the cross-attention module attribute + in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`). + value_vector (:obj:`str`): + The name of the variable in the forward pass of the attention module containing the value vector + (e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of + the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2). """ - attention_module: str + self_attention_module: str + value_vector: str + cross_attention_module: Optional[str] = None MODEL_CONFIGS = { diff --git a/inseq/models/model_config.yaml b/inseq/models/model_config.yaml index b48ed209..bcc32e41 100644 --- a/inseq/models/model_config.yaml +++ b/inseq/models/model_config.yaml @@ -1,2 +1,111 @@ +# Decoder-only models +BioGptForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +BloomForCausalLM: + self_attention_module: "self_attention" + value_vector: "value_layer" +CodeGenForCausalLM: + self_attention_module: "attn" + value_vector: "value" +FalconForCausalLM: + self_attention_module: "self_attention" + value_vector: "value_layer" +GemmaForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +GPTBigCodeForCausalLM: + self_attention_module: "attn" + value_vector: "value" +GPTJForCausalLM: + self_attention_module: "attn" + value_vector: "value" GPT2LMHeadModel: - attention_module: "attn" \ No newline at end of file + self_attention_module: "attn" + value_vector: "value" +GPTNeoForCausalLM: + self_attention_module: "attn" + value_vector: "value" +GPTNeoXForCausalLM: + self_attention_module: "attention" + value_vector: "value" +LlamaForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +MistralForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +MixtralForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +MptForCausalLM: + self_attention_module: "attn" + value_vector: "value_states" +OpenAIGPTLMHeadModel: + self_attention_module: "attn" + value_vector: "value" +OPTForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +PhiForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +Qwen2ForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +StableLmForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" +XGLMForCausalLM: + self_attention_module: "self_attn" + value_vector: "value_states" + +# Encoder-decoder models +BartForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +MarianMTModel: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +FSMTForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "v" +M2M100ForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +MBartForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +MT5ForConditionalGeneration: + self_attention_module: "SelfAttention" + cross_attention_module: "EncDecAttention" + value_vector: "value_states" +NllbMoeForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "cross_attention" + value_vector: "value_states" +PegasusForConditionalGeneration: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" +SeamlessM4TForTextToText: + self_attention_module: "self_attn" + cross_attention_module: "cross_attention" + value_vector: "value" +SeamlessM4Tv2ForTextToText: + self_attention_module: "self_attn" + cross_attention_module: "cross_attention" + value_vector: "value" +T5ForConditionalGeneration: + self_attention_module: "SelfAttention" + cross_attention_module: "EncDecAttention" + value_vector: "value_states" +UMT5ForConditionalGeneration: + self_attention_module: "SelfAttention" + cross_attention_module: "EncDecAttention" + value_vector: "value_states" \ No newline at end of file diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index 29f81615..69d9d1ad 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -8,6 +8,7 @@ MissingAttributionMethodError, UnknownAttributionMethodError, ) +from .hooks import StackFrame, get_post_variable_assignment_hook from .import_utils import ( is_captum_available, is_datasets_available, @@ -49,12 +50,16 @@ check_device, euclidean_distance, filter_logits, + find_block_stack, get_default_device, get_front_padding, get_sequences_from_batched_steps, normalize, + pad_with_nan, + recursive_get_submodule, remap_from_filtered, top_p_logits_mask, + validate_indices, ) __all__ = [ @@ -118,4 +123,9 @@ "top_p_logits_mask", "filter_logits", "cli_arg", + "get_post_variable_assignment_hook", + "StackFrame", + "validate_indices", + "pad_with_nan", + "recursive_get_submodule", ] diff --git a/inseq/utils/hooks.py b/inseq/utils/hooks.py new file mode 100644 index 00000000..02472f4e --- /dev/null +++ b/inseq/utils/hooks.py @@ -0,0 +1,110 @@ +import re +from inspect import getsourcelines +from sys import gettrace, settrace +from typing import Callable, Optional, TypeVar + +from torch import nn + +from .misc import get_left_padding + +StackFrame = TypeVar("StackFrame") + + +def get_last_variable_assignment_position( + module: nn.Module, + varname: str, + fname: str = "forward", +) -> Optional[int]: + """Extract the code line number of the last variable assignment for a variable of interest in the specified method + of a `nn.Module` object. + + Args: + module (`nn.Module`): + A PyTorch module containing a method with a variable assignment after which the hook should be executed. + varname (`str`): + The name of the variable to use as anchor for the hook. + fname (`str`, *optional*, defaults to "forward"): + The name of the method in which the variable assignment should be searched. + + Returns: + `Optional[int]`: Returns the line number in the file (not relative to the method) of the last variable + assignment. Returns None if no assignment to the variable was found. + """ + # Matches any assignment of variable varname + pattern = rf"^\s*(?:\w+\s*,\s*)*\b{varname}\b\s*(?:,.+\s*)*=\s*[^\W=]+.*$" + code, startline = getsourcelines(getattr(module, fname)) + line_numbers = [] + i = 0 + while i < len(code): + line = code[i] + # Handles multi-line assignments + if re.match(pattern, line): + parentheses_count = line.count("(") - line.count(")") + ends_with_newline = lambda l: l.strip().endswith("\\") + follow_indent = lambda l, i: len(code) > i + 1 and get_left_padding(code[i + 1]) > get_left_padding(l) + while (ends_with_newline(line) or follow_indent(line, i) or parentheses_count > 0) and len(code) > i + 1: + i += 1 + line = code[i] + parentheses_count += line.count("(") - line.count(")") + line_numbers.append(i) + i += 1 + if len(line_numbers) == 0: + return None + return line_numbers[-1] + startline + 1 + + +def get_post_variable_assignment_hook( + module: nn.Module, + varname: str, + fname: str = "forward", + hook_fn: Callable[[StackFrame], None] = lambda **kwargs: None, + **kwargs, +) -> Callable[[], None]: + """Creates a hook that is called after the last variable assignment in the specified method of a `nn.Module`. + + This is a hacky method using the ``sys.settrace()`` function to circumvent the limited hook points of Pytorch hooks + and set a custom hook point dynamically. This approach is preferred to ensure a broader compatibility with Hugging + Face transformers models that do not provide hook points in their architectures for the moment. + + Args: + module (`nn.Module`): + A PyTorch module containing a method with a variable assignment after which the hook should be executed. + varname (`str`): + The name of the variable to use as anchor for the hook. + fname (`str`, *optional*, defaults to "forward"): + The name of the method in which the variable assignment should be searched. + hook_fn (`Callable[[FrameType], None]`, *optional*, defaults to lambdaframe): + A custom hook function that is called after the last variable assignment in the specified method. The first + parameter is the current frame in the execution at the hook point, and any additional arguments can be + passed when creating the hook. ``frame.f_locals`` is a dictionary containing all local variables. + + Returns: + The hook function that can be registered with the module. If hooking the module's ``forward()`` method, the + hook can be registered with Pytorch native hook methods. + """ + hook_line_num = get_last_variable_assignment_position(module, varname, fname) + curr_trace_fn = gettrace() + if hook_line_num is None: + raise ValueError(f"Could not find assignment to {varname} in {module}'s {fname}() method") + + def var_tracer(frame, event, arg=None): + curr_line_num = frame.f_lineno + curr_func_name = frame.f_code.co_name + + # Matches the first executable line after hook_line_num in the same function of the same module + if ( + event == "line" + and curr_line_num >= hook_line_num + and curr_func_name == fname + and isinstance(frame.f_locals.get("self"), nn.Module) + and frame.f_locals.get("self")._get_name() == module._get_name() + ): + # Call the custom hook providing the current frame and any additional arguments as context + hook_fn(frame, **kwargs) + settrace(curr_trace_fn) + return var_tracer + + def hook(*args, **kwargs): + settrace(var_tracer) + + return hook diff --git a/inseq/utils/misc.py b/inseq/utils/misc.py index e09e5df7..628995bc 100644 --- a/inseq/utils/misc.py +++ b/inseq/utils/misc.py @@ -10,7 +10,6 @@ from functools import wraps from importlib import import_module from inspect import signature -from itertools import dropwhile from numbers import Number from os import PathLike, fsync from typing import Any, Callable, Optional, Union @@ -171,10 +170,10 @@ def pad(seq: Sequence[Sequence[Any]], pad_id: Any): return seq -def drop_padding(seq: Sequence[Any], pad_id: Any): +def drop_padding(seq: Sequence[TokenWithId], pad_id: str): if pad_id is None: return seq - return list(reversed(list(dropwhile(lambda x: x == pad_id, reversed(seq))))) + return [x for x in seq if x.token != pad_id] def isnotebook(): @@ -435,3 +434,8 @@ def clean_tokens(tokens: list[str], remove_tokens: list[str]) -> tuple[list[str] else: removed_token_idxs += [idx] return clean_tokens, removed_token_idxs + + +def get_left_padding(text: str): + """Returns the number of spaces at the beginning of a string.""" + return len(text) - len(text.lstrip()) diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py index 88e807cc..86acd635 100644 --- a/inseq/utils/torch_utils.py +++ b/inseq/utils/torch_utils.py @@ -5,11 +5,14 @@ import torch import torch.nn.functional as F from jaxtyping import Int, Num +from torch import nn from torch.backends.cuda import is_built as is_cuda_built from torch.backends.mps import is_available as is_mps_available from torch.backends.mps import is_built as is_mps_built from torch.cuda import is_available as is_cuda_available +from .typing import OneOrMoreIndices + if TYPE_CHECKING: pass @@ -244,3 +247,118 @@ def get_default_device() -> str: return "cpu" else: return "cpu" + + +def find_block_stack(module): + """Recursively searches for the first instance of a `nn.ModuleList` submodule within a given `torch.nn.Module`. + + Args: + module (:obj:`torch.nn.Module`): A Pytorch :obj:`nn.Module` object. + + Returns: + :obj:`torch.nn.ModuleList`: The first instance of a :obj:`nn.Module` submodule found within the given object. + None: If no `nn.ModuleList` submodule is found within the given `nn.Module` object. + """ + # Check if the current module is an instance of nn.ModuleList + if isinstance(module, nn.ModuleList): + return module + + # Recursively search for nn.ModuleList in the submodules of the current module + for submodule in module.children(): + module_list = find_block_stack(submodule) + if module_list is not None: + return module_list + + # If nn.ModuleList is not found in any submodules, return None + return None + + +def validate_indices( + scores: torch.Tensor, + dim: int = -1, + indices: Optional[OneOrMoreIndices] = None, +) -> OneOrMoreIndices: + """Validates a set of indices for a given dimension of a tensor of scores. Supports single indices, spans and lists + of indices, including negative indices to specify positions relative to the end of the tensor. + + Args: + scores (torch.Tensor): The tensor of scores. + dim (int, optional): The dimension of the tensor that will be indexed. Defaults to -1. + indices (Union[int, tuple[int, int], list[int], None], optional): + - If an integer, it is interpreted as a single index for the dimension. + - If a tuple of two integers, it is interpreted as a span of indices for the dimension. + - If a list of integers, it is interpreted as a list of individual indices for the dimension. + + Returns: + ``Union[int, tuple[int, int], list[int]]``: The validated list of positive indices for indexing the dimension. + """ + if dim >= scores.ndim: + raise IndexError(f"Dimension {dim} is greater than tensor dimension {scores.ndim}") + n_units = scores.shape[dim] + if not isinstance(indices, (int, tuple, list)) and indices is not None: + raise TypeError( + "Indices must be an integer, a (start, end) tuple of indices representing a span, a list of individual" + " indices or a single index." + ) + if hasattr(indices, "__iter__"): + if len(indices) == 0: + raise RuntimeError("An empty sequence of indices is not allowed.") + if len(indices) == 1: + indices = indices[0] + + if isinstance(indices, int): + if indices not in range(-n_units, n_units): + raise IndexError(f"Index out of range. Scores only have {n_units} units.") + indices = indices if indices >= 0 else n_units + indices + return torch.tensor(indices) + else: + if indices is None: + indices = (0, n_units) + logger.info("No indices specified. Using all indices by default.") + + # Convert negative indices to positive indices + if hasattr(indices, "__iter__"): + indices = type(indices)([h_idx if h_idx >= 0 else n_units + h_idx for h_idx in indices]) + if not hasattr(indices, "__iter__") or ( + len(indices) == 2 and isinstance(indices, tuple) and indices[0] >= indices[1] + ): + raise RuntimeError( + "A (start, end) tuple of indices representing a span, a list of individual indices" + " or a single index must be specified." + ) + max_idx_val = n_units if isinstance(indices, list) else n_units + 1 + if not all(h in range(-n_units, max_idx_val) for h in indices): + raise IndexError(f"One or more index out of range. Scores only have {n_units} units.") + if len(set(indices)) != len(indices): + raise IndexError("Duplicate indices are not allowed.") + if isinstance(indices, tuple): + return torch.arange(indices[0], indices[1]) + else: + return torch.tensor(indices) + + +def pad_with_nan(t: torch.Tensor, dim: int, pad_size: int, front: bool = False) -> torch.Tensor: + """Utility to pad a tensor with nan values along a given dimension.""" + nan_tensor = torch.ones( + *t.shape[:dim], + pad_size, + *t.shape[dim + 1 :], + device=t.device, + ) * float("nan") + if front: + return torch.cat([nan_tensor, t], dim=dim) + return torch.cat([t, nan_tensor], dim=dim) + + +def recursive_get_submodule(parent: nn.Module, target: str) -> Optional[nn.Module]: + if target == "": + return parent + mod = None + if hasattr(parent, target): + mod = getattr(parent, target) + else: + for submodule in parent.children(): + mod = recursive_get_submodule(submodule, target) + if mod is not None: + break + return mod diff --git a/inseq/utils/typing.py b/inseq/utils/typing.py index 7599bbc7..4eec4a5b 100644 --- a/inseq/utils/typing.py +++ b/inseq/utils/typing.py @@ -1,13 +1,17 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import torch +from captum.attr._utils.attribution import Attribution from jaxtyping import Float, Float32, Int64 from transformers import PreTrainedModel TextInput = Union[str, Sequence[str]] +if TYPE_CHECKING: + from inseq.models import AttributionModel + @dataclass class TokenWithId: @@ -28,6 +32,34 @@ def __eq__(self, other: Union[str, int, "TokenWithId"]): return False +class InseqAttribution(Attribution): + """A wrapper class for the Captum library's Attribution class to type hint the ``forward_func`` attribute + as an :class:`~inseq.models.AttributionModel`. + """ + + def __init__(self, forward_func: "AttributionModel") -> None: + r""" + Args: + forward_func (:class:`~inseq.models.AttributionModel`): The model hooker to the attribution method. + """ + self.forward_func = forward_func + + attribute: Callable + + @property + def multiplies_by_inputs(self): + return False + + def has_convergence_delta(self) -> bool: + return False + + compute_convergence_delta: Callable + + @classmethod + def get_name(cls: type["InseqAttribution"]) -> str: + return "".join([char if char.islower() or idx == 0 else " " + char for idx, char in enumerate(cls.__name__)]) + + @dataclass class TextSequences: targets: TextInput @@ -40,6 +72,8 @@ class TextSequences: OneOrMoreAttributionSequences = Sequence[Sequence[float]] IndexSpan = Union[tuple[int, int], Sequence[tuple[int, int]]] +OneOrMoreIndices = Union[int, list[int], tuple[int, int]] +OneOrMoreIndicesDict = dict[int, OneOrMoreIndices] IdsTensor = Int64[torch.Tensor, "batch_size seq_len"] TargetIdsTensor = Int64[torch.Tensor, "batch_size"] diff --git a/requirements-dev.txt b/requirements-dev.txt index 91a4d3f2..92a9ca95 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -361,7 +361,7 @@ traitlets==5.14.1 # jupyter-client # jupyter-core # matplotlib-inline -transformers==4.37.2 +transformers==4.38.1 typeguard==2.13.3 # via jaxtyping typer==0.9.0 diff --git a/requirements.txt b/requirements.txt index 9f392d72..a0a99e61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -93,7 +93,7 @@ tqdm==4.66.2 # captum # huggingface-hub # transformers -transformers==4.37.2 +transformers==4.38.1 typeguard==2.13.3 # via jaxtyping typing-extensions==4.9.0 diff --git a/tests/attr/feat/test_feature_attribution.py b/tests/attr/feat/test_feature_attribution.py index 07b2045c..80856176 100644 --- a/tests/attr/feat/test_feature_attribution.py +++ b/tests/attr/feat/test_feature_attribution.py @@ -1,8 +1,14 @@ +from typing import Any, Optional + import torch +from captum._utils.typing import TensorOrTupleOfTensorsGeneric from pytest import fixture import inseq +from inseq.attr.feat.internals_attribution import InternalsAttributionRegistry +from inseq.data import MultiDimensionalFeatureAttributionStepOutput from inseq.models import HuggingfaceDecoderOnlyModel, HuggingfaceEncoderDecoderModel +from inseq.utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor @fixture(scope="session") @@ -69,7 +75,7 @@ def test_contrastive_attribution_seq2seq_alignments(saliency_mt_model_larger: Hu "orig_tgt": "I soldati della pace ONU", "contrast_tgt": "Le forze militari di pace delle Nazioni Unite", "alignments": [[(0, 0), (1, 1), (2, 2), (3, 4), (4, 5), (5, 7), (6, 9)]], - "aligned_tgts": ["ā–Le ā†’ ā–I", "ā–forze ā†’ ā–soldati", "ā–di ā†’ ā–della", "ā–pace", "ā–Nazioni ā†’ ā–ONU", ""], + "aligned_tgts": ["", "ā–Le ā†’ ā–I", "ā–forze ā†’ ā–soldati", "ā–di ā†’ ā–della", "ā–pace", "ā–Nazioni ā†’ ā–ONU", ""], } out = saliency_mt_model_larger.attribute( aligned["src"], @@ -129,3 +135,122 @@ def test_mcd_weighted_attribution_gpt(saliency_gpt_model): ) attribution_scores = out.sequence_attributions[0].target_attributions assert isinstance(attribution_scores, torch.Tensor) + + +class MultiStepAttentionWeights(InseqAttribution): + """Variant of the AttentionWeights class with is_final_step_method = False. + As a result, the attention matrix is computed and sliced at every generation step. + We define it here to test consistency with the final step method. + """ + + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + additional_forward_args: TensorOrTupleOfTensorsGeneric, + encoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, + decoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, + cross_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, + ) -> MultiDimensionalFeatureAttributionStepOutput: + # We adopt the format [batch_size, sequence_length, num_layers, num_heads] + # for consistency with other multi-unit methods (e.g. gradient attribution) + decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2) + if self.forward_func.is_encoder_decoder: + sequence_scores = {} + if len(inputs) > 1: + target_attributions = decoder_self_attentions + else: + target_attributions = None + sequence_scores["decoder_self_attentions"] = decoder_self_attentions + sequence_scores["encoder_self_attentions"] = ( + encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2) + ) + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2), + target_attributions=target_attributions, + sequence_scores=sequence_scores, + _num_dimensions=2, # num_layers, num_heads + ) + else: + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=None, + target_attributions=decoder_self_attentions, + _num_dimensions=2, # num_layers, num_heads + ) + + +class MultiStepAttentionWeightsAttribution(InternalsAttributionRegistry): + """Variant of the basic attention attribution method computing attention weights at every generation step.""" + + method_name = "per_step_attention" + + def __init__(self, attribution_model, **kwargs): + super().__init__(attribution_model) + # Attention weights will be passed to the attribute_step method + self.use_attention_weights = True + # Does not rely on predicted output (i.e. decoding strategy agnostic) + self.use_predicted_target = False + self.method = MultiStepAttentionWeights(attribution_model) + + def attribute_step( + self, + attribute_fn_main_args: dict[str, Any], + attribution_args: dict[str, Any], + ) -> MultiDimensionalFeatureAttributionStepOutput: + return self.method.attribute(**attribute_fn_main_args, **attribution_args) + + +def test_seq2seq_final_step_per_step_conformity(saliency_mt_model_larger: HuggingfaceEncoderDecoderModel): + out_per_step = saliency_mt_model_larger.attribute( + "Hello ladies and badgers!", + method="per_step_attention", + attribute_target=True, + show_progress=False, + output_step_attributions=True, + ) + out_final_step = saliency_mt_model_larger.attribute( + "Hello ladies and badgers!", + method="attention", + attribute_target=True, + show_progress=False, + output_step_attributions=True, + ) + assert out_per_step[0] == out_final_step[0] + + +def test_gpt_final_step_per_step_conformity(saliency_gpt_model_larger: HuggingfaceDecoderOnlyModel): + out_per_step = saliency_gpt_model_larger.attribute( + "Hello ladies and badgers!", + method="per_step_attention", + show_progress=False, + output_step_attributions=True, + ) + out_final_step = saliency_gpt_model_larger.attribute( + "Hello ladies and badgers!", + method="attention", + show_progress=False, + output_step_attributions=True, + ) + assert out_per_step[0] == out_final_step[0] + + +# Batching for Seq2Seq models is not supported when using is_final_step methods +# Passing several sentences will attributed them one by one under the hood +# def test_seq2seq_multi_step_attention_weights_batched_full_match(saliency_mt_model: HuggingfaceEncoderDecoderModel): + + +def test_gpt_multi_step_attention_weights_batched_full_match(saliency_gpt_model_larger: HuggingfaceDecoderOnlyModel): + out_per_step = saliency_gpt_model_larger.attribute( + ["Hello world!", "Colorless green ideas sleep furiously."], + method="per_step_attention", + show_progress=False, + ) + out_final_step = saliency_gpt_model_larger.attribute( + ["Hello world!", "Colorless green ideas sleep furiously."], + method="attention", + show_progress=False, + ) + for i in range(2): + assert out_per_step[i].target_attributions.shape == out_final_step[i].target_attributions.shape + assert torch.allclose( + out_per_step[i].target_attributions, out_final_step[i].target_attributions, equal_nan=True, atol=1e-5 + ) diff --git a/tests/data/test_aggregator.py b/tests/data/test_aggregator.py index eb5086ca..f7e7c3e5 100644 --- a/tests/data/test_aggregator.py +++ b/tests/data/test_aggregator.py @@ -39,14 +39,14 @@ def test_sequence_attribution_aggregator(saliency_mt_model: HuggingfaceEncoderDe ) seqattr = out.sequence_attributions[0] assert seqattr.source_attributions.shape == (6, 7, 512) - assert seqattr.target_attributions.shape == (7, 7, 512) + assert seqattr.target_attributions.shape == (8, 7, 512) assert seqattr.step_scores["probability"].shape == (7,) for i, step in enumerate(out.step_attributions): assert step.source_attributions.shape == (1, 6, 512) assert step.target_attributions.shape == (1, i + 1, 512) out_agg = seqattr.aggregate() assert out_agg.source_attributions.shape == (6, 7) - assert out_agg.target_attributions.shape == (7, 7) + assert out_agg.target_attributions.shape == (8, 7) assert out_agg.step_scores["probability"].shape == (7,) @@ -56,9 +56,9 @@ def test_continuous_span_aggregator(saliency_mt_model: HuggingfaceEncoderDecoder ) seqattr = out.sequence_attributions[0] out_agg = seqattr.aggregate(ContiguousSpanAggregator, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg.source_attributions.shape == (5, 4, 512) - assert out_agg.target_attributions.shape == (4, 4, 512) - assert out_agg.step_scores["probability"].shape == (4,) + assert out_agg.source_attributions.shape == (5, 5, 512) + assert out_agg.target_attributions.shape == (5, 5, 512) + assert out_agg.step_scores["probability"].shape == (5,) def test_span_aggregator_with_prefix(saliency_gpt_model: HuggingfaceDecoderOnlyModel): @@ -76,14 +76,14 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel): seqattr = out.sequence_attributions[0] squeezesum = AggregatorPipeline([ContiguousSpanAggregator, SequenceAttributionAggregator]) out_agg_squeezesum = seqattr.aggregate(squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_squeezesum.source_attributions.shape == (5, 4) - assert out_agg_squeezesum.target_attributions.shape == (4, 4) - assert out_agg_squeezesum.step_scores["probability"].shape == (4,) + assert out_agg_squeezesum.source_attributions.shape == (5, 5) + assert out_agg_squeezesum.target_attributions.shape == (5, 5) + assert out_agg_squeezesum.step_scores["probability"].shape == (5,) sumsqueeze = AggregatorPipeline([SequenceAttributionAggregator, ContiguousSpanAggregator]) out_agg_sumsqueeze = seqattr.aggregate(sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_sumsqueeze.source_attributions.shape == (5, 4) - assert out_agg_sumsqueeze.target_attributions.shape == (4, 4) - assert out_agg_sumsqueeze.step_scores["probability"].shape == (4,) + assert out_agg_sumsqueeze.source_attributions.shape == (5, 5) + assert out_agg_sumsqueeze.target_attributions.shape == (5, 5) + assert out_agg_sumsqueeze.step_scores["probability"].shape == (5,) assert not torch.allclose(out_agg_squeezesum.source_attributions, out_agg_sumsqueeze.source_attributions) assert not torch.allclose(out_agg_squeezesum.target_attributions, out_agg_sumsqueeze.target_attributions) # Named indexing version @@ -91,12 +91,12 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel): named_sumsqueeze = ["scores", "spans"] out_agg_squeezesum_named = seqattr.aggregate(named_squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) out_agg_sumsqueeze_named = seqattr.aggregate(named_sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_squeezesum_named.source_attributions.shape == (5, 4) - assert out_agg_squeezesum_named.target_attributions.shape == (4, 4) - assert out_agg_squeezesum_named.step_scores["probability"].shape == (4,) - assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 4) - assert out_agg_sumsqueeze_named.target_attributions.shape == (4, 4) - assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (4,) + assert out_agg_squeezesum_named.source_attributions.shape == (5, 5) + assert out_agg_squeezesum_named.target_attributions.shape == (5, 5) + assert out_agg_squeezesum_named.step_scores["probability"].shape == (5,) + assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 5) + assert out_agg_sumsqueeze_named.target_attributions.shape == (5, 5) + assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (5,) assert not torch.allclose( out_agg_squeezesum_named.source_attributions, out_agg_sumsqueeze_named.source_attributions ) diff --git a/tests/fixtures/aggregator.json b/tests/fixtures/aggregator.json index fc029eec..53123526 100644 --- a/tests/fixtures/aggregator.json +++ b/tests/fixtures/aggregator.json @@ -36,6 +36,7 @@ ], "target": "Inseq \u00e8 un framework per l'attribuzione automatica di modelli sequenziali.", "target_subwords": [ + "", "\u2581In", "se", "q", @@ -58,6 +59,7 @@ "" ], "target_merged": [ + "", "\u2581Inseq", "\u2581\u00e8", "\u2581un", diff --git a/tests/inference_commons.py b/tests/inference_commons.py index 3da21068..19810018 100644 --- a/tests/inference_commons.py +++ b/tests/inference_commons.py @@ -1,3 +1,6 @@ +import json +import os + from inseq.data import EncoderDecoderBatch from inseq.utils import json_advanced_load @@ -9,3 +12,8 @@ def get_example_batches(): dict_batches["batches"] = [batch.torch() for batch in dict_batches["batches"]] assert all(isinstance(batch, EncoderDecoderBatch) for batch in dict_batches["batches"]) return dict_batches + + +def load_examples() -> dict: + file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/huggingface_model.json") + return json.load(open(file)) diff --git a/tests/models/test_huggingface_model.py b/tests/models/test_huggingface_model.py index 993c07ac..72da4a2f 100644 --- a/tests/models/test_huggingface_model.py +++ b/tests/models/test_huggingface_model.py @@ -2,8 +2,6 @@ since it is bugged is not very elegant, this will need to be refactored. """ -import json -import os import pytest import torch @@ -15,8 +13,9 @@ from inseq.data import FeatureAttributionOutput, FeatureAttributionSequenceOutput from inseq.utils import get_default_device -EXAMPLES_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/huggingface_model.json") -EXAMPLES = json.load(open(EXAMPLES_FILE)) +from ..inference_commons import load_examples + +EXAMPLES = load_examples() USE_REFERENCE_TEXT = [True, False] ATTRIBUTE_TARGET = [True, False] @@ -275,8 +274,8 @@ def test_attribute_slice_seq2seq(saliency_mt_model): assert ex2.attr_pos_start == len(ex2.target) assert ex2.attr_pos_end == len(ex2.target) assert ex2.source_attributions.shape[1] == 0 and ex2.target_attributions.shape[1] == 0 - assert ex3.attr_pos_start == 12 - assert ex3.attr_pos_end == 15 + assert ex3.attr_pos_start == 13 + assert ex3.attr_pos_end == 16 assert ex1.source_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start assert ex1.target_attributions.shape[0] == ex1.attr_pos_end @@ -303,12 +302,12 @@ def test_attribute_decoder(saliency_gpt2_model): assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start assert ex1.target_attributions.shape[0] == ex1.attr_pos_end # Empty attributions outputs have start and end set to seq length - assert ex2.attr_pos_start == 17 - assert ex2.attr_pos_end == 22 + assert ex2.attr_pos_start == 9 + assert ex2.attr_pos_end == 14 assert ex2.target_attributions.shape[1] == ex2.attr_pos_end - ex2.attr_pos_start assert ex2.target_attributions.shape[0] == ex2.attr_pos_end - assert ex3.attr_pos_start == 17 - assert ex3.attr_pos_end == 22 + assert ex3.attr_pos_start == 12 + assert ex3.attr_pos_end == 17 assert ex3.target_attributions.shape[1] == ex3.attr_pos_end - ex3.attr_pos_start assert ex3.target_attributions.shape[0] == ex3.attr_pos_end assert out.info["attr_pos_start"] == 17