Commit

Added caching for POSTagTokenSampler, minor fixes
gsarti committed Feb 19, 2024
1 parent 06d5f60 commit 82d6e93
Showing 10 changed files with 131 additions and 25 deletions.
4 changes: 2 additions & 2 deletions inseq/attr/feat/ops/__init__.py
@@ -1,13 +1,13 @@
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
-from .reagent import ReAGent
+from .reagent import Reagent
from .sequential_integrated_gradients import SequentialIntegratedGradients

__all__ = [
    "DiscretetizedIntegratedGradients",
    "MonotonicPathBuilder",
    "Lime",
-    "ReAGent",
+    "Reagent",
    "SequentialIntegratedGradients",
]
29 changes: 18 additions & 11 deletions inseq/attr/feat/ops/reagent.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Union
+from typing import TYPE_CHECKING, Any, Union

import torch
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
@@ -10,12 +10,21 @@
from .reagent_core.importance_score_evaluator.delta_prob import DeltaProbImportanceScoreEvaluator
from .reagent_core.stopping_condition_evaluator.top_k import TopKStoppingConditionEvaluator
from .reagent_core.token_replacement.token_replacer.uniform import UniformTokenReplacer
-from .reagent_core.token_replacement.token_sampler.postag import POSTagTokenSampler
+from .reagent_core.token_sampler import POSTagTokenSampler

+if TYPE_CHECKING:
+    from ....models import HuggingfaceModel


-class ReAGent(PerturbationAttribution):
-    r"""
-    ReAGent
+class Reagent(PerturbationAttribution):
+    r"""Recursive attribution generator (ReAGent) method.
+
+    Measures importance as the drop in prediction probability produced by replacing a token with a plausible
+    alternative predicted by a LM.
+
+    Reference implementation:
+    `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
+    <https://arxiv.org/abs/2402.00794>`__

    Args:
        forward_func (callable): The forward function of the model or any
@@ -28,10 +37,6 @@ class ReAGent(PerturbationAttribution):
        max_probe_steps (int): max_probe_steps
        num_probes (int): number of probes in parallel

-    References:
-        `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
-        <https://arxiv.org/abs/2402.00794>`_

    Examples:
        ```
        import inseq
@@ -52,7 +57,7 @@ class ReAGent(PerturbationAttribution):

    def __init__(
        self,
-        attribution_model: Callable,
+        attribution_model: "HuggingfaceModel",
        rational_size: int = 5,
        rational_size_ratio: float = None,
        stopping_condition_top_k: int = 3,
@@ -65,7 +70,9 @@ def __init__(
        model = attribution_model.model
        tokenizer = attribution_model.tokenizer

-        token_sampler = POSTagTokenSampler(tokenizer=tokenizer, device=model.device)
+        token_sampler = POSTagTokenSampler(
+            tokenizer=tokenizer, identifier=attribution_model.model_name, device=attribution_model.device
+        )

        stopping_condition_evaluator = TopKStoppingConditionEvaluator(
            model=model,
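The practical effect of the new `identifier` argument above: it keys the on-disk cache, so the expensive vocabulary-wide POS tagging should run once per model and be reloaded from a pickle on later runs. A minimal standalone sketch of constructing the updated sampler directly (the model name and token ids are illustrative, not part of this commit; requires nltk installed):

```
import torch

from inseq.attr.feat.ops.reagent_core.token_sampler import POSTagTokenSampler

# First construction tags the whole vocabulary with nltk and, with the default
# save_cache=True, pickles the mapping to
# INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache" / "gpt2.pkl"; subsequent runs
# with the same identifier should load it from disk instead of re-tagging.
sampler = POSTagTokenSampler(tokenizer="gpt2", identifier="gpt2", device="cpu")

ids = torch.tensor([[464, 2068, 7586]])  # arbitrary GPT-2 token ids
replacements = sampler(ids)  # same shape; each id resampled within its POS group
```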
Empty file.
@@ -63,7 +63,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]:
"""
super().sample(input)

token_sampled = self.token_sampler.sample(input)
token_sampled = self.token_sampler(input)

input_replaced = input * ~self.mask_replacing + token_sampled * self.mask_replacing

@@ -40,7 +40,7 @@ def sample(self, input: torch.Tensor) -> Union[torch.Tensor, torch.Tensor]:
        sample_uniform = torch.rand(input.shape, device=input.device)
        mask_replacing = sample_uniform < self.ratio

-        token_sampled = self.token_sampler.sample(input)
+        token_sampled = self.token_sampler(input)

        input_replaced = input * ~mask_replacing + token_sampled * mask_replacing

@@ -2,8 +2,6 @@
from transformers import AutoModelWithLMHead, AutoTokenizer
from typing_extensions import override

-from .base import TokenSampler
-

class InferentialMTokenSampler(TokenSampler):
    """Sample tokens from a seq-2-seq model"""
@@ -37,8 +35,6 @@ def sample(self, inputs: torch.Tensor) -> torch.Tensor:
            token_inferences: sampled (placement) tokens by inference
        """
-        super().sample(inputs)
-
        batch_li = []
        for seq_i in torch.arange(inputs.shape[0]):
            seq_li = []
90 changes: 90 additions & 0 deletions inseq/attr/feat/ops/reagent_core/token_sampler.py
@@ -0,0 +1,90 @@
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import Any, Optional, Union

import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from .....utils import INSEQ_ARTIFACTS_CACHE, cache_results, is_nltk_available
from .....utils.typing import IdsTensor

logger = logging.getLogger(__name__)


class TokenSampler(ABC):
    """Base class for token samplers"""

    @abstractmethod
    def __call__(self, input: IdsTensor, **kwargs) -> IdsTensor:
        """Sample tokens according to the specified strategy."""
        pass


class POSTagTokenSampler(TokenSampler):
    """Sample tokens from Uniform distribution on a set of words with the same POS tag."""

    def __init__(
        self,
        tokenizer: Union[str, PreTrainedTokenizerBase],
        identifier: str = "pos_tag_sampler",
        save_cache: bool = True,
        overwrite_cache: bool = False,
        cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache",
        device: Optional[str] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = {},
    ) -> None:
        if isinstance(tokenizer, PreTrainedTokenizerBase):
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
        cache_filename = cache_dir / f"{identifier.split('/')[-1]}.pkl"
        self.pos2ids = self.build_pos_mapping_from_vocab(
            cache_dir,
            cache_filename,
            save_cache,
            overwrite_cache,
            tokenizer=self.tokenizer,
        )
        num_postags = len(self.pos2ids)
        self.id2pos = torch.zeros([self.tokenizer.vocab_size], dtype=torch.long, device=device)
        for pos_idx, ids in enumerate(self.pos2ids.values()):
            self.id2pos[ids] = pos_idx
        self.num_ids_per_pos = torch.tensor(
            [len(ids) for ids in self.pos2ids.values()], dtype=torch.long, device=device
        )
        self.offsets = torch.sum(
            torch.tril(torch.ones([num_postags, num_postags], device=device), diagonal=-1) * self.num_ids_per_pos,
            dim=-1,
        )
        self.compact_idx = torch.cat(
            tuple(torch.tensor(v, dtype=torch.long, device=device) for v in self.pos2ids.values())
        )

    @staticmethod
    @cache_results
    def build_pos_mapping_from_vocab(
        tokenizer: PreTrainedTokenizerBase,
        log_every: int = 5000,
    ) -> dict[str, list[int]]:
        """Build mapping from POS tags to list of token ids from tokenizer's vocabulary."""
        if not is_nltk_available():
            raise ImportError("nltk is required to build POS tag mapping. Please install nltk.")
        import nltk

        nltk.download("averaged_perceptron_tagger")
        pos2ids = defaultdict(list)
        for i in range(tokenizer.vocab_size):
            word = tokenizer.decode([i])
            _, tag = nltk.pos_tag([word.strip()])[0]
            pos2ids[tag].append(i)
            if i % log_every == 0:
                logger.info(f"Loading vocab from tokenizer - {i / tokenizer.vocab_size * 100:.2f}%")
        return pos2ids

    def __call__(self, input_ids: IdsTensor) -> IdsTensor:
        input_ids_pos = self.id2pos[input_ids]
        sample_uniform = torch.rand(input_ids.shape, device=input_ids.device)
        compact_group_idx = (sample_uniform * self.num_ids_per_pos[input_ids_pos] + self.offsets[input_ids_pos]).long()
        return self.compact_idx[compact_group_idx]
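The `offsets`/`compact_idx` construction above makes within-group sampling fully vectorized: `compact_idx` lays out all vocabulary ids group by group, `offsets[p]` marks where group `p` starts in that flat tensor, and `__call__` scales a uniform draw by the group size and shifts it by the group start. A toy walk-through of the same indexing trick (fake vocabulary and tags, no tokenizer or nltk needed; variable names are illustrative):

```
import torch

# Toy vocabulary: six token ids bucketed under three fake POS tags.
pos2ids = {"NN": [0, 3, 5], "VB": [1, 4], "DT": [2]}

num_postags = len(pos2ids)
id2pos = torch.zeros(6, dtype=torch.long)
for pos_idx, ids in enumerate(pos2ids.values()):
    id2pos[ids] = pos_idx

num_ids_per_pos = torch.tensor([len(ids) for ids in pos2ids.values()])
# offsets[p] = number of ids in all groups before p -> here [0, 3, 5]
offsets = torch.sum(
    torch.tril(torch.ones(num_postags, num_postags), diagonal=-1) * num_ids_per_pos, dim=-1
)
# All ids laid out group by group -> [0, 3, 5, 1, 4, 2]
compact_idx = torch.cat(tuple(torch.tensor(v) for v in pos2ids.values()))

input_ids = torch.tensor([[2, 1, 0]])  # one DT, one VB, one NN
pos = id2pos[input_ids]
u = torch.rand(input_ids.shape)
# Scale the uniform draw by the group size, then shift by the group start:
# every sampled id shares the POS tag of the token it replaces, e.g. the DT
# position can only ever be resampled as id 2.
sampled = compact_idx[(u * num_ids_per_pos[pos] + offsets[pos]).long()]
print(sampled)
```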
18 changes: 12 additions & 6 deletions inseq/attr/feat/perturbation_attribution.py
@@ -10,7 +10,7 @@
from ...utils import Registry
from .attribution_utils import get_source_target_attributions
from .gradient_attribution import FeatureAttribution
-from .ops import Lime, ReAGent
+from .ops import Lime, Reagent

logger = logging.getLogger(__name__)

@@ -119,16 +119,22 @@ def attribute_step(
)


-class ReAGentAttribution(PerturbationAttributionRegistry):
-    """ReAGent-based attribution method.
-    The main part of the code is in ops/reagent.py.
+class ReagentAttribution(PerturbationAttributionRegistry):
+    """Recursive attribution generator (ReAGent) method.
+
+    Measures importance as the drop in prediction probability produced by replacing a token with a plausible
+    alternative predicted by a LM.
+
+    Reference implementation:
+    `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
+    <https://arxiv.org/abs/2402.00794>`__
    """

-    method_name = "ReAGent"
+    method_name = "reagent"

    def __init__(self, attribution_model, **kwargs):
        super().__init__(attribution_model)
-        self.method = ReAGent(attribution_model=self.attribution_model, **kwargs)
+        self.method = Reagent(attribution_model=self.attribution_model, **kwargs)

    def attribute_step(
        self,
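With the registry entry renamed to the lowercase `reagent`, the method is selected like any other inseq attribution method. A minimal usage sketch (model and prompt are illustrative; constructor options such as `stopping_condition_top_k` would be forwarded through `**kwargs` to `Reagent`):

```
import inseq

# Load a HuggingFace LM with the newly renamed "reagent" attribution method.
model = inseq.load_model("gpt2", "reagent")

# Attribute a short generation and visualize token importances.
out = model.attribute("Superstition is the belief")
out.show()
```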
2 changes: 2 additions & 0 deletions inseq/utils/__init__.py
@@ -13,6 +13,7 @@
    is_datasets_available,
    is_ipywidgets_available,
    is_joblib_available,
+    is_nltk_available,
    is_scikitlearn_available,
    is_sentencepiece_available,
    is_transformers_available,
@@ -94,6 +95,7 @@
"is_datasets_available",
"is_captum_available",
"is_joblib_available",
"is_nltk_available",
"check_device",
"get_default_device",
"ndarray_to_bin_str",
5 changes: 5 additions & 0 deletions inseq/utils/import_utils.py
@@ -7,6 +7,7 @@
_datasets_available = find_spec("datasets") is not None
_captum_available = find_spec("captum") is not None
_joblib_available = find_spec("joblib") is not None
+_nltk_available = find_spec("nltk") is not None


def is_ipywidgets_available():
@@ -35,3 +36,7 @@ def is_captum_available():

def is_joblib_available():
    return _joblib_available


+def is_nltk_available():
+    return _nltk_available
