Skip to content

Commit

Permalink
Add treescope requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Aug 1, 2024
1 parent e5b835b commit e03ba89
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

## 🚀 Features

- Added [treescope](https://github.com/google-deepmind/treescope) for model and tensor visualization.

- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM`, `Gemma2ForCausalLM` to model config.

- Add `rescale_attributions` to Inseq CLI commands for `rescale=True` ([#280](https://github.com/inseq-team/inseq/pull/280)).
Expand Down
4 changes: 2 additions & 2 deletions inseq/data/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ def from_step_attributions(
pos_start.append(curr_pos_start)
source = tokenized_target_sentences[seq_idx][:curr_pos_start] if not sources else sources[seq_idx]
curr_seq_attribution: FeatureAttributionSequenceOutput = attr.get_sequence_cls(
source=source,
target=tokenized_target_sentences[seq_idx],
source=deepcopy(source),
target=deepcopy(tokenized_target_sentences[seq_idx]),
attr_pos_start=pos_start[seq_idx],
attr_pos_end=attr_pos_end,
)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ dependencies = [
"torch>=2.0",
"matplotlib>=3.5.3",
"tqdm>=4.64.0",
"treescope>=0.1.0",
"nvidia-cublas-cu11>=11.10.3.66; sys_platform=='Linux'",
"nvidia-cuda-cupti-cu11>=11.7.101; sys_platform=='Linux'",
"nvidia-cuda-nvrtc-cu11>=11.7.99; sys_platform=='Linux'",
Expand Down
3 changes: 3 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ numpy==1.26.4
# scikit-learn
# scipy
# transformers
# treescope
packaging==23.2
# via
# datasets
Expand Down Expand Up @@ -402,6 +403,8 @@ traitlets==5.14.1
# matplotlib-inline
transformers==4.38.1
# via inseq (pyproject.toml)
treescope==0.1.0
# via inseq (pyproject.toml)
typeguard==2.13.3
# via
# inseq (pyproject.toml)
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ numpy==1.26.4
# jaxtyping
# matplotlib
# transformers
# treescope
packaging==23.2
# via
# huggingface-hub
Expand Down Expand Up @@ -106,6 +107,8 @@ tqdm==4.66.4
# transformers
transformers==4.38.1
# via inseq (pyproject.toml)
treescope==0.1.0
# via inseq (pyproject.toml)
typeguard==2.13.3
# via
# inseq (pyproject.toml)
Expand Down

0 comments on commit e03ba89

Please sign in to comment.