From 0059c3daa4d60a28a00b4ff87f7dc1a3e3b24403 Mon Sep 17 00:00:00 2001 From: Alan <41682961+alan-cooney@users.noreply.github.com> Date: Wed, 15 Nov 2023 23:46:51 -0800 Subject: [PATCH] Move towards abstract final pattern (#74) --- .vscode/cspell.json | 3 + poetry.lock | 19 +- pyproject.toml | 2 + sparse_autoencoder/__init__.py | 20 +- .../abstract_activation_resampler.py | 47 ++++ .../activation_store/__init__.py | 4 +- .../activation_store/base_store.py | 32 +-- .../activation_store/disk_store.py | 14 +- .../activation_store/list_store.py | 20 +- .../activation_store/tensor_store.py | 24 +- .../activation_store/utils/extend_resize.py | 19 +- .../utils/tests/test_extend_resize.py | 10 +- .../components/tests/test_tied_bias.py | 5 +- .../components/tests/test_unit_norm_linear.py | 59 ++++- .../autoencoder/components/tied_bias.py | 15 +- .../components/unit_norm_linear.py | 113 +++++----- sparse_autoencoder/autoencoder/loss.py | 90 -------- sparse_autoencoder/autoencoder/model.py | 19 +- .../autoencoder/tests/test_loss.py | 37 ---- sparse_autoencoder/loss/__init__.py | 23 ++ sparse_autoencoder/loss/abstract_loss.py | 163 ++++++++++++++ .../loss/learned_activations_l1.py | 68 ++++++ .../loss/mse_reconstruction_loss.py | 55 +++++ sparse_autoencoder/loss/reducer.py | 102 +++++++++ .../loss/tests/test_abstract_loss.py | 75 +++++++ .../loss/tests/test_learned_activations_l1.py | 63 ++++++ .../tests/test_mse_reconstruction_loss.py | 54 +++++ .../test_towards_monosemanticity_loss.py | 43 ++++ sparse_autoencoder/metrics/abstract_metric.py | 93 ++++++++ .../optimizer/abstract_optimizer.py | 15 ++ .../optimizer/adam_with_reset.py | 10 +- .../source_data/abstract_dataset.py | 6 +- .../src_model/store_activations_hook.py | 9 +- .../tests/test_store_activations_hook.py | 3 +- sparse_autoencoder/tensor_types.py | 205 ++++++++++++++++++ sparse_autoencoder/train/abstract_pipeline.py | 134 ++++++++++++ .../train/generate_activations.py | 9 +- 
sparse_autoencoder/train/metrics/capacity.py | 8 +- .../train/metrics/feature_density.py | 12 +- .../train/metrics/tests/test_capacities.py | 5 +- sparse_autoencoder/train/pipeline.py | 12 +- sparse_autoencoder/train/resample_neurons.py | 88 ++++---- .../train/tests/test_resample_neurons.py | 14 +- sparse_autoencoder/train/train_autoencoder.py | 42 ++-- 44 files changed, 1476 insertions(+), 387 deletions(-) create mode 100644 sparse_autoencoder/activation_resampler/abstract_activation_resampler.py delete mode 100644 sparse_autoencoder/autoencoder/loss.py delete mode 100644 sparse_autoencoder/autoencoder/tests/test_loss.py create mode 100644 sparse_autoencoder/loss/__init__.py create mode 100644 sparse_autoencoder/loss/abstract_loss.py create mode 100644 sparse_autoencoder/loss/learned_activations_l1.py create mode 100644 sparse_autoencoder/loss/mse_reconstruction_loss.py create mode 100644 sparse_autoencoder/loss/reducer.py create mode 100644 sparse_autoencoder/loss/tests/test_abstract_loss.py create mode 100644 sparse_autoencoder/loss/tests/test_learned_activations_l1.py create mode 100644 sparse_autoencoder/loss/tests/test_mse_reconstruction_loss.py create mode 100644 sparse_autoencoder/loss/tests/test_towards_monosemanticity_loss.py create mode 100644 sparse_autoencoder/metrics/abstract_metric.py create mode 100644 sparse_autoencoder/optimizer/abstract_optimizer.py create mode 100644 sparse_autoencoder/tensor_types.py create mode 100644 sparse_autoencoder/train/abstract_pipeline.py diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 455dfa89..311054e5 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -21,6 +21,7 @@ "docstrings", "donjayamanne", "dtype", + "dtypes", "dunder", "earlyterminate", "einops", @@ -36,6 +37,7 @@ "intuniform", "invloguniform", "invloguniformvalues", + "itemwise", "jaxtyping", "kaiming", "keepdim", @@ -71,6 +73,7 @@ "randn", "randperm", "relu", + "resampler", "resid", "rtol", "runcap", diff --git a/poetry.lock 
b/poetry.lock index acacd583..c2f0ddf1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2119,7 +2119,6 @@ optional = false python-versions = ">=3" files = [ {file = "nvidia_cudnn_cu12-8.9.6.50-py3-none-manylinux1_x86_64.whl", hash = "sha256:02fbdf6a9f00ba88da68b275d1f175111cc01bbe2294fb688cd309fd61af8844"}, - {file = "nvidia_cudnn_cu12-8.9.6.50-py3-none-win_amd64.whl", hash = "sha256:acfc4447a9345e8ba525e3b0641ee64bdfd35189ab9904241814ff991792f77a"}, ] [package.dependencies] @@ -3612,6 +3611,22 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "strenum" +version = "0.4.15" +description = "An Enum that inherits from str." +optional = false +python-versions = "*" +files = [ + {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, + {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, +] + +[package.extras] +docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] +release = ["twine"] +test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] + [[package]] name = "sympy" version = "1.12" @@ -4453,4 +4468,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10, <3.12" -content-hash = "fe9961f7d226ecb7bcde8ad4b3cb6ce4d06635138b842841d3fbc4c9aaf6dd0b" +content-hash = "4d4c458f649e9618b43a794543a00cfafddafc089a4bd68287bc9a769484149e" diff --git a/pyproject.toml b/pyproject.toml index 47a3dcfe..b5db06c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ [tool.poetry.dependencies] einops=">=0.6" python=">=3.10, <3.12" + strenum="^0.4.15" torch=">=2.1" wandb=">=0.15.12" zstandard="^0.22.0" # Required for downloading datasets such as The Pile @@ -125,6 +126,7 @@ "ARG001", # Fixtures often have unused arguments "PT004", # Fixtures don't return anything "S101", # Assert is needed in PyTest + "TCH001", # 
Don't need to mark type-only imports ] [tool.ruff.lint.pydocstyle] diff --git a/sparse_autoencoder/__init__.py b/sparse_autoencoder/__init__.py index f70eebae..d3ef4e92 100644 --- a/sparse_autoencoder/__init__.py +++ b/sparse_autoencoder/__init__.py @@ -1,23 +1,33 @@ """Sparse Autoencoder Library.""" from sparse_autoencoder.activation_store import ( ActivationStore, - ActivationStoreBatch, - ActivationStoreItem, DiskActivationStore, ListActivationStore, TensorActivationStore, ) from sparse_autoencoder.autoencoder.model import SparseAutoencoder +from sparse_autoencoder.loss import ( + AbstractLoss, + LearnedActivationsL1Loss, + LossLogType, + LossReducer, + LossReductionType, + MSEReconstructionLoss, +) from sparse_autoencoder.train.pipeline import pipeline __all__ = [ + "AbstractLoss", "ActivationStore", - "ActivationStoreBatch", - "ActivationStoreItem", "DiskActivationStore", + "LearnedActivationsL1Loss", "ListActivationStore", - "TensorActivationStore", + "LossLogType", + "LossReducer", + "LossReductionType", + "MSEReconstructionLoss", "SparseAutoencoder", + "TensorActivationStore", "pipeline", ] diff --git a/sparse_autoencoder/activation_resampler/abstract_activation_resampler.py b/sparse_autoencoder/activation_resampler/abstract_activation_resampler.py new file mode 100644 index 00000000..ef6fc57f --- /dev/null +++ b/sparse_autoencoder/activation_resampler/abstract_activation_resampler.py @@ -0,0 +1,47 @@ +"""Abstract activation resampler.""" + +from abc import ABC, abstractmethod + +from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore +from sparse_autoencoder.tensor_types import ( + DeadDecoderNeuronWeightUpdates, + DeadEncoderNeuronBiasUpdates, + DeadEncoderNeuronWeightUpdates, + NeuronActivity, +) + + +class AbstractActivationResampler(ABC): + """Abstract activation resampler.""" + + @abstractmethod + def resample_dead_neurons( + self, + neuron_activity: NeuronActivity, + store: TensorActivationStore, + num_input_activations: 
int = 819_200, + ) -> tuple[ + DeadEncoderNeuronWeightUpdates, DeadEncoderNeuronBiasUpdates, DeadDecoderNeuronWeightUpdates + ]: + """Resample dead neurons. + + Over the course of training, a subset of autoencoder neurons will have zero activity across + a large number of datapoints. The authors of *Towards Monosemanticity: Decomposing Language + Models With Dictionary Learning* found that “resampling” these dead neurons during training + improves the number of likely-interpretable features (i.e., those in the high density + cluster) and reduces total loss. This resampling may be compatible with the Lottery Ticket + Hypothesis and increase the number of chances the network has to find promising feature + directions. + + Warning: + The optimizer should be reset after applying this function, as the Adam state will be + incorrect for the modified weights and biases. + + Args: + neuron_activity: Number of times each neuron fired. + store: Activation store. + num_input_activations: Number of input activations to use when resampling. Will be + rounded down to be divisible by the batch size, and cannot be larger than the number + of items currently in the store. 
+ """ + raise NotImplementedError diff --git a/sparse_autoencoder/activation_store/__init__.py b/sparse_autoencoder/activation_store/__init__.py index 0869d051..f48b385a 100644 --- a/sparse_autoencoder/activation_store/__init__.py +++ b/sparse_autoencoder/activation_store/__init__.py @@ -1,5 +1,5 @@ """Activation Stores.""" -from .base_store import ActivationStore, ActivationStoreBatch, ActivationStoreItem +from .base_store import ActivationStore from .disk_store import DiskActivationStore from .list_store import ListActivationStore from .tensor_store import TensorActivationStore @@ -7,8 +7,6 @@ _all__ = [ ActivationStore, - ActivationStoreBatch, - ActivationStoreItem, DiskActivationStore, ListActivationStore, TensorActivationStore, diff --git a/sparse_autoencoder/activation_store/base_store.py b/sparse_autoencoder/activation_store/base_store.py index 5061fd02..652d52f0 100644 --- a/sparse_autoencoder/activation_store/base_store.py +++ b/sparse_autoencoder/activation_store/base_store.py @@ -3,29 +3,13 @@ from concurrent.futures import Future from typing import final -from jaxtyping import Float import torch -from torch import Tensor from torch.utils.data import Dataset +from sparse_autoencoder.tensor_types import InputOutputActivationBatch, InputOutputActivationVector -ActivationStoreItem = Float[Tensor, "neuron"] -"""Activation Store Dataset Item Type. -A single vector containing activations. For example this could be the activations from a specific -MLP layer, for a specific position and batch item. -""" - -ActivationStoreBatch = Float[Tensor, "*any neuron"] -"""Activation Store Dataset Batch Type. - -This can be e.g. a [batch, pos, neurons] tensor, containing activations from a specific MLP layer -in a transformer. Alternatively, it could be e.g. a [batch, pos, head_idx, neurons] tensor from an -attention layer. 
-""" - - -class ActivationStore(Dataset[ActivationStoreItem], ABC): +class ActivationStore(Dataset[InputOutputActivationVector], ABC): """Activation Store Abstract Class. Extends the `torch.utils.data.Dataset` class to provide an activation store, with additional @@ -43,16 +27,16 @@ class ActivationStore(Dataset[ActivationStoreItem], ABC): ... super().__init__() ... self._data = [] # In this example, we just store in a list ... - ... def append(self, item: ActivationStoreItem) -> None: + ... def append(self, item) -> None: ... self._data.append(item) ... - ... def extend(self, batch: ActivationStoreBatch): + ... def extend(self, batch): ... self._data.extend(batch) ... ... def empty(self): ... self._data = [] ... - ... def __getitem__(self, index: int) -> ActivationStoreItem: + ... def __getitem__(self, index: int): ... return self._data[index] ... ... def __len__(self) -> int: @@ -65,12 +49,12 @@ class ActivationStore(Dataset[ActivationStoreItem], ABC): """ @abstractmethod - def append(self, item: ActivationStoreItem) -> Future | None: + def append(self, item: InputOutputActivationVector) -> Future | None: """Add a Single Item to the Store.""" raise NotImplementedError @abstractmethod - def extend(self, batch: ActivationStoreBatch) -> Future | None: + def extend(self, batch: InputOutputActivationBatch) -> Future | None: """Add a Batch to the Store.""" raise NotImplementedError @@ -85,7 +69,7 @@ def __len__(self) -> int: raise NotImplementedError @abstractmethod - def __getitem__(self, index: int) -> ActivationStoreItem: + def __getitem__(self, index: int) -> InputOutputActivationVector: """Get an Item from the Store.""" raise NotImplementedError diff --git a/sparse_autoencoder/activation_store/disk_store.py b/sparse_autoencoder/activation_store/disk_store.py index 05d36073..f052bf13 100644 --- a/sparse_autoencoder/activation_store/disk_store.py +++ b/sparse_autoencoder/activation_store/disk_store.py @@ -10,12 +10,14 @@ from 
sparse_autoencoder.activation_store.base_store import ( ActivationStore, - ActivationStoreBatch, - ActivationStoreItem, ) from sparse_autoencoder.activation_store.utils.extend_resize import ( resize_to_list_vectors, ) +from sparse_autoencoder.tensor_types import ( + InputOutputActivationVector, + SourceModelActivations, +) DEFAULT_DISK_ACTIVATION_STORE_PATH = Path(tempfile.gettempdir()) / "activation_store" @@ -132,7 +134,7 @@ def _write_to_disk(self, *, wait_for_max: bool = False) -> None: filename = f"{self.__len__}.pt" torch.save(stacked_activations, self._storage_path / filename) - def append(self, item: ActivationStoreItem) -> Future | None: + def append(self, item: InputOutputActivationVector) -> Future | None: """Add a Single Item to the Store. Example: @@ -158,7 +160,7 @@ def append(self, item: ActivationStoreItem) -> Future | None: return None # Keep mypy happy - def extend(self, batch: ActivationStoreBatch) -> Future | None: + def extend(self, batch: SourceModelActivations) -> Future | None: """Add a Batch to the Store. Example: @@ -175,7 +177,7 @@ def extend(self, batch: ActivationStoreBatch) -> Future | None: Future that completes when the activation vectors have queued to be written to disk, and if needed, written to disk. """ - items: list[ActivationStoreItem] = resize_to_list_vectors(batch) + items: list[InputOutputActivationVector] = resize_to_list_vectors(batch) with self._cache_lock: self._cache.extend(items) @@ -228,7 +230,7 @@ def empty(self) -> None: file.unlink() self._disk_n_activation_vectors.value = 0 - def __getitem__(self, index: int) -> ActivationStoreItem: + def __getitem__(self, index: int) -> InputOutputActivationVector: """Get Item Dunder Method. 
Args: diff --git a/sparse_autoencoder/activation_store/list_store.py b/sparse_autoencoder/activation_store/list_store.py index 6226da3d..52562940 100644 --- a/sparse_autoencoder/activation_store/list_store.py +++ b/sparse_autoencoder/activation_store/list_store.py @@ -9,12 +9,14 @@ from sparse_autoencoder.activation_store.base_store import ( ActivationStore, - ActivationStoreBatch, - ActivationStoreItem, ) from sparse_autoencoder.activation_store.utils.extend_resize import ( resize_to_list_vectors, ) +from sparse_autoencoder.tensor_types import ( + InputOutputActivationVector, + SourceModelActivations, +) class ListActivationStore(ActivationStore): @@ -69,7 +71,7 @@ class ListActivationStore(ActivationStore): torch.Size([2, 100]) """ - _data: list[ActivationStoreItem] | ListProxy + _data: list[InputOutputActivationVector] | ListProxy """Underlying List Data Store.""" _device: torch.device | None @@ -92,7 +94,7 @@ class ListActivationStore(ActivationStore): def __init__( self, - data: list[ActivationStoreItem] | None = None, + data: list[InputOutputActivationVector] | None = None, device: torch.device | None = None, max_workers: int | None = None, *, @@ -166,7 +168,7 @@ def __sizeof__(self) -> int: return total_tensors_size + list_of_pointers_size - def __getitem__(self, index: int) -> ActivationStoreItem: + def __getitem__(self, index: int) -> InputOutputActivationVector: """Get Item Dunder Method. Example: @@ -205,7 +207,7 @@ def shuffle(self) -> None: self.wait_for_writes_to_complete() random.shuffle(self._data) - def append(self, item: ActivationStoreItem) -> Future | None: + def append(self, item: InputOutputActivationVector) -> Future | None: """Append a single item to the dataset. Note **append is blocking**. For better performance use extend instead with batches. 
@@ -223,7 +225,7 @@ def append(self, item: ActivationStoreItem) -> Future | None: """ self._data.append(item.to(self._device)) - def _extend(self, batch: ActivationStoreBatch) -> None: + def _extend(self, batch: SourceModelActivations) -> None: """Extend threadpool method. To be called by :meth:`extend`. @@ -233,13 +235,13 @@ def _extend(self, batch: ActivationStoreBatch) -> None: """ try: # Unstack to a list of tensors - items: list[ActivationStoreItem] = resize_to_list_vectors(batch) + items: list[InputOutputActivationVector] = resize_to_list_vectors(batch) self._data.extend(items) except Exception as e: # noqa: BLE001 self._pool_exceptions.append(e) - def extend(self, batch: ActivationStoreBatch) -> Future | None: + def extend(self, batch: SourceModelActivations) -> Future | None: """Extend the dataset with multiple items (non-blocking). Example: diff --git a/sparse_autoencoder/activation_store/tensor_store.py b/sparse_autoencoder/activation_store/tensor_store.py index 09885d07..3b91c79f 100644 --- a/sparse_autoencoder/activation_store/tensor_store.py +++ b/sparse_autoencoder/activation_store/tensor_store.py @@ -1,21 +1,19 @@ """Tensor Activation Store.""" -from jaxtyping import Float import torch -from torch import Tensor from sparse_autoencoder.activation_store.base_store import ( ActivationStore, - ActivationStoreBatch, - ActivationStoreItem, StoreFullError, ) from sparse_autoencoder.activation_store.utils.extend_resize import ( resize_to_single_item_dimension, ) - - -TensorActivationStoreData = Float[Tensor, "item neuron"] -"""Tensor Activation Store Dataset Item Type.""" +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + InputOutputActivationVector, + SourceModelActivations, + StoreActivations, +) class TensorActivationStore(ActivationStore): @@ -59,7 +57,7 @@ class TensorActivationStore(ActivationStore): torch.Size([2, 100]) """ - _data: TensorActivationStoreData + _data: StoreActivations """Underlying Tensor Data Store.""" 
items_stored: int = 0 @@ -112,7 +110,7 @@ def __sizeof__(self) -> int: """ return self._data.element_size() * self._data.nelement() - def __getitem__(self, index: int) -> ActivationStoreItem: + def __getitem__(self, index: int) -> InputOutputActivationVector: """Get Item Dunder Method. Example: @@ -161,7 +159,7 @@ def shuffle(self) -> None: # Use this permutation to shuffle the active data in-place self._data[: self.items_stored] = self._data[perm] - def append(self, item: ActivationStoreItem) -> None: + def append(self, item: InputOutputActivationVector) -> None: """Add a single item to the store. Example: @@ -187,7 +185,7 @@ def append(self, item: ActivationStoreItem) -> None: ) self.items_stored += 1 - def extend(self, batch: ActivationStoreBatch) -> None: + def extend(self, batch: SourceModelActivations) -> None: """Add a batch to the store. Examples: @@ -208,7 +206,7 @@ def extend(self, batch: ActivationStoreBatch) -> None: Raises: IndexError: If there is no space remaining. """ - reshaped: Float[Tensor, "subset_item neuron"] = resize_to_single_item_dimension( + reshaped: InputOutputActivationBatch = resize_to_single_item_dimension( batch, ) diff --git a/sparse_autoencoder/activation_store/utils/extend_resize.py b/sparse_autoencoder/activation_store/utils/extend_resize.py index d079e845..0127adfa 100644 --- a/sparse_autoencoder/activation_store/utils/extend_resize.py +++ b/sparse_autoencoder/activation_store/utils/extend_resize.py @@ -1,17 +1,16 @@ """Resize Tensors for Extend Methods.""" from einops import rearrange -from jaxtyping import Float -from torch import Tensor -from sparse_autoencoder.activation_store.base_store import ( - ActivationStoreBatch, - ActivationStoreItem, +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + InputOutputActivationVector, + SourceModelActivations, ) def resize_to_list_vectors( - batched_tensor: ActivationStoreBatch, -) -> list[ActivationStoreItem]: + batched_tensor: SourceModelActivations, +) -> 
list[InputOutputActivationVector]: """Resize Extend List Vectors. Takes a tensor of activation vectors, with arbitrary numbers of dimensions (the last of which is @@ -46,7 +45,7 @@ def resize_to_list_vectors( Returns: List of Activation Store Item Vectors """ - rearranged: Float[Tensor, "batch neuron"] = rearrange( + rearranged: InputOutputActivationBatch = rearrange( batched_tensor, "... neurons -> (...) neurons", ) @@ -55,8 +54,8 @@ def resize_to_list_vectors( def resize_to_single_item_dimension( - batch_activations: ActivationStoreBatch, -) -> Float[Tensor, "item neuron"]: + batch_activations: SourceModelActivations, +) -> InputOutputActivationBatch: """Resize Extend Single Item Dimension. Takes a tensor of activation vectors, with arbitrary numbers of dimensions (the last of which is diff --git a/sparse_autoencoder/activation_store/utils/tests/test_extend_resize.py b/sparse_autoencoder/activation_store/utils/tests/test_extend_resize.py index 7bc2755a..14842d8b 100644 --- a/sparse_autoencoder/activation_store/utils/tests/test_extend_resize.py +++ b/sparse_autoencoder/activation_store/utils/tests/test_extend_resize.py @@ -2,11 +2,11 @@ import pytest import torch -from sparse_autoencoder.activation_store.base_store import ActivationStoreBatch from sparse_autoencoder.activation_store.utils.extend_resize import ( resize_to_list_vectors, resize_to_single_item_dimension, ) +from sparse_autoencoder.tensor_types import InputOutputActivationBatch class TestResizeListVectors: @@ -28,7 +28,7 @@ def test_resize_to_list_vectors( ) -> None: """Check each item's shape in the resulting list.""" input_tensor = torch.rand(input_shape) - result = resize_to_list_vectors(ActivationStoreBatch(input_tensor)) + result = resize_to_list_vectors(InputOutputActivationBatch(input_tensor)) assert len(result) == expected_len, f"Expected list of length {expected_len}" assert all( @@ -44,7 +44,7 @@ def test_resize_to_list_vectors_values(self) -> None: torch.tensor([5.0, 6]), torch.tensor([7.0, 
8]), ] - result = resize_to_list_vectors(ActivationStoreBatch(input_tensor)) + result = resize_to_list_vectors(InputOutputActivationBatch(input_tensor)) for expected, output in zip(expected_output, result, strict=True): assert torch.all( @@ -70,7 +70,7 @@ def test_resize_to_single_item_dimension( ) -> None: """Check the resulting tensor shape.""" input_tensor = torch.randn(input_shape) - result = resize_to_single_item_dimension(ActivationStoreBatch(input_tensor)) + result = resize_to_single_item_dimension(InputOutputActivationBatch(input_tensor)) assert result.shape == expected_shape, f"Expected tensor shape {expected_shape}" @@ -78,7 +78,7 @@ def test_resize_to_single_item_dimension_values(self) -> None: """Check the resulting tensor values.""" input_tensor = torch.tensor([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]]) expected_output = torch.tensor([[1.0, 2], [3, 4], [5, 6], [7, 8]]) - result = resize_to_single_item_dimension(ActivationStoreBatch(input_tensor)) + result = resize_to_single_item_dimension(InputOutputActivationBatch(input_tensor)) assert torch.all( torch.eq(expected_output, result), diff --git a/sparse_autoencoder/autoencoder/components/tests/test_tied_bias.py b/sparse_autoencoder/autoencoder/components/tests/test_tied_bias.py index 4d43019f..9a537544 100644 --- a/sparse_autoencoder/autoencoder/components/tests/test_tied_bias.py +++ b/sparse_autoencoder/autoencoder/components/tests/test_tied_bias.py @@ -2,11 +2,12 @@ import torch from sparse_autoencoder.autoencoder.components.tied_bias import TiedBias, TiedBiasPosition +from sparse_autoencoder.tensor_types import InputOutputActivationBatch def test_pre_encoder_subtracts_bias() -> None: """Check that the pre-encoder bias subtracts the bias.""" - encoder_input = torch.tensor([5.0, 3.0, 1.0]) + encoder_input: InputOutputActivationBatch = torch.tensor([[5.0, 3.0, 1.0]]) bias = torch.tensor([2.0, 4.0, 6.0]) expected = encoder_input - bias @@ -18,7 +19,7 @@ def test_pre_encoder_subtracts_bias() -> None: def 
test_post_encoder_adds_bias() -> None: """Check that the post-encoder bias adds the bias.""" - decoder_output = torch.tensor([5.0, 3.0, 1.0]) + decoder_output: InputOutputActivationBatch = torch.tensor([[5.0, 3.0, 1.0]]) bias = torch.tensor([2.0, 4.0, 6.0]) expected = decoder_output + bias diff --git a/sparse_autoencoder/autoencoder/components/tests/test_unit_norm_linear.py b/sparse_autoencoder/autoencoder/components/tests/test_unit_norm_linear.py index 06d9e09a..a52182a0 100644 --- a/sparse_autoencoder/autoencoder/components/tests/test_unit_norm_linear.py +++ b/sparse_autoencoder/autoencoder/components/tests/test_unit_norm_linear.py @@ -4,20 +4,67 @@ from sparse_autoencoder.autoencoder.components.unit_norm_linear import ConstrainedUnitNormLinear +def test_initialization() -> None: + """Test that the weights are initialized with unit norm.""" + layer = ConstrainedUnitNormLinear(learnt_features=3, decoded_features=4) + weight_norms = torch.norm(layer.weight, dim=1) + assert torch.allclose(weight_norms, torch.ones_like(weight_norms)) + + +def test_forward_pass() -> None: + """Test the forward pass of the layer.""" + layer = ConstrainedUnitNormLinear(learnt_features=3, decoded_features=4) + input_tensor = torch.randn(5, 3) # Batch size of 5, learnt features of 3 + output = layer(input_tensor) + assert output.shape == (5, 4) # Batch size of 5, decoded features of 4 + + +def test_bias_initialization_and_usage() -> None: + """Test the bias is initialized and used correctly.""" + layer = ConstrainedUnitNormLinear(learnt_features=3, decoded_features=4, bias=True) + assert layer.bias is not None + # Check the bias is used in the forward pass + input_tensor = torch.zeros(5, 3) + output = layer(input_tensor) + assert torch.allclose(output, layer.bias.unsqueeze(0).expand(5, -1)) + + +def test_multiple_training_steps() -> None: + """Test the unit norm constraint over multiple training iterations.""" + torch.random.manual_seed(0) + layer = 
ConstrainedUnitNormLinear(learnt_features=3, decoded_features=4) + optimizer = torch.optim.Adam(layer.parameters(), lr=0.01) + for _ in range(10): + data = torch.randn(5, 3) + optimizer.zero_grad() + logits = layer(data) + + weight_norms = torch.norm(layer.weight, dim=1) + assert torch.allclose(weight_norms, torch.ones_like(weight_norms), atol=2e-3) + + loss = torch.mean(logits**2) + loss.backward() + optimizer.step() + + def test_unit_norm_applied_backward() -> None: """Check that the unit norm is applied after each gradient step.""" - layer = ConstrainedUnitNormLinear(3, 4) + torch.random.manual_seed(0) + layer = ConstrainedUnitNormLinear(learnt_features=3, decoded_features=4) + optimizer = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0) data = torch.randn((3), requires_grad=True) logits = layer(data) - loss = torch.sum(logits**2) + loss = torch.mean(logits**2) loss.backward() - weight_norms = torch.sum(layer.weight**2, dim=1) - - # Check that the weights still have unit norm - assert torch.allclose(weight_norms, torch.ones_like(weight_norms)) # Check that the gradient is not zero (as that would be a trivial way the weights could be kept # unit norm) grad = layer.weight.grad assert grad is not None assert not torch.allclose(grad, torch.zeros_like(grad)) + + optimizer.step() + + # Check that the weights still have unit norm + weight_norms = torch.sum(layer.weight**2, dim=1) + assert torch.allclose(weight_norms, torch.ones_like(weight_norms), atol=2e-3) diff --git a/sparse_autoencoder/autoencoder/components/tied_bias.py b/sparse_autoencoder/autoencoder/components/tied_bias.py index 3047e982..e1ffaa24 100644 --- a/sparse_autoencoder/autoencoder/components/tied_bias.py +++ b/sparse_autoencoder/autoencoder/components/tied_bias.py @@ -1,10 +1,13 @@ """Tied Biases (Pre-Encoder and Post-Decoder).""" from enum import Enum -from jaxtyping import Float -from torch import Tensor from torch.nn import Module +from sparse_autoencoder.tensor_types import ( + 
InputOutputActivationBatch, + InputOutputActivationVector, +) + class TiedBiasPosition(str, Enum): """Tied Bias Position.""" @@ -24,13 +27,13 @@ class TiedBias(Module): https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-bias """ - _bias_reference: Float[Tensor, " input_activations"] + _bias_reference: InputOutputActivationVector _bias_position: TiedBiasPosition def __init__( self, - bias: Float[Tensor, " input_activations"], + bias: InputOutputActivationVector, position: TiedBiasPosition, ) -> None: """Initialize the bias layer. @@ -50,8 +53,8 @@ def __init__( def forward( self, - x: Float[Tensor, "*batch input_activations"], - ) -> Float[Tensor, "*batch input_activations"]: + x: InputOutputActivationBatch, + ) -> InputOutputActivationBatch: """Forward Pass. Args: diff --git a/sparse_autoencoder/autoencoder/components/unit_norm_linear.py b/sparse_autoencoder/autoencoder/components/unit_norm_linear.py index 1e814361..17ce0d7b 100644 --- a/sparse_autoencoder/autoencoder/components/unit_norm_linear.py +++ b/sparse_autoencoder/autoencoder/components/unit_norm_linear.py @@ -2,23 +2,29 @@ import math import einops -from jaxtyping import Float import torch from torch import Tensor from torch.nn import Module, init from torch.nn.parameter import Parameter +from sparse_autoencoder.tensor_types import ( + Axis, + DecoderWeights, + EncoderWeights, + InputOutputActivationVector, +) + class ConstrainedUnitNormLinear(Module): - """Constrained unit norm linear layer. + """Constrained unit norm linear decoder layer. - Linear layer for autoencoders, where the dictionary vectors (columns of the weight matrix) are + Linear layer for autoencoders, where the dictionary vectors (rows of the weight matrix) are constrained to have unit norm. This is done by removing the gradient information parallel to the dictionary vectors before applying the gradient step, using a backward hook. 
Motivation: - Unit norming the dictionary vectors, which are essentially the columns of the encoding and - decoding matrices, serves a few purposes: + Unit norming the dictionary vectors, which are essentially the rows of the decoding + matrices, serves a few purposes: 1. It helps with numerical stability, by preventing the dictionary vectors from growing too large. @@ -36,25 +42,22 @@ class ConstrainedUnitNormLinear(Module): loss](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-optimization). """ - in_features: int - """Number of input features.""" - - out_features: int - """Number of output features.""" + learnt_features: int + """Number of learnt features (inputs to this layer).""" - DIMENSION_CONSTRAIN_UNIT_NORM: int = -1 - """Dimension to constrain to unit norm.""" + decoded_features: int + """Number of decoded features (outputs from this layer).""" - weight: Float[Tensor, " out_features in_features"] + weight: DecoderWeights """Weight parameter.""" - bias: Float[Tensor, " out_features"] | None + bias: InputOutputActivationVector | None """Bias parameter.""" def __init__( self, - in_features: int, - out_features: int, + learnt_features: int, + decoded_features: int, *, bias: bool = True, device: torch.device | None = None, @@ -63,27 +66,27 @@ def __init__( """Initialize the constrained unit norm linear layer. Args: - in_features: Number of input features. - out_features: Number of output features. + learnt_features: Number of learnt features in the autoencoder. + decoded_features: Number of decoded (output) features in the autoencoder. bias: Whether to include a bias term. device: Device to use. dtype: Data type to use. 
""" # Create the linear layer as per the standard PyTorch linear layer super().__init__() - self.in_features = in_features - self.out_features = out_features + self.learnt_features = learnt_features + self.decoded_features = decoded_features self.weight = Parameter( - torch.empty((out_features, in_features), device=device, dtype=dtype) + torch.empty((decoded_features, learnt_features), device=device, dtype=dtype) ) if bias: - self.bias = Parameter(torch.empty(out_features, device=device, dtype=dtype)) + self.bias = Parameter(torch.empty(decoded_features, device=device, dtype=dtype)) else: self.register_parameter("bias", None) self.reset_parameters() # Register backward hook to remove any gradient information parallel to the dictionary - # vectors (columns of the weight matrix) before applying the gradient step. + # vectors (rows of the weight matrix) before applying the gradient step. self.weight.register_hook(self._weight_backward_hook) def reset_parameters(self) -> None: @@ -91,25 +94,23 @@ def reset_parameters(self) -> None: Example: >>> import torch - >>> layer = ConstrainedUnitNormLinear(3, 3) + >>> # Create a layer with 4 columns (learnt features) and 3 rows (decoded features) + >>> layer = ConstrainedUnitNormLinear(learnt_features=4, decoded_features=3) >>> layer.reset_parameters() - >>> column_norms = torch.sum(layer.weight ** 2, dim=1) - >>> column_norms.round(decimals=3).tolist() + >>> # Get the norm across the rows (by summing across the columns) + >>> row_norms = torch.sum(layer.weight ** 2, dim=1) + >>> row_norms.round(decimals=3).tolist() [1.0, 1.0, 1.0] """ # Initialize the weights with a normal distribution. Note we don't use e.g. kaiming # normalisation here, since we immediately scale the weights to have unit norm (so the # initial standard deviation doesn't matter). Note also that `init.normal_` is in place. 
- self.weight: Float[Tensor, "out_features in_features"] = init.normal_( - self.weight, mean=0, std=1 - ) + self.weight: EncoderWeights = init.normal_(self.weight, mean=0, std=1) - # Scale so that each column has unit norm + # Scale so that each row has unit norm with torch.no_grad(): - torch.nn.functional.normalize( - self.weight, dim=self.DIMENSION_CONSTRAIN_UNIT_NORM, out=self.weight - ) + torch.nn.functional.normalize(self.weight, dim=-1, out=self.weight) # Initialise the bias # This is the standard approach used in `torch.nn.Linear.reset_parameters` @@ -120,8 +121,8 @@ def reset_parameters(self) -> None: def _weight_backward_hook( self, - grad: Float[Tensor, "out_features in_features"], - ) -> Float[Tensor, "out_features in_features"]: + grad: EncoderWeights, + ) -> EncoderWeights: """Unit norm backward hook. By subtracting the projection of the gradient onto the dictionary vectors, we remove the @@ -129,17 +130,6 @@ def _weight_backward_hook( component that is orthogonal to the dictionary vectors (i.e. moving around the hypersphere). The result is that the backward pass does not change the norm of the dictionary vectors. - Example: - >>> import torch - >>> layer = ConstrainedUnitNormLinear(3, 4) - >>> data = torch.randn((3), requires_grad=True) - >>> logits = layer(data) - >>> loss = torch.sum(logits ** 2) - >>> loss.backward() # The hook is applied here - >>> weight_norms = torch.sum(layer.weight.data ** 2, dim=1) - >>> weight_norms.round(decimals=3).tolist() - [1.0, 1.0, 1.0, 1.0] - Args: grad: Gradient with respect to the weights. """ @@ -148,18 +138,29 @@ def _weight_backward_hook( # the gradient onto the dictionary vectors is the component of the gradient that is parallel # to the dictionary vectors, i.e. the component that moves to or from the center of the # hypersphere. 
- dot_product: Float[Tensor, " out_features"] = einops.einsum( - grad, self.weight, "out_features in_features, out_features in_features -> out_features" + normalized_weight: EncoderWeights = self.weight / torch.norm( + self.weight, dim=-1, keepdim=True ) - normalized_weight: Float[Tensor, "out_features in_features"] = self.weight / torch.norm( - self.weight, dim=self.DIMENSION_CONSTRAIN_UNIT_NORM, keepdim=True + # Calculate the dot product of the gradients with the dictionary vectors. + # This represents the component of the gradient parallel to each dictionary vector. + # The result will be a tensor of shape [decoded_features]. + dot_product = einops.einsum( + grad, + normalized_weight, + f"{Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE}, \ + {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} \ + -> {Axis.LEARNT_FEATURE}", ) + # Scale the normalized weights by the dot product to get the projection. + # The result will be of the same shape as 'grad' and 'self.weight'. projection = einops.einsum( dot_product, normalized_weight, - "out_features, out_features in_features -> out_features in_features", + f"{Axis.LEARNT_FEATURE}, \ + {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} \ + -> {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE}", ) # Subtracting the parallel component from the gradient leaves only the component that is @@ -184,15 +185,13 @@ def constrain_weights_unit_norm(self) -> None: >>> layer = ConstrainedUnitNormLinear(3, 3) >>> layer.weight.data = torch.ones((3, 3)) * 10 >>> layer.constrain_weights_unit_norm() - >>> column_norms = torch.sum(layer.weight ** 2, dim=1) - >>> column_norms.round(decimals=3).tolist() + >>> row_norms = torch.sum(layer.weight ** 2, dim=1) + >>> row_norms.round(decimals=3).tolist() [1.0, 1.0, 1.0] """ with torch.no_grad(): - torch.nn.functional.normalize( - self.weight, dim=self.DIMENSION_CONSTRAIN_UNIT_NORM, out=self.weight - ) + torch.nn.functional.normalize(self.weight, dim=-1, out=self.weight) def forward(self, x: Tensor) -> 
Tensor: """Forward pass. @@ -214,6 +213,6 @@ def forward(self, x: Tensor) -> Tensor: def extra_repr(self) -> str: """String extra representation of the module.""" return ( - f"in_features={self.in_features}, out_features={self.out_features}, " + f"in_features={self.learnt_features}, out_features={self.decoded_features}, " f"bias={self.bias is not None}" ) diff --git a/sparse_autoencoder/autoencoder/loss.py b/sparse_autoencoder/autoencoder/loss.py deleted file mode 100644 index 90aefb2c..00000000 --- a/sparse_autoencoder/autoencoder/loss.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Loss function for the Sparse Autoencoder.""" -from jaxtyping import Float -import torch -from torch import Tensor -from torch.nn.functional import mse_loss - - -def reconstruction_loss( - input_activations: Float[Tensor, "item input_features"], - output_activations: Float[Tensor, "item input_features"], -) -> Float[Tensor, " item"]: - """Reconstruction Loss (MSE). - - MSE reconstruction loss is calculated as the mean squared error between each each input vector - and it's corresponding decoded vector. The original paper found that models trained with some - loss functions such as cross-entropy loss generally prefer to represent features - polysemantically, whereas models trained with MSE may achieve the same loss for both - polysemantic and monosemantic representations of true features. - - Examples: - >>> input_activations = torch.tensor([[5.0, 4], [3.0, 4]]) - >>> output_activations = torch.tensor([[1.0, 5], [1.0, 5]]) - >>> reconstruction_loss(input_activations, output_activations) - tensor([8.5000, 2.5000]) - - Args: - input_activations: Input activations. - output_activations: Reconstructed activations. - - Returns: - Mean Squared Error reconstruction loss, over the features dimension. 
- """ - return mse_loss(input_activations, output_activations, reduction="none").mean(dim=-1) - - -def l1_loss(learned_activations: Float[Tensor, "item learned_features"]) -> Float[Tensor, " item"]: - """L1 Loss on Learned Activations. - - L1 loss penalty is the absolute sum of the learned activations. The L1 penalty is this - multiplied by the l1_coefficient (designed to encourage sparsity). - - Examples: - >>> learned_activations = torch.tensor([[2.0, -3], [2.0, -3]]) - >>> l1_loss(learned_activations) - tensor([5., 5.]) - - Args: - learned_activations: Activations from the hidden layer. - - Returns: - L1 loss on learned activations, summed over the features dimension. - """ - return torch.abs(learned_activations).sum(dim=-1) - - -def sae_training_loss( - reconstruction_loss_mse: Float[Tensor, " item"], - l1_loss_learned_activations: Float[Tensor, " item"], - l1_coefficient: float, -) -> Float[Tensor, " item"]: - """Loss Function for the Sparse Autoencoder. - - The original paper used L2 reconstruction loss, plus l1 loss on the hidden (learned) - activations. - - https://transformer-circuits.pub/2023/monosemantic-features/index.html#setup-autoencoder-motivation - - Warning: - It isn't meaningful to compare training loss across hyperparameters that change the loss - function, such as L1 coefficients. - - Examples: - >>> reconstruction_loss_mse = torch.tensor([8.5000, 2.5000]) - >>> l1_loss_learned_activations = torch.tensor([1., 1.]) - >>> l1_coefficient = 0.5 - >>> sae_training_loss(reconstruction_loss_mse, l1_loss_learned_activations, l1_coefficient) - tensor([9., 3.]) - - Args: - reconstruction_loss_mse: MSE reconstruction loss. - l1_loss_learned_activations: L1 loss on learned activations. - l1_coefficient: L1 coefficient. The original paper experimented with L1 coefficients of - [0.01, 0.008, 0.006, 0.004, 0.001]. They used 250 tokens per prompt, so as an - approximate guide if you use e.g. 
2x this number of tokens you might consider using 0.5x - the l1 coefficient. - - Returns: - Overall training loss. - """ - return reconstruction_loss_mse + l1_loss_learned_activations * l1_coefficient diff --git a/sparse_autoencoder/autoencoder/model.py b/sparse_autoencoder/autoencoder/model.py index 37a2401d..7013576f 100644 --- a/sparse_autoencoder/autoencoder/model.py +++ b/sparse_autoencoder/autoencoder/model.py @@ -1,26 +1,29 @@ """The Sparse Autoencoder Model.""" from collections import OrderedDict -from jaxtyping import Float import torch -from torch import Tensor from torch.nn import Linear, Module, ReLU, Sequential from torch.nn.parameter import Parameter from sparse_autoencoder.autoencoder.components.tied_bias import TiedBias, TiedBiasPosition from sparse_autoencoder.autoencoder.components.unit_norm_linear import ConstrainedUnitNormLinear +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + InputOutputActivationVector, + LearnedActivationBatch, +) class SparseAutoencoder(Module): """Sparse Autoencoder Model.""" - geometric_median_dataset: Float[Tensor, " input_activations"] + geometric_median_dataset: InputOutputActivationVector """Estimated Geometric Median of the Dataset. Used for initialising :attr:`tied_bias`. """ - tied_bias: Float[Parameter, " input_activations"] + tied_bias: InputOutputActivationBatch """Tied Bias Parameter. The same bias is used pre-encoder and post-decoder. 
@@ -48,7 +51,7 @@ def __init__( self, n_input_features: int, n_learned_features: int, - geometric_median_dataset: Float[Tensor, " input_activations"], + geometric_median_dataset: InputOutputActivationVector, device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> None: @@ -104,10 +107,10 @@ def __init__( def forward( self, - x: Float[Tensor, "batch input_activations"], + x: InputOutputActivationBatch, ) -> tuple[ - Float[Tensor, "batch learned_activations"], - Float[Tensor, "batch input_activations"], + LearnedActivationBatch, + InputOutputActivationBatch, ]: """Forward Pass. diff --git a/sparse_autoencoder/autoencoder/tests/test_loss.py b/sparse_autoencoder/autoencoder/tests/test_loss.py deleted file mode 100644 index 7fad42a9..00000000 --- a/sparse_autoencoder/autoencoder/tests/test_loss.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Loss Function Tests.""" -import torch - -from sparse_autoencoder.autoencoder.loss import ( - l1_loss, - reconstruction_loss, - sae_training_loss, -) - - -def test_loss() -> None: - """Test loss against a non-vectorised approach.""" - input_activations: list[float] = [3.0, 4] - learned_activations: list[float] = [2.0, -3] - output_activations: list[float] = [1.0, 5] - l1_coefficient = 0.5 - - squared_errors: float = 0.0 - for i, o in zip(input_activations, output_activations, strict=True): - squared_errors += (i - o) ** 2 - mse = squared_errors / len(input_activations) - - l1_penalty: float = 0.0 - for neuron in learned_activations: - l1_penalty += abs(neuron) * l1_coefficient - - expected: float = mse + l1_penalty - - # Compute the reconstruction_loss, l1_loss, and sae_training_loss - mse_tensor = reconstruction_loss( - torch.tensor(input_activations).unsqueeze(0), - torch.tensor(output_activations).unsqueeze(0), - ) - l1_tensor = l1_loss(torch.tensor(learned_activations).unsqueeze(0)) - result = sae_training_loss(mse_tensor, l1_tensor, l1_coefficient) - - assert torch.allclose(result, torch.tensor([expected])) diff --git 
a/sparse_autoencoder/loss/__init__.py b/sparse_autoencoder/loss/__init__.py new file mode 100644 index 00000000..2ca18417 --- /dev/null +++ b/sparse_autoencoder/loss/__init__.py @@ -0,0 +1,23 @@ +"""Loss Modules. + +Loss modules are specialised PyTorch modules that calculate the loss for a Sparse Autoencoder. They +all inherit from AbstractLoss, which defines the interface for loss modules and some common methods. + +If you want to create your own loss function, see :class:`AbstractLoss`. + +For combining multiple loss modules into a single loss module, see :class:`LossReducer`. +""" +from sparse_autoencoder.loss.abstract_loss import AbstractLoss, LossLogType, LossReductionType +from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss +from sparse_autoencoder.loss.mse_reconstruction_loss import MSEReconstructionLoss +from sparse_autoencoder.loss.reducer import LossReducer + + +__all__ = [ + "AbstractLoss", + "LearnedActivationsL1Loss", + "LossLogType", + "LossReducer", + "LossReductionType", + "MSEReconstructionLoss", +] diff --git a/sparse_autoencoder/loss/abstract_loss.py b/sparse_autoencoder/loss/abstract_loss.py new file mode 100644 index 00000000..d0934a2e --- /dev/null +++ b/sparse_autoencoder/loss/abstract_loss.py @@ -0,0 +1,163 @@ +"""Abstract loss.""" +from abc import ABC, abstractmethod +from typing import TypeAlias, final + +from strenum import LowercaseStrEnum +import torch +from torch.nn import Module + +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + ItemTensor, + LearnedActivationBatch, + TrainBatchStatistic, +) + + +class LossReductionType(LowercaseStrEnum): + """Loss reduction type (across batch items).""" + + MEAN = "mean" + """Mean loss across batch items.""" + + SUM = "sum" + """Sum the loss from all batch items.""" + + +LossLogType: TypeAlias = dict[str, int | float | str] +"""Loss log dict.""" + + +class AbstractLoss(Module, ABC): + """Abstract loss interface. 
+ + Interface for implementing batch itemwise loss functions. + """ + + _modules: dict[str, "AbstractLoss"] # type: ignore[assignment] (narrowing) + """Children loss modules.""" + + @abstractmethod + def forward( + self, + source_activations: InputOutputActivationBatch, + learned_activations: LearnedActivationBatch, + decoded_activations: InputOutputActivationBatch, + ) -> TrainBatchStatistic: + """Batch itemwise loss. + + Args: + source_activations: Source activations (input activations to the autoencoder from the + source model). + learned_activations: Learned activations (intermediate activations in the autoencoder). + decoded_activations: Decoded activations. + + Returns: + Loss per batch item. + """ + raise NotImplementedError + + @final + def batch_scalar_loss( + self, + source_activations: InputOutputActivationBatch, + learned_activations: LearnedActivationBatch, + decoded_activations: InputOutputActivationBatch, + reduction: LossReductionType = LossReductionType.MEAN, + ) -> ItemTensor: + """Batch scalar loss. + + Args: + source_activations: Source activations (input activations to the autoencoder from the + source model). + learned_activations: Learned activations (intermediate activations in the autoencoder). + decoded_activations: Decoded activations. + reduction: Loss reduction type. Typically you would choose LossReductionType.MEAN to + make the loss independent of the batch size. + + Returns: + Loss for the batch. 
+ """ + itemwise_loss = self.forward(source_activations, learned_activations, decoded_activations) + + match reduction: + case LossReductionType.MEAN: + return itemwise_loss.mean().squeeze() + case LossReductionType.SUM: + return itemwise_loss.sum().squeeze() + + @final + def batch_scalar_loss_with_log( + self, + source_activations: InputOutputActivationBatch, + learned_activations: LearnedActivationBatch, + decoded_activations: InputOutputActivationBatch, + reduction: LossReductionType = LossReductionType.MEAN, + ) -> tuple[ItemTensor, LossLogType]: + """Batch scalar loss. + + Args: + source_activations: Source activations (input activations to the autoencoder from the + source model). + learned_activations: Learned activations (intermediate activations in the autoencoder). + decoded_activations: Decoded activations. + reduction: Loss reduction type. Typically you would choose LossReductionType.MEAN to + make the loss independent of the batch size. + + Returns: + Tuple of the batch scalar loss and a dict of any properties to log. + """ + children_loss_scalars: list[ItemTensor] = [] + metrics: LossLogType = {} + + # If the loss module has children (e.g. 
it is a reducer): + if len(self._modules) > 0: + for loss_module in self._modules.values(): + child_loss, child_metrics = loss_module.batch_scalar_loss_with_log( + source_activations, + learned_activations, + decoded_activations, + reduction=reduction, + ) + children_loss_scalars.append(child_loss) + metrics.update(child_metrics) + + # Get the total loss & metric + current_module_loss = torch.stack(children_loss_scalars).sum() + + # Otherwise if it is a leaf loss module: + else: + current_module_loss = self.batch_scalar_loss( + source_activations, learned_activations, decoded_activations, reduction + ) + + # Add in the current loss module's metric + class_name = self.__class__.__name__ + metrics[class_name] = current_module_loss.detach().cpu().item() + + return current_module_loss, metrics + + @final + def __call__( + self, + source_activations: InputOutputActivationBatch, + learned_activations: LearnedActivationBatch, + decoded_activations: InputOutputActivationBatch, + reduction: LossReductionType = LossReductionType.MEAN, + ) -> tuple[ItemTensor, LossLogType]: + """Batch scalar loss. + + Args: + source_activations: Source activations (input activations to the autoencoder from the + source model). + learned_activations: Learned activations (intermediate activations in the autoencoder). + decoded_activations: Decoded activations. + reduction: Loss reduction type. Typically you would choose LossReductionType.MEAN to + make the loss independent of the batch size. + + Returns: + Tuple of the batch scalar loss and a dict of any properties to log. 
+ """ + return self.batch_scalar_loss_with_log( + source_activations, learned_activations, decoded_activations, reduction + ) diff --git a/sparse_autoencoder/loss/learned_activations_l1.py b/sparse_autoencoder/loss/learned_activations_l1.py new file mode 100644 index 00000000..13ff98fd --- /dev/null +++ b/sparse_autoencoder/loss/learned_activations_l1.py @@ -0,0 +1,68 @@ +"""Learned activations L1 (absolute error) loss.""" +from typing import final + +import torch + +from sparse_autoencoder.loss.abstract_loss import AbstractLoss +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + LearnedActivationBatch, + TrainBatchStatistic, +) + + +@final +class LearnedActivationsL1Loss(AbstractLoss): + """Learned activations L1 (absolute error) loss. + + L1 loss penalty is the absolute sum of the learned activations. The L1 penalty is this + multiplied by the l1_coefficient (designed to encourage sparsity). + + Example: + >>> l1_loss = LearnedActivationsL1Loss(0.1) + >>> learned_activations = torch.tensor([[2.0, -3], [2.0, -3]]) + >>> unused_activations = torch.zeros_like(learned_activations) + >>> # Returns loss and metrics to log + >>> l1_loss(unused_activations, learned_activations, unused_activations) + (tensor(0.5000), {'LearnedActivationsL1Loss': 0.5}) + """ + + l1_coefficient: float + """L1 coefficient.""" + + def __init__(self, l1_coefficient: float) -> None: + """Initialize the absolute error loss. + + Args: + l1_coefficient: L1 coefficient. The original paper experimented with L1 coefficients of + [0.01, 0.008, 0.006, 0.004, 0.001]. They used 250 tokens per prompt, so as an + approximate guide if you use e.g. 2x this number of tokens you might consider using + 0.5x the l1 coefficient. 
+ """ + self.l1_coefficient = l1_coefficient + super().__init__() + + def forward( + self, + source_activations: InputOutputActivationBatch, # noqa: ARG002 + learned_activations: LearnedActivationBatch, + decoded_activations: InputOutputActivationBatch, # noqa: ARG002 + ) -> TrainBatchStatistic: + """Learned activations L1 (absolute error) loss. + + Args: + source_activations: Source activations (input activations to the autoencoder from the + source model). + learned_activations: Learned activations (intermediate activations in the autoencoder). + decoded_activations: Decoded activations. + + Returns: + Loss per batch item. + """ + absolute_loss = torch.abs(learned_activations) + + return absolute_loss.sum(dim=-1) * self.l1_coefficient + + def extra_repr(self) -> str: + """Extra representation string.""" + return f"l1_coefficient={self.l1_coefficient}" diff --git a/sparse_autoencoder/loss/mse_reconstruction_loss.py b/sparse_autoencoder/loss/mse_reconstruction_loss.py new file mode 100644 index 00000000..1255dbcf --- /dev/null +++ b/sparse_autoencoder/loss/mse_reconstruction_loss.py @@ -0,0 +1,55 @@ +"""MSE Reconstruction loss.""" +from typing import final + +from torch.nn.functional import mse_loss + +from sparse_autoencoder.loss.abstract_loss import AbstractLoss +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + LearnedActivationBatch, + TrainBatchStatistic, +) + + +@final +class MSEReconstructionLoss(AbstractLoss): + """MSE Reconstruction loss. + + MSE reconstruction loss is calculated as the mean squared error between each input vector + and its corresponding decoded vector. The original paper found that models trained with some + loss functions such as cross-entropy loss generally prefer to represent features + polysemantically, whereas models trained with MSE may achieve the same loss for both + polysemantic and monosemantic representations of true features.
+ + Example: + >>> import torch + >>> loss = MSEReconstructionLoss() + >>> input_activations = torch.tensor([[5.0, 4], [3.0, 4]]) + >>> output_activations = torch.tensor([[1.0, 5], [1.0, 5]]) + >>> unused_activations = torch.zeros_like(input_activations) + >>> # Outputs both loss and metrics to log + >>> loss(input_activations, unused_activations, output_activations) + (tensor(5.5000), {'MSEReconstructionLoss': 5.5}) + """ + + def forward( + self, + source_activations: InputOutputActivationBatch, + learned_activations: LearnedActivationBatch, # noqa: ARG002 + decoded_activations: InputOutputActivationBatch, + ) -> TrainBatchStatistic: + """MSE Reconstruction loss (mean across features dimension). + + Args: + source_activations: Source activations (input activations to the autoencoder from the + source model). + learned_activations: Learned activations (intermediate activations in the autoencoder). + decoded_activations: Decoded activations. + + Returns: + Loss per batch item. + """ + square_error_loss = mse_loss(source_activations, decoded_activations, reduction="none") + + # Mean over just the features dimension (i.e. batch itemwise loss) + return square_error_loss.mean(dim=-1) diff --git a/sparse_autoencoder/loss/reducer.py b/sparse_autoencoder/loss/reducer.py new file mode 100644 index 00000000..8b3aec8d --- /dev/null +++ b/sparse_autoencoder/loss/reducer.py @@ -0,0 +1,102 @@ +"""Loss reducer.""" +from collections.abc import Iterator +from typing import final + +from jaxtyping import Float +import torch +from torch import Tensor + +from sparse_autoencoder.loss.abstract_loss import AbstractLoss +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + ItemTensor, + LearnedActivationBatch, +) + + +@final +class LossReducer(AbstractLoss): + """Loss reducer. + + Reduces multiple loss algorithms into a single loss algorithm (by summing). Analogous to + nn.Sequential. 
+ + Example: + >>> from sparse_autoencoder.loss.mse_reconstruction_loss import MSEReconstructionLoss + >>> from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss + >>> LossReducer( + ... MSEReconstructionLoss(), + ... LearnedActivationsL1Loss(0.001), + ... ) + LossReducer( + (0): MSEReconstructionLoss() + (1): LearnedActivationsL1Loss(l1_coefficient=0.001) + ) + + """ + + _modules: dict[str, "AbstractLoss"] + """Children loss modules.""" + + def __init__( + self, + *loss_modules: AbstractLoss, + ): + """Initialize the loss reducer. + + Args: + loss_modules: Loss modules to reduce. + + Raises: + ValueError: If the loss reducer has no loss modules. + """ + super().__init__() + + for idx, loss_module in enumerate(loss_modules): + self._modules[str(idx)] = loss_module + + if len(self) == 0: + error_message = "Loss reducer must have at least one loss module." + raise ValueError(error_message) + + def forward( + self, + source_activations: InputOutputActivationBatch, + learned_activations: LearnedActivationBatch, + decoded_activations: InputOutputActivationBatch, + ) -> ItemTensor: + """Reduce loss. + + Args: + source_activations: Source activations (input activations to the autoencoder from the + source model). + learned_activations: Learned activations (intermediate activations in the autoencoder). + decoded_activations: Decoded activations. + + Returns: + Loss per batch item, summed across the loss modules.
+ """ + all_modules_loss: Float[Tensor, "module train_batch"] = torch.stack( + [ + loss_module.forward(source_activations, learned_activations, decoded_activations) + for loss_module in self._modules.values() + ] + ) + + return all_modules_loss.sum(dim=0) + + def __dir__(self) -> list[str]: + """Dir dunder method.""" + return list(self._modules.__dir__()) + + def __getitem__(self, idx: int) -> AbstractLoss: + """Get item dunder method.""" + return self._modules[str(idx)] + + def __iter__(self) -> Iterator[AbstractLoss]: + """Iterator dunder method.""" + return iter(self._modules.values()) + + def __len__(self) -> int: + """Length dunder method.""" + return len(self._modules) diff --git a/sparse_autoencoder/loss/tests/test_abstract_loss.py b/sparse_autoencoder/loss/tests/test_abstract_loss.py new file mode 100644 index 00000000..b7c04333 --- /dev/null +++ b/sparse_autoencoder/loss/tests/test_abstract_loss.py @@ -0,0 +1,75 @@ +"""Tests for the AbstractLoss class.""" +import pytest +import torch + +from sparse_autoencoder.loss.abstract_loss import AbstractLoss, LossReductionType +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + LearnedActivationBatch, + TrainBatchStatistic, +) + + +class DummyLoss(AbstractLoss): + """Dummy loss for testing.""" + + def forward( + self, + source_activations: InputOutputActivationBatch, # noqa: ARG002 + learned_activations: LearnedActivationBatch, # noqa: ARG002 + decoded_activations: InputOutputActivationBatch, # noqa: ARG002 + ) -> TrainBatchStatistic: + """Batch itemwise loss.""" + # Simple dummy implementation for testing + return torch.tensor([1.0, 2.0, 3.0]) + + +@pytest.fixture() +def dummy_loss() -> DummyLoss: + """Dummy loss for testing.""" + return DummyLoss() + + +def test_abstract_class_enforced() -> None: + """Test that initializing the abstract class raises an error.""" + with pytest.raises(TypeError): + AbstractLoss() # type: ignore + + +@pytest.mark.parametrize( + ("loss_reduction", 
"expected"), + [ + (LossReductionType.MEAN, 2.0), # Mean of [1.0, 2.0, 3.0] + (LossReductionType.SUM, 6.0), # Sum of [1.0, 2.0, 3.0] + ], +) +def test_batch_scalar_loss( + dummy_loss: DummyLoss, loss_reduction: LossReductionType, expected: float +) -> None: + """Test the batch scalar loss.""" + source_activations = learned_activations = decoded_activations = torch.ones((1, 3)) + + loss_mean = dummy_loss.batch_scalar_loss( + source_activations, learned_activations, decoded_activations, loss_reduction + ) + assert loss_mean.item() == expected + + +def test_batch_scalar_loss_with_log(dummy_loss: DummyLoss) -> None: + """Test the batch scalar loss with log.""" + source_activations = learned_activations = decoded_activations = torch.ones((1, 3)) + _loss, log = dummy_loss.batch_scalar_loss_with_log( + source_activations, learned_activations, decoded_activations + ) + assert "DummyLoss" in log + expected = 2.0 # Mean of [1.0, 2.0, 3.0] + assert log["DummyLoss"] == expected + + +def test_call_method(dummy_loss: DummyLoss) -> None: + """Test the call method.""" + source_activations = learned_activations = decoded_activations = torch.ones((1, 3)) + _loss, log = dummy_loss(source_activations, learned_activations, decoded_activations) + assert "DummyLoss" in log + expected = 2.0 # Mean of [1.0, 2.0, 3.0] + assert log["DummyLoss"] == expected diff --git a/sparse_autoencoder/loss/tests/test_learned_activations_l1.py b/sparse_autoencoder/loss/tests/test_learned_activations_l1.py new file mode 100644 index 00000000..6cee0bdc --- /dev/null +++ b/sparse_autoencoder/loss/tests/test_learned_activations_l1.py @@ -0,0 +1,63 @@ +"""Tests for LearnedActivationsL1Loss.""" +import pytest +import torch + +from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss + + +@pytest.fixture() +def l1_loss() -> LearnedActivationsL1Loss: + """Fixture for LearnedActivationsL1Loss with a default L1 coefficient.""" + return LearnedActivationsL1Loss(l1_coefficient=0.1) + + +def 
test_l1_loss_forward(l1_loss: LearnedActivationsL1Loss) -> None: + """Test the forward method of LearnedActivationsL1Loss.""" + learned_activations = torch.tensor([[2.0, -3.0], [2.0, -3.0]]) + source_activations = decoded_activations = torch.zeros_like(learned_activations) + + expected_loss = torch.tensor([0.5, 0.5]) # (|2| + |-3|) * 0.1 for each row + calculated_loss = l1_loss.forward(source_activations, learned_activations, decoded_activations) + + assert torch.allclose(calculated_loss, expected_loss), "L1 loss calculation is incorrect." + + +def test_l1_loss_with_different_l1_coefficients() -> None: + """Test LearnedActivationsL1Loss with different L1 coefficients.""" + learned_activations = torch.tensor([[2.0, -3.0], [2.0, -3.0]]) + source_activations = decoded_activations = torch.zeros_like(learned_activations) + + for coefficient in [0.01, 0.1, 0.5]: + l1_loss = LearnedActivationsL1Loss(l1_coefficient=coefficient) + expected_loss = torch.abs(learned_activations).sum(dim=-1) * coefficient + calculated_loss = l1_loss.forward( + source_activations, learned_activations, decoded_activations + ) + + assert torch.allclose( + calculated_loss, expected_loss + ), f"L1 loss calculation is incorrect for coefficient {coefficient}." + + +def test_l1_loss_with_zero_input(l1_loss: LearnedActivationsL1Loss) -> None: + """Test the L1 loss function with zero inputs.""" + learned_activations = torch.zeros((2, 3)) + source_activations = decoded_activations = torch.zeros_like(learned_activations) + + expected_loss = torch.zeros(2) + calculated_loss = l1_loss.forward(source_activations, learned_activations, decoded_activations) + + assert torch.all(calculated_loss == expected_loss), "L1 loss should be zero for zero inputs." 
+ + +def test_l1_loss_with_negative_input(l1_loss: LearnedActivationsL1Loss) -> None: + """Test the L1 loss function with negative inputs.""" + learned_activations = torch.tensor([[-2.0, -3.0], [-1.0, -4.0]]) + source_activations = decoded_activations = torch.zeros_like(learned_activations) + + expected_loss = torch.tensor([0.5, 0.5]) # (|2| + |-3|) * 0.1 for each row + calculated_loss = l1_loss.forward(source_activations, learned_activations, decoded_activations) + + assert torch.allclose( + calculated_loss, expected_loss + ), "L1 loss calculation is incorrect with negative inputs." diff --git a/sparse_autoencoder/loss/tests/test_mse_reconstruction_loss.py b/sparse_autoencoder/loss/tests/test_mse_reconstruction_loss.py new file mode 100644 index 00000000..d2582f81 --- /dev/null +++ b/sparse_autoencoder/loss/tests/test_mse_reconstruction_loss.py @@ -0,0 +1,54 @@ +"""Tests for the MSEReconstructionLoss class.""" +import pytest +import torch + +from sparse_autoencoder.loss.mse_reconstruction_loss import MSEReconstructionLoss + + +@pytest.fixture() +def mse_loss() -> MSEReconstructionLoss: + """Fixture for MSEReconstructionLoss.""" + return MSEReconstructionLoss() + + +def test_mse_loss_forward(mse_loss: MSEReconstructionLoss) -> None: + """Test the forward method of MSEReconstructionLoss.""" + input_activations = torch.tensor([[5.0, 4.0], [3.0, 4.0]]) + output_activations = torch.tensor([[1.0, 5.0], [1.0, 5.0]]) + learned_activations = torch.zeros_like(input_activations) + + expected_loss = torch.tensor([8.5, 2.5]) + calculated_loss = mse_loss.forward(input_activations, learned_activations, output_activations) + + assert torch.allclose(calculated_loss, expected_loss), "MSE loss calculation is incorrect." 
+ + +def test_mse_loss_with_zero_input(mse_loss: MSEReconstructionLoss) -> None: + """Test the MSE loss function with zero inputs.""" + input_activations = torch.zeros((2, 3)) + output_activations = torch.zeros_like(input_activations) + learned_activations = torch.zeros_like(input_activations) + + expected_loss = torch.zeros(2) + calculated_loss = mse_loss.forward(input_activations, learned_activations, output_activations) + + assert torch.all( + calculated_loss == expected_loss + ), "MSE loss should be zero for identical zero inputs." + + +def test_mse_loss_with_varying_input_shapes(mse_loss: MSEReconstructionLoss) -> None: + """Test the MSE loss function with varying input shapes.""" + for shape in [(1, 3), (5, 3), (10, 5)]: + input_activations = torch.rand(shape) + output_activations = torch.rand(shape) + learned_activations = torch.zeros_like(input_activations) + + calculated_loss = mse_loss.forward( + input_activations, learned_activations, output_activations + ) + + # Just checking if the loss calculation completes without error for different shapes + assert ( + calculated_loss.shape[0] == shape[0] + ), f"MSE loss calculation failed for shape {shape}." 
diff --git a/sparse_autoencoder/loss/tests/test_towards_monosemanticity_loss.py b/sparse_autoencoder/loss/tests/test_towards_monosemanticity_loss.py new file mode 100644 index 00000000..98ae2f8d --- /dev/null +++ b/sparse_autoencoder/loss/tests/test_towards_monosemanticity_loss.py @@ -0,0 +1,43 @@ +"""Test the loss function from the Towards Monosemanticity paper.""" +import torch + +from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss +from sparse_autoencoder.loss.mse_reconstruction_loss import MSEReconstructionLoss +from sparse_autoencoder.loss.reducer import LossReducer + + +class TestTowardsMonosemanticityLoss: + """Test the loss function from the Towards Monosemanticity paper.""" + + def test_loss(self) -> None: + """Test loss against a non-vectorised approach.""" + # Calculate the expected loss with a non-vectorised approach + input_activations: list[float] = [3.0, 4] + learned_activations: list[float] = [2.0, -3] + output_activations: list[float] = [1.0, 5] + l1_coefficient = 0.5 + + squared_errors: float = 0.0 + for i, o in zip(input_activations, output_activations, strict=True): + squared_errors += (i - o) ** 2 + mse = squared_errors / len(input_activations) + + l1_penalty: float = 0.0 + for neuron in learned_activations: + l1_penalty += abs(neuron) * l1_coefficient + + expected: float = mse + l1_penalty + + # Compare against the actual loss function + loss = LossReducer( + MSEReconstructionLoss(), + LearnedActivationsL1Loss(l1_coefficient), + ) + + result, _metrics = loss( + torch.tensor(input_activations).unsqueeze(0), + torch.tensor(learned_activations).unsqueeze(0), + torch.tensor(output_activations).unsqueeze(0), + ) + + assert torch.allclose(result, torch.tensor(expected)) diff --git a/sparse_autoencoder/metrics/abstract_metric.py b/sparse_autoencoder/metrics/abstract_metric.py new file mode 100644 index 00000000..b9dacd82 --- /dev/null +++ b/sparse_autoencoder/metrics/abstract_metric.py @@ -0,0 +1,93 @@ +"""Abstract 
metric classes.""" +from abc import ABC, abstractmethod +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, final + +from sparse_autoencoder.tensor_types import ( + InputOutputActivationBatch, + LearnedActivationBatch, +) + + +@dataclass +class GenerateMetricData: + """Generate metric data.""" + + generated_activations: InputOutputActivationBatch + + +@dataclass +class TrainMetricData: + """Train metric data.""" + + input_activations: InputOutputActivationBatch + + learned_activations: LearnedActivationBatch + + decoded_activations: InputOutputActivationBatch + + +@dataclass +class ValidationMetricData: + """Validation metric data.""" + + source_model_loss: float + + autoencoder_loss: float + + +class AbstractMetric(ABC): + """Abstract metric.""" + + _should_log_progress_bar: bool + + _should_log_weights_and_biases: bool + + @final + def __init__(self, *, log_progress_bar: bool = False, log_weights_and_biases: bool = True): + """Initialise the train metric.""" + self._should_log_progress_bar = log_progress_bar + self._should_log_weights_and_biases = log_weights_and_biases + + +class AbstractGenerateMetric(AbstractMetric, ABC): + """Abstract generate metric.""" + + @abstractmethod + def create_progress_bar_postfix(self, data: GenerateMetricData) -> OrderedDict[str, Any]: + """Create a progress bar postfix.""" + raise NotImplementedError + + @abstractmethod + def create_weights_and_biases_log(self, data: GenerateMetricData) -> OrderedDict[str, Any]: + """Create a log item for Weights and Biases.""" + raise NotImplementedError + + +class AbstractTrainMetric(AbstractMetric, ABC): + """Abstract train metric.""" + + @abstractmethod + def create_progress_bar_postfix(self, data: TrainMetricData) -> OrderedDict[str, Any]: + """Create a progress bar postfix.""" + raise NotImplementedError + + @abstractmethod + def create_weights_and_biases_log(self, data: TrainMetricData) -> OrderedDict[str, Any]: + """Create a log item for 
Weights and Biases.""" + raise NotImplementedError + + +class AbstractValidationMetric(AbstractMetric, ABC): + """Abstract validation metric.""" + + @abstractmethod + def create_progress_bar_postfix(self, data: ValidationMetricData) -> OrderedDict[str, Any]: + """Create a progress bar postfix.""" + raise NotImplementedError + + @abstractmethod + def create_weights_and_biases_log(self, data: ValidationMetricData) -> OrderedDict[str, Any]: + """Create a log item for Weights and Biases.""" + raise NotImplementedError diff --git a/sparse_autoencoder/optimizer/abstract_optimizer.py b/sparse_autoencoder/optimizer/abstract_optimizer.py new file mode 100644 index 00000000..da0d3b5a --- /dev/null +++ b/sparse_autoencoder/optimizer/abstract_optimizer.py @@ -0,0 +1,15 @@ +"""Abstract optimizer with reset.""" +from abc import ABC, abstractmethod + + +class AbstractOptimizerWithReset(ABC): + """Abstract optimizer with reset.""" + + @abstractmethod + def reset_state_all_parameters(self) -> None: + """Reset the state for all parameters. + + Resets any optimizer state (e.g. momentum). This is for use after manually editing model + parameters (e.g. with activation resampling). + """ + raise NotImplementedError diff --git a/sparse_autoencoder/optimizer/adam_with_reset.py b/sparse_autoencoder/optimizer/adam_with_reset.py index e39a2584..bfdfa001 100644 --- a/sparse_autoencoder/optimizer/adam_with_reset.py +++ b/sparse_autoencoder/optimizer/adam_with_reset.py @@ -3,15 +3,19 @@ This reset method is useful when resampling dead neurons during training. 
""" from collections.abc import Iterator +from typing import final -from jaxtyping import Int from torch import Tensor from torch.nn.parameter import Parameter from torch.optim import Adam from torch.optim.optimizer import params_t +from sparse_autoencoder.optimizer.abstract_optimizer import AbstractOptimizerWithReset +from sparse_autoencoder.tensor_types import DeadNeuronIndices -class AdamWithReset(Adam): + +@final +class AdamWithReset(Adam, AbstractOptimizerWithReset): """Adam Optimizer with a reset method. The :meth:`reset_state_all_parameters` and :meth:`reset_neurons_state` methods are useful when @@ -139,7 +143,7 @@ def _get_parameter_name_idx(self, parameter_name: str) -> int: def reset_neurons_state( self, parameter_name: str, - neuron_indices: Int[Tensor, " neuron_idx"], + neuron_indices: DeadNeuronIndices, axis: int, parameter_group: int = 0, ) -> None: diff --git a/sparse_autoencoder/source_data/abstract_dataset.py b/sparse_autoencoder/source_data/abstract_dataset.py index 83787c9d..86825f19 100644 --- a/sparse_autoencoder/source_data/abstract_dataset.py +++ b/sparse_autoencoder/source_data/abstract_dataset.py @@ -3,11 +3,11 @@ from typing import Any, Generic, TypedDict, TypeVar, final from datasets import IterableDataset, load_dataset -from jaxtyping import Int -from torch import Tensor from torch.utils.data import DataLoader from torch.utils.data import Dataset as TorchDataset +from sparse_autoencoder.tensor_types import BatchTokenizedPrompts + TokenizedPrompt = list[int] """A tokenized prompt.""" @@ -22,7 +22,7 @@ class TokenizedPrompts(TypedDict): class TorchTokenizedPrompts(TypedDict): """Tokenized prompts prepared for PyTorch.""" - input_ids: Int[Tensor, "batch pos"] + input_ids: BatchTokenizedPrompts HuggingFaceDatasetItem = TypeVar("HuggingFaceDatasetItem", bound=Any) diff --git a/sparse_autoencoder/src_model/store_activations_hook.py b/sparse_autoencoder/src_model/store_activations_hook.py index 9100408d..119ddc57 100644 --- 
a/sparse_autoencoder/src_model/store_activations_hook.py +++ b/sparse_autoencoder/src_model/store_activations_hook.py @@ -1,16 +1,15 @@ """TransformerLens Hook for storing activations.""" -from jaxtyping import Float -from torch import Tensor from transformer_lens.hook_points import HookPoint from sparse_autoencoder.activation_store.base_store import ActivationStore +from sparse_autoencoder.tensor_types import SourceModelActivations def store_activations_hook( - value: Float[Tensor, "*any neuron"], - hook: HookPoint, # noqa: ARG001 as needed by TransformerLens + value: SourceModelActivations, + hook: HookPoint, # noqa: ARG001 store: ActivationStore, -) -> Float[Tensor, "*any neuron"]: +) -> SourceModelActivations: """Store Activations Hook. Useful for getting just the specific activations wanted, rather than the full cache. diff --git a/sparse_autoencoder/src_model/tests/test_store_activations_hook.py b/sparse_autoencoder/src_model/tests/test_store_activations_hook.py index f9dddfa8..00ef2f49 100644 --- a/sparse_autoencoder/src_model/tests/test_store_activations_hook.py +++ b/sparse_autoencoder/src_model/tests/test_store_activations_hook.py @@ -6,6 +6,7 @@ from sparse_autoencoder.activation_store.list_store import ListActivationStore from sparse_autoencoder.src_model.store_activations_hook import store_activations_hook +from sparse_autoencoder.tensor_types import BatchTokenizedPrompts def test_hook_stores_activations() -> None: @@ -18,7 +19,7 @@ def test_hook_stores_activations() -> None: partial(store_activations_hook, store=store), ) - tokens = model.to_tokens("Hello world") + tokens: BatchTokenizedPrompts = model.to_tokens("Hello world") logits = model.forward(tokens, stop_at_layer=2) # type: ignore number_of_tokens = tokens.numel() diff --git a/sparse_autoencoder/tensor_types.py b/sparse_autoencoder/tensor_types.py new file mode 100644 index 00000000..d3b7aa70 --- /dev/null +++ b/sparse_autoencoder/tensor_types.py @@ -0,0 +1,205 @@ +"""Tensor Types. 
+ +Tensor types with axis labels. Note that this uses the `jaxtyping` library, which works with PyTorch +tensors as well. +""" +from enum import auto +from typing import TypeAlias + +from jaxtyping import Float, Int +from strenum import LowercaseStrEnum +from torch import Tensor + + +class Axis(LowercaseStrEnum): + """Tensor axis names. + + Used to annotate tensor types. + + Example: + When used directly it prints a string: + + >>> print(Axis.INPUT_OUTPUT_FEATURE) + input_output_feature + + The primary use is to annotate tensor types: + + >>> from jaxtyping import Float + >>> from torch import Tensor + >>> from typing import TypeAlias + >>> batch: TypeAlias = Float[Tensor, Axis.dims(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] + >>> print(batch) + + + You can also join multiple axes together to represent the dimensions of a tensor: + + >>> print(Axis.dims(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)) + batch input_output_feature + """ + + # Batches + SOURCE_DATA_BATCH = auto() + """Batch of prompts used to generate source model activations.""" + + BATCH = auto() + """Batch of items that the SAE is being trained on.""" + + ITEMS = auto() + """Arbitrary number of items.""" + + # Features + INPUT_OUTPUT_FEATURE = auto() + """Input or output feature (e.g. feature in activation vector from source model).""" + + LEARNT_FEATURE = auto() + """Learnt feature (e.g. feature in learnt activation vector).""" + + DEAD_FEATURE = auto() + """Dead feature.""" + + ALIVE_FEATURE = auto() + """Alive feature.""" + + # Feature indices + LEARNT_FEATURE_IDX = auto() + + # Other + POSITION = auto() + """Token position.""" + + SINGLE_ITEM = "" + """Single item axis.""" + + ANY = "*any" + """Any number of axes.""" + + @staticmethod + def dims(*axis: "Axis") -> str: + """Join multiple axes together, to represent the dimensions of a tensor. + + Example: + >>> print(Axis.dims(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)) + batch input_output_feature + + Args: + *axis: Axes to join.
+ + Returns: + Joined axis string. + """ + return " ".join(axis) + + +# Activation vectors +InputOutputActivationVector: TypeAlias = Float[Tensor, Axis.INPUT_OUTPUT_FEATURE] +"""Input/output activation vector. + +This is either an input activation vector from the source model, or a decoded activation vector +from the autoencoder. +""" + +LearntActivationVector: TypeAlias = Float[Tensor, Axis.LEARNT_FEATURE] +"""Learned activation vector. + +Activation vector from the hidden (learnt) layer of the autoencoder. Typically this is larger than +the input/output activation vector. +""" + +# Activation batches/stores +StoreActivations: TypeAlias = Float[Tensor, Axis.dims(Axis.ITEMS, Axis.INPUT_OUTPUT_FEATURE)] +"""Store of activation vectors. + +This is used to store large numbers of activation vectors from the source model. +""" + +SourceModelActivations: TypeAlias = Float[Tensor, Axis.dims(Axis.ANY, Axis.INPUT_OUTPUT_FEATURE)] +"""Source model activations. + +Can have any number of preceding dimensions (e.g. an attention head may generate activations of +shape (batch_size, num_heads, seq_len, feature_dim)). +""" + +InputOutputActivationBatch: TypeAlias = Float[ + Tensor, Axis.dims(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE) +] +"""Input/output activation batch. + +This is either a batch of input activation vectors from the source model, or a batch of decoded +activation vectors from the autoencoder. +""" + +LearnedActivationBatch: TypeAlias = Float[Tensor, Axis.dims(Axis.BATCH, Axis.LEARNT_FEATURE)] +"""Learned activation batch. + +This is a batch of activation vectors from the hidden (learnt) layer of the autoencoder. Typically +the feature dimension is larger than the input/output activation vector. +""" + +# Statistics +TrainBatchStatistic: TypeAlias = Float[Tensor, Axis.BATCH] +"""Train batch statistic. + +Contains one scalar value per item in the batch.
+""" + +# Weights and biases +EncoderWeights: TypeAlias = Float[Tensor, Axis.dims(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)] +"""Encoder weights. + +These weights are part of the encoder module of the autoencoder, responsible for decompressing the +input data (activations from a source model) into a higher-dimensional representation. + +The dictionary vectors (basis vectors in the learnt feature space), they can be thought of as +columns of this weight matrix, where each column corresponds to a particular feature in the +lower-dimensional space. The sparsity constraint (hopefully) enforces that they respond relatively +strongly to only a small portion of possible input vectors. +""" + +DecoderWeights: TypeAlias = Float[Tensor, Axis.dims(Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE)] +"""Decoder weights. + +These weights form the decoder part of the autoencoder, which aims to reconstruct the original input +data from the decompressed representation created by the encoder. + +The tensor's shape aligns with the training features and the learnt features. In this case, if we +view the dictionary vectors in the context of reconstruction, they can be thought of as rows in this +weight matrix. +""" + +# Weights and biases updated +NeuronActivity: TypeAlias = Int[Tensor, Axis.LEARNT_FEATURE] +"""Neuron activity. + +Number of times each neuron has fired (since the last reset). 
+""" + +DeadNeuronIndices: TypeAlias = Int[Tensor, Axis.LEARNT_FEATURE_IDX] +"""Dead neuron indices.""" + +SampledDeadNeuronInputs: TypeAlias = Float[ + Tensor, Axis.dims(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE) +] +"""Sampled dead neuron inputs.""" + +AliveEncoderWeights: TypeAlias = Float[ + Tensor, Axis.dims(Axis.ALIVE_FEATURE, Axis.INPUT_OUTPUT_FEATURE) +] +"""Alive encoder weights (the rows of the encoder weights that belong to alive neurons).""" + +DeadEncoderNeuronWeightUpdates: TypeAlias = Float[ + Tensor, Axis.dims(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE) +] +"""Dead encoder neuron weight updates.""" + +DeadEncoderNeuronBiasUpdates: TypeAlias = Float[Tensor, Axis.DEAD_FEATURE] +"""Dead encoder neuron bias updates.""" + +DeadDecoderNeuronWeightUpdates: TypeAlias = Float[ + Tensor, Axis.dims(Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE) +] +"""Dead decoder neuron weight updates.""" + +# Other +BatchTokenizedPrompts: TypeAlias = Int[Tensor, Axis.dims(Axis.SOURCE_DATA_BATCH, Axis.POSITION)] +"""Batch of tokenized prompts.""" + +ItemTensor: TypeAlias = Float[Tensor, Axis.SINGLE_ITEM] +"""Single element item tensor.""" diff --git a/sparse_autoencoder/train/abstract_pipeline.py b/sparse_autoencoder/train/abstract_pipeline.py new file mode 100644 index 00000000..1372264c --- /dev/null +++ b/sparse_autoencoder/train/abstract_pipeline.py @@ -0,0 +1,134 @@ +"""Abstract pipeline.""" +from abc import ABC, abstractmethod +from typing import final + +from tqdm.auto import tqdm +from transformer_lens import HookedTransformer + +from sparse_autoencoder.activation_resampler.abstract_activation_resampler import ( + AbstractActivationResampler, +) +from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore +from sparse_autoencoder.autoencoder.model import SparseAutoencoder +from sparse_autoencoder.loss.abstract_loss import AbstractLoss +from sparse_autoencoder.metrics.abstract_metric import ( + AbstractGenerateMetric, + AbstractTrainMetric, + AbstractValidationMetric, +) +from
sparse_autoencoder.optimizer.abstract_optimizer import AbstractOptimizerWithReset +from sparse_autoencoder.source_data.abstract_dataset import SourceDataset +from sparse_autoencoder.tensor_types import NeuronActivity + + +class AbstractPipeline(ABC): + """Pipeline for training a Sparse Autoencoder on TransformerLens activations. + + Includes all the key functionality to train a sparse autoencoder, with a specific set of + hyperparameters. + """ + + generate_metrics: list[AbstractGenerateMetric] + + train_metrics: list[AbstractTrainMetric] + + validation_metric: list[AbstractValidationMetric] + + source_model: HookedTransformer + + source_dataset: SourceDataset + + autoencoder: SparseAutoencoder + + loss: AbstractLoss + + optimizer: AbstractOptimizerWithReset + + activation_resampler: AbstractActivationResampler | None + + progress_bar: tqdm | None + + @final + def __init__( + self, + generate_metrics: list[AbstractGenerateMetric], + train_metrics: list[AbstractTrainMetric], + validation_metric: list[AbstractValidationMetric], + source_model: HookedTransformer, + autoencoder: SparseAutoencoder, + source_dataset: SourceDataset, + activation_resampler: AbstractActivationResampler | None, + optimizer: AbstractOptimizerWithReset, + loss: AbstractLoss, + ): + """Initialize the pipeline.""" + self.generate_metrics = generate_metrics + self.train_metrics = train_metrics + self.validation_metric = validation_metric + self.source_model = source_model + self.autoencoder = autoencoder + self.source_dataset = source_dataset + self.activation_resampler = activation_resampler + self.optimizer = optimizer + self.loss = loss + + @abstractmethod + def generate_activations(self) -> TensorActivationStore: + """Generate activations.""" + raise NotImplementedError + + @abstractmethod + def train_autoencoder(self, activations: TensorActivationStore) -> NeuronActivity: + """Train the sparse autoencoder.""" + raise NotImplementedError + + @abstractmethod + def resample_neurons(self, 
neuron_activity: NeuronActivity) -> None: + """Resample dead neurons.""" + raise NotImplementedError + + def validate_sae(self) -> None: + """Get validation metrics.""" + raise NotImplementedError + + @final + def run_pipeline( + self, + source_batch_size: int, + resample_frequency: int, + validate_frequency: int, + checkpoint_frequency: int, + max_activations: int, + ) -> None: + """Run the full training pipeline.""" + last_resampled: int = 0 + last_validated: int = 0 + last_checkpoint: int = 0 + neuron_activity: NeuronActivity | None = None + + for _ in tqdm(range(0, max_activations, source_batch_size), desc="Activations trained on"): + # Generate + activations: TensorActivationStore = self.generate_activations() + + # Train + batch_neuron_activity: NeuronActivity = self.train_autoencoder(activations) + detached_neuron_activity = batch_neuron_activity.detach().cpu() + if neuron_activity is not None: + neuron_activity.add_(detached_neuron_activity) + else: + neuron_activity = detached_neuron_activity + + # Resample dead neurons (if needed). NOTE(review): counters never incremented + if last_resampled > resample_frequency: + self.resample_neurons(neuron_activity) + last_resampled = 0 + + # Get validation metrics (if needed) + if last_validated > validate_frequency: + self.validate_sae() + last_validated = 0 + + # Checkpoint (if needed) + if last_checkpoint > checkpoint_frequency: + self.autoencoder.save_to_hf() + last_checkpoint = 0 diff --git a/sparse_autoencoder/train/generate_activations.py b/sparse_autoencoder/train/generate_activations.py index dacdfd9d..72740ee8 100644 --- a/sparse_autoencoder/train/generate_activations.py +++ b/sparse_autoencoder/train/generate_activations.py @@ -1,10 +1,9 @@ """Generate activations for training a Sparse Autoencoder.""" from collections.abc import Iterable from functools import partial +from typing import TYPE_CHECKING -from jaxtyping import Int import torch -from torch import Tensor from transformer_lens import HookedTransformer from
sparse_autoencoder.activation_store.base_store import ( @@ -14,6 +13,10 @@ from sparse_autoencoder.src_model.store_activations_hook import store_activations_hook +if TYPE_CHECKING: + from sparse_autoencoder.tensor_types import BatchTokenizedPrompts + + def generate_activations( model: HookedTransformer, layer: int, @@ -77,5 +80,5 @@ def generate_activations( if len(store) + activations_per_batch > total: break - input_ids: Int[Tensor, "batch pos"] = batch["input_ids"].to(device) + input_ids: BatchTokenizedPrompts = batch["input_ids"].to(device) model.forward(input_ids, stop_at_layer=layer + 1) # type: ignore (TLens is typed incorrectly) diff --git a/sparse_autoencoder/train/metrics/capacity.py b/sparse_autoencoder/train/metrics/capacity.py index 7234eb40..e6d94d25 100644 --- a/sparse_autoencoder/train/metrics/capacity.py +++ b/sparse_autoencoder/train/metrics/capacity.py @@ -1,15 +1,15 @@ """Capacity metrics for sets of learned features.""" import einops -from jaxtyping import Float from numpy import histogram import numpy as np from numpy.typing import NDArray import torch -from torch import Tensor import wandb +from sparse_autoencoder.tensor_types import LearnedActivationBatch, TrainBatchStatistic -def calc_capacities(features: Float[Tensor, "n_feats feat_dim"]) -> Float[Tensor, " n_feats"]: + +def calc_capacities(features: LearnedActivationBatch) -> TrainBatchStatistic: """Calculate capacities. Measure the capacity of a set of features as defined in [Polysemanticity and Capacity in Neural Networks](https://arxiv.org/pdf/2210.01892.pdf). @@ -45,7 +45,7 @@ def calc_capacities(features: Float[Tensor, "n_feats feat_dim"]) -> Float[Tensor def wandb_capacities_histogram( - capacities: Float[Tensor, " n_feats"], + capacities: TrainBatchStatistic, ) -> wandb.Histogram: """Create a W&B histogram of the capacities. 
diff --git a/sparse_autoencoder/train/metrics/feature_density.py b/sparse_autoencoder/train/metrics/feature_density.py index c496f478..a3fc35e4 100644 --- a/sparse_autoencoder/train/metrics/feature_density.py +++ b/sparse_autoencoder/train/metrics/feature_density.py @@ -1,18 +1,18 @@ """Feature density metrics & histogram.""" import einops -from jaxtyping import Float from numpy import histogram import numpy as np from numpy.typing import NDArray import torch -from torch import Tensor import wandb +from sparse_autoencoder.tensor_types import LearnedActivationBatch, LearntActivationVector + def calc_feature_density( - activations: Float[Tensor, "sample activation"], threshold: float = 0.001 -) -> Float[Tensor, " activation"]: + activations: LearnedActivationBatch, threshold: float = 0.001 +) -> LearntActivationVector: """Count how many times each feature was active. Percentage of samples in which each feature was active (i.e. the neuron has "fired"). @@ -31,7 +31,7 @@ def calc_feature_density( Returns: Number of times each feature was active in a sample. """ - has_fired: Float[Tensor, "sample activation"] = torch.gt(activations, threshold).to( + has_fired: LearnedActivationBatch = torch.gt(activations, threshold).to( # Use float as einops requires this (64 as some features are very sparse) dtype=torch.float64 ) @@ -40,7 +40,7 @@ def calc_feature_density( def wandb_feature_density_histogram( - feature_density: Float[Tensor, " activation"], + feature_density: LearntActivationVector, ) -> wandb.Histogram: """Create a W&B histogram of the feature density. 
diff --git a/sparse_autoencoder/train/metrics/tests/test_capacities.py b/sparse_autoencoder/train/metrics/tests/test_capacities.py index 7bdce927..9b7c5d74 100644 --- a/sparse_autoencoder/train/metrics/tests/test_capacities.py +++ b/sparse_autoencoder/train/metrics/tests/test_capacities.py @@ -2,12 +2,11 @@ import math -from jaxtyping import Float import pytest from syrupy.session import SnapshotSession import torch -from torch import Tensor +from sparse_autoencoder.tensor_types import LearnedActivationBatch, TrainBatchStatistic from sparse_autoencoder.train.metrics.capacity import calc_capacities, wandb_capacities_histogram @@ -31,7 +30,7 @@ ], ) def test_calc_capacities( - features: Float[Tensor, "n_feats feat_dim"], expected_capacities: Float[Tensor, " n_feats"] + features: LearnedActivationBatch, expected_capacities: TrainBatchStatistic ) -> None: """Check that the capacity calculation is correct.""" capacities = calc_capacities(features) diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py index f2e03dc5..b5f82c01 100644 --- a/sparse_autoencoder/train/pipeline.py +++ b/sparse_autoencoder/train/pipeline.py @@ -1,10 +1,9 @@ """Training Pipeline.""" from collections.abc import Iterable +from typing import TYPE_CHECKING import warnings -from jaxtyping import Int import torch -from torch import Tensor from torch.utils.data import DataLoader from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm @@ -20,6 +19,9 @@ from sparse_autoencoder.train.train_autoencoder import train_autoencoder +if TYPE_CHECKING: + from sparse_autoencoder.tensor_types import NeuronActivity + DEFAULT_RESAMPLE_N = 819_200 @@ -114,10 +116,8 @@ def pipeline( # noqa: PLR0913 total_steps: int = 0 activations_since_resampling: int = 0 - neuron_activity: Int[Tensor, " learned_features"] = torch.zeros( - autoencoder.n_learned_features, - dtype=torch.int32, - device=device, + neuron_activity: NeuronActivity = torch.zeros( + 
autoencoder.n_learned_features, dtype=torch.int32, device=device ) total_activations: int = 0 diff --git a/sparse_autoencoder/train/resample_neurons.py b/sparse_autoencoder/train/resample_neurons.py index b6e48e9e..77f2015d 100644 --- a/sparse_autoencoder/train/resample_neurons.py +++ b/sparse_autoencoder/train/resample_neurons.py @@ -3,14 +3,29 @@ from typing import TYPE_CHECKING from einops import rearrange -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool import torch from torch import Tensor from torch.utils.data import DataLoader from sparse_autoencoder.activation_store.base_store import ActivationStore -from sparse_autoencoder.autoencoder.loss import l1_loss, reconstruction_loss, sae_training_loss from sparse_autoencoder.autoencoder.model import SparseAutoencoder +from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss +from sparse_autoencoder.loss.mse_reconstruction_loss import MSEReconstructionLoss +from sparse_autoencoder.loss.reducer import LossReducer +from sparse_autoencoder.tensor_types import ( + AliveEncoderWeights, + DeadEncoderNeuronWeightUpdates, + DeadNeuronIndices, + DecoderWeights, + EncoderWeights, + InputOutputActivationBatch, + ItemTensor, + LearntActivationVector, + NeuronActivity, + SampledDeadNeuronInputs, + TrainBatchStatistic, +) from sparse_autoencoder.train.sweep_config import SweepParametersRuntime @@ -19,8 +34,8 @@ def get_dead_neuron_indices( - neuron_activity: Int[Tensor, " learned_features"], threshold: int = 0 -) -> Int[Tensor, " dead_neuron"]: + neuron_activity: NeuronActivity, threshold: int = 0 +) -> DeadNeuronIndices: """Identify the indices of neurons that have zero activity. 
Example: @@ -45,7 +60,7 @@ def compute_loss_and_get_activations( autoencoder: SparseAutoencoder, sweep_parameters: SweepParametersRuntime, num_inputs: int, -) -> tuple[Float[Tensor, " item"], Float[Tensor, "item input_feature"]]: +) -> tuple[TrainBatchStatistic, InputOutputActivationBatch]: """Compute the loss on a random subset of inputs. Computes the loss and also stores the input activations (for use in resampling neurons). @@ -60,8 +75,13 @@ def compute_loss_and_get_activations( A tuple containing the loss per item, and all input activations. """ with torch.no_grad(): - loss_batches: list[Float[Tensor, " batch_item"]] = [] - input_activations_batches: list[Float[Tensor, "batch_item input_feature"]] = [] + loss_fn = LossReducer( + MSEReconstructionLoss(), + LearnedActivationsL1Loss(sweep_parameters.l1_coefficient), + ) + + loss_batches: list[TrainBatchStatistic] = [] + input_activations_batches: list[InputOutputActivationBatch] = [] batch_size: int = sweep_parameters.batch_size dataloader = DataLoader(store, batch_size=batch_size) batches: int = num_inputs // batch_size @@ -69,13 +89,7 @@ def compute_loss_and_get_activations( for batch_idx, batch in enumerate(iter(dataloader)): input_activations_batches.append(batch) learned_activations, reconstructed_activations = autoencoder(batch) - loss_batches.append( - sae_training_loss( - reconstruction_loss(batch, reconstructed_activations), - l1_loss(learned_activations), - sweep_parameters.l1_coefficient, - ) - ) + loss_batches.append(loss_fn.forward(batch, learned_activations, reconstructed_activations)) if batch_idx >= batches: break @@ -92,7 +106,7 @@ def compute_loss_and_get_activations( return loss, input_activations -def assign_sampling_probabilities(loss: Float[Tensor, " item"]) -> Tensor: +def assign_sampling_probabilities(loss: TrainBatchStatistic) -> Tensor: """Assign the sampling probabilities for each input activations vector.
Assign each input vector a probability of being picked that is proportional to the square of @@ -114,10 +128,10 @@ def assign_sampling_probabilities(loss: Float[Tensor, " item"]) -> Tensor: def sample_input( - probabilities: Float[Tensor, " item"], - input_activations: Float[Tensor, "item input_feature"], + probabilities: TrainBatchStatistic, + input_activations: InputOutputActivationBatch, num_samples: int, -) -> Float[Tensor, "dead_neuron input_feature"]: +) -> SampledDeadNeuronInputs: """Sample an input vector based on the provided probabilities. Example: @@ -152,17 +166,15 @@ def sample_input( device=input_activations.device, ) - sample_indices: Int[Tensor, " dead_neuron"] = torch.multinomial( - probabilities, num_samples=num_samples - ) + sample_indices: DeadNeuronIndices = torch.multinomial(probabilities, num_samples=num_samples) return input_activations[sample_indices, :] def renormalize_and_scale( - sampled_input: Float[Tensor, "dead_neuron input_feature"], - neuron_activity: Int[Tensor, " learned_features"], - encoder_weight: Float[Tensor, "learned_feature input_feature"], -) -> Float[Tensor, "dead_neuron input_feature"]: + sampled_input: SampledDeadNeuronInputs, + neuron_activity: NeuronActivity, + encoder_weight: EncoderWeights, +) -> DeadEncoderNeuronWeightUpdates: """Renormalize and scale the resampled dictionary vectors. Renormalize the input vector to equal the average norm of the encoder weights for alive neurons @@ -202,21 +214,19 @@ def renormalize_and_scale( ) # Calculate the average norm of the encoder weights for alive neurons. 
- alive_encoder_weights: Float[Tensor, "learned_feature alive_input_features"] = encoder_weight[ - alive_neuron_mask, : - ] - average_alive_norm: Float[Tensor, 1] = alive_encoder_weights.norm(dim=-1).mean() + alive_encoder_weights: AliveEncoderWeights = encoder_weight[alive_neuron_mask, :] + average_alive_norm: ItemTensor = alive_encoder_weights.norm(dim=-1).mean() # Renormalize the input vector to equal the average norm of the encoder weights for alive # neurons times 0.2. - renormalized_input: Float[Tensor, "dead_neuron input_feature"] = torch.nn.functional.normalize( + renormalized_input: SampledDeadNeuronInputs = torch.nn.functional.normalize( sampled_input, dim=-1 ) return renormalized_input * (average_alive_norm * 0.2) def resample_dead_neurons( - neuron_activity: Int[Tensor, " learned_features"], + neuron_activity: NeuronActivity, store: ActivationStore, autoencoder: SparseAutoencoder, sweep_parameters: SweepParametersRuntime, @@ -269,27 +279,27 @@ def resample_dead_neurons( # Assign each input vector a probability of being picked that is proportional to the square # of the autoencoder's loss on that input. 
- sample_probabilities: Float[Tensor, " item"] = assign_sampling_probabilities(loss) + sample_probabilities: TrainBatchStatistic = assign_sampling_probabilities(loss) # Get references to the encoder and decoder parameters encoder_linear: torch.nn.Linear = autoencoder.encoder.get_submodule("Linear") # type: ignore decoder_linear: ConstrainedUnitNormLinear = autoencoder.decoder.get_submodule( "ConstrainedUnitNormLinear" ) # type: ignore - encoder_weight: Float[Tensor, "learned_feature input_feature"] = encoder_linear.weight - encoder_bias: Float[Tensor, " learned_feature"] = encoder_linear.bias - decoder_weight: Float[Tensor, "input_feature learned_feature"] = decoder_linear.weight + encoder_weight: EncoderWeights = encoder_linear.weight + encoder_bias: LearntActivationVector = encoder_linear.bias + decoder_weight: DecoderWeights = decoder_linear.weight # For each dead neuron sample an input according to these probabilities. - sampled_input: Float[Tensor, "dead_neuron input_feature"] = sample_input( + sampled_input: SampledDeadNeuronInputs = sample_input( sample_probabilities, input_activations, len(dead_neuron_indices) ) # Renormalize the input vector to have unit L2 norm and set this to be the dictionary # vector for the dead autoencoder neuron. 
- renormalized_input: Float[ - Tensor, "dead_neuron input_feature" - ] = torch.nn.functional.normalize(sampled_input, dim=-1) + renormalized_input: SampledDeadNeuronInputs = torch.nn.functional.normalize( + sampled_input, dim=-1 + ) decoder_weight[:, dead_neuron_indices] = rearrange( renormalized_input, "dead_neuron input_feature -> input_feature dead_neuron" diff --git a/sparse_autoencoder/train/tests/test_resample_neurons.py b/sparse_autoencoder/train/tests/test_resample_neurons.py index afa042de..0ff6d6f1 100644 --- a/sparse_autoencoder/train/tests/test_resample_neurons.py +++ b/sparse_autoencoder/train/tests/test_resample_neurons.py @@ -1,7 +1,6 @@ """Tests for the resample_neurons module.""" import copy -from jaxtyping import Float, Int import pytest import torch from torch import Tensor @@ -9,6 +8,11 @@ from sparse_autoencoder.activation_store.base_store import ActivationStore from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore from sparse_autoencoder.autoencoder.model import SparseAutoencoder +from sparse_autoencoder.tensor_types import ( + AliveEncoderWeights, + NeuronActivity, + SampledDeadNeuronInputs, +) from sparse_autoencoder.train.resample_neurons import ( assign_sampling_probabilities, compute_loss_and_get_activations, @@ -205,10 +209,10 @@ class TestRenormalizeAndScale: @staticmethod def calculate_expected_output( - sampled_input: Float[Tensor, "dead_neuron input_feature"], - neuron_activity: Int[Tensor, " learned_features"], - encoder_weight: Float[Tensor, "learned_feature input_feature"], - ) -> Float[Tensor, "dead_neuron input_feature"]: + sampled_input: SampledDeadNeuronInputs, + neuron_activity: NeuronActivity, + encoder_weight: AliveEncoderWeights, + ) -> SampledDeadNeuronInputs: """Non-vectorized approach to compare against.""" # Initialize variables total_norm = 0 diff --git a/sparse_autoencoder/train/train_autoencoder.py b/sparse_autoencoder/train/train_autoencoder.py index cffa3097..cca5e063 100644 --- 
a/sparse_autoencoder/train/train_autoencoder.py +++ b/sparse_autoencoder/train/train_autoencoder.py @@ -1,18 +1,16 @@ """Training Pipeline.""" -from jaxtyping import Float, Int import torch -from torch import Tensor, device +from torch import device from torch.optim import Optimizer from torch.utils.data import DataLoader import wandb from sparse_autoencoder.activation_store.base_store import ActivationStore -from sparse_autoencoder.autoencoder.loss import ( - l1_loss, - reconstruction_loss, - sae_training_loss, -) from sparse_autoencoder.autoencoder.model import SparseAutoencoder +from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss +from sparse_autoencoder.loss.mse_reconstruction_loss import MSEReconstructionLoss +from sparse_autoencoder.loss.reducer import LossReducer +from sparse_autoencoder.tensor_types import LearntActivationVector, NeuronActivity from sparse_autoencoder.train.sweep_config import SweepParametersRuntime @@ -24,7 +22,7 @@ def train_autoencoder( previous_steps: int, log_interval: int = 10, device: device | None = None, -) -> tuple[int, Float[Tensor, " learned_feature"]]: +) -> tuple[int, LearntActivationVector]: """Sparse Autoencoder Training Loop. 
Args: @@ -45,10 +43,15 @@ def train_autoencoder( batch_size=sweep_parameters.batch_size, ) - learned_activations_fired_count: Int[Tensor, " learned_feature"] = torch.zeros( + learned_activations_fired_count: NeuronActivity = torch.zeros( autoencoder.n_learned_features, dtype=torch.int32, device=device ) + loss = LossReducer( + MSEReconstructionLoss(), + LearnedActivationsL1Loss(sweep_parameters.l1_coefficient), + ) + step: int = 0 # Initialize step for step, store_batch in enumerate(activations_dataloader): # Zero the gradients @@ -61,15 +64,8 @@ def train_autoencoder( learned_activations, reconstructed_activations = autoencoder(batch) # Get metrics - reconstruction_loss_mse: Float[Tensor, " item"] = reconstruction_loss( - batch, - reconstructed_activations, - ) - l1_loss_learned_activations: Float[Tensor, " item"] = l1_loss(learned_activations) - total_loss: Float[Tensor, " item"] = sae_training_loss( - reconstruction_loss_mse, - l1_loss_learned_activations, - sweep_parameters.l1_coefficient, + total_loss, metrics = loss.batch_scalar_loss_with_log( + batch, learned_activations, reconstructed_activations ) # Store count of how many neurons have fired @@ -78,18 +74,12 @@ def train_autoencoder( learned_activations_fired_count.add_(fired.sum(dim=0)) # Backwards pass - total_loss.mean().backward() + total_loss.backward() optimizer.step() # Log if step % log_interval == 0 and wandb.run is not None: - wandb.log( - { - "reconstruction_loss": reconstruction_loss_mse.mean().item(), - "l1_loss": l1_loss_learned_activations.mean().item(), - "loss": total_loss.mean().item(), - }, - ) + wandb.log(metrics) current_step = previous_steps + step + 1