Skip to content

Commit

Permalink
Merge pull request #44 from chanind/pytest-setup
Browse files Browse the repository at this point in the history
chore: setting up pytest
  • Loading branch information
callummcdougall authored Apr 24, 2024
2 parents d759ef0 + 2079d00 commit 034eefa
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ jobs:
run: poetry run ruff format --check
- name: check types
run: poetry run pyright
- name: test
run: poetry run pytest
- name: build
run: poetry build

Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ lint:
poetry run ruff format --check .
poetry run pyright .

test:
poetry run pytest

check-all:
make format
make lint
make test
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ To cite this work, you can use this bibtex citation:

This project uses [Poetry](https://python-poetry.org/) for dependency management. After cloning the repo, install dependencies with `poetry install`.

This project uses [Ruff](https://docs.astral.sh/ruff/) for formatting and linting, and [Pyright](https://github.com/microsoft/pyright) for type-checking. If you submit a PR, make sure that your code passes all checks. You can run all check with `make check-all`.
This project uses [Ruff](https://docs.astral.sh/ruff/) for formatting and linting, [Pyright](https://github.com/microsoft/pyright) for type-checking, and [Pytest](https://docs.pytest.org/) for tests. If you submit a PR, make sure that your code passes all checks. You can run all checks with `make check-all`.

# Version history (recording started at `0.2.9`)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ reportConstantRedefinition = "none"
reportUnknownLambdaType = "none"
reportUnknownParameterType = "none"
reportPrivateUsage = "none"
reportPrivateImportUsage = "none"

[build-system]
requires = ["poetry-core"]
Expand Down
33 changes: 33 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
import torch
from transformer_lens import HookedTransformer

from sae_vis.model_fns import AutoEncoder, AutoEncoderConfig


@pytest.fixture
def model() -> HookedTransformer:
    """Provide the pretrained ``tiny-stories-1M`` transformer on CPU, switched to eval mode."""
    tiny_stories = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
    tiny_stories.eval()
    return tiny_stories


@pytest.fixture
def autoencoder() -> AutoEncoder:
    """Build an ``AutoEncoder`` whose parameters are fixed, hand-written values.

    The config uses ``d_in=64`` and ``dict_mult=2``. ``W_enc``, ``W_dec``,
    ``b_enc`` and ``b_dec`` are all overwritten via ``load_state_dict`` so that
    tests do not depend on random initialisation.

    Returns:
        The same ``AutoEncoder`` instance that the hardcoded state dict was
        loaded into.
    """
    cfg = AutoEncoderConfig(d_in=64, dict_mult=2)
    autoencoder = AutoEncoder(cfg)
    # set weights and biases to hardcoded values so tests are consistent
    seed1 = torch.tensor([0.1, -0.2, 0.3, -0.4] * 16)  # 64
    seed2 = torch.tensor([0.2, -0.1, 0.4, -0.2] * 32)  # 64 x 2
    seed3 = torch.tensor([0.3, -0.3, 0.6, -0.6] * 16)  # 64
    seed4 = torch.tensor([-0.4, 0.4, 0.8, -0.8] * 32)  # 64 x 2
    autoencoder.load_state_dict(
        {
            "W_enc": torch.outer(seed1, seed2),  # (64, 128)
            "W_dec": torch.outer(seed4, seed3),  # (128, 64)
            "b_enc": torch.zeros_like(autoencoder.b_enc) + 0.5,
            "b_dec": torch.zeros_like(autoencoder.b_dec) + 0.3,
        }
    )

    # Return the instance that received the hardcoded weights. Constructing a
    # fresh AutoEncoder(cfg) here would silently throw them away.
    return autoencoder
89 changes: 89 additions & 0 deletions tests/test_data_storing_fns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json
from pathlib import Path

from transformer_lens import HookedTransformer

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.model_fns import AutoEncoder

ROOT_DIR = Path(__file__).parent.parent


def test_SaeVisData_create_results_look_reasonable(
    model: HookedTransformer, autoencoder: AutoEncoder
):
    """Sanity-check the feature statistics that ``SaeVisData.create`` produces for two prompts."""
    vis_cfg = SaeVisConfig(hook_point="blocks.2.hook_resid_pre", minibatch_size_tokens=2)
    prompts = [
        "But what about second breakfast?" * 3,
        "Nothing is cheesier than cheese." * 3,
    ]
    vis_data = SaeVisData.create(
        encoder=autoencoder,
        model=model,
        tokens=model.to_tokens(prompts),
        cfg=vis_cfg,
    )

    # the inputs should be stored on the result
    assert vis_data.encoder == autoencoder
    assert vis_data.model == model
    assert vis_data.cfg == vis_cfg

    stats = vis_data.feature_stats
    # kurtosis and skew are both empty, is this intentional?
    # 128 is presumably d_in (64) * dict_mult (2) from the autoencoder fixture
    for per_feature_values in (stats.max, stats.frac_nonzero, stats.quantile_data):
        assert len(per_feature_values) == 128
    assert len(stats.quantiles) > 1000

    assert all(peak >= 0 for peak in stats.max)
    assert all(0 <= fraction <= 1 for fraction in stats.frac_nonzero)

    # quantiles must be non-decreasing
    quantiles = stats.quantiles
    assert all(lower <= upper for lower, upper in zip(quantiles[:-1], quantiles[1:]))

    for bounds, precision in stats.ranges_and_precisions:
        assert len(bounds) == 2
        assert bounds[0] <= bounds[1]
        assert precision > 0

    # each feature should get its own key
    expected_keys = set(range(128))
    assert set(vis_data.feature_data_dict.keys()) == expected_keys


def _assert_all_files_embedded(directory: Path, pattern: str, html_contents: str) -> None:
    """Assert that *directory* has files matching *pattern* and each one appears verbatim in the HTML."""
    # Path.glob returns a one-shot generator. Materialise it once so that
    # counting the matches does not leave the loop below with nothing to check.
    paths = sorted(directory.glob(pattern))
    assert len(paths) > 0, f"no files matching {pattern} in {directory}"
    for path in paths:
        with open(path) as f:
            assert f.read() in html_contents, f"{path.name} is not embedded in the HTML"


def test_SaeVisData_create_and_save_feature_centric_vis(
    model: HookedTransformer,
    autoencoder: AutoEncoder,
    tmp_path: Path,
):
    """Save a feature-centric vis and check the HTML embeds every asset and the aggregate data.

    The saved file must exist, must contain the full text of every CSS, JS and
    HTML template file under ``sae_vis/``, and must contain the JSON dump of
    ``data.feature_stats.aggdata``.
    """
    cfg = SaeVisConfig(hook_point="blocks.2.hook_resid_pre", minibatch_size_tokens=2)
    tokens = model.to_tokens(
        [
            "But what about second breakfast?" * 3,
            "Nothing is cheesier than cheese." * 3,
        ]
    )
    data = SaeVisData.create(encoder=autoencoder, model=model, tokens=tokens, cfg=cfg)
    save_path = tmp_path / "feature_centric_vis.html"
    data.save_feature_centric_vis(save_path)
    assert save_path.exists()
    with open(save_path) as f:
        html_contents = f.read()

    package_dir = ROOT_DIR / "sae_vis"

    # all the CSS should be in the HTML
    _assert_all_files_embedded(package_dir / "css", "*.css", html_contents)

    # all the JS should be in the HTML
    _assert_all_files_embedded(package_dir / "js", "*.js", html_contents)

    # all the HTML templates should be in the HTML
    _assert_all_files_embedded(package_dir / "html", "*.html", html_contents)

    assert json.dumps(data.feature_stats.aggdata) in html_contents

0 comments on commit 034eefa

Please sign in to comment.