Bump ruff style to py39
gsarti committed Jan 14, 2024
1 parent 8d89b11 commit 4872d5b
Showing 40 changed files with 346 additions and 341 deletions.
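The diffs below all apply the same mechanical rewrite: with ruff's style target raised to Python 3.9 (presumably via `target-version = "py39"` in the project's ruff configuration), annotations using `typing.List`, `typing.Dict`, and `typing.Tuple` become the builtin generics `list`, `dict`, and `tuple` standardized by PEP 585, and abstract collection types such as `Sequence` are imported from `collections.abc` instead of `typing`. A minimal sketch of the resulting style, using a hypothetical helper loosely modeled on the `get_batched` function in the first file below:

    # py39-style annotations (PEP 585): builtin generics replace
    # typing.List/Dict/Tuple, and abstract collection types come from
    # collections.abc. Optional, Union, Callable, and Any stay in typing,
    # since PEP 604's `X | Y` union syntax only arrives in Python 3.10,
    # which is why the diffs leave those imports untouched.
    from collections.abc import Sequence
    from typing import Any, Optional

    def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]:
        """Split a sequence into batches of at most bs items (hypothetical sketch)."""
        items = list(seq)
        if bs is None:
            return [items]
        return [items[i : i + bs] for i in range(0, len(items), bs)]

    print(get_batched(2, ["a", "b", "c"]))  # [['a', 'b'], ['c']]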
5 changes: 3 additions & 2 deletions inseq/attr/attribution_decorators.py
@@ -14,8 +14,9 @@
 """Decorators for attribution methods."""
 
 import logging
+from collections.abc import Sequence
 from functools import wraps
-from typing import Any, Callable, List, Optional, Sequence
+from typing import Any, Callable, Optional
 
 from ..data.data_utils import TensorWrapper

@@ -55,7 +56,7 @@ def batched(f: Callable[..., Any]) -> Callable[..., Any]:
 
     @wraps(f)
     def batched_wrapper(self, *args, batch_size: Optional[int] = None, **kwargs):
-        def get_batched(bs: Optional[int], seq: Sequence[Any]) -> List[List[Any]]:
+        def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]:
             if isinstance(seq, str):
                 seq = [seq]
             if isinstance(seq, list):
18 changes: 9 additions & 9 deletions inseq/attr/feat/attribution_utils.py
@@ -1,6 +1,6 @@
 import logging
 import math
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 from ...utils import extract_signature_args, get_aligned_idx
 from ...utils.typing import (
@@ -50,7 +50,7 @@ def check_attribute_positions(
     max_length: int,
     attr_pos_start: Optional[int] = None,
     attr_pos_end: Optional[int] = None,
-) -> Tuple[int, int]:
+) -> tuple[int, int]:
     r"""Checks whether the combination of start/end positions for attribution is valid.
 
     Args:
@@ -89,8 +89,8 @@ def join_token_ids(
     tokens: OneOrMoreTokenSequences,
     ids: OneOrMoreIdSequences,
     contrast_tokens: Optional[OneOrMoreTokenSequences] = None,
-    contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
-) -> List[TokenWithId]:
+    contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
+) -> list[TokenWithId]:
     """Joins tokens and ids into a list of TokenWithId objects."""
     if contrast_tokens is None:
         contrast_tokens = tokens
@@ -115,10 +115,10 @@ def join_token_ids(
 def extract_args(
     attribution_method: "FeatureAttribution",
     attributed_fn: Callable[..., SingleScorePerStepTensor],
-    step_scores: List[str],
-    default_args: List[str],
+    step_scores: list[str],
+    default_args: list[str],
     **kwargs,
-) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
     attribution_args = kwargs.pop("attribution_args", {})
     attributed_fn_args = kwargs.pop("attributed_fn_args", {})
     step_scores_args = kwargs.pop("step_scores_args", {})
@@ -142,9 +142,9 @@ def extract_args(
 
 
 def get_source_target_attributions(
-    attr: Union[StepAttributionTensor, Tuple[StepAttributionTensor, StepAttributionTensor]],
+    attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
     is_encoder_decoder: bool,
-) -> Tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
+) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
     if isinstance(attr, tuple):
         if is_encoder_decoder:
             return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
40 changes: 20 additions & 20 deletions inseq/attr/feat/feature_attribution.py
@@ -18,7 +18,7 @@
 """
 import logging
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from jaxtyping import Int
@@ -173,12 +173,12 @@ def prepare_and_attribute(
         pretty_progress: bool = True,
         output_step_attributions: bool = False,
         attribute_target: bool = False,
-        step_scores: List[str] = [],
+        step_scores: list[str] = [],
         include_eos_baseline: bool = False,
         attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None,
-        attribution_args: Dict[str, Any] = {},
-        attributed_fn_args: Dict[str, Any] = {},
-        step_scores_args: Dict[str, Any] = {},
+        attribution_args: dict[str, Any] = {},
+        attributed_fn_args: dict[str, Any] = {},
+        step_scores_args: dict[str, Any] = {},
     ) -> FeatureAttributionOutput:
         r"""Prepares inputs and performs attribution.
@@ -276,11 +276,11 @@ def format_contrastive_targets(
         self,
         target_sequences: TextSequences,
         target_tokens: OneOrMoreTokenSequences,
-        attributed_fn_args: Dict[str, Any],
-        step_scores_args: Dict[str, Any],
+        attributed_fn_args: dict[str, Any],
+        step_scores_args: dict[str, Any],
         attr_pos_start: int,
         attr_pos_end: int,
-    ) -> Tuple[Optional[DecoderOnlyBatch], Optional[List[List[Tuple[int, int]]]], Dict[str, Any], Dict[str, Any]]:
+    ) -> tuple[Optional[DecoderOnlyBatch], Optional[list[list[tuple[int, int]]]], dict[str, Any], dict[str, Any]]:
         contrast_batch, contrast_targets_alignments = None, None
         contrast_targets = attributed_fn_args.get("contrast_targets", None)
         if contrast_targets is None:
@@ -327,10 +327,10 @@ def attribute(
         pretty_progress: bool = True,
         output_step_attributions: bool = False,
         attribute_target: bool = False,
-        step_scores: List[str] = [],
-        attribution_args: Dict[str, Any] = {},
-        attributed_fn_args: Dict[str, Any] = {},
-        step_scores_args: Dict[str, Any] = {},
+        step_scores: list[str] = [],
+        attribution_args: dict[str, Any] = {},
+        attributed_fn_args: dict[str, Any] = {},
+        step_scores_args: dict[str, Any] = {},
     ) -> FeatureAttributionOutput:
         r"""Performs the feature attribution procedure using the specified attribution method.
@@ -501,10 +501,10 @@ def filtered_attribute_step(
         attributed_fn: Callable[..., SingleScorePerStepTensor],
         target_attention_mask: Optional[Int[torch.Tensor, "batch_size 1"]] = None,
         attribute_target: bool = False,
-        step_scores: List[str] = [],
-        attribution_args: Dict[str, Any] = {},
-        attributed_fn_args: Dict[str, Any] = {},
-        step_scores_args: Dict[str, Any] = {},
+        step_scores: list[str] = [],
+        attribution_args: dict[str, Any] = {},
+        attributed_fn_args: dict[str, Any] = {},
+        step_scores_args: dict[str, Any] = {},
     ) -> FeatureAttributionStepOutput:
         r"""Performs a single attribution step for all the sequences in the batch that
         still have valid target_ids, as identified by the target_attention_mask.
@@ -597,15 +597,15 @@ def filtered_attribute_step(
             step_output = step_output.detach().to("cpu")
         return step_output
 
-    def get_attribution_args(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    def get_attribution_args(self, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
         if hasattr(self, "method") and hasattr(self.method, "attribute"):
             return extract_signature_args(kwargs, self.method.attribute, self.ignore_extra_args, return_remaining=True)
         return {}, {}
 
     def attribute_step(
         self,
-        attribute_fn_main_args: Dict[str, Any],
-        attribution_args: Dict[str, Any] = {},
+        attribute_fn_main_args: dict[str, Any],
+        attribution_args: dict[str, Any] = {},
     ) -> FeatureAttributionStepOutput:
         r"""Performs a single attribution step for the specified attribution arguments.
@@ -663,7 +663,7 @@ class DummyAttribution(FeatureAttribution):
     method_name = "dummy"
 
     def attribute_step(
-        self, attribute_fn_main_args: Dict[str, Any], attribution_args: Dict[str, Any] = {}
+        self, attribute_fn_main_args: dict[str, Any], attribution_args: dict[str, Any] = {}
     ) -> FeatureAttributionStepOutput:
         return FeatureAttributionStepOutput(
             source_attributions=None,
6 changes: 3 additions & 3 deletions inseq/attr/feat/gradient_attribution.py
@@ -14,7 +14,7 @@
 """Gradient-based feature attribution methods."""
 
 import logging
-from typing import Any, Dict
+from typing import Any
 
 from captum.attr import (
     DeepLift,
@@ -65,8 +65,8 @@ def unhook(self, **kwargs):
 
     def attribute_step(
         self,
-        attribute_fn_main_args: Dict[str, Any],
-        attribution_args: Dict[str, Any] = {},
+        attribute_fn_main_args: dict[str, Any],
+        attribution_args: dict[str, Any] = {},
     ) -> GranularFeatureAttributionStepOutput:
         r"""Performs a single attribution step for the specified attribution arguments.
6 changes: 3 additions & 3 deletions inseq/attr/feat/internals_attribution.py
@@ -14,7 +14,7 @@
 """Attention-based feature attribution methods."""
 
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from captum._utils.typing import TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import Attribution
@@ -110,7 +110,7 @@ def __init__(self, attribution_model, **kwargs):
 
     def attribute_step(
         self,
-        attribute_fn_main_args: Dict[str, Any],
-        attribution_args: Dict[str, Any],
+        attribute_fn_main_args: dict[str, Any],
+        attribution_args: dict[str, Any],
     ) -> MultiDimensionalFeatureAttributionStepOutput:
         return self.method.attribute(**attribute_fn_main_args, **attribution_args)
12 changes: 6 additions & 6 deletions inseq/attr/feat/ops/discretized_integrated_gradients.py
@@ -17,7 +17,7 @@
 # OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
 from pathlib import Path
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Union
 
 import torch
 from captum._utils.common import (
@@ -51,7 +51,7 @@ def load_monotonic_path_builder(
         self,
         model_name: str,
         vocabulary_embeddings: VocabularyEmbeddingsTensor,
-        special_tokens: List[int],
+        special_tokens: list[int],
         cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "dig_knn",
         embedding_scaling: int = 1,
         **kwargs,
@@ -67,7 +67,7 @@ def load_monotonic_path_builder(
         )
 
     @staticmethod
-    def get_inputs_baselines(scaled_features_tpl: Tuple[Tensor, ...], n_steps: int) -> Tuple[Tensor, ...]:
+    def get_inputs_baselines(scaled_features_tpl: tuple[Tensor, ...], n_steps: int) -> tuple[Tensor, ...]:
         # Baseline and inputs are reversed in the path builder
         # For every element in the batch, the first embedding of the sub-tensor
         # of shape (n_steps x embedding_dim) is the baseline, the last is the input.
@@ -96,7 +96,7 @@ def attribute(  # type: ignore
         method: str = "greedy",
         internal_batch_size: Union[None, int] = None,
         return_convergence_delta: bool = False,
-    ) -> Union[TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]]:
+    ) -> Union[TensorOrTupleOfTensorsGeneric, tuple[TensorOrTupleOfTensorsGeneric, Tensor]]:
         n_examples = inputs[0].shape[0]
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
@@ -146,11 +146,11 @@ def attribute(  # type: ignore
 
     def _attribute(
         self,
-        scaled_features_tpl: Tuple[Tensor, ...],
+        scaled_features_tpl: tuple[Tensor, ...],
         target: TargetType = None,
         additional_forward_args: Any = None,
         n_steps: int = 50,
-    ) -> Tuple[Tensor, ...]:
+    ) -> tuple[Tensor, ...]:
         additional_forward_args = _format_additional_forward_args(additional_forward_args)
         input_additional_args = (
             _expand_additional_forward_args(additional_forward_args, n_steps)
14 changes: 7 additions & 7 deletions inseq/attr/feat/ops/monotonic_path_builder.py
@@ -23,7 +23,7 @@
 from enum import Enum
 from itertools import islice
 from pathlib import Path
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import torch
 from jaxtyping import Float, Int
@@ -56,7 +56,7 @@ class UnknownPathBuildingStrategy(Exception):
     def __init__(
         self,
         strategy: str,
-        *args: Tuple[Any],
+        *args: tuple[Any],
     ) -> None:
         """Initialize the exception."""
         super().__init__(
@@ -75,7 +75,7 @@ def __init__(
         self,
         vocabulary_embeddings: VocabularyEmbeddingsTensor,
         knn_graph: "csr_matrix",
-        special_tokens: List[int] = [],
+        special_tokens: list[int] = [],
     ) -> None:
         """Initialize the monotonic path builder."""
         self.vocabulary_embeddings = vocabulary_embeddings
@@ -112,7 +112,7 @@ def load(
         overwrite_cache: bool = False,
         cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "path_knn",
         vocabulary_embeddings: Optional[VocabularyEmbeddingsTensor] = None,
-        special_tokens: List[int] = [],
+        special_tokens: list[int] = [],
         embedding_scaling: int = 1,
     ) -> "MonotonicPathBuilder":
         """Load a cached monotonic path builder from a model name, or compute it if it does not exist."""
@@ -188,7 +188,7 @@ def find_path(
         baseline_idx: int,
         n_steps: Optional[int] = 30,
         strategy: Optional[str] = "greedy",
-    ) -> List[int]:
+    ) -> list[int]:
         """Find a monotonic path from a word to a baseline."""
         # if word_idx is a special token copy it and return
         if word_idx in self.special_tokens:
@@ -207,7 +207,7 @@ def find_path(
         return word_path
 
     def build_monotonic_path_embedding(
-        self, word_path: List[int], baseline_idx: int, n_steps: int = 30
+        self, word_path: list[int], baseline_idx: int, n_steps: int = 30
     ) -> Float[torch.Tensor, "n_steps embed_size"]:
         """Build a monotonic path embedding from a word path."""
         baseline_vec = self.vocabulary_embeddings[baseline_idx]
@@ -231,7 +231,7 @@ def get_closest_word(
         self,
         word_idx: int,
         baseline_idx: int,
-        word_path: List[int],
+        word_path: list[int],
         strategy: str = "greedy",
         n_steps: int = 30,
     ) -> int:
14 changes: 7 additions & 7 deletions inseq/attr/feat/ops/sequential_integrated_gradients.py
@@ -17,7 +17,7 @@
 # OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
 import typing
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, Union
 
 import torch
 from captum._utils.common import (
@@ -138,7 +138,7 @@ def attribute(
         internal_batch_size: Union[None, int] = None,
         *,
         return_convergence_delta: Literal[True],
-    ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
+    ) -> tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
         ...
 
     def attribute(  # type: ignore
@@ -153,7 +153,7 @@ def attribute(  # type: ignore
         return_convergence_delta: bool = False,
     ) -> Union[
         TensorOrTupleOfTensorsGeneric,
-        Tuple[TensorOrTupleOfTensorsGeneric, Tensor],
+        tuple[TensorOrTupleOfTensorsGeneric, Tensor],
     ]:
         r"""
         This method attributes the output of the model with given target index
@@ -367,15 +367,15 @@
 
     def _attribute(
         self,
-        inputs: Tuple[Tensor, ...],
-        baselines: Tuple[Union[Tensor, int, float], ...],
+        inputs: tuple[Tensor, ...],
+        baselines: tuple[Union[Tensor, int, float], ...],
         target: TargetType = None,
         additional_forward_args: Any = None,
         n_steps: int = 50,
         method: str = "gausslegendre",
         idx: int = None,
-        step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None,
-    ) -> Tuple[Tensor, ...]:
+        step_sizes_and_alphas: Union[None, tuple[list[float], list[float]]] = None,
+    ) -> tuple[Tensor, ...]:
         if step_sizes_and_alphas is None:
             # retrieve step size and scaling factor for specified
             # approximation method
(Diffs for the remaining changed files are not shown.)