Finished show_tokens
gsarti committed Aug 8, 2024
1 parent 451af7c commit 20f963c
Showing 7 changed files with 309 additions and 55 deletions.
1 change: 1 addition & 0 deletions inseq/__init__.py
@@ -8,6 +8,7 @@
merge_attributions,
show_attributions,
show_granular_attributions,
show_token_attributions,
)
from .models import AttributionModel, list_supported_frameworks, load_model, register_model_config
from .utils.id_utils import explain
3 changes: 2 additions & 1 deletion inseq/data/__init__.py
@@ -31,7 +31,7 @@
EncoderDecoderBatch,
slice_batch_from_position,
)
from .viz import show_attributions, show_granular_attributions
from .viz import show_attributions, show_granular_attributions, show_token_attributions

__all__ = [
"Aggregator",
@@ -59,6 +59,7 @@
"TextInput",
"show_attributions",
"show_granular_attributions",
"show_token_attributions",
"list_aggregation_functions",
"MultiDimensionalFeatureAttributionStepOutput",
"get_batch_from_inputs",
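
Together, the two export changes above make the new helper reachable both from the top-level package and from inseq.data. A minimal sketch (not part of the commit) of the resulting import paths:

import inseq
from inseq.data import show_token_attributions

print(inseq.show_token_attributions)   # exported via inseq/__init__.py
print(show_token_attributions)         # exported via inseq/data/__init__.py
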
126 changes: 114 additions & 12 deletions inseq/data/attribution.py
@@ -38,7 +38,7 @@
from .aggregator import AggregableMixin, Aggregator, AggregatorPipeline
from .batch import Batch, BatchEmbedding, BatchEncoding, DecoderOnlyBatch, EncoderDecoderBatch
from .data_utils import TensorWrapper
from .viz import get_saliency_heatmap_treescope
from .viz import get_saliency_heatmap_treescope, get_tokens_heatmap_treescope

if TYPE_CHECKING:
from ..models import AttributionModel
@@ -235,12 +235,23 @@ def granular_attribution_visualizer(
else:
return treescope.IPythonVisualization(
treescope.figures.inline(
adapter.get_array_summary(value, fast=False),
treescope.render_array(
value,
axis_labels={0: f"Generated Tokens: {value.shape[0]}"},
axis_item_labels={0: column_labels},
adapter.get_array_summary(value, fast=False) + "\n\n",
treescope.figures.figure_from_treescope_rendering_part(
treescope.rendering_parts.indented_children(
[
get_tokens_heatmap_treescope(
tokens=column_labels,
scores=value.numpy(),
max_val=value.max().item(),
)
]
)
),
# treescope.render_array(
# value,
# axis_labels={0: f"Generated Tokens: {value.shape[0]}"},
# axis_item_labels={0: column_labels},
# ),
),
replace=True,
)
@@ -431,6 +442,7 @@ def show(
slice_dims: dict[int | str, tuple[int, int]] | None = None,
display: bool = True,
return_html: bool | None = False,
return_figure: bool = False,
aggregator: AggregatorPipeline | type[Aggregator] = None,
do_aggregation: bool = True,
**kwargs,
@@ -460,6 +472,8 @@
for later use.
return_html (:obj:`bool`, *optional*, defaults to False):
Whether to return the HTML code of the visualization.
return_figure (:obj:`bool`, *optional*, defaults to False):
For granular visualization, whether to return the Treescope figure object for further manipulation.
aggregator (:obj:`AggregatorPipeline`, *optional*, defaults to None):
Aggregates attributions before visualizing them. If not specified, the default aggregator for the class
is used.
@@ -496,6 +510,7 @@
show_dim=show_dim,
display=display,
return_html=return_html,
return_figure=return_figure,
slice_dims=slice_dims,
)

@@ -508,6 +523,7 @@ def show_granular(
slice_dims: dict[int | str, tuple[int, int]] | None = None,
display: bool = True,
return_html: bool | None = False,
return_figure: bool = False,
) -> str | None:
from inseq import show_granular_attributions

@@ -520,6 +536,36 @@
slice_dims=slice_dims,
display=display,
return_html=return_html,
return_figure=return_figure,
)

def show_tokens(
self,
min_val: int | None = None,
max_val: int | None = None,
display: bool = True,
return_html: bool | None = False,
return_figure: bool = False,
replace_char: dict[str, str] | None = None,
wrap_after: int | str | list[str] | tuple[str] | None = None,
step_score_highlight: str | None = None,
aggregator: AggregatorPipeline | type[Aggregator] = None,
do_aggregation: bool = True,
**kwargs,
) -> str | None:
from inseq import show_token_attributions

aggregated = self.aggregate(aggregator, **kwargs) if do_aggregation else self
return show_token_attributions(
attributions=aggregated,
min_val=min_val,
max_val=max_val,
display=display,
return_html=return_html,
return_figure=return_figure,
replace_char=replace_char,
wrap_after=wrap_after,
step_score_highlight=step_score_highlight,
)

@property
@@ -859,10 +905,11 @@ def show(
slice_dims: dict[int | str, tuple[int, int]] | None = None,
display: bool = True,
return_html: bool | None = False,
return_figure: bool = False,
aggregator: AggregatorPipeline | type[Aggregator] = None,
do_aggregation: bool = True,
**kwargs,
) -> str | None:
) -> str | list | None:
"""Visualize the sequence attributions.
Args:
@@ -873,6 +920,7 @@
slice_dims (dict[int or str, tuple[int, int]], optional): Dimensions to slice.
display (bool, optional): If True, display the attribution visualization.
return_html (bool, optional): If True, return the attribution visualization as HTML.
return_figure (bool, optional): If True, return the Treescope figure object for further manipulation.
aggregator (:obj:`AggregatorPipeline` or :obj:`Type[Aggregator]`, optional): Aggregator
or pipeline to use. If not provided, the default aggregator for every sequence attribution
is used.
@@ -881,26 +929,35 @@
attributions are already aggregated.
Returns:
str: Attribution visualization as HTML if `return_html=True`, None otherwise.
str: Attribution visualization as HTML if `return_html=True`
list: List of Treescope figure objects if `return_figure=True`
None if `return_html=False` and `return_figure=False`
"""
out_str = ""
out_figs = []
for attr in self.sequence_attributions:
curr_out_str = attr.show(
curr_out = attr.show(
min_val=min_val,
max_val=max_val,
max_show_size=max_show_size,
show_dim=show_dim,
slice_dims=slice_dims,
display=display,
return_html=return_html,
return_figure=return_figure,
aggregator=aggregator,
do_aggregation=do_aggregation,
**kwargs,
)
if return_html:
out_str += curr_out_str
out_str += curr_out
if return_figure:
out_figs.append(curr_out)
if return_html:
return out_str
if return_figure:
return out_figs

def show_granular(
self,
@@ -911,10 +968,12 @@ def show_granular(
slice_dims: dict[int | str, tuple[int, int]] | None = None,
display: bool = True,
return_html: bool = False,
return_figure: bool = False,
) -> str | None:
out_str = ""
out_figs = []
for attr in self.sequence_attributions:
curr_out_str = attr.show_granular(
curr_out = attr.show_granular(
min_val=min_val,
max_val=max_val,
max_show_size=max_show_size,
@@ -924,9 +983,52 @@
return_html=return_html,
)
if return_html:
out_str += curr_out_str
out_str += curr_out
if return_figure:
out_figs.append(curr_out)
if return_html:
return out_str
if return_figure:
return out_figs

def show_tokens(
self,
min_val: int | None = None,
max_val: int | None = None,
display: bool = True,
return_html: bool = False,
return_figure: bool = False,
replace_char: dict[str, str] | None = None,
wrap_after: int | str | list[str] | tuple[str] | None = None,
step_score_highlight: str | None = None,
aggregator: AggregatorPipeline | type[Aggregator] = None,
do_aggregation: bool = True,
**kwargs,
) -> str | None:
out_str = ""
out_figs = []
for attr in self.sequence_attributions:
curr_out = attr.show_tokens(
min_val=min_val,
max_val=max_val,
display=display,
return_html=return_html,
return_figure=return_figure,
replace_char=replace_char,
wrap_after=wrap_after,
step_score_highlight=step_score_highlight,
aggregator=aggregator,
do_aggregation=do_aggregation,
**kwargs,
)
if return_html:
out_str += curr_out
if return_figure:
out_figs.append(curr_out)
if return_html:
return out_str
if return_figure:
return out_figs

def weight_attributions(self, step_score_id: str):
for i, attr in enumerate(self.sequence_attributions):
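
A short usage sketch of the show_tokens entry point introduced in this commit; the model name, attribution method, and prompt below are illustrative assumptions, not part of the diff:

import inseq

# Illustrative setup: any supported model/method pair would do.
model = inseq.load_model("gpt2", "saliency")
out = model.attribute("The developer argued with the designer because she")

# Render the token-level heatmap inline (display defaults to True).
out.show_tokens(wrap_after=10)

# The new return flags added in this commit:
# return_figure=True -> list of Treescope figure objects (one per sequence),
# return_html=True   -> concatenated HTML string of the visualization.
figs = out.show_tokens(display=False, return_figure=True)
html = out.show_tokens(display=False, return_html=True)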