Drop Python 3.9 support
gsarti committed Aug 6, 2024
1 parent e03ba89 commit 9c3ab45
Showing 55 changed files with 535 additions and 546 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -12,7 +12,7 @@ jobs:
if: github.actor != 'dependabot[bot]' && github.actor != 'dependabot-preview[bot]'
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -14,7 +14,7 @@ sphinx:
build:
os: ubuntu-20.04
tools:
python: "3.9"
python: "3.10"

python:
install:
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -4,7 +4,7 @@

## 🚀 Features

-- Added [treescope](https://github.com/google-deepmind/treescope) for model and tensor visualization.
+- Added [treescope](https://github.com/google-deepmind/treescope) for interactive model and tensor visualization. ([#283](https://github.com/inseq-team/inseq/pull/283))

- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM`, `Gemma2ForCausalLM` to model config.

@@ -90,4 +90,4 @@ out_female = attrib_model.attribute(

## 💥 Breaking Changes

-*No changes*
+- Dropped support for Python 3.9. Please use Python >= 3.10. ([#283](https://github.com/inseq-team/inseq/pull/283))
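The 3.10 floor is what enables the mechanical changes in the rest of this diff. As a minimal sketch (illustrative only, not code from the repository), these are the three Python 3.10+ idioms adopted throughout:

```python
from collections.abc import Callable, Sequence  # preferred over typing.Callable (PEP 585 deprecation)


# PEP 604: `int | None` replaces `typing.Optional[int]` in annotations
def first(seq: Sequence[int], default: int | None = None) -> int | None:
    return seq[0] if seq else default


# PEP 618: zip() gained an explicit `strict` keyword in Python 3.10
pairs = list(zip([1, 2, 3], ["a", "b"], strict=False))  # truncates: [(1, 'a'), (2, 'b')]

# collections.abc.Callable is subscriptable in annotations, like typing.Callable
increment: Callable[[int], int] = lambda x: x + 1
```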
6 changes: 3 additions & 3 deletions README.md
@@ -13,12 +13,12 @@
[![Downloads](https://static.pepy.tech/badge/inseq)](https://pepy.tech/project/inseq)
[![License](https://img.shields.io/github/license/inseq-team/inseq)](https://github.com/inseq-team/inseq/blob/main/LICENSE)
[![Demo Paper](https://img.shields.io/badge/ACL%20Anthology%20-%20?logo=data%3Aimage%2Fx-icon%3Bbase64%2C<base64 favicon data elided>&labelColor=white&color=red&link=https%3A%2F%2Faclanthology.org%2F2023.acl-demo.40%2F
-)](http://arxiv.org/abs/2302.13942)
+)](https://aclanthology.org/2023.acl-demo.40)

</div>
<div align="center">

-[![Follow Inseq on Twitter]( https://img.shields.io/badge/Twitter-1DA1F2?style=for-the-badge&logo=twitter&logoColor=white)](https://twitter.com/InseqLib)
+[![Follow Inseq on Twitter](https://img.shields.io/badge/Twitter-1DA1F2?style=for-the-badge&logo=twitter&logoColor=white)](https://twitter.com/InseqLib)
[![Join the Inseq Discord server](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/V5VgwwFPbu)
[![Read the Docs](https://img.shields.io/badge/-Docs-blue?style=for-the-badge&logo=Read-the-Docs&logoColor=white&link=https://inseq.org)](https://inseq.org)
[![Tutorial](https://img.shields.io/badge/-Tutorial-orange?style=for-the-badge&logo=Jupyter&logoColor=white&link=https://github.com/inseq-team/inseq/blob/main/examples/inseq_tutorial.ipynb)](https://github.com/inseq-team/inseq/blob/main/examples/inseq_tutorial.ipynb)
@@ -30,7 +30,7 @@ Inseq is a Pytorch-based hackable toolkit to democratize access to common post-h

## Installation

-Inseq is available on PyPI and can be installed with `pip` for Python >= 3.9, <= 3.12:
+Inseq is available on PyPI and can be installed with `pip` for Python >= 3.10, <= 3.12:

```bash
# Install latest stable version
pip install inseq
```
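For context, a hedged sketch of a runtime guard matching the new requirement (not part of the package; purely illustrative):

```python
import sys

# Inseq now targets Python >= 3.10 and is tested up to 3.12
if not ((3, 10) <= sys.version_info[:2] <= (3, 12)):
    raise RuntimeError(f"Python >= 3.10, <= 3.12 required, found {sys.version.split()[0]}")
```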
14 changes: 8 additions & 6 deletions inseq/attr/attribution_decorators.py
@@ -14,9 +14,9 @@
"""Decorators for attribution methods."""

import logging
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from functools import wraps
-from typing import Any, Callable, Optional
+from typing import Any

from ..data.data_utils import TensorWrapper

@@ -55,14 +55,14 @@ def batched(f: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator that enables batching of the args."""

@wraps(f)
-def batched_wrapper(self, *args, batch_size: Optional[int] = None, **kwargs):
-def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]:
+def batched_wrapper(self, *args, batch_size: int | None = None, **kwargs):
+def get_batched(bs: int | None, seq: Sequence[Any]) -> list[list[Any]]:
if isinstance(seq, str):
seq = [seq]
if isinstance(seq, list):
return [seq[i : i + bs] for i in range(0, len(seq), bs)] # noqa
if isinstance(seq, tuple):
-return list(zip(*[get_batched(bs, s) for s in seq]))
+return list(zip(*[get_batched(bs, s) for s in seq], strict=False))
elif isinstance(seq, TensorWrapper):
return [seq.slice_batch(slice(i, i + bs)) for i in range(0, len(seq), bs)] # noqa
else:
@@ -75,7 +75,9 @@ def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]:
len_batches = len(batched_args[0])
assert all(len(batch) == len_batches for batch in batched_args)
output = []
-zipped_batched_args = zip(*batched_args) if len(batched_args) > 1 else [(x,) for x in batched_args[0]]
+zipped_batched_args = (
+    zip(*batched_args, strict=False) if len(batched_args) > 1 else [(x,) for x in batched_args[0]]
+)
for i, batch in enumerate(zipped_batched_args):
logger.debug(f"Batching enabled: processing batch {i + 1} of {len_batches}...")
out = f(self, *batch, **kwargs)
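The `strict=False` arguments added above (and in the files below) rely on `zip()` accepting a `strict` keyword from Python 3.10 onward (PEP 618). `strict=False` keeps the historical silent-truncation behavior but makes it explicit, which also satisfies linters such as Ruff's `B905` check. A standalone demo of the two modes, for illustration:

```python
# strict=False (the implicit pre-3.10 behavior) truncates to the shortest iterable
print(list(zip([1, 2, 3], ["a", "b"], strict=False)))  # [(1, 'a'), (2, 'b')]

# strict=True raises instead of truncating, surfacing length mismatches early
try:
    list(zip([1, 2, 3], ["a", "b"], strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```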
25 changes: 13 additions & 12 deletions inseq/attr/feat/attribution_utils.py
@@ -1,6 +1,7 @@
import logging
import math
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any

from ...utils import extract_signature_args, get_aligned_idx
from ...utils.typing import (
@@ -24,8 +25,8 @@
def tok2string(
attribution_model: "AttributionModel",
token_lists: OneOrMoreTokenSequences,
-start: Optional[int] = None,
-end: Optional[int] = None,
+start: int | None = None,
+end: int | None = None,
as_targets: bool = True,
) -> TextInput:
"""Enables bounded tokenization of a list of lists of tokens with start and end positions."""
@@ -42,14 +43,14 @@ def rescale_attributions_to_tokens(
) -> OneOrMoreAttributionSequences:
return [
attr[: len(tokens)] if not all(math.isnan(x) for x in attr) else []
-for attr, tokens in zip(attributions, tokens)
+for attr, tokens in zip(attributions, tokens, strict=False)
]


def check_attribute_positions(
max_length: int,
-attr_pos_start: Optional[int] = None,
-attr_pos_end: Optional[int] = None,
+attr_pos_start: int | None = None,
+attr_pos_end: int | None = None,
) -> tuple[int, int]:
r"""Checks whether the combination of start/end positions for attribution is valid.
@@ -88,8 +89,8 @@ def check_attribute_positions(
def join_token_ids(
tokens: OneOrMoreTokenSequences,
ids: OneOrMoreIdSequences,
-contrast_tokens: Optional[OneOrMoreTokenSequences] = None,
-contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
+contrast_tokens: OneOrMoreTokenSequences | None = None,
+contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
) -> list[TokenWithId]:
"""Joins tokens and ids into a list of TokenWithId objects."""
if contrast_tokens is None:
Expand All @@ -99,10 +100,10 @@ def join_token_ids(
contrast_targets_alignments = [[(idx, idx) for idx, _ in enumerate(seq)] for seq in tokens]
sequences = []
for target_tokens_seq, contrast_target_tokens_seq, input_ids_seq, alignments_seq in zip(
-tokens, contrast_tokens, ids, contrast_targets_alignments
+tokens, contrast_tokens, ids, contrast_targets_alignments, strict=False
):
curr_seq = []
-for pos_idx, (token, token_idx) in enumerate(zip(target_tokens_seq, input_ids_seq)):
+for pos_idx, (token, token_idx) in enumerate(zip(target_tokens_seq, input_ids_seq, strict=False)):
contrast_pos_idx = get_aligned_idx(pos_idx, alignments_seq)
if contrast_pos_idx != -1 and token != contrast_target_tokens_seq[contrast_pos_idx]:
curr_seq.append(TokenWithId(f"{contrast_target_tokens_seq[contrast_pos_idx]}{token}", -1))
@@ -142,10 +143,10 @@ def extract_args(


def get_source_target_attributions(
-attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
+attr: StepAttributionTensor | tuple[StepAttributionTensor, StepAttributionTensor],
is_encoder_decoder: bool,
has_sequence_scores: bool = False,
-) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
+) -> tuple[StepAttributionTensor | None, StepAttributionTensor | None]:
if isinstance(attr, tuple):
if is_encoder_decoder:
if has_sequence_scores:
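The `Optional[X]` and `Union[X, Y]` rewrites above are behavior-preserving: on Python 3.10+, `|` on types builds a runtime union object that compares equal to the `typing` spelling (PEP 604). A small self-check, illustrative only:

```python
import types
import typing

# `X | Y` on types builds a types.UnionType at runtime (Python 3.10+)
assert isinstance(int | None, types.UnionType)

# ...which is interchangeable with the old typing spelling
assert (int | None) == typing.Optional[int]

# isinstance() accepts the new union objects directly
assert isinstance(None, int | None)
```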
23 changes: 12 additions & 11 deletions inseq/attr/feat/feature_attribution.py
@@ -17,8 +17,9 @@
* 🟡: Allow custom arguments for model loading in the :class:`FeatureAttribution` :meth:`load` method.
"""
import logging
+from collections.abc import Callable
from datetime import datetime
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional

import torch
from jaxtyping import Int
@@ -123,7 +124,7 @@ def load(
cls,
method_name: str,
attribution_model: Optional["AttributionModel"] = None,
-model_name_or_path: Optional[ModelIdentifier] = None,
+model_name_or_path: ModelIdentifier | None = None,
**kwargs,
) -> "FeatureAttribution":
r"""Load the selected method and hook it to an existing or available
@@ -168,16 +169,16 @@ def prepare_and_attribute(
self,
sources: FeatureAttributionInput,
targets: FeatureAttributionInput,
-attr_pos_start: Optional[int] = None,
-attr_pos_end: Optional[int] = None,
+attr_pos_start: int | None = None,
+attr_pos_end: int | None = None,
show_progress: bool = True,
pretty_progress: bool = True,
output_step_attributions: bool = False,
attribute_target: bool = False,
step_scores: list[str] = [],
include_eos_baseline: bool = False,
skip_special_tokens: bool = False,
-attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None,
+attributed_fn: str | Callable[..., SingleScorePerStepTensor] | None = None,
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
step_scores_args: dict[str, Any] = {},
@@ -317,7 +318,7 @@ def format_contrastive_targets(
attr_pos_start: int,
attr_pos_end: int,
skip_special_tokens: bool = False,
-) -> tuple[Optional[DecoderOnlyBatch], Optional[list[list[tuple[int, int]]]], dict[str, Any], dict[str, Any]]:
+) -> tuple[DecoderOnlyBatch | None, list[list[tuple[int, int]]] | None, dict[str, Any], dict[str, Any]]:
contrast_batch, contrast_targets_alignments = None, None
contrast_targets = attributed_fn_args.get("contrast_targets", None)
if contrast_targets is None:
@@ -357,10 +358,10 @@ def format_contrastive_targets(

def attribute(
self,
-batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
+batch: DecoderOnlyBatch | EncoderDecoderBatch,
attributed_fn: Callable[..., SingleScorePerStepTensor],
-attr_pos_start: Optional[int] = None,
-attr_pos_end: Optional[int] = None,
+attr_pos_start: int | None = None,
+attr_pos_end: int | None = None,
show_progress: bool = True,
pretty_progress: bool = True,
output_step_attributions: bool = False,
@@ -545,10 +546,10 @@ def attribute(

def filtered_attribute_step(
self,
-batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
+batch: DecoderOnlyBatch | EncoderDecoderBatch,
target_ids: Int[torch.Tensor, "batch_size 1"],
attributed_fn: Callable[..., SingleScorePerStepTensor],
-target_attention_mask: Optional[Int[torch.Tensor, "batch_size 1"]] = None,
+target_attention_mask: Int[torch.Tensor, "batch_size 1"] | None = None,
attribute_target: bool = False,
step_scores: list[str] = [],
attribution_args: dict[str, Any] = {},
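Note that `Optional["AttributionModel"]` in `load` is deliberately left as-is while the surrounding `Optional`s become `| None`. A plausible reason (my reading, not stated in the commit): the annotation is a quoted forward reference, and `"AttributionModel" | None` would raise `TypeError` at definition time because `str` does not implement `|`, whereas `Optional[...]` wraps the string in a `ForwardRef`. A sketch under that assumption:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from inseq.models import AttributionModel  # assumed import path, type checkers only

# `"AttributionModel" | None` would fail at runtime (str has no `|` operator);
# the quoted forward reference therefore keeps the Optional[...] spelling.
# Deferring evaluation with `from __future__ import annotations` would also work.
def load(attribution_model: Optional["AttributionModel"] = None) -> None:
    ...
```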
8 changes: 4 additions & 4 deletions inseq/attr/feat/internals_attribution.py
@@ -14,7 +14,7 @@
"""Attention-based feature attribution methods."""

import logging
-from typing import Any, Optional
+from typing import Any

from captum._utils.typing import TensorOrTupleOfTensorsGeneric

@@ -46,9 +46,9 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
additional_forward_args: TensorOrTupleOfTensorsGeneric,
-encoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
-decoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
-cross_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
+encoder_self_attentions: MultiLayerMultiUnitScoreTensor | None = None,
+decoder_self_attentions: MultiLayerMultiUnitScoreTensor | None = None,
+cross_attentions: MultiLayerMultiUnitScoreTensor | None = None,
) -> MultiDimensionalFeatureAttributionStepOutput:
"""Extracts the attention weights from the model.
13 changes: 7 additions & 6 deletions inseq/attr/feat/ops/discretized_integrated_gradients.py
@@ -16,8 +16,9 @@
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
# OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

+from collections.abc import Callable
from pathlib import Path
-from typing import Any, Callable, Union
+from typing import Any

import torch
from captum._utils.common import (
@@ -94,9 +95,9 @@ def attribute( # type: ignore
additional_forward_args: Any = None,
n_steps: int = 50,
method: str = "greedy",
-internal_batch_size: Union[None, int] = None,
+internal_batch_size: None | int = None,
return_convergence_delta: bool = False,
-) -> Union[TensorOrTupleOfTensorsGeneric, tuple[TensorOrTupleOfTensorsGeneric, Tensor]]:
+) -> TensorOrTupleOfTensorsGeneric | tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
n_examples = inputs[0].shape[0]
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
@@ -112,7 +113,7 @@ n_steps=n_steps,
n_steps=n_steps,
scale_strategy=method,
)
-for input_tensor, baseline_tensor in zip(inputs, baselines)
+for input_tensor, baseline_tensor in zip(inputs, baselines, strict=False)
)
if internal_batch_size is not None:
attributions = _batch_attribution(
@@ -181,7 +182,7 @@ def _attribute(
# total_grads has the same dimensionality as the original inputs
total_grads = tuple(
_reshape_and_sum(scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:])
-for (scaled_grad, grad) in zip(scaled_grads, grads)
+for (scaled_grad, grad) in zip(scaled_grads, grads, strict=False)
)
# computes attribution for each tensor in input_tuple
# attributions has the same dimensionality as the original inputs
@@ -191,5 +192,5 @@
inputs, baselines = self.get_inputs_baselines(scaled_features_tpl, n_steps)
return tuple(
total_grad * (input - baseline)
-for (total_grad, input, baseline) in zip(total_grads, inputs, baselines)
+for (total_grad, input, baseline) in zip(total_grads, inputs, baselines, strict=False)
)
14 changes: 10 additions & 4 deletions inseq/attr/feat/ops/lime.py
@@ -1,8 +1,9 @@
import inspect
import logging
import math
+from collections.abc import Callable
from functools import partial
-from typing import Any, Callable, Optional, cast
+from typing import Any, cast

import torch
from captum._utils.common import _expand_additional_forward_args, _expand_target
@@ -25,8 +26,8 @@ def __init__(
similarity_func: Callable = None,
perturb_func: Callable = None,
perturb_interpretable_space: bool = False,
-from_interp_rep_transform: Optional[Callable] = None,
-to_interp_rep_transform: Optional[Callable] = None,
+from_interp_rep_transform: Callable | None = None,
+to_interp_rep_transform: Callable | None = None,
mask_prob: float = 0.3,
) -> None:
if interpretable_model is None:
@@ -271,7 +272,12 @@ def detach_to_list(t):

# Merge the binary mask with the special_token_ids mask
mask = (
-torch.tensor([m + s if s == 0 else s for m, s in zip(mask_multinomial_binary, mask_special_token_ids)])
+torch.tensor(
+    [
+        m + s if s == 0 else s
+        for m, s in zip(mask_multinomial_binary, mask_special_token_ids, strict=False)
+    ]
+)
.to(self.attribution_model.device)
.unsqueeze(-1) # 1D -> 2D
)
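The reformatted expression above merges LIME's random perturbation mask with a special-tokens mask so that special tokens are never perturbed: where the special mask is 0 the multinomial draw passes through (`m + 0`), and where it is 1 it overrides the draw. A standalone sketch with made-up values (the 0/1 conventions are assumed, not taken from the library):

```python
import torch

mask_random = [1, 0, 1, 0]   # multinomial draw: which tokens the perturbation keeps
mask_special = [1, 0, 0, 1]  # 1 marks special tokens that must stay unmasked

merged = torch.tensor(
    [m + s if s == 0 else s for m, s in zip(mask_random, mask_special, strict=False)]
)
print(merged)  # tensor([1, 0, 1, 1])
```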
(Diffs for the remaining changed files are not shown.)
