Skip to content

Commit

Permalink
Option to spare computations
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Feb 27, 2024
1 parent def38cf commit de4b304
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
46 changes: 28 additions & 18 deletions inseq/attr/feat/ops/value_zeroing.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def attribute(
cross_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
encoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None,
decoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None,
output_decoder_self_scores: bool = True,
output_encoder_self_scores: bool = True,
) -> TensorOrTupleOfTensorsGeneric:
"""Perform attribution using the Value Zeroing method.
Expand All @@ -334,6 +336,11 @@ def attribute(
encoder-decoders models. Default: None.
decoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1,
target_seq_len, hidden_size]`` containing hidden states of the decoder.
output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
decoder self-attention value vectors in encoder-decoder models. Cannot be ``False`` for decoder-only models,
or if target-side attribution is requested using ``attribute_target=True``. Default: True.
output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
encoder self-attention value vectors in encoder-decoder models. Default: True.
Returns:
`TensorOrTupleOfTensorsGeneric`: Attribution outputs for source-only or source + target feature attribution
Expand All @@ -343,29 +350,32 @@ def attribute(
f"Similarity metric {similarity_metric} not available."
f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}"
)

decoder_scores = self.compute_modules_post_zeroing_similarity(
inputs=inputs,
additional_forward_args=additional_forward_args,
hidden_states=decoder_hidden_states,
attention_module_name=self.forward_func.config.self_attention_module,
similarity_metric=similarity_metric,
mode=ValueZeroingModule.DECODER.value,
zeroed_units_indices=decoder_zeroed_units_indices,
use_causal_mask=True,
)
# Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values
# Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py
if self.forward_func.is_encoder_decoder:
encoder_scores = self.compute_modules_post_zeroing_similarity(
decoder_scores = None
if not self.forward_func.is_encoder_decoder or output_decoder_self_scores or len(inputs) > 1:
decoder_scores = self.compute_modules_post_zeroing_similarity(
inputs=inputs,
additional_forward_args=additional_forward_args,
hidden_states=encoder_hidden_states,
hidden_states=decoder_hidden_states,
attention_module_name=self.forward_func.config.self_attention_module,
similarity_metric=similarity_metric,
mode=ValueZeroingModule.ENCODER.value,
zeroed_units_indices=encoder_zeroed_units_indices,
mode=ValueZeroingModule.DECODER.value,
zeroed_units_indices=decoder_zeroed_units_indices,
use_causal_mask=True,
)
# Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values
# Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py
if self.forward_func.is_encoder_decoder:
encoder_scores = None
if output_encoder_self_scores:
encoder_scores = self.compute_modules_post_zeroing_similarity(
inputs=inputs,
additional_forward_args=additional_forward_args,
hidden_states=encoder_hidden_states,
attention_module_name=self.forward_func.config.self_attention_module,
similarity_metric=similarity_metric,
mode=ValueZeroingModule.ENCODER.value,
zeroed_units_indices=encoder_zeroed_units_indices,
)
cross_scores = self.compute_modules_post_zeroing_similarity(
inputs=inputs,
additional_forward_args=additional_forward_args,
Expand Down
6 changes: 4 additions & 2 deletions inseq/attr/feat/perturbation_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,10 @@ def attribute_step(
target_attributions = decoder_self_scores.to("cpu")
else:
target_attributions = None
sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu")
sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu")
if decoder_self_scores is not None:
sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu")
if encoder_self_scores is not None:
sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu")
return MultiDimensionalFeatureAttributionStepOutput(
source_attributions=decoder_cross_scores.to("cpu"),
target_attributions=target_attributions,
Expand Down

0 comments on commit de4b304

Please sign in to comment.