Attributed behavior for contrastive step functions (#228)
* Add contrast utils and is_attributed_fn step function arg

* Fix imports
gsarti authored Oct 30, 2023
1 parent 8183185 commit 6feda95
Showing 9 changed files with 277 additions and 109 deletions.
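
For context, this commit lets contrastive step functions detect when they are used as the attribution target (the new `is_attributed_fn` flag) and adds a `contrast_force_inputs` escape hatch for functions that are risky as targets. A minimal usage sketch, assuming the public inseq API; the model identifier and texts are illustrative and not part of this commit:

import inseq

# Illustrative model and texts; any inseq-supported model should work similarly.
model = inseq.load_model("gpt2", "saliency")

out = model.attribute(
    "Can you stop the dog from",
    "Can you stop the dog from barking",
    # Contrastive step function as attribution target: inside the step function,
    # args.is_attributed_fn is True for this call.
    attributed_fn="contrast_prob_diff",
    contrast_targets="Can you stop the dog from crying",
    step_scores=["contrast_prob_diff"],
)
out.show()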
1 change: 1 addition & 0 deletions inseq/attr/feat/feature_attribution.py
@@ -586,6 +586,7 @@ def filtered_attribute_step(
attribution_model=self.attribution_model,
forward_output=output,
target_ids=target_ids,
is_attributed_fn=False,
batch=batch,
)
step_fn_extra_args = get_step_scores_args([step_score], step_scores_args)
133 changes: 29 additions & 104 deletions inseq/attr/step_functions.py
@@ -6,9 +6,10 @@
import torch.nn.functional as F
from transformers.modeling_outputs import ModelOutput

from ..data import DecoderOnlyBatch, FeatureAttributionInput, get_batch_from_inputs, slice_batch_from_position
from ..data import FeatureAttributionInput
from ..data.aggregation_functions import DEFAULT_ATTRIBUTION_AGGREGATE_DICT
from ..utils import extract_signature_args, filter_logits, top_p_logits_mask
from ..utils.contrast_utils import _get_contrast_output, _setup_contrast_args, contrast_fn_docstring
from ..utils.typing import EmbeddingsTensor, IdsTensor, SingleScorePerStepTensor, TargetIdsTensor

if TYPE_CHECKING:
@@ -27,6 +28,9 @@ class StepFunctionBaseArgs:
forward_output (:class:`~inseq.models.ModelOutput`): The output of the model's forward pass.
target_ids (:obj:`torch.Tensor`): Tensor of target token ids of size :obj:`(batch_size,)` corresponding to
the target predicted tokens for the next generation step.
is_attributed_fn (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether the step function is being used
as attribution target. Enables custom behavior depending on whether the function is used as attribution
target or as a regular step score.
encoder_input_ids (:obj:`torch.Tensor`): Tensor of ids of encoder input tokens of size
:obj:`(batch_size, source_seq_len)`, representing encoder inputs at the present step. Available only for
encoder-decoder models.
Expand All @@ -50,6 +54,7 @@ class StepFunctionBaseArgs:
decoder_input_ids: IdsTensor
decoder_input_embeds: EmbeddingsTensor
decoder_attention_mask: IdsTensor
is_attributed_fn: bool


@dataclass
@@ -76,36 +81,6 @@ def __call__(
...


CONTRAST_FN_ARGS_DOCSTRING = """Args:
contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute
the contrastive step function for encoder-decoder models. If not specified, the source text is assumed to
match the original source text. Defaults to :obj:`None`.
contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original
target text. If not specified, the original target text is used as contrastive target (will result in same
output unless ``contrast_sources`` are specified). Defaults to :obj:`None`.
contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the
first element is the index of the original target token and the second element is the index of the
contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is
not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all
available tokens. Defaults to :obj:`None`.
"""


def contrast_fn_docstring():
def docstring_decorator(fn: StepFunction):
"""Returns the docstring for the contrastive step functions."""
if fn.__doc__ is not None:
if "Args:\n" in fn.__doc__:
fn.__doc__ = fn.__doc__.replace("Args:\n", CONTRAST_FN_ARGS_DOCSTRING)
else:
fn.__doc__ = fn.__doc__ + "\n " + CONTRAST_FN_ARGS_DOCSTRING
else:
fn.__doc__ = CONTRAST_FN_ARGS_DOCSTRING
return fn

return docstring_decorator


def logit_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor:
"""Compute the logit of the target_ids from the model's output logits."""
logits = args.attribution_model.output2logits(args.forward_output)
@@ -149,87 +124,27 @@ def perplexity_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor:
return 2 ** crossentropy_fn(args)


@contrast_fn_docstring()
def _get_contrast_output(
args: StepFunctionArgs,
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
return_contrastive_target_ids: bool = False,
**forward_kwargs,
) -> ModelOutput:
"""Utility function to return the output of the model for given contrastive inputs.
Args:
return_contrastive_target_ids (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to return the
contrastive target ids as well as the model output. Defaults to :obj:`False`.
**forward_kwargs: Additional keyword arguments to be passed to the model's forward pass.
"""
c_tgt_ids = None
is_enc_dec = args.attribution_model.is_encoder_decoder
if contrast_targets:
c_batch = DecoderOnlyBatch.from_batch(
get_batch_from_inputs(
attribution_model=args.attribution_model,
inputs=contrast_targets,
as_targets=is_enc_dec,
)
)
curr_prefix_len = args.decoder_input_ids.size(1)
if len(contrast_targets_alignments) > 0 and isinstance(contrast_targets_alignments[0], list):
contrast_targets_alignments = contrast_targets_alignments[0]
c_batch, c_tgt_ids = slice_batch_from_position(c_batch, curr_prefix_len, contrast_targets_alignments)

if args.decoder_input_ids.size(0) != c_batch.target_ids.size(0):
raise ValueError(
f"Contrastive batch size ({c_batch.target_ids.size(0)}) must match candidate batch size"
f" ({args.decoder_input_ids.size(0)}). Multi-sentence attribution and methods expanding inputs to"
" multiple steps (e.g. Integrated Gradients) are not yet supported for contrastive attribution."
)

args.decoder_input_ids = c_batch.target_ids
args.decoder_input_embeds = c_batch.target_embeds
args.decoder_attention_mask = c_batch.target_mask
if contrast_sources:
if not (is_enc_dec and isinstance(args, StepFunctionEncoderDecoderArgs)):
raise ValueError(
"Contrastive source inputs can only be used with encoder-decoder models. "
"Use `contrast_targets` to set a contrastive target containing a prefix for decoder-only models."
)
c_enc_in = args.attribution_model.encode(contrast_sources)
args.encoder_input_ids = c_enc_in.input_ids
args.encoder_attention_mask = c_enc_in.attention_mask
args.encoder_input_embeds = args.attribution_model.embed(args.encoder_input_ids, as_targets=False)
c_batch = args.attribution_model.formatter.convert_args_to_batch(args)
c_out = args.attribution_model.get_forward_output(c_batch, use_embeddings=is_enc_dec, **forward_kwargs)
if return_contrastive_target_ids:
return c_out, c_tgt_ids
return c_out


@contrast_fn_docstring()
def contrast_logits_fn(
args: StepFunctionArgs,
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
contrast_force_inputs: bool = False,
):
"""Returns the logit of a generation target given contrastive context or target prediction alternative.
If only ``contrast_targets`` are specified, the logit of the contrastive prediction is computed given same
context. The logit for the same token given contrastive source/target preceding context can also be computed
using ``contrast_sources`` without specifying ``contrast_targets``.
"""
c_output, c_tgt_ids = _get_contrast_output(
c_args = _setup_contrast_args(
args,
contrast_sources=contrast_sources,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
return_contrastive_target_ids=True,
contrast_force_inputs=contrast_force_inputs,
)
if c_tgt_ids:
args.target_ids = c_tgt_ids
args.forward_output = c_output
return logit_fn(args)
return logit_fn(c_args)


@contrast_fn_docstring()
@@ -239,23 +154,21 @@ def contrast_prob_fn(
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
logprob: bool = False,
contrast_force_inputs: bool = False,
):
"""Returns the probability of a generation target given contrastive context or target prediction alternative.
If only ``contrast_targets`` are specified, the probability of the contrastive prediction is computed given same
context. The probability for the same token given contrastive source/target preceding context can also be computed
using ``contrast_sources`` without specifying ``contrast_targets``.
"""
c_output, c_tgt_ids = _get_contrast_output(
c_args = _setup_contrast_args(
args,
contrast_sources=contrast_sources,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
return_contrastive_target_ids=True,
contrast_force_inputs=contrast_force_inputs,
)
if c_tgt_ids:
args.target_ids = c_tgt_ids
args.forward_output = c_output
return probability_fn(args, logprob=logprob)
return probability_fn(c_args, logprob=logprob)


@contrast_fn_docstring()
@@ -264,7 +177,7 @@ def pcxmi_fn(
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
**kwargs,
contrast_force_inputs: bool = False,
) -> SingleScorePerStepTensor:
"""Compute the pointwise conditional cross-mutual information (P-CXMI) of target ids given original and contrastive
input options. The P-CXMI is defined as the negative log-ratio between the conditional probability of the target
@@ -277,6 +190,7 @@ def pcxmi_fn(
contrast_sources=contrast_sources,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
contrast_force_inputs=contrast_force_inputs,
)
return -torch.log2(torch.div(original_probs, contrast_probs))

@@ -290,6 +204,7 @@ def kl_divergence_fn(
top_k: int = 0,
top_p: float = 1.0,
min_tokens_to_keep: int = 1,
contrast_force_inputs: bool = False,
) -> SingleScorePerStepTensor:
"""Compute the pointwise Kullback-Leibler divergence of target ids given original and contrastive input options.
The KL divergence is the expectation of the log difference between the probabilities of regular (P) and contrastive
Expand All @@ -304,7 +219,11 @@ def kl_divergence_fn(
min_tokens_to_keep (:obj:`int`): Minimum number of tokens to keep with :obj:`top_p` filtering. Defaults to
:obj:`1`.
"""

if not contrast_force_inputs and args.is_attributed_fn:
raise RuntimeError(
"Using KL divergence as attribution target might lead to unexpected results, depending on the attribution"
"method used. Use --contrast_force_inputs in the model.attribute call to proceed."
)
original_logits: torch.Tensor = args.attribution_model.output2logits(args.forward_output)
contrast_output = _get_contrast_output(
args=args,
Expand All @@ -313,7 +232,7 @@ def kl_divergence_fn(
contrast_targets_alignments=contrast_targets_alignments,
return_contrastive_target_ids=False,
)
contrast_logits: torch.Tensor = args.attribution_model.output2logits(contrast_output)
contrast_logits: torch.Tensor = args.attribution_model.output2logits(contrast_output.forward_output)
filtered_original_logits, filtered_contrast_logits = filter_logits(
original_logits=original_logits,
contrast_logits=contrast_logits,
@@ -338,6 +257,7 @@ def contrast_prob_diff_fn(
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
logprob: bool = False,
contrast_force_inputs: bool = False,
):
"""Returns the difference between next step probability for a candidate generation target vs. a contrastive
alternative. Can be used as attribution target to answer the question: "Which features were salient in the
@@ -353,6 +273,7 @@ def contrast_prob_diff_fn(
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
logprob=logprob,
contrast_force_inputs=contrast_force_inputs,
)
return model_probs - contrast_probs

@@ -363,6 +284,7 @@ def contrast_logits_diff_fn(
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
contrast_force_inputs: bool = False,
):
"""Equivalent to ``contrast_prob_diff_fn`` but for logits. The original target function used in
`Yin and Neubig (2022) <https://aclanthology.org/2022.emnlp-main.14>`__
@@ -373,6 +295,7 @@ def contrast_logits_diff_fn(
contrast_sources=contrast_sources,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
contrast_force_inputs=contrast_force_inputs,
)
return model_logits - contrast_logits

@@ -383,6 +306,7 @@ def in_context_pvi_fn(
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
contrast_force_inputs: bool = False,
):
"""Returns the in-context pointwise V-usable information as defined by `Lu et al. (2023)
<https://arxiv.org/abs/2310.12300>`__. In-context PVI is a variant of P-CXMI that captures the amount of usable
@@ -400,6 +324,7 @@ def in_context_pvi_fn(
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
logprob=True,
contrast_force_inputs=contrast_force_inputs,
)
return -orig_logprob + contrast_logprob
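
Beyond the built-in functions above, user-defined step functions can read `args.is_attributed_fn` to adapt their behavior when used as the attribution target. A hedged sketch of such a custom function, assuming inseq's `register_step_function` API and the importability of `probability_fn`; the function name and fallback logic are illustrative:

from inseq import register_step_function
from inseq.attr.step_functions import StepFunctionArgs, probability_fn

def cautious_prob_fn(args: StepFunctionArgs, logprob: bool = True):
    # When used as attribution target, force plain probabilities; otherwise honor
    # the requested logprob behavior for the step score.
    if args.is_attributed_fn:
        return probability_fn(args, logprob=False)
    return probability_fn(args, logprob=logprob)

register_step_function(fn=cautious_prob_fn, identifier="cautious_prob")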

2 changes: 2 additions & 0 deletions inseq/data/batch.py
@@ -235,6 +235,8 @@ def from_batch(self, batch: Batch) -> "DecoderOnlyBatch":
def slice_batch_from_position(
batch: DecoderOnlyBatch, curr_idx: int, alignments: Optional[List[Tuple[int, int]]] = None
) -> Tuple[DecoderOnlyBatch, IdsTensor]:
if len(alignments) > 0 and isinstance(alignments[0], list):
alignments = alignments[0]
truncate_idx = get_aligned_idx(curr_idx, alignments)
tgt_ids = batch.target_ids[:, truncate_idx]
return batch[:truncate_idx], tgt_ids
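
The guard added above makes `slice_batch_from_position` accept alignments either as a flat list of index pairs or wrapped in an extra batch-level list, mirroring the unwrapping previously done in `_get_contrast_output`. A small illustration of the normalization with made-up values:

# Illustrative only: both forms now reach get_aligned_idx in the same flat shape.
nested = [[(0, 0), (1, 1), (2, 3)]]   # batch-level wrapping, as produced upstream
flat = [(0, 0), (1, 1), (2, 3)]

for alignments in (nested, flat):
    if len(alignments) > 0 and isinstance(alignments[0], list):
        alignments = alignments[0]
    assert alignments == [(0, 0), (1, 1), (2, 3)]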
3 changes: 2 additions & 1 deletion inseq/models/attribution_model.py
@@ -133,6 +133,7 @@ def format_step_function_args(
forward_output: ModelOutput,
target_ids: ExpandedTargetIdsTensor,
batch: DecoderOnlyBatch,
is_attributed_fn: bool = False,
) -> StepFunctionArgs:
raise NotImplementedError()

@@ -650,7 +651,7 @@ def _forward(
output = self.get_forward_output(batch, use_embeddings=use_embeddings, **kwargs)
logger.debug(f"logits: {pretty_tensor(output.logits)}")
step_fn_args = self.formatter.format_step_function_args(
attribution_model=self, forward_output=output, target_ids=target_ids, batch=batch
attribution_model=self, forward_output=output, target_ids=target_ids, is_attributed_fn=True, batch=batch
)
step_fn_extra_args = {k: v for k, v in zip(attributed_fn_argnames, args) if v is not None}
return attributed_fn(step_fn_args, **step_fn_extra_args)
2 changes: 2 additions & 0 deletions inseq/models/decoder_only.py
@@ -136,11 +136,13 @@ def format_step_function_args(
forward_output: ModelOutput,
target_ids: ExpandedTargetIdsTensor,
batch: DecoderOnlyBatch,
is_attributed_fn: bool = False,
) -> StepFunctionDecoderOnlyArgs:
return StepFunctionDecoderOnlyArgs(
attribution_model=attribution_model,
forward_output=forward_output,
target_ids=target_ids,
is_attributed_fn=is_attributed_fn,
decoder_input_ids=batch.target_ids,
decoder_attention_mask=batch.target_mask,
decoder_input_embeds=batch.target_embeds,
2 changes: 2 additions & 0 deletions inseq/models/encoder_decoder.py
@@ -177,11 +177,13 @@ def format_step_function_args(
forward_output: ModelOutput,
target_ids: ExpandedTargetIdsTensor,
batch: EncoderDecoderBatch,
is_attributed_fn: bool = False,
) -> StepFunctionEncoderDecoderArgs:
return StepFunctionEncoderDecoderArgs(
attribution_model=attribution_model,
forward_output=forward_output,
target_ids=target_ids,
is_attributed_fn=is_attributed_fn,
encoder_input_ids=batch.source_ids,
decoder_input_ids=batch.target_ids,
encoder_input_embeds=batch.source_embeds,
9 changes: 5 additions & 4 deletions inseq/utils/alignment_utils.py
@@ -319,8 +319,8 @@ def get_adjusted_alignments(
# Default behavior: fill missing alignments with 1:1 position alignments starting from the bottom of the
# two sequences
if not match_pairs:
if len(contrast_tokens) < step_idx:
filled_alignments.append((pair_idx, 0))
if (len(contrast_tokens) - step_idx) < start_pos:
filled_alignments.append((pair_idx, len(contrast_tokens) - 1))
else:
filled_alignments.append((pair_idx, len(contrast_tokens) - step_idx))
else:
Expand All @@ -329,11 +329,12 @@ def get_adjusted_alignments(
valid_match = match_pairs_unaligned[0] if match_pairs_unaligned else match_pairs[0]
filled_alignments.append(valid_match)
if alignments != filled_alignments:
alignments = filled_alignments
logger.warning(
f"Provided alignments do not cover all {end_pos - start_pos} tokens from the original"
" sequence.\nFilling missing position with right-aligned 1:1 position alignments."
" sequence.\nFilling missing position with right-aligned 1:1 position alignments.\n"
f"Generated alignments: {alignments}"
)
alignments = filled_alignments
if is_auto_aligned:
logger.warning(
f"Using {ALIGN_MODEL_ID} for automatic alignments. Provide custom alignments for non-linguistic "
