generated from alan-cooney/transformer-lens-starter-template
-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move towards abstract final pattern (#74)
- Loading branch information
1 parent
75c9b9e
commit 0059c3d
Showing 44 changed files with 1,476 additions and 387 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,33 @@ | ||
"""Sparse Autoencoder Library.""" | ||
from sparse_autoencoder.activation_store import ( | ||
ActivationStore, | ||
ActivationStoreBatch, | ||
ActivationStoreItem, | ||
DiskActivationStore, | ||
ListActivationStore, | ||
TensorActivationStore, | ||
) | ||
from sparse_autoencoder.autoencoder.model import SparseAutoencoder | ||
from sparse_autoencoder.loss import ( | ||
AbstractLoss, | ||
LearnedActivationsL1Loss, | ||
LossLogType, | ||
LossReducer, | ||
LossReductionType, | ||
MSEReconstructionLoss, | ||
) | ||
from sparse_autoencoder.train.pipeline import pipeline | ||
|
||
|
||
# Public API of the package: every name re-exported by `from sparse_autoencoder import *`.
# Kept in sorted order with each name listed exactly once.
__all__ = [
    "AbstractLoss",
    "ActivationStore",
    "ActivationStoreBatch",
    "ActivationStoreItem",
    "DiskActivationStore",
    "LearnedActivationsL1Loss",
    "ListActivationStore",
    "LossLogType",
    "LossReducer",
    "LossReductionType",
    "MSEReconstructionLoss",
    "SparseAutoencoder",
    "TensorActivationStore",
    "pipeline",
]
47 changes: 47 additions & 0 deletions
47
sparse_autoencoder/activation_resampler/abstract_activation_resampler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""Abstract activation resampler.""" | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore | ||
from sparse_autoencoder.tensor_types import ( | ||
DeadDecoderNeuronWeightUpdates, | ||
DeadEncoderNeuronBiasUpdates, | ||
DeadEncoderNeuronWeightUpdates, | ||
NeuronActivity, | ||
) | ||
|
||
|
||
class AbstractActivationResampler(ABC):
    """Abstract activation resampler.

    Defines the interface for resampling dead autoencoder neurons. Concrete subclasses must
    implement `resample_dead_neurons`.
    """

    @abstractmethod
    def resample_dead_neurons(
        self,
        neuron_activity: NeuronActivity,
        store: TensorActivationStore,
        num_input_activations: int = 819_200,
    ) -> tuple[
        DeadEncoderNeuronWeightUpdates, DeadEncoderNeuronBiasUpdates, DeadDecoderNeuronWeightUpdates
    ]:
        """Resample dead neurons.

        Over the course of training, a subset of autoencoder neurons will have zero activity across
        a large number of datapoints. The authors of *Towards Monosemanticity: Decomposing Language
        Models With Dictionary Learning* found that "resampling" these dead neurons during training
        improves the number of likely-interpretable features (i.e., those in the high density
        cluster) and reduces total loss. This resampling may be compatible with the Lottery Ticket
        Hypothesis and increase the number of chances the network has to find promising feature
        directions.

        Warning:
            The optimizer should be reset after applying this function, as the Adam state will be
            incorrect for the modified weights and biases.

        Args:
            neuron_activity: Number of times each neuron fired.
            store: Activation store. Presumably the source of the input activations used for
                resampling (see `num_input_activations`). TODO: the original docstring marked this
                parameter as subject to change.
            num_input_activations: Number of input activations to use when resampling. Will be
                rounded down to be divisible by the batch size, and cannot be larger than the number
                of items currently in the store.

        Returns:
            Tuple of updates for the dead neurons, in the order (encoder weight updates, encoder
            bias updates, decoder weight updates).

        Raises:
            NotImplementedError: If the base implementation is invoked directly (e.g. through
                `super()`), since subclasses are required to provide the behaviour.
        """
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.