From de4b304bfaf0f83091b7febf799b8808ced471e5 Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Tue, 27 Feb 2024 20:41:29 +0100 Subject: [PATCH] Option to spare computations --- inseq/attr/feat/ops/value_zeroing.py | 46 +++++++++++++-------- inseq/attr/feat/perturbation_attribution.py | 6 ++- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/inseq/attr/feat/ops/value_zeroing.py b/inseq/attr/feat/ops/value_zeroing.py index 430667c6..c50ca3f8 100644 --- a/inseq/attr/feat/ops/value_zeroing.py +++ b/inseq/attr/feat/ops/value_zeroing.py @@ -312,6 +312,8 @@ def attribute( cross_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, encoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None, decoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None, + output_decoder_self_scores: bool = True, + output_encoder_self_scores: bool = True, ) -> TensorOrTupleOfTensorsGeneric: """Perform attribution using the Value Zeroing method. @@ -334,6 +336,11 @@ def attribute( encoder-decoders models. Default: None. decoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1, target_seq_len, hidden_size]`` containing hidden states of the decoder. + output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + decoder self-attention value vectors in encoder-decoder models. Cannot be false for decoder-only, or + if target-side attribution is requested using `attribute_target=True`. Default: True. + output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the + encoder self-attention value vectors in encoder-decoder models. Default: True. Returns: `TensorOrTupleOfTensorsGeneric`: Attribution outputs for source-only or source + target feature attribution @@ -343,29 +350,32 @@ def attribute( f"Similarity metric {similarity_metric} not available." 
f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}" ) - - decoder_scores = self.compute_modules_post_zeroing_similarity( - inputs=inputs, - additional_forward_args=additional_forward_args, - hidden_states=decoder_hidden_states, - attention_module_name=self.forward_func.config.self_attention_module, - similarity_metric=similarity_metric, - mode=ValueZeroingModule.DECODER.value, - zeroed_units_indices=decoder_zeroed_units_indices, - use_causal_mask=True, - ) - # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values - # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py - if self.forward_func.is_encoder_decoder: - encoder_scores = self.compute_modules_post_zeroing_similarity( + decoder_scores = None + if not self.forward_func.is_encoder_decoder or output_decoder_self_scores or len(inputs) > 1: + decoder_scores = self.compute_modules_post_zeroing_similarity( inputs=inputs, additional_forward_args=additional_forward_args, - hidden_states=encoder_hidden_states, + hidden_states=decoder_hidden_states, attention_module_name=self.forward_func.config.self_attention_module, similarity_metric=similarity_metric, - mode=ValueZeroingModule.ENCODER.value, - zeroed_units_indices=encoder_zeroed_units_indices, + mode=ValueZeroingModule.DECODER.value, + zeroed_units_indices=decoder_zeroed_units_indices, + use_causal_mask=True, ) + # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values + # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py + if self.forward_func.is_encoder_decoder: + encoder_scores = None + if output_encoder_self_scores: + encoder_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=encoder_hidden_states, + attention_module_name=self.forward_func.config.self_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.ENCODER.value, + zeroed_units_indices=encoder_zeroed_units_indices, + ) cross_scores = self.compute_modules_post_zeroing_similarity( inputs=inputs, additional_forward_args=additional_forward_args, diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index 40da111c..8c3c7486 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -153,8 +153,10 @@ def attribute_step( target_attributions = decoder_self_scores.to("cpu") else: target_attributions = None - sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu") - sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu") + if decoder_self_scores is not None: + sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu") + if encoder_self_scores is not None: + sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu") return MultiDimensionalFeatureAttributionStepOutput( source_attributions=decoder_cross_scores.to("cpu"), target_attributions=target_attributions,