Move towards abstract final pattern (#74)
alan-cooney authored Nov 16, 2023
1 parent 75c9b9e commit 0059c3d
Showing 44 changed files with 1,476 additions and 387 deletions.
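For context: the "abstract final pattern" in the commit title combines abstract base classes with methods marked @final, so that shared behaviour lives in the base class and subclasses may only fill in the abstract hooks (the base_store.py changes below import typing.final for this). A minimal sketch with illustrative names, not code from this commit:

    from abc import ABC, abstractmethod
    from typing import final


    class AbstractStore(ABC):
        """Base class in the abstract final style."""

        @abstractmethod
        def append(self, item: float) -> None:
            """Hook that each concrete store must implement."""

        @final
        def fill(self, items: list[float]) -> None:
            """Shared helper; @final tells type checkers to reject subclass overrides."""
            for item in items:
                self.append(item)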
3 changes: 3 additions & 0 deletions .vscode/cspell.json
@@ -21,6 +21,7 @@
     "docstrings",
     "donjayamanne",
     "dtype",
+    "dtypes",
     "dunder",
     "earlyterminate",
     "einops",
@@ -36,6 +37,7 @@
     "intuniform",
     "invloguniform",
     "invloguniformvalues",
+    "itemwise",
     "jaxtyping",
     "kaiming",
     "keepdim",
@@ -71,6 +73,7 @@
     "randn",
     "randperm",
     "relu",
+    "resampler",
     "resid",
     "rtol",
     "runcap",
19 changes: 17 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -10,6 +10,7 @@
 [tool.poetry.dependencies]
 einops=">=0.6"
 python=">=3.10, <3.12"
+strenum="^0.4.15"
 torch=">=2.1"
 wandb=">=0.15.12"
 zstandard="^0.22.0" # Required for downloading datasets such as The Pile
@@ -125,6 +126,7 @@
     "ARG001", # Fixtures often have unused arguments
     "PT004", # Fixtures don't return anything
     "S101", # Assert is needed in PyTest
+    "TCH001", # Don't need to mark type-only imports
 ]

 [tool.ruff.lint.pydocstyle]
20 changes: 15 additions & 5 deletions sparse_autoencoder/__init__.py
@@ -1,23 +1,33 @@
 """Sparse Autoencoder Library."""
 from sparse_autoencoder.activation_store import (
     ActivationStore,
-    ActivationStoreBatch,
-    ActivationStoreItem,
     DiskActivationStore,
     ListActivationStore,
     TensorActivationStore,
 )
 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
+from sparse_autoencoder.loss import (
+    AbstractLoss,
+    LearnedActivationsL1Loss,
+    LossLogType,
+    LossReducer,
+    LossReductionType,
+    MSEReconstructionLoss,
+)
 from sparse_autoencoder.train.pipeline import pipeline


 __all__ = [
+    "AbstractLoss",
     "ActivationStore",
-    "ActivationStoreBatch",
-    "ActivationStoreItem",
     "DiskActivationStore",
+    "LearnedActivationsL1Loss",
     "ListActivationStore",
-    "TensorActivationStore",
+    "LossLogType",
+    "LossReducer",
+    "LossReductionType",
+    "MSEReconstructionLoss",
     "SparseAutoencoder",
+    "TensorActivationStore",
     "pipeline",
 ]
@@ -0,0 +1,47 @@
+"""Abstract activation resampler."""
+
+from abc import ABC, abstractmethod
+
+from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
+from sparse_autoencoder.tensor_types import (
+    DeadDecoderNeuronWeightUpdates,
+    DeadEncoderNeuronBiasUpdates,
+    DeadEncoderNeuronWeightUpdates,
+    NeuronActivity,
+)
+
+
+class AbstractActivationResampler(ABC):
+    """Abstract activation resampler."""
+
+    @abstractmethod
+    def resample_dead_neurons(
+        self,
+        neuron_activity: NeuronActivity,
+        store: TensorActivationStore,
+        num_input_activations: int = 819_200,
+    ) -> tuple[
+        DeadEncoderNeuronWeightUpdates, DeadEncoderNeuronBiasUpdates, DeadDecoderNeuronWeightUpdates
+    ]:
+        """Resample dead neurons.
+
+        Over the course of training, a subset of autoencoder neurons will have zero activity
+        across a large number of datapoints. The authors of *Towards Monosemanticity:
+        Decomposing Language Models With Dictionary Learning* found that “resampling” these dead
+        neurons during training improves the number of likely-interpretable features (i.e., those
+        in the high-density cluster) and reduces total loss. This resampling may be compatible
+        with the Lottery Ticket Hypothesis and increase the number of chances the network has to
+        find promising feature directions.
+
+        Warning:
+            The optimizer should be reset after applying this function, as the Adam state will be
+            incorrect for the modified weights and biases.
+
+        Args:
+            neuron_activity: Number of times each neuron fired.
+            store: Activation store.
+            num_input_activations: Number of input activations to use when resampling. Will be
+                rounded down to be divisible by the batch size, and cannot be larger than the
+                number of items currently in the store.
+        """
+        raise NotImplementedError
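To illustrate the interface shape, a minimal concrete subclass might look like the sketch below. The zero-update strategy and the 512-wide input dimension are assumptions for illustration, not the resampling scheme from the paper, and plain torch.Tensor stands in for the tensor type aliases:

    import torch
    from torch import Tensor


    class ZeroUpdateResampler(AbstractActivationResampler):
        """Illustrative only: proposes zero updates for every dead neuron."""

        def resample_dead_neurons(
            self,
            neuron_activity: Tensor,  # [learned_feature] fire counts
            store: TensorActivationStore,
            num_input_activations: int = 819_200,
        ) -> tuple[Tensor, Tensor, Tensor]:
            dead_neuron_indices = (neuron_activity == 0).nonzero().squeeze(-1)
            n_dead = int(dead_neuron_indices.numel())
            n_input_features = 512  # assumed source activation width
            return (
                torch.zeros(n_dead, n_input_features),  # dead encoder weight updates
                torch.zeros(n_dead),  # dead encoder bias updates
                torch.zeros(n_input_features, n_dead),  # dead decoder weight updates
            )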
4 changes: 1 addition & 3 deletions sparse_autoencoder/activation_store/__init__.py
@@ -1,14 +1,12 @@
 """Activation Stores."""
-from .base_store import ActivationStore, ActivationStoreBatch, ActivationStoreItem
+from .base_store import ActivationStore
 from .disk_store import DiskActivationStore
 from .list_store import ListActivationStore
 from .tensor_store import TensorActivationStore


 __all__ = [
     ActivationStore,
-    ActivationStoreBatch,
-    ActivationStoreItem,
     DiskActivationStore,
     ListActivationStore,
     TensorActivationStore,
32 changes: 8 additions & 24 deletions sparse_autoencoder/activation_store/base_store.py
@@ -3,29 +3,13 @@
 from concurrent.futures import Future
 from typing import final

-from jaxtyping import Float
-import torch
-from torch import Tensor
 from torch.utils.data import Dataset

+from sparse_autoencoder.tensor_types import InputOutputActivationBatch, InputOutputActivationVector

-ActivationStoreItem = Float[Tensor, "neuron"]
-"""Activation Store Dataset Item Type.
-
-A single vector containing activations. For example this could be the activations from a specific
-MLP layer, for a specific position and batch item.
-"""
-
-ActivationStoreBatch = Float[Tensor, "*any neuron"]
-"""Activation Store Dataset Batch Type.
-
-This can be e.g. a [batch, pos, neurons] tensor, containing activations from a specific MLP layer
-in a transformer. Alternatively, it could be e.g. a [batch, pos, head_idx, neurons] tensor from an
-attention layer.
-"""
-
-
-class ActivationStore(Dataset[ActivationStoreItem], ABC):
+class ActivationStore(Dataset[InputOutputActivationVector], ABC):
     """Activation Store Abstract Class.

     Extends the `torch.utils.data.Dataset` class to provide an activation store, with additional
@@ -43,16 +27,16 @@ class ActivationStore(Dataset[ActivationStoreItem], ABC):
     ...     super().__init__()
     ...     self._data = []  # In this example, we just store in a list
     ...
-    ...     def append(self, item: ActivationStoreItem) -> None:
+    ...     def append(self, item) -> None:
     ...         self._data.append(item)
     ...
-    ...     def extend(self, batch: ActivationStoreBatch):
+    ...     def extend(self, batch):
     ...         self._data.extend(batch)
     ...
     ...     def empty(self):
     ...         self._data = []
     ...
-    ...     def __getitem__(self, index: int) -> ActivationStoreItem:
+    ...     def __getitem__(self, index: int):
     ...         return self._data[index]
     ...
     ...     def __len__(self) -> int:
@@ -65,12 +49,12 @@ class ActivationStore(Dataset[ActivationStoreItem], ABC):
     """

     @abstractmethod
-    def append(self, item: ActivationStoreItem) -> Future | None:
+    def append(self, item: InputOutputActivationVector) -> Future | None:
         """Add a Single Item to the Store."""
         raise NotImplementedError

     @abstractmethod
-    def extend(self, batch: ActivationStoreBatch) -> Future | None:
+    def extend(self, batch: InputOutputActivationBatch) -> Future | None:
         """Add a Batch to the Store."""
         raise NotImplementedError

@@ -85,7 +69,7 @@ def __len__(self) -> int:
         raise NotImplementedError

     @abstractmethod
-    def __getitem__(self, index: int) -> ActivationStoreItem:
+    def __getitem__(self, index: int) -> InputOutputActivationVector:
         """Get an Item from the Store."""
         raise NotImplementedError
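The sparse_autoencoder.tensor_types module is not rendered on this page; judging from the aliases it replaces above, its definitions are presumably along these lines (the exact shape strings are an assumption):

    from jaxtyping import Float
    from torch import Tensor

    InputOutputActivationVector = Float[Tensor, "input_output_feature"]
    """A single activation vector, e.g. from one MLP layer at one position."""

    InputOutputActivationBatch = Float[Tensor, "item input_output_feature"]
    """A batch of activation vectors."""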
14 changes: 8 additions & 6 deletions sparse_autoencoder/activation_store/disk_store.py
@@ -10,12 +10,14 @@

 from sparse_autoencoder.activation_store.base_store import (
     ActivationStore,
-    ActivationStoreBatch,
-    ActivationStoreItem,
 )
 from sparse_autoencoder.activation_store.utils.extend_resize import (
     resize_to_list_vectors,
 )
+from sparse_autoencoder.tensor_types import (
+    InputOutputActivationVector,
+    SourceModelActivations,
+)


 DEFAULT_DISK_ACTIVATION_STORE_PATH = Path(tempfile.gettempdir()) / "activation_store"
@@ -132,7 +134,7 @@ def _write_to_disk(self, *, wait_for_max: bool = False) -> None:
         filename = f"{len(self)}.pt"
         torch.save(stacked_activations, self._storage_path / filename)

-    def append(self, item: ActivationStoreItem) -> Future | None:
+    def append(self, item: InputOutputActivationVector) -> Future | None:
         """Add a Single Item to the Store.

         Example:
@@ -158,7 +160,7 @@ def append(self, item: ActivationStoreItem) -> Future | None:

         return None  # Keep mypy happy

-    def extend(self, batch: ActivationStoreBatch) -> Future | None:
+    def extend(self, batch: SourceModelActivations) -> Future | None:
         """Add a Batch to the Store.

         Example:
@@ -175,7 +177,7 @@ def extend(self, batch: ActivationStoreBatch) -> Future | None:
             Future that completes when the activation vectors have queued to be written to disk,
             and if needed, written to disk.
         """
-        items: list[ActivationStoreItem] = resize_to_list_vectors(batch)
+        items: list[InputOutputActivationVector] = resize_to_list_vectors(batch)

         with self._cache_lock:
             self._cache.extend(items)
@@ -228,7 +230,7 @@ def empty(self) -> None:
                 file.unlink()
             self._disk_n_activation_vectors.value = 0

-    def __getitem__(self, index: int) -> ActivationStoreItem:
+    def __getitem__(self, index: int) -> InputOutputActivationVector:
         """Get Item Dunder Method.

         Args:
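resize_to_list_vectors is likewise not rendered here; from its call sites and the [*any, neuron] batch type it previously accepted, its behaviour is presumably equivalent to this sketch:

    import torch
    from torch import Tensor


    def resize_to_list_vectors(batch: Tensor) -> list[Tensor]:
        """Flatten e.g. [batch, pos, neuron] activations into a list of [neuron] vectors."""
        num_neurons: int = batch.shape[-1]
        return list(batch.reshape(-1, num_neurons).unbind(0))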
20 changes: 11 additions & 9 deletions sparse_autoencoder/activation_store/list_store.py
@@ -9,12 +9,14 @@

 from sparse_autoencoder.activation_store.base_store import (
     ActivationStore,
-    ActivationStoreBatch,
-    ActivationStoreItem,
 )
 from sparse_autoencoder.activation_store.utils.extend_resize import (
     resize_to_list_vectors,
 )
+from sparse_autoencoder.tensor_types import (
+    InputOutputActivationVector,
+    SourceModelActivations,
+)


 class ListActivationStore(ActivationStore):
@@ -69,7 +71,7 @@ class ListActivationStore(ActivationStore):
         torch.Size([2, 100])
     """

-    _data: list[ActivationStoreItem] | ListProxy
+    _data: list[InputOutputActivationVector] | ListProxy
     """Underlying List Data Store."""

     _device: torch.device | None
@@ -92,7 +94,7 @@ class ListActivationStore(ActivationStore):

     def __init__(
         self,
-        data: list[ActivationStoreItem] | None = None,
+        data: list[InputOutputActivationVector] | None = None,
         device: torch.device | None = None,
         max_workers: int | None = None,
         *,
@@ -166,7 +168,7 @@ def __sizeof__(self) -> int:

         return total_tensors_size + list_of_pointers_size

-    def __getitem__(self, index: int) -> ActivationStoreItem:
+    def __getitem__(self, index: int) -> InputOutputActivationVector:
         """Get Item Dunder Method.

         Example:
@@ -205,7 +207,7 @@ def shuffle(self) -> None:
         self.wait_for_writes_to_complete()
         random.shuffle(self._data)

-    def append(self, item: ActivationStoreItem) -> Future | None:
+    def append(self, item: InputOutputActivationVector) -> Future | None:
         """Append a single item to the dataset.

         Note **append is blocking**. For better performance use extend instead with batches.
@@ -223,7 +225,7 @@ def append(self, item: ActivationStoreItem) -> Future | None:
         """
         self._data.append(item.to(self._device))

-    def _extend(self, batch: ActivationStoreBatch) -> None:
+    def _extend(self, batch: SourceModelActivations) -> None:
         """Extend threadpool method.

         To be called by :meth:`extend`.
@@ -233,13 +235,13 @@ def _extend(self, batch: ActivationStoreBatch) -> None:
         """
         try:
             # Unstack to a list of tensors
-            items: list[ActivationStoreItem] = resize_to_list_vectors(batch)
+            items: list[InputOutputActivationVector] = resize_to_list_vectors(batch)

             self._data.extend(items)
         except Exception as e:  # noqa: BLE001
             self._pool_exceptions.append(e)

-    def extend(self, batch: ActivationStoreBatch) -> Future | None:
+    def extend(self, batch: SourceModelActivations) -> Future | None:
         """Extend the dataset with multiple items (non-blocking).

         Example:
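Putting the new signatures together, typical use of the non-blocking extend looks roughly like this (shapes are illustrative, and calling wait_for_writes_to_complete directly assumes the method referenced in the shuffle hunk above is public):

    import torch
    from sparse_autoencoder import ListActivationStore

    store = ListActivationStore()
    batch = torch.randn(2, 128, 100)  # e.g. [batch, pos, neurons] source activations
    store.extend(batch)  # queued on a thread pool; returns without blocking
    store.wait_for_writes_to_complete()  # drain pending writes before reading
    print(len(store))  # 2 * 128 = 256 activation vectors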