Fix multi-gpu inference and duplicate BOS issue for decoder-only (#280)
* Fix skip special tokens if pre-specified in prompt

* Move attentions to cpu preemptively

* Add rescale and readme changes
gsarti authored Jul 23, 2024
1 parent 979d223 commit e5b835b
Showing 9 changed files with 46 additions and 8 deletions.
13 changes: 8 additions & 5 deletions CHANGELOG.md
@@ -4,7 +4,9 @@

## 🚀 Features

- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM`, `Gemma2ForCausalLM` to model config.

- Add `rescale_attributions` to Inseq CLI commands for `rescale=True` ([#280](https://github.com/inseq-team/inseq/pull/280)).

- Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment [#282](https://github.com/inseq-team/inseq/pull/282)

@@ -23,7 +25,7 @@ out.save("output.json")
out.save("output_fp16.json", scores_precision="float16") # or "float8"

# Automatic conversion to float32
out_loaded = inseq.FeatureAttributionOutput.load("output_fp16.json")
```

- A new `SliceAggregator` (`"slices"`) is added to allow for slicing source (in encoder-decoder) or target (in decoder-only) tokens from a `FeatureAttributionSequenceOutput` object, using the same syntax of `ContiguousSpanAggregator`. The `__getitem__` method of the `FeatureAttributionSequenceOutput` is a shortcut for this, allowing slicing with `[start:stop]` syntax. [#282](https://github.com/inseq-team/inseq/pull/282)
@@ -71,18 +73,19 @@ out_female = attrib_model.attribute(
## 🔧 Fixes and Refactoring

- Fix the issue in the attention implementation from [#268](https://github.com/inseq-team/inseq/issues/268) where non-terminal positions in the tensor were set to nan if they were 0s ([#269](https://github.com/inseq-team/inseq/pull/269)).

- Fix the pad token in cases where it is not specified by default in the loaded model (e.g. for Qwen models) ([#269](https://github.com/inseq-team/inseq/pull/269)).

- Fix bug reported in [#266](https://github.com/inseq-team/inseq/issues/266) making `value_zeroing` unusable for SDPA attention. This enables using the method on models using SDPA attention as default (e.g. `GemmaForCausalLM`) without passing `model_kwargs={'attn_implementation': 'eager'}` ([#267](https://github.com/inseq-team/inseq/pull/267)).

- Fix multi-device support and duplicate BOS for chat template models ([#280](https://github.com/inseq-team/inseq/pull/280)).

- The directions of generated/attributed tokens were clarified in the visualization using arrows instead of x/y [#282](https://github.com/inseq-team/inseq/pull/282)

## 📝 Documentation and Tutorials

*No changes*

## 💥 Breaking Changes

*No changes*
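
The `rescale_attributions` CLI flag listed in the features above surfaces the `rescale=True` aggregation option from the Python API. Below is a minimal sketch of that option, assuming the default aggregator accepts the same `normalize`/`rescale` keywords the CLI helper forwards; the model and attribution method are arbitrary examples, not prescribed by this release.

```python
import inseq

# Arbitrary small model and method for illustration purposes.
model = inseq.load_model("gpt2", "saliency")
out = model.attribute("Hello world, how are you?")

# normalize=True -> relative salience (scores sum to 1)
# rescale=True   -> absolute salience (scores sum to the number of tokens)
rescaled = out.aggregate(normalize=False, rescale=True)
rescaled.show()
```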
7 changes: 5 additions & 2 deletions README.md
@@ -280,7 +280,7 @@ Our vision for Inseq is to create a centralized, comprehensive and robust set of

## Citing Inseq

If you use Inseq in your research we suggest to include a mention to the specific release (e.g. v0.6.0) and we kindly ask you to cite our reference paper as:

```bibtex
@inproceedings{sarti-etal-2023-inseq,
@@ -308,7 +308,7 @@ If you use Inseq in your research we suggest to include a mention to the specifi
Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.

> [!TIP]
> Last update: June 2024. Please open a pull request to add your publication to the list.

<details>
<summary><b>2023</b></summary>
@@ -331,6 +331,9 @@ Inseq has been used in various research projects. A list of known publications t
<li><a href="https://arxiv.org/abs/2402.00794">ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models</a> (Zhao et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2404.02421">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
<li><a href="https://hal.science/hal-04581586">Exploring NMT Explainability for Translators Using NMT Visualising Tools</a> (Gonzalez-Saez et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2405.14899">DETAIL: Task DEmonsTration Attribution for Interpretable In-context Learning</a> (Zhou et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2406.06399">Should We Fine-Tune or RAG? Evaluating Different Techniques to Adapt LLMs for Dialogue</a> (Alghisi et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2406.13663">Model Internals-based Answer Attribution for Trustworthy Retrieval-Augmented Generation</a> (Qi, Sarti et al., 2024)</li>
</ol>

</details>
3 changes: 3 additions & 0 deletions inseq/commands/attribute/attribute.py
@@ -11,12 +11,14 @@ def aggregate_attribution_scores(
selectors: Optional[list[int]] = None,
aggregators: Optional[list[str]] = None,
normalize_attributions: bool = False,
rescale_attributions: bool = False,
) -> FeatureAttributionOutput:
if selectors is not None and aggregators is not None:
for select_idx, aggregator_fn in zip(selectors, aggregators):
out = out.aggregate(
aggregator=aggregator_fn,
normalize=normalize_attributions,
rescale=rescale_attributions,
select_idx=select_idx,
do_post_aggregation_checks=False,
)
@@ -79,6 +81,7 @@ def attribute(input_texts, generated_texts, args: AttributeExtendedArgs):
selectors=args.attribution_selectors,
aggregators=args.attribution_aggregators,
normalize_attributions=args.normalize_attributions,
rescale_attributions=args.rescale_attributions,
)
print(f"Saving {'aggregated ' if args.aggregate_output else ''}attributions to {args.save_path}")
out.save(args.save_path, overwrite=True)
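
For context, a hedged sketch of calling the patched helper directly, assuming the import path matches the file location shown above; the model and attribution method are placeholders, and selectors/aggregators are left at their defaults.

```python
import inseq
from inseq.commands.attribute.attribute import aggregate_attribution_scores

model = inseq.load_model("gpt2", "saliency")  # placeholder model and method
out = model.attribute("Hello world, how are you?")

aggregated = aggregate_attribution_scores(
    out,
    selectors=None,               # no custom selector/aggregator pairs
    aggregators=None,
    normalize_attributions=False,
    rescale_attributions=True,    # new flag, forwarded as rescale=True to out.aggregate
)
aggregated.save("attributions.json", overwrite=True)
```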
8 changes: 8 additions & 0 deletions inseq/commands/attribute/attribute_args.py
@@ -61,6 +61,14 @@ class AttributeBaseArgs:
"for each context are normalized to sum up to 1, providing a relative notion of input salience."
),
)
rescale_attributions: bool = cli_arg(
default=False,
help=(
"Whether to rescale the attribution scores for each context. If ``True``, the attribution scores "
"for each context are rescaled to sum up to the number of tokens in the input, providing an absolute"
" notion of input salience."
),
)
model_kwargs: dict = cli_arg(
default_factory=dict,
help="Additional keyword arguments passed to the model constructor in JSON format.",
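
The `normalize_attributions` and `rescale_attributions` flags describe two re-scalings of the same raw scores. A toy illustration of the sums described in the help texts above (the numbers are made up and the exact formula used by the aggregators may differ):

```python
import torch

# Made-up per-token scores for a 4-token input.
scores = torch.tensor([0.2, 0.1, 0.5, 0.2])

normalized = scores / scores.abs().sum()  # normalize: relative salience, sums to 1
rescaled = normalized * scores.numel()    # rescale: absolute salience, sums to the token count
print(normalized.sum().item(), rescaled.sum().item())  # ~1.0 ~4.0
```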
1 change: 1 addition & 0 deletions inseq/commands/attribute_context/attribute_context.py
@@ -211,6 +211,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
selectors=args.attribution_selectors,
aggregators=args.attribution_aggregators,
normalize_attributions=args.normalize_attributions,
rescale_attributions=args.rescale_attributions,
)[0]
if args.show_intermediate_outputs:
cci_attrib_out.show(do_aggregation=False)
4 changes: 3 additions & 1 deletion inseq/models/attribution_model.py
@@ -411,7 +411,9 @@ def attribute(
"Step scores are not supported for final step methods since they do not iterate over the full"
" sequence. Please remove the step scores and compute them separatly passing method='dummy'."
)
input_texts, generated_texts = format_input_texts(input_texts, generated_texts)
input_texts, generated_texts = format_input_texts(
input_texts, generated_texts, skip_special_tokens, self.special_tokens
)
has_generated_texts = generated_texts is not None
if not self.is_encoder_decoder:
for i in range(len(input_texts)):
8 changes: 8 additions & 0 deletions inseq/models/huggingface_model.py
@@ -456,6 +456,12 @@ def get_attentions_dict(
) -> dict[str, MultiLayerMultiUnitScoreTensor]:
if output.encoder_attentions is None or output.decoder_attentions is None:
raise ValueError("Model does not support attribution relying on attention outputs.")
if output.encoder_attentions is not None:
output.encoder_attentions = tuple(att.to("cpu") for att in output.encoder_attentions)
if output.decoder_attentions is not None:
output.decoder_attentions = tuple(att.to("cpu") for att in output.decoder_attentions)
if output.cross_attentions is not None:
output.cross_attentions = tuple(att.to("cpu") for att in output.cross_attentions)
return {
"encoder_self_attentions": torch.stack(output.encoder_attentions, dim=1),
"decoder_self_attentions": torch.stack(output.decoder_attentions, dim=1),
@@ -506,6 +512,8 @@ def configure_embeddings_scale(self):
def get_attentions_dict(output: CausalLMOutput) -> dict[str, MultiLayerMultiUnitScoreTensor]:
if output.attentions is None:
raise ValueError("Model does not support attribution relying on attention outputs.")
else:
output.attentions = tuple(att.to("cpu") for att in output.attentions)
return {
"decoder_self_attentions": torch.stack(output.attentions, dim=1),
}
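
The preemptive CPU moves above matter for multi-GPU inference (e.g. with `device_map="auto"`), where per-layer attention tensors can come back on different devices while `torch.stack` requires a single device. A standalone sketch of the pattern with dummy tensors:

```python
import torch

# Dummy per-layer attention tensors of shape (batch, heads, tgt_len, src_len).
# In a sharded model these could live on different devices.
attentions = (torch.rand(1, 8, 5, 5), torch.rand(1, 8, 5, 5))

# Move everything to CPU first, then stack along a new layer dimension,
# mirroring the .to("cpu") calls added in get_attentions_dict above.
attentions = tuple(att.to("cpu") for att in attentions)
stacked = torch.stack(attentions, dim=1)  # (batch, n_layers, heads, tgt_len, src_len)
print(stacked.shape)
```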
3 changes: 3 additions & 0 deletions inseq/models/model_config.yaml
@@ -20,6 +20,9 @@ FalconForCausalLM:
GemmaForCausalLM:
self_attention_module: "self_attn"
value_vector: "value_states"
Gemma2ForCausalLM:
self_attention_module: "self_attn"
value_vector: "value_states"
GPTBigCodeForCausalLM:
self_attention_module: "attn"
value_vector: "value"
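
The new `Gemma2ForCausalLM` entry tells attention-based methods such as `value_zeroing` where to find the self-attention module and value vectors. A hedged usage sketch follows; the checkpoint name is a hypothetical example and requires the corresponding Hugging Face weights.

```python
import inseq

# Hypothetical Gemma 2 checkpoint; any Gemma2ForCausalLM weights would do.
model = inseq.load_model("google/gemma-2-9b", "value_zeroing")
out = model.attribute("The capital of France is")
out.show()
```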
7 changes: 7 additions & 0 deletions inseq/utils/misc.py
@@ -202,6 +202,8 @@ def isnotebook():
def format_input_texts(
texts: TextInput,
ref_texts: Optional[TextInput] = None,
skip_special_tokens: bool = False,
special_tokens: list[str] = [],
) -> tuple[list[str], list[str]]:
texts = [texts] if isinstance(texts, str) else texts
reference_texts = [ref_texts] if isinstance(ref_texts, str) else ref_texts
@@ -211,6 +213,11 @@
len(texts), len(reference_texts)
)
)
if skip_special_tokens:
for special_token in special_tokens:
texts = [text.replace(special_token, "") for text in texts]
if reference_texts is not None:
reference_texts = [text.replace(special_token, "") for text in reference_texts]
return texts, reference_texts


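
A small sketch of the special-token stripping added above, assuming `format_input_texts` is importable from its module path; the `<bos>` marker stands in for whatever special tokens the loaded tokenizer defines.

```python
from inseq.utils.misc import format_input_texts

# Prompts that already contain a BOS token (e.g. produced by a chat template).
texts, refs = format_input_texts(
    "<bos>Hello world",
    ref_texts="<bos>Hello world!",
    skip_special_tokens=True,
    special_tokens=["<bos>"],
)
print(texts, refs)  # ['Hello world'] ['Hello world!']
```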
