diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py index 48011294..abe53d29 100644 --- a/inseq/attr/feat/ops/__init__.py +++ b/inseq/attr/feat/ops/__init__.py @@ -1,13 +1,13 @@ from .discretized_integrated_gradients import DiscretetizedIntegratedGradients from .lime import Lime from .monotonic_path_builder import MonotonicPathBuilder -from .reagent import ReAGent +from .reagent import Reagent from .sequential_integrated_gradients import SequentialIntegratedGradients __all__ = [ "DiscretetizedIntegratedGradients", "MonotonicPathBuilder", "Lime", - "ReAGent", + "Reagent", "SequentialIntegratedGradients", ] diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 8e3e9765..95fd8988 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import TYPE_CHECKING, Any, Union import torch from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric @@ -10,12 +10,21 @@ from .reagent_core.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator from .reagent_core.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator from .reagent_core.token_replacement.token_replacer.uniform import UniformTokenReplacer -from .reagent_core.token_replacement.token_sampler.postag import POSTagTokenSampler +from .reagent_core.token_sampler import POSTagTokenSampler +if TYPE_CHECKING: + from ....models import HuggingfaceModel -class ReAGent(PerturbationAttribution): - r""" - ReAGent + +class Reagent(PerturbationAttribution): + r"""Recursive attribution generator (ReAGent) method. + + Measures importance as the drop in prediction probability produced by replacing a token with a plausible + alternative predicted by a LM. 
+ + Reference implementation: + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models + <https://arxiv.org/abs/2402.00794>`__ Args: forward_func (callable): The forward function of the model or any @@ -28,10 +37,6 @@ class ReAGent(PerturbationAttribution): max_probe_steps (int): max_probe_steps num_probes (int): number of probes in parallel - References: - `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models - <https://arxiv.org/abs/2402.00794>`_ - Examples: ``` import inseq @@ -52,7 +57,7 @@ class ReAGent(PerturbationAttribution): def __init__( self, - attribution_model: Callable, + attribution_model: "HuggingfaceModel", rational_size: int = 5, rational_size_ratio: float = None, stopping_condition_top_k: int = 3, @@ -65,7 +70,9 @@ def __init__( model = attribution_model.model tokenizer = attribution_model.tokenizer - token_sampler = POSTagTokenSampler(tokenizer=tokenizer, device=model.device) + token_sampler = POSTagTokenSampler( + tokenizer=tokenizer, identifier=attribution_model.model_name, device=attribution_model.device + ) stopping_condition_evaluator = TopKStoppingConditionEvaluator( model=model, diff --git a/inseq/attr/feat/ops/reagent_core/__init__.py b/inseq/attr/feat/ops/reagent_core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py index 56541e98..fe842652 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/ranking.py @@ -63,7 +63,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: """ super().sample(input) - token_sampled = self.token_sampler.sample(input) + token_sampled = self.token_sampler(input) input_replaced = input * ~self.mask_replacing + token_sampled * self.mask_replacing diff --git 
a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py index 5fac259a..4c663bf1 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_replacer/uniform.py @@ -40,7 +40,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]: sample_uniform = torch.rand(input.shape, device=input.device) mask_replacing = sample_uniform < self.ratio - token_sampled = self.token_sampler.sample(input) + token_sampled = self.token_sampler(input) input_replaced = input * ~mask_replacing + token_sampled * mask_replacing diff --git a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py index 4526fd38..954b43ae 100644 --- a/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py +++ b/inseq/attr/feat/ops/reagent_core/token_replacement/token_sampler/inferential_m.py @@ -2,8 +2,6 @@ from transformers import AutoModelWithLMHead, AutoTokenizer from typing_extensions import override -from .base import TokenSampler - class InferentialMTokenSampler(TokenSampler): """Sample tokens from a seq-2-seq model""" @@ -37,8 +35,6 @@ def sample(self, inputs: torch.Tensor) -> torch.Tensor: token_inferences: sampled (placement) tokens by inference """ - super().sample(inputs) - batch_li = [] for seq_i in torch.arange(inputs.shape[0]): seq_li = [] diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py new file mode 100644 index 00000000..73cfb7e7 --- /dev/null +++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py @@ -0,0 +1,90 @@ +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from pathlib import Path +from typing import Any, Optional, Union + +import 
torch +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from .....utils import INSEQ_ARTIFACTS_CACHE, cache_results, is_nltk_available +from .....utils.typing import IdsTensor + +logger = logging.getLogger(__name__) + + +class TokenSampler(ABC): + """Base class for token samplers""" + + @abstractmethod + def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor: + """Sample tokens according to the specified strategy.""" + pass + + +class POSTagTokenSampler(TokenSampler): + """Sample tokens from Uniform distribution on a set of words with the same POS tag.""" + + def __init__( + self, + tokenizer: Union[str, PreTrainedTokenizerBase], + identifier: str = "pos_tag_sampler", + save_cache: bool = True, + overwrite_cache: bool = False, + cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache", + device: Optional[str] = None, + tokenizer_kwargs: Optional[dict[str, Any]] = {}, + ) -> None: + if isinstance(tokenizer, PreTrainedTokenizerBase): + self.tokenizer = tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs) + cache_filename = cache_dir / f"{identifier.split('/')[-1]}.pkl" + self.pos2ids = self.build_pos_mapping_from_vocab( + cache_dir, + cache_filename, + save_cache, + overwrite_cache, + tokenizer=self.tokenizer, + ) + num_postags = len(self.pos2ids) + self.id2pos = torch.zeros([self.tokenizer.vocab_size], dtype=torch.long, device=device) + for pos_idx, ids in enumerate(self.pos2ids.values()): + self.id2pos[ids] = pos_idx + self.num_ids_per_pos = torch.tensor( + [len(ids) for ids in self.pos2ids.values()], dtype=torch.long, device=device + ) + self.offsets = torch.sum( + torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.num_ids_per_pos, + dim=-1, + ) + self.compact_idx = torch.cat( + tuple(torch.tensor(v, dtype=torch.long, device=device) for v in self.pos2ids.values()) + ) + + @staticmethod + @cache_results + def build_pos_mapping_from_vocab( + tokenizer: 
PreTrainedTokenizerBase, + log_every: int = 5000, + ) -> dict[str, list[int]]: + """Build mapping from POS tags to list of token ids from tokenizer's vocabulary.""" + if not is_nltk_available(): + raise ImportError("nltk is required to build POS tag mapping. Please install nltk.") + import nltk + + nltk.download("averaged_perceptron_tagger") + pos2ids = defaultdict(list) + for i in range(tokenizer.vocab_size): + word = tokenizer.decode([i]) + _, tag = nltk.pos_tag([word.strip()])[0] + pos2ids[tag].append(i) + if i % log_every == 0: + logger.info(f"Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%") + return pos2ids + + def __call__(self, input_ids: IdsTensor) -> IdsTensor: + input_ids_pos = self.id2pos[input_ids] + sample_uniform = torch.rand(input_ids.shape, device=input_ids.device) + compact_group_idx = (sample_uniform * self.num_ids_per_pos[input_ids_pos] + self.offsets[input_ids_pos]).long() + return self.compact_idx[compact_group_idx] diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index 138d4bfb..24a4aa68 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -10,7 +10,7 @@ from ...utils import Registry from .attribution_utils import get_source_target_attributions from .gradient_attribution import FeatureAttribution -from .ops import Lime, ReAGent +from .ops import Lime, Reagent logger = logging.getLogger(__name__) @@ -119,16 +119,22 @@ def attribute_step( ) -class ReAGentAttribution(PerturbationAttributionRegistry): - """ReAGent-based attribution method. - The main part of the code is in ops/reagent.py. +class ReagentAttribution(PerturbationAttributionRegistry): + """Recursive attribution generator (ReAGent) method. + + Measures importance as the drop in prediction probability produced by replacing a token with a plausible + alternative predicted by a LM. 
+ + Reference implementation: + `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models + <https://arxiv.org/abs/2402.00794>`__ """ - method_name = "ReAGent" + method_name = "reagent" def __init__(self, attribution_model, **kwargs): super().__init__(attribution_model) - self.method = ReAGent(attribution_model=self.attribution_model, **kwargs) + self.method = Reagent(attribution_model=self.attribution_model, **kwargs) def attribute_step( self, diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py index 29f81615..7dabecbb 100644 --- a/inseq/utils/__init__.py +++ b/inseq/utils/__init__.py @@ -13,6 +13,7 @@ is_datasets_available, is_ipywidgets_available, is_joblib_available, + is_nltk_available, is_scikitlearn_available, is_sentencepiece_available, is_transformers_available, @@ -94,6 +95,7 @@ "is_datasets_available", "is_captum_available", "is_joblib_available", + "is_nltk_available", "check_device", "get_default_device", "ndarray_to_bin_str", diff --git a/inseq/utils/import_utils.py b/inseq/utils/import_utils.py index cbd03420..2a1ccc2d 100644 --- a/inseq/utils/import_utils.py +++ b/inseq/utils/import_utils.py @@ -7,6 +7,7 @@ _datasets_available = find_spec("datasets") is not None _captum_available = find_spec("captum") is not None _joblib_available = find_spec("joblib") is not None +_nltk_available = find_spec("nltk") is not None def is_ipywidgets_available(): @@ -35,3 +36,7 @@ def is_captum_available(): def is_joblib_available(): return _joblib_available + + +def is_nltk_available(): + return _nltk_available