diff --git a/.vscode/cspell.json b/.vscode/cspell.json index bab2b0c3..fee7316d 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -42,6 +42,7 @@ "intuniform", "invloguniform", "invloguniformvalues", + "ipynb", "itemwise", "jaxtyping", "kaiming", diff --git a/docs/content/SUMMARY.md b/docs/content/SUMMARY.md index 03e1fbb2..0c0ca3c2 100644 --- a/docs/content/SUMMARY.md +++ b/docs/content/SUMMARY.md @@ -1,6 +1,7 @@ * [Home](index.md) * [Demo](demo.ipynb) +* [Flexible Demo](flexible_demo.ipynb) * [Reference](reference/) * [Contributing](contributing.md) * [Citation](citation.md) \ No newline at end of file diff --git a/docs/content/demo.ipynb b/docs/content/demo.ipynb index 762bc330..599d553f 100644 --- a/docs/content/demo.ipynb +++ b/docs/content/demo.ipynb @@ -13,30 +13,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Sparse Autoencoder Training Demo\n", + "# Quick Start Training Demo\n", "\n", - "This demo shows you how to train a sparse autoencoder (SAE) on a\n", - "[TransformerLens](https://github.com/neelnanda-io/TransformerLens) model. It replicates Neel Nanda's\n", + "This is a quick start demo to get training a SAE right away. All you need to do is choose a few\n", + "hyperparameters (like the model to train on), and then set it off.\n", + "By default it replicates Neel Nanda's\n", "[comment on the Anthropic dictionary learning\n", - "paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda).\n", - "\n", - "## Introduction\n", - "\n", - "The way this library works is that we provide all the components necessary to train a sparse\n", - "autoencoder. For the most part, these are just standard PyTorch modules. For example `AdamWithReset` is\n", - "just an extension of `torch.optim.Adam`, with a few extra bells and whistles that are needed for training a SAE\n", - "(e.g. a method to reset the optimizer state when you are also resampling dead neurons).\n", - "\n", - "This is very flexible - it's easy for you to extend and change just one component if you want, just\n", - "like you'd do with a standard PyTorch mode. It also means it's very easy to see what is going on\n", - "under the hood. However to get you started, the following demo sets up a\n", - "default SAE that uses the implementation that Neel Nanda used in his comment above.\n", - "\n", - "### Approach\n", - "\n", - "The approach is pretty simple - we run a training pipeline that alternates between generating\n", - "activations from a *source model*, and training the *sparse autoencoder* model on these generated\n", - "activations." + "paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda)." 
] }, { @@ -73,52 +56,34 @@ "\n", "# Otherwise enable hot reloading in dev mode\n", "if not in_colab:\n", - " from IPython import get_ipython # type: ignore\n", - "\n", - " ip = get_ipython()\n", - " if ip is not None and ip.extension_manager is not None and not ip.extension_manager.loaded:\n", - " ip.extension_manager.load(\"autoreload\") # type: ignore\n", - " %autoreload 2" + " %load_ext autoreload\n", + " %autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device: mps\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", - "from pathlib import Path\n", - "\n", - "import torch\n", - "from transformer_lens import HookedTransformer\n", - "from transformer_lens.utils import get_device\n", "\n", "from sparse_autoencoder import (\n", - " ActivationResampler,\n", - " AdamWithReset,\n", - " L2ReconstructionLoss,\n", - " LearnedActivationsL1Loss,\n", - " LossReducer,\n", - " Pipeline,\n", - " PreTokenizedDataset,\n", - " SparseAutoencoder,\n", + " sweep,\n", + " SweepConfig,\n", + " Hyperparameters,\n", + " SourceModelHyperparameters,\n", + " Parameter,\n", + " SourceDataHyperparameters,\n", + " Method,\n", + " LossHyperparameters,\n", + " OptimizerHyperparameters,\n", ")\n", "import wandb\n", "\n", "\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "\n", - "device = get_device()\n", - "print(f\"Using device: {device}\") # You will need a GPU" + "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"demo.ipynb\"" ] }, { @@ -132,359 +97,272 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The way this library works is that you can define your own hyper-parameters and then setup the\n", - "underlying components with them. This is extremely flexible, but to help you get started we've\n", - "included some common ones below along with some sensible defaults. You can also easily sweep through\n", - "multiple hyperparameters with `wandb.sweep`." + "Customize any hyperparameters you want below (by default we're sweeping over l1 coefficient and\n", + "learning rate):" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], - "source": [ - "torch.random.manual_seed(49)\n", - "\n", - "hyperparameters = {\n", - " # Expansion factor is the number of features in the sparse representation, relative to the\n", - " # number of features in the original MLP layer. 
The original paper experimented with 1x to 256x,\n", - " # and we have found that 4x is a good starting point.\n", - " \"expansion_factor\": 4,\n", - " # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).\n", - " \"l1_coefficient\": 3e-4,\n", - " # Adam parameters (set to the default ones here)\n", - " \"lr\": 1e-4,\n", - " \"adam_beta_1\": 0.9,\n", - " \"adam_beta_2\": 0.999,\n", - " \"adam_epsilon\": 1e-8,\n", - " \"adam_weight_decay\": 0.0,\n", - " # Batch sizes\n", - " \"train_batch_size\": 4096,\n", - " \"context_size\": 128,\n", - " # Source model hook point\n", - " \"source_model_name\": \"gelu-2l\",\n", - " \"source_model_dtype\": \"float32\",\n", - " \"source_model_hook_point\": \"blocks.0.hook_mlp_out\",\n", - " \"source_model_hook_point_layer\": 0,\n", - " # Train pipeline parameters\n", - " \"max_store_size\": 384 * 4096 * 2,\n", - " \"max_activations\": 2_000_000_000,\n", - " \"resample_frequency\": 122_880_000,\n", - " \"checkpoint_frequency\": 100_000_000,\n", - " \"validation_frequency\": 384 * 4096 * 2 * 100, # Every 100 generations\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Source Model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The source model is just a [TransformerLens](https://github.com/neelnanda-io/TransformerLens) model\n", - "(see [here](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", - "for a full list of supported models).\n", - "\n", - "In this example we're training a sparse autoencoder on the activations from the first MLP layer, so\n", - "we'll also get some details about that hook point." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model gelu-2l into HookedTransformer\n" - ] - }, { "data": { "text/plain": [ - "'Source: gelu-2l, Hook: blocks.0.hook_mlp_out, Features: 512'" + "SweepConfig(parameters=Hyperparameters(\n", + " source_data=SourceDataHyperparameters(dataset_path=Parameter(value=NeelNanda/c4-code-tokenized-2b), context_size=Parameter(value=128))\n", + " source_model=SourceModelHyperparameters(name=Parameter(value=gelu-2l), hook_site=Parameter(value=mlp_out), hook_layer=Parameter(value=0), hook_dimension=Parameter(value=512), dtype=Parameter(value=float32))\n", + " activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_resamples=Parameter(value=4), n_steps_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=819200), dead_neuron_threshold=Parameter(value=0.0))\n", + " autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=4))\n", + " loss=LossHyperparameters(l1_coefficient=Parameter(values=[0.001, 0.0001, 1e-05]))\n", + " optimizer=OptimizerHyperparameters(lr=Parameter(values=[0.001, 0.0001, 1e-05]), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n", + " pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=12), train_batch_size=Parameter(value=4096), max_store_size=Parameter(value=3145728), max_activations=Parameter(value=2000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=314572800), validation_number_activations=Parameter(value=1024))\n", + " 
random_seed=Parameter(value=49)\n", + "), method=, metric=Metric(name='total_loss', goal=, impute=None, imputewhilerunning=None, target=None), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None, runcap=None)" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Source model setup with TransformerLens\n", - "src_model = HookedTransformer.from_pretrained(\n", - " str(hyperparameters[\"source_model_name\"]), dtype=str(hyperparameters[\"source_model_dtype\"])\n", + "sweep_config = SweepConfig(\n", + " parameters=Hyperparameters(\n", + " loss=LossHyperparameters(\n", + " l1_coefficient=Parameter(values=[1e-3, 1e-4, 1e-5]),\n", + " ),\n", + " optimizer=OptimizerHyperparameters(\n", + " lr=Parameter(values=[1e-3, 1e-4, 1e-5]),\n", + " ),\n", + " source_model=SourceModelHyperparameters(\n", + " name=Parameter(\"gelu-2l\"),\n", + " hook_site=Parameter(\"mlp_out\"),\n", + " hook_layer=Parameter(0),\n", + " hook_dimension=Parameter(512),\n", + " ),\n", + " source_data=SourceDataHyperparameters(\n", + " dataset_path=Parameter(\"NeelNanda/c4-code-tokenized-2b\"),\n", + " ),\n", + " ),\n", + " method=Method.RANDOM,\n", ")\n", - "\n", - "# Details about the activations we'll train the sparse autoencoder on\n", - "autoencoder_input_dim: int = src_model.cfg.d_model # type: ignore (TransformerLens typing is currently broken)\n", - "\n", - "f\"Source: {hyperparameters['source_model_name']}, \\\n", - " Hook: {hyperparameters['source_model_hook_point']}, \\\n", - " Features: {autoencoder_input_dim}\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Sparse Autoencoder" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can then setup the sparse autoencoder. The default model (`SparseAutoencoder`) is setup as per\n", - "the original Anthropic paper [Towards Monosemanticity: Decomposing Language Models With Dictionary\n", - "Learning ](https://transformer-circuits.pub/2023/monosemantic-features/index.html).\n", - "\n", - "However it's just a standard PyTorch model, so you can create your own model instead if you want to\n", - "use a different architecture. To do this you just need to extend the `AbstractAutoencoder`, and\n", - "optionally the underlying `AbstractEncoder`, `AbstractDecoder` and `AbstractOuterBias`. See these\n", - "classes (which are fully documented) for more details." 
+ "sweep_config" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "SparseAutoencoder(\n", - " (_pre_encoder_bias): TiedBias(position=pre_encoder)\n", - " (_encoder): LinearEncoder(\n", - " in_features=512, out_features=2048\n", - " (activation_function): ReLU()\n", - " )\n", - " (_decoder): UnitNormDecoder(in_features=2048, out_features=512)\n", - " (_post_decoder_bias): TiedBias(position=post_decoder)\n", - ")" + "Parameter(value=gelu-2l)" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "expansion_factor = hyperparameters[\"expansion_factor\"]\n", - "autoencoder = SparseAutoencoder(\n", - " n_input_features=autoencoder_input_dim, # size of the activations we are autoencoding\n", - " n_learned_features=int(autoencoder_input_dim * expansion_factor), # size of SAE\n", - ").to(device)\n", - "autoencoder" + "Parameter(\"gelu-2l\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We'll also want to setup an Optimizer and Loss function. In this case we'll also use the standard\n", - "approach from the original Anthropic paper. However you can create your own loss functions and\n", - "optimizers by extending `AbstractLoss` and `AbstractOptimizerWithReset` respectively." + "### Run the sweep" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Create sweep with ID: f5e0gllf\n", + "Sweep URL: https://wandb.ai/alan-cooney/sparse-autoencoder/sweeps/f5e0gllf\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: sjlk7s82 with config:\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \tactivation_resampler: {'dead_neuron_threshold': 0, 'max_resamples': 4, 'n_steps_collate': 100000000, 'resample_dataset_size': 819200, 'resample_interval': 200000000}\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \tautoencoder: {'expansion_factor': 4}\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \tloss: {'l1_coefficient': 0.0001}\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \toptimizer: {'adam_beta_1': 0.9, 'adam_beta_2': 0.99, 'adam_weight_decay': 0, 'amsgrad': False, 'fused': False, 'lr': 0.0001}\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \tpipeline: {'checkpoint_frequency': 100000000, 'log_frequency': 100, 'max_activations': 2000000000, 'max_store_size': 3145728, 'source_data_batch_size': 12, 'train_batch_size': 4096, 'validation_frequency': 314572800, 'validation_number_activations': 1024}\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \trandom_seed: 49\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \tsource_data: {'context_size': 128, 'dataset_path': 'NeelNanda/c4-code-tokenized-2b'}\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \tsource_model: {'dtype': 'float32', 'hook_dimension': 512, 'hook_layer': 0, 'hook_site': 'mlp_out', 'name': 'gelu-2l'}\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33malan-cooney\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Ignored wandb.init() arg project when running a sweep.\n" + ] + }, { "data": { + "text/html": [ + "Tracking run with wandb version 0.16.0" + ], "text/plain": [ - "LossReducer(\n", - " (0): LearnedActivationsL1Loss(l1_coefficient=0.0003)\n", - " (1): L2ReconstructionLoss()\n", - ")" + "" ] }, - "execution_count": 6, "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# We use a loss reducer, which simply adds up the losses from the underlying loss functions.\n", - "loss = LossReducer(\n", - " LearnedActivationsL1Loss(\n", - " l1_coefficient=float(hyperparameters[\"l1_coefficient\"]),\n", - " ),\n", - " L2ReconstructionLoss(),\n", - ")\n", - "loss" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ + "output_type": "display_data" + }, { "data": { + "text/html": [ + "Run data is saved locally in /Users/alan/Documents/Repos/sparse_autoencoder/docs/content/wandb/run-20231203_173216-sjlk7s82" + ], "text/plain": [ - "AdamWithReset (\n", - "Parameter Group 0\n", - " amsgrad: False\n", - " betas: (0.9, 0.999)\n", - " capturable: False\n", - " differentiable: False\n", - " eps: 1e-08\n", - " foreach: None\n", - " fused: None\n", - " lr: 0.0001\n", - " maximize: False\n", - " weight_decay: 0.0\n", - ")" + "" ] }, - "execution_count": 7, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run cerulean-sweep-1 to Weights & Biases (docs)
Sweep page: https://wandb.ai/alan-cooney/sparse-autoencoder/sweeps/f5e0gllf" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/alan-cooney/sparse-autoencoder" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View sweep at https://wandb.ai/alan-cooney/sparse-autoencoder/sweeps/f5e0gllf" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/alan-cooney/sparse-autoencoder/runs/sjlk7s82" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model gelu-2l into HookedTransformer\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2d8e963dce5741b391d0dae1f558a5cd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/28 [00:00cerulean-sweep-1 at: https://wandb.ai/alan-cooney/sparse-autoencoder/runs/sjlk7s82
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20231203_173216-sjlk7s82/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "optimizer = AdamWithReset(\n", - " params=autoencoder.parameters(),\n", - " named_parameters=autoencoder.named_parameters(),\n", - " lr=float(hyperparameters[\"lr\"]),\n", - " betas=(float(hyperparameters[\"adam_beta_1\"]), float(hyperparameters[\"adam_beta_2\"])),\n", - " eps=float(hyperparameters[\"adam_epsilon\"]),\n", - " weight_decay=float(hyperparameters[\"adam_weight_decay\"]),\n", - ")\n", - "optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally we'll initialise an activation resampler." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "activation_resampler = ActivationResampler(\n", - " resample_interval=10_000, n_steps_collate=10_000, max_resamples=5\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Source dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This is just a dataset of tokenized prompts, to be used in generating activations (which are in turn\n", - "used to train the SAE)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "source_data = PreTokenizedDataset(\n", - " dataset_path=\"NeelNanda/c4-code-tokenized-2b\", context_size=int(hyperparameters[\"context_size\"])\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If you initialise [wandb](https://wandb.ai/site), the pipeline will automatically log all metrics to\n", - "wandb. However, we should pass in a dictionary with all of our hyperaparameters so they're on \n", - "wandb. \n", + "sweep(sweep_config=sweep_config)\n", "\n", - "We strongly encourage users to make use of wandb in order to understand the training process." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "checkpoint_path = Path(\"../../.checkpoints\")\n", - "checkpoint_path.mkdir(exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "Path(\".cache/\").mkdir(exist_ok=True)\n", - "wandb.init(\n", - " project=\"sparse-autoencoder\",\n", - " dir=\".cache\",\n", - " config=hyperparameters,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipeline = Pipeline(\n", - " activation_resampler=activation_resampler,\n", - " autoencoder=autoencoder,\n", - " cache_name=str(hyperparameters[\"source_model_hook_point\"]),\n", - " checkpoint_directory=checkpoint_path,\n", - " layer=int(hyperparameters[\"source_model_hook_point_layer\"]),\n", - " loss=loss,\n", - " optimizer=optimizer,\n", - " source_data_batch_size=6,\n", - " source_dataset=source_data,\n", - " source_model=src_model,\n", - ")\n", - "\n", - "pipeline.run_pipeline(\n", - " train_batch_size=int(hyperparameters[\"train_batch_size\"]),\n", - " max_store_size=int(hyperparameters[\"max_store_size\"]),\n", - " max_activations=int(hyperparameters[\"max_activations\"]),\n", - " checkpoint_frequency=int(hyperparameters[\"checkpoint_frequency\"]),\n", - " validate_frequency=int(hyperparameters[\"validation_frequency\"]),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "wandb.finish()" ] } diff --git a/docs/content/flexible_demo.ipynb b/docs/content/flexible_demo.ipynb new file mode 100644 index 00000000..88e538a5 --- /dev/null +++ b/docs/content/flexible_demo.ipynb @@ -0,0 +1,518 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Flexible Training Demo\n", + "\n", + "This demo shows you how to train a sparse autoencoder (SAE) on a\n", + "[TransformerLens](https://github.com/neelnanda-io/TransformerLens) model. It replicates Neel Nanda's\n", + "[comment on the Anthropic dictionary learning\n", + "paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html#comment-nanda).\n", + "\n", + "## Introduction\n", + "\n", + "The way this library works is that we provide all the components necessary to train a sparse\n", + "autoencoder. For the most part, these are just standard PyTorch modules. For example `AdamWithReset` is\n", + "just an extension of `torch.optim.Adam`, with a few extra bells and whistles that are needed for training a SAE\n", + "(e.g. a method to reset the optimizer state when you are also resampling dead neurons).\n", + "\n", + "This is very flexible - it's easy for you to extend and change just one component if you want, just\n", + "like you'd do with a standard PyTorch mode. It also means it's very easy to see what is going on\n", + "under the hood. However to get you started, the following demo sets up a\n", + "default SAE that uses the implementation that Neel Nanda used in his comment above.\n", + "\n", + "### Approach\n", + "\n", + "The approach is pretty simple - we run a training pipeline that alternates between generating\n", + "activations from a *source model*, and training the *sparse autoencoder* model on these generated\n", + "activations." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Check if we're in Colab\n", + "try:\n", + " import google.colab # noqa: F401 # type: ignore\n", + "\n", + " in_colab = True\n", + "except ImportError:\n", + " in_colab = False\n", + "\n", + "# Install if in Colab\n", + "if in_colab:\n", + " %pip install sparse_autoencoder transformer_lens transformers wandb\n", + "\n", + "# Otherwise enable hot reloading in dev mode\n", + "if not in_colab:\n", + " from IPython import get_ipython # type: ignore\n", + "\n", + " ip = get_ipython()\n", + " if ip is not None and ip.extension_manager is not None and not ip.extension_manager.loaded:\n", + " ip.extension_manager.load(\"autoreload\") # type: ignore\n", + " %autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: mps\n" + ] + } + ], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "import torch\n", + "from transformer_lens import HookedTransformer\n", + "from transformer_lens.utils import get_device\n", + "\n", + "from sparse_autoencoder import (\n", + " ActivationResampler,\n", + " AdamWithReset,\n", + " L2ReconstructionLoss,\n", + " LearnedActivationsL1Loss,\n", + " LossReducer,\n", + " Pipeline,\n", + " PreTokenizedDataset,\n", + " SparseAutoencoder,\n", + ")\n", + "import wandb\n", + "\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "\n", + "device = get_device()\n", + "print(f\"Using device: {device}\") # You will need a GPU" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Hyperparameters" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The way this library works is that you can define your own hyper-parameters and then setup the\n", + "underlying components with them. This is extremely flexible, but to help you get started we've\n", + "included some common ones below along with some sensible defaults. You can also easily sweep through\n", + "multiple hyperparameters with `wandb.sweep`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "torch.random.manual_seed(49)\n", + "\n", + "hyperparameters = {\n", + " # Expansion factor is the number of features in the sparse representation, relative to the\n", + " # number of features in the original MLP layer. 
The original paper experimented with 1x to 256x,\n", + " # and we have found that 4x is a good starting point.\n", + " \"expansion_factor\": 4,\n", + " # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).\n", + " \"l1_coefficient\": 3e-4,\n", + " # Adam parameters (set to the default ones here)\n", + " \"lr\": 1e-4,\n", + " \"adam_beta_1\": 0.9,\n", + " \"adam_beta_2\": 0.999,\n", + " \"adam_epsilon\": 1e-8,\n", + " \"adam_weight_decay\": 0.0,\n", + " # Batch sizes\n", + " \"train_batch_size\": 4096,\n", + " \"context_size\": 128,\n", + " # Source model hook point\n", + " \"source_model_name\": \"gelu-2l\",\n", + " \"source_model_dtype\": \"float32\",\n", + " \"source_model_hook_point\": \"blocks.0.hook_mlp_out\",\n", + " \"source_model_hook_point_layer\": 0,\n", + " # Train pipeline parameters\n", + " \"max_store_size\": 384 * 4096 * 2,\n", + " \"max_activations\": 2_000_000_000,\n", + " \"resample_frequency\": 122_880_000,\n", + " \"checkpoint_frequency\": 100_000_000,\n", + " \"validation_frequency\": 384 * 4096 * 2 * 100, # Every 100 generations\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Source Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The source model is just a [TransformerLens](https://github.com/neelnanda-io/TransformerLens) model\n", + "(see [here](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + "for a full list of supported models).\n", + "\n", + "In this example we're training a sparse autoencoder on the activations from the first MLP layer, so\n", + "we'll also get some details about that hook point." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model gelu-2l into HookedTransformer\n" + ] + }, + { + "data": { + "text/plain": [ + "'Source: gelu-2l, Hook: blocks.0.hook_mlp_out, Features: 512'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Source model setup with TransformerLens\n", + "src_model = HookedTransformer.from_pretrained(\n", + " str(hyperparameters[\"source_model_name\"]), dtype=str(hyperparameters[\"source_model_dtype\"])\n", + ")\n", + "\n", + "# Details about the activations we'll train the sparse autoencoder on\n", + "autoencoder_input_dim: int = src_model.cfg.d_model # type: ignore (TransformerLens typing is currently broken)\n", + "\n", + "f\"Source: {hyperparameters['source_model_name']}, \\\n", + " Hook: {hyperparameters['source_model_hook_point']}, \\\n", + " Features: {autoencoder_input_dim}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sparse Autoencoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can then setup the sparse autoencoder. The default model (`SparseAutoencoder`) is setup as per\n", + "the original Anthropic paper [Towards Monosemanticity: Decomposing Language Models With Dictionary\n", + "Learning ](https://transformer-circuits.pub/2023/monosemantic-features/index.html).\n", + "\n", + "However it's just a standard PyTorch model, so you can create your own model instead if you want to\n", + "use a different architecture. To do this you just need to extend the `AbstractAutoencoder`, and\n", + "optionally the underlying `AbstractEncoder`, `AbstractDecoder` and `AbstractOuterBias`. 
See these\n", + "classes (which are fully documented) for more details." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "SparseAutoencoder(\n", + " (_pre_encoder_bias): TiedBias(position=pre_encoder)\n", + " (_encoder): LinearEncoder(\n", + " in_features=512, out_features=2048\n", + " (activation_function): ReLU()\n", + " )\n", + " (_decoder): UnitNormDecoder(in_features=2048, out_features=512)\n", + " (_post_decoder_bias): TiedBias(position=post_decoder)\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "expansion_factor = hyperparameters[\"expansion_factor\"]\n", + "autoencoder = SparseAutoencoder(\n", + " n_input_features=autoencoder_input_dim, # size of the activations we are autoencoding\n", + " n_learned_features=int(autoencoder_input_dim * expansion_factor), # size of SAE\n", + ").to(device)\n", + "autoencoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll also want to setup an Optimizer and Loss function. In this case we'll also use the standard\n", + "approach from the original Anthropic paper. However you can create your own loss functions and\n", + "optimizers by extending `AbstractLoss` and `AbstractOptimizerWithReset` respectively." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LossReducer(\n", + " (0): LearnedActivationsL1Loss(l1_coefficient=0.0003)\n", + " (1): L2ReconstructionLoss()\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# We use a loss reducer, which simply adds up the losses from the underlying loss functions.\n", + "loss = LossReducer(\n", + " LearnedActivationsL1Loss(\n", + " l1_coefficient=float(hyperparameters[\"l1_coefficient\"]),\n", + " ),\n", + " L2ReconstructionLoss(),\n", + ")\n", + "loss" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AdamWithReset (\n", + "Parameter Group 0\n", + " amsgrad: False\n", + " betas: (0.9, 0.999)\n", + " capturable: False\n", + " differentiable: False\n", + " eps: 1e-08\n", + " foreach: None\n", + " fused: None\n", + " lr: 0.0001\n", + " maximize: False\n", + " weight_decay: 0.0\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "optimizer = AdamWithReset(\n", + " params=autoencoder.parameters(),\n", + " named_parameters=autoencoder.named_parameters(),\n", + " lr=float(hyperparameters[\"lr\"]),\n", + " betas=(float(hyperparameters[\"adam_beta_1\"]), float(hyperparameters[\"adam_beta_2\"])),\n", + " eps=float(hyperparameters[\"adam_epsilon\"]),\n", + " weight_decay=float(hyperparameters[\"adam_weight_decay\"]),\n", + ")\n", + "optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we'll initialise an activation resampler." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "activation_resampler = ActivationResampler(\n", + " resample_interval=10_000, n_steps_collate=10_000, max_resamples=5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Source dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is just a dataset of tokenized prompts, to be used in generating activations (which are in turn\n", + "used to train the SAE)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source_data = PreTokenizedDataset(\n", + " dataset_path=\"NeelNanda/c4-code-tokenized-2b\", context_size=int(hyperparameters[\"context_size\"])\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you initialise [wandb](https://wandb.ai/site), the pipeline will automatically log all metrics to\n", + "wandb. However, we should pass in a dictionary with all of our hyperaparameters so they're on \n", + "wandb. \n", + "\n", + "We strongly encourage users to make use of wandb in order to understand the training process." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "checkpoint_path = Path(\"../../.checkpoints\")\n", + "checkpoint_path.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Path(\".cache/\").mkdir(exist_ok=True)\n", + "wandb.init(\n", + " project=\"sparse-autoencoder\",\n", + " dir=\".cache\",\n", + " config=hyperparameters,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline = Pipeline(\n", + " activation_resampler=activation_resampler,\n", + " autoencoder=autoencoder,\n", + " cache_name=str(hyperparameters[\"source_model_hook_point\"]),\n", + " checkpoint_directory=checkpoint_path,\n", + " layer=int(hyperparameters[\"source_model_hook_point_layer\"]),\n", + " loss=loss,\n", + " optimizer=optimizer,\n", + " source_data_batch_size=6,\n", + " source_dataset=source_data,\n", + " source_model=src_model,\n", + ")\n", + "\n", + "pipeline.run_pipeline(\n", + " train_batch_size=int(hyperparameters[\"train_batch_size\"]),\n", + " max_store_size=int(hyperparameters[\"max_store_size\"]),\n", + " max_activations=int(hyperparameters[\"max_activations\"]),\n", + " checkpoint_frequency=int(hyperparameters[\"checkpoint_frequency\"]),\n", + " validate_frequency=int(hyperparameters[\"validation_frequency\"]),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wandb.finish()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + }, + "vscode": { + "interpreter": { + "hash": "31186ba1239ad81afeb3c631b4833e71f34259d3b92eebb37a9091b916e08620" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 63a98afc..31bf45a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-149,11 +149,15 @@ select=["ALL"] [tool.ruff.lint] + [tool.ruff.lint.flake8-tidy-imports] + ban-relative-imports="all" + [tool.ruff.lint.flake8-annotations] mypy-init-return=true [tool.ruff.lint.isort] force-sort-within-sections=true + known-third-party=["wandb"] lines-after-imports=2 [tool.ruff.lint.per-file-ignores] diff --git a/sparse_autoencoder/__init__.py b/sparse_autoencoder/__init__.py index 7e0b91ca..2766e940 100644 --- a/sparse_autoencoder/__init__.py +++ b/sparse_autoencoder/__init__.py @@ -14,22 +14,76 @@ from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset from sparse_autoencoder.source_data.text_dataset import TextDataset from sparse_autoencoder.train.pipeline import Pipeline +from sparse_autoencoder.train.sweep import ( + sweep, +) +from sparse_autoencoder.train.sweep_config import ( + ActivationResamplerHyperparameters, + AutoencoderHyperparameters, + Hyperparameters, + LossHyperparameters, + OptimizerHyperparameters, + PipelineHyperparameters, + SourceDataHyperparameters, + SourceModelHyperparameters, + SourceModelRuntimeHyperparameters, + SweepConfig, +) +from sparse_autoencoder.train.utils.wandb_sweep_types import ( + Controller, + ControllerType, + Distribution, + Goal, + HyperbandStopping, + HyperbandStoppingType, + Impute, + ImputeWhileRunning, + Kind, + Method, + Metric, + NestedParameter, + Parameter, +) __all__ = [ "ActivationResampler", + "ActivationResamplerHyperparameters", "AdamWithReset", + "AutoencoderHyperparameters", "CapacityMetric", + "Controller", + "ControllerType", "DiskActivationStore", + "Distribution", + "Goal", + "HyperbandStopping", + "HyperbandStoppingType", + "Hyperparameters", + "Impute", + "ImputeWhileRunning", + "Kind", "L2ReconstructionLoss", "LearnedActivationsL1Loss", "ListActivationStore", + "LossHyperparameters", "LossLogType", "LossReducer", "LossReductionType", + "Method", + "Metric", + "NestedParameter", + "OptimizerHyperparameters", + "Parameter", "Pipeline", + "PipelineHyperparameters", "PreTokenizedDataset", + "SourceDataHyperparameters", + "SourceModelHyperparameters", + "SourceModelRuntimeHyperparameters", "SparseAutoencoder", + "sweep", + "SweepConfig", "TensorActivationStore", "TextDataset", "TrainBatchFeatureDensityMetric", diff --git a/sparse_autoencoder/activation_resampler/activation_resampler.py b/sparse_autoencoder/activation_resampler/activation_resampler.py index 8b66926c..75b1aeee 100644 --- a/sparse_autoencoder/activation_resampler/activation_resampler.py +++ b/sparse_autoencoder/activation_resampler/activation_resampler.py @@ -457,3 +457,14 @@ def resample_dead_neurons( dead_encoder_bias_updates=dead_encoder_bias_updates, dead_decoder_weight_updates=dead_decoder_weight_updates, ) + + def __str__(self) -> str: + """Return a string representation of the activation resampler.""" + return ( + f"ActivationResampler(" + f"resample_interval={self.resample_interval}, " + f"max_resamples={self._max_resamples}, " + f"n_steps_collate={self.n_steps_collate}, " + f"resample_dataset_size={self._resample_dataset_size}, " + f"dead_neuron_threshold={self._dead_neuron_threshold})" + ) diff --git a/sparse_autoencoder/train/sweep.py b/sparse_autoencoder/train/sweep.py new file mode 100644 index 00000000..05cb5cda --- /dev/null +++ b/sparse_autoencoder/train/sweep.py @@ -0,0 +1,236 @@ +"""Sweep.""" +from pathlib import Path + +import torch +from transformer_lens import HookedTransformer +from transformer_lens.utils import get_act_name, get_device +import wandb + +from sparse_autoencoder import ( + 
ActivationResampler, + AdamWithReset, + L2ReconstructionLoss, + LearnedActivationsL1Loss, + LossReducer, + Pipeline, + PreTokenizedDataset, + SparseAutoencoder, +) +from sparse_autoencoder.train.sweep_config import ( + RuntimeHyperparameters, + SweepConfig, +) + + +def setup_activation_resampler(hyperparameters: RuntimeHyperparameters) -> ActivationResampler: + """Setup the activation resampler for the autoencoder. + + Args: + hyperparameters: The hyperparameters dictionary. + + Returns: + ActivationResampler: The initialized activation resampler. + """ + return ActivationResampler( + resample_interval=hyperparameters["activation_resampler"]["resample_interval"], + max_resamples=hyperparameters["activation_resampler"]["max_resamples"], + n_steps_collate=hyperparameters["activation_resampler"]["n_steps_collate"], + resample_dataset_size=hyperparameters["activation_resampler"]["resample_dataset_size"], + dead_neuron_threshold=hyperparameters["activation_resampler"]["dead_neuron_threshold"], + ) + + +def setup_source_model(hyperparameters: RuntimeHyperparameters) -> HookedTransformer: + """Setup the source model using HookedTransformer. + + Args: + hyperparameters: The hyperparameters dictionary. + + Returns: + The initialized source model. + """ + return HookedTransformer.from_pretrained( + hyperparameters["source_model"]["name"], + dtype=hyperparameters["source_model"]["dtype"], + ) + + +def setup_autoencoder( + hyperparameters: RuntimeHyperparameters, device: torch.device +) -> SparseAutoencoder: + """Setup the sparse autoencoder. + + Args: + hyperparameters: The hyperparameters dictionary. + device: The computation device. + + Returns: + The initialized sparse autoencoder. + """ + autoencoder_input_dim: int = hyperparameters["source_model"]["hook_dimension"] + expansion_factor = hyperparameters["autoencoder"]["expansion_factor"] + return SparseAutoencoder( + n_input_features=autoencoder_input_dim, + n_learned_features=autoencoder_input_dim * expansion_factor, + ).to(device) + + +def setup_loss_function(hyperparameters: RuntimeHyperparameters) -> LossReducer: + """Setup the loss function for the autoencoder. + + Args: + hyperparameters: The hyperparameters dictionary. + + Returns: + The combined loss function. + """ + return LossReducer( + LearnedActivationsL1Loss( + l1_coefficient=hyperparameters["loss"]["l1_coefficient"], + ), + L2ReconstructionLoss(), + ) + + +def setup_optimizer( + autoencoder: SparseAutoencoder, hyperparameters: RuntimeHyperparameters +) -> AdamWithReset: + """Setup the optimizer for the autoencoder. + + Args: + autoencoder: The sparse autoencoder model. + hyperparameters: The hyperparameters dictionary. + + Returns: + The initialized optimizer. + """ + return AdamWithReset( + params=autoencoder.parameters(), + named_parameters=autoencoder.named_parameters(), + lr=hyperparameters["optimizer"]["lr"], + betas=( + hyperparameters["optimizer"]["adam_beta_1"], + hyperparameters["optimizer"]["adam_beta_2"], + ), + weight_decay=hyperparameters["optimizer"]["adam_weight_decay"], + amsgrad=hyperparameters["optimizer"]["amsgrad"], + fused=hyperparameters["optimizer"]["fused"], + ) + + +def setup_source_data(hyperparameters: RuntimeHyperparameters) -> PreTokenizedDataset: + """Setup the source data for training. + + Args: + hyperparameters: The hyperparameters dictionary. + + Returns: + PreTokenizedDataset: The initialized source data. 
+ """ + return PreTokenizedDataset( + dataset_path=hyperparameters["source_data"]["dataset_path"], + context_size=hyperparameters["source_data"]["context_size"], + ) + + +def setup_wandb() -> RuntimeHyperparameters: + """Initialise wandb for experiment tracking.""" + wandb.init(project="sparse-autoencoder") + return dict(wandb.config) # type: ignore + + +def run_training_pipeline( + hyperparameters: RuntimeHyperparameters, + source_model: HookedTransformer, + autoencoder: SparseAutoencoder, + loss: LossReducer, + optimizer: AdamWithReset, + activation_resampler: ActivationResampler, + source_data: PreTokenizedDataset, +) -> None: + """Run the training pipeline for the sparse autoencoder. + + Args: + hyperparameters: The hyperparameters dictionary. + source_model: The source model. + autoencoder: The sparse autoencoder. + loss: The loss function. + optimizer: The optimizer. + activation_resampler: The activation resampler. + source_data: The source data. + """ + checkpoint_path = Path("../../.checkpoints") + checkpoint_path.mkdir(exist_ok=True) + + random_seed = hyperparameters["random_seed"] + torch.random.manual_seed(random_seed) + + hook_point = get_act_name( + hyperparameters["source_model"]["hook_site"], hyperparameters["source_model"]["hook_layer"] + ) + pipeline = Pipeline( + activation_resampler=activation_resampler, + autoencoder=autoencoder, + cache_name=hook_point, + checkpoint_directory=checkpoint_path, + layer=hyperparameters["source_model"]["hook_layer"], + loss=loss, + optimizer=optimizer, + source_data_batch_size=hyperparameters["pipeline"]["source_data_batch_size"], + source_dataset=source_data, + source_model=source_model, + log_frequency=hyperparameters["pipeline"]["log_frequency"], + ) + + pipeline.run_pipeline( + train_batch_size=hyperparameters["pipeline"]["train_batch_size"], + max_store_size=hyperparameters["pipeline"]["max_store_size"], + max_activations=hyperparameters["pipeline"]["max_activations"], + checkpoint_frequency=hyperparameters["pipeline"]["checkpoint_frequency"], + validate_frequency=hyperparameters["pipeline"]["validation_frequency"], + validation_number_activations=hyperparameters["pipeline"]["validation_number_activations"], + ) + + +def sweep(sweep_config: SweepConfig) -> None: + """Main function to run the training pipeline with wandb hyperparameter sweep.""" + sweep_id = wandb.sweep(sweep_config.to_dict(), project="sparse-autoencoder") + + def train() -> None: + """Train the sparse autoencoder using the hyperparameters from the WandB sweep.""" + # Set up WandB + hyperparameters = setup_wandb() + + # Setup the device for training + device = get_device() + + # Set up the source model + source_model = setup_source_model(hyperparameters) + + # Set up the autoencoder + autoencoder = setup_autoencoder(hyperparameters, device) + + # Set up the loss function + loss_function = setup_loss_function(hyperparameters) + + # Set up the optimizer + optimizer = setup_optimizer(autoencoder, hyperparameters) + + # Set up the activation resampler + activation_resampler = setup_activation_resampler(hyperparameters) + + # Set up the source data + source_data = setup_source_data(hyperparameters) + + # Run the training pipeline + run_training_pipeline( + hyperparameters=hyperparameters, + source_model=source_model, + autoencoder=autoencoder, + loss=loss_function, + optimizer=optimizer, + activation_resampler=activation_resampler, + source_data=source_data, + ) + + wandb.agent(sweep_id, train) diff --git a/sparse_autoencoder/train/sweep_config.py 
b/sparse_autoencoder/train/sweep_config.py index 46f14922..45f8173b 100644 --- a/sparse_autoencoder/train/sweep_config.py +++ b/sparse_autoencoder/train/sweep_config.py @@ -1,104 +1,323 @@ -"""Sweep Config.""" -from dataclasses import asdict, dataclass, field -from typing import Any +"""Sweep config. + +Default hyperparameter setup for quick tuning of a sparse autoencoder. +""" +from dataclasses import dataclass, field +from typing import TypedDict, final from sparse_autoencoder.train.utils.wandb_sweep_types import ( Method, Metric, + NestedParameter, Parameter, Parameters, WandbSweepConfig, ) -# NOTE: This must be kept in sync with SweepParametersRuntime -@dataclass(frozen=True) -class SweepParameterConfig(Parameters): - """Sweep Parameter Config.""" +# Warning: The runtime hyperparameter classes must be manually kept in sync with the hyperparameter +# classes, so that static type checking works. - lr: Parameter[float] | None - """Adam Learning Rate.""" - adam_beta_1: Parameter[float] | None - """Adam Beta 1. +@dataclass +class ActivationResamplerHyperparameters(NestedParameter): + """Activation resampler hyperparameters.""" - The exponential decay rate for the first moment estimates (mean) of the gradient. + resample_interval: Parameter[int] = field(default_factory=lambda: Parameter(200_000_000)) + """Resample interval.""" + + max_resamples: Parameter[int] = field(default_factory=lambda: Parameter(4)) + """Maximum number of resamples.""" + + n_steps_collate: Parameter[int] = field(default_factory=lambda: Parameter(100_000_000)) + """Number of steps to collate before resampling. + + Number of autoencoder learned activation vectors to collate before resampling. """ - adam_beta_2: Parameter[float] | None - """Adam Beta 2. + resample_dataset_size: Parameter[int] = field(default_factory=lambda: Parameter(819_200)) + """Resample dataset size. - The exponential decay rate for the second moment estimates (variance) of the gradient. + Number of autoencoder input activations to use for calculating the loss, as part of the + resampling process to create the reset neuron weights. """ - adam_epsilon: Parameter[float] | None - """Adam Epsilon. + dead_neuron_threshold: Parameter[float] = field(default_factory=lambda: Parameter(0.0)) + """Dead neuron threshold. - A small constant for numerical stability. + Threshold for determining if a neuron is dead (has "fired" in less than this portion of the + collated sample). """ - adam_weight_decay: Parameter[float] | None - """Adam Weight Decay. - Weight decay (L2 penalty). +class ActivationResamplerRuntimeHyperparameters(TypedDict): + """Activation resampler runtime hyperparameters.""" + + resample_interval: int + max_resamples: int + n_steps_collate: int + resample_dataset_size: int + dead_neuron_threshold: float + + +@dataclass +class AutoencoderHyperparameters(NestedParameter): + """Sparse autoencoder hyperparameters.""" + + expansion_factor: Parameter[int] = field(default_factory=lambda: Parameter(4)) + """Expansion Factor. + + Size of the learned features relative to the input features. """ - l1_coefficient: Parameter[float] | None + +class AutoencoderRuntimeHyperparameters(TypedDict): + """Autoencoder runtime hyperparameters.""" + + expansion_factor: int + + +@dataclass +class LossHyperparameters(NestedParameter): + """Loss hyperparameters.""" + + l1_coefficient: Parameter[float] = field(default_factory=lambda: Parameter(1e-4)) """L1 Penalty Coefficient. The L1 penalty is the absolute sum of learned (hidden) activations, multiplied by this constant. 
The penalty encourages sparsity in the learned activations. This loss penalty can be reduced by using more features, or using a lower L1 coefficient. + """ + + +class LossRuntimeHyperparameters(TypedDict): + """Loss runtime hyperparameters.""" + + l1_coefficient: float + + +@dataclass +class OptimizerHyperparameters(NestedParameter): + """Optimizer hyperparameters.""" + + lr: Parameter[float] = field(default_factory=lambda: Parameter(values=[1e-3, 1e-4, 1e-5])) + """Learning rate.""" + + adam_beta_1: Parameter[float] = field(default_factory=lambda: Parameter(0.9)) + """Adam Beta 1. + + The exponential decay rate for the first moment estimates (mean) of the gradient. + """ + + adam_beta_2: Parameter[float] = field(default_factory=lambda: Parameter(0.99)) + """Adam Beta 2. + + The exponential decay rate for the second moment estimates (variance) of the gradient. + """ + + adam_weight_decay: Parameter[float] = field(default_factory=lambda: Parameter(0.0)) + """Adam Weight Decay. + + Weight decay (L2 penalty). + """ + + amsgrad: Parameter[bool] = field(default_factory=lambda: Parameter(value=False)) + """AMSGrad. + + Whether to use the AMSGrad variant of this algorithm from the paper [On the Convergence of Adam + and Beyond](https://arxiv.org/abs/1904.09237). + """ + + fused: Parameter[bool] = field(default_factory=lambda: Parameter(value=False)) + """Fused. - Default values from the [original - paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html). + Whether to use a fused implementation of the optimizer (may be faster on CUDA). """ - batch_size: Parameter[int] | None - """Batch size. - Used in SAE Forward Pass.""" +class OptimizerRuntimeHyperparameters(TypedDict): + """Optimizer runtime hyperparameters.""" + + lr: float + adam_beta_1: float + adam_beta_2: float + adam_weight_decay: float + amsgrad: bool + fused: bool + + +@dataclass +class SourceDataHyperparameters(NestedParameter): + """Source data hyperparameters.""" + + dataset_path: Parameter[str] + """Dataset path.""" + + context_size: Parameter[int] = field(default_factory=lambda: Parameter(128)) + """Context size.""" + + +class SourceDataRuntimeHyperparameters(TypedDict): + """Source data runtime hyperparameters.""" + + dataset_path: str + context_size: int + + +@dataclass +class SourceModelHyperparameters(NestedParameter): + """Source model hyperparameters.""" + + name: Parameter[str] + """Source model name.""" + + hook_site: Parameter[str] + """Source model hook site.""" + + hook_layer: Parameter[int] + """Source model hook point layer.""" + + hook_dimension: Parameter[int] + """Source model hook point dimension.""" + + dtype: Parameter[str] = field(default_factory=lambda: Parameter("float32")) + """Source model dtype.""" + + +class SourceModelRuntimeHyperparameters(TypedDict): + """Source model runtime hyperparameters.""" + + name: str + hook_site: str + hook_layer: int + hook_dimension: int + dtype: str + + +@dataclass +class PipelineHyperparameters(NestedParameter): + """Pipeline hyperparameters.""" + + log_frequency: Parameter[int] = field(default_factory=lambda: Parameter(100)) + """Training log frequency.""" + + source_data_batch_size: Parameter[int] = field(default_factory=lambda: Parameter(12)) + """Source data batch size.""" + + train_batch_size: Parameter[int] = field(default_factory=lambda: Parameter(4096)) + """Train batch size.""" + + max_store_size: Parameter[int] = field(default_factory=lambda: Parameter(384 * 4096 * 2)) + """Max store size.""" + + max_activations: Parameter[int] = 
field(default_factory=lambda: Parameter(2_000_000_000)) + """Max activations.""" + + checkpoint_frequency: Parameter[int] = field(default_factory=lambda: Parameter(100_000_000)) + """Checkpoint frequency.""" + + validation_frequency: Parameter[int] = field( + default_factory=lambda: Parameter(384 * 4096 * 2 * 100) + ) + """Validation frequency.""" + + validation_number_activations: Parameter[int] = field(default_factory=lambda: Parameter(1024)) + """Number of activations to use for validation.""" + + +class PipelineRuntimeHyperparameters(TypedDict): + """Pipeline runtime hyperparameters.""" + + log_frequency: int + source_data_batch_size: int + train_batch_size: int + max_store_size: int + max_activations: int + checkpoint_frequency: int + validation_frequency: int + validation_number_activations: int + + +@dataclass +class Hyperparameters(Parameters): + """Sweep Hyperparameters.""" + + # Required parameters + source_data: SourceDataHyperparameters + + source_model: SourceModelHyperparameters + + # Optional parameters + activation_resampler: ActivationResamplerHyperparameters = field( + default_factory=lambda: ActivationResamplerHyperparameters() + ) + autoencoder: AutoencoderHyperparameters = field( + default_factory=lambda: AutoencoderHyperparameters() + ) -# NOTE: This must be kept in sync with SweepParameterConfig -@dataclass(frozen=True) -class SweepParametersRuntime(dict[str, Any]): - """Sweep parameter runtime values.""" + loss: LossHyperparameters = field(default_factory=lambda: LossHyperparameters()) - lr: float = 0.001 + optimizer: OptimizerHyperparameters = field(default_factory=lambda: OptimizerHyperparameters()) - adam_beta_1: float = 0.9 + pipeline: PipelineHyperparameters = field(default_factory=lambda: PipelineHyperparameters()) - adam_beta_2: float = 0.999 + random_seed: Parameter[int] = field(default_factory=lambda: Parameter(49)) + """Random seed.""" - adam_epsilon: float = 1e-8 + def __post_init__(self) -> None: + """Post initialisation checks.""" + # Check the resample dataset size <= the store size (currently only works if value is used + # for both). 
+ if ( + self.activation_resampler.resample_dataset_size.value is not None + and self.pipeline.max_store_size.value is not None + and self.activation_resampler.resample_dataset_size.value + >= int(self.pipeline.max_store_size.value) + ): + error_message = ( + "Resample dataset size must be less than or equal to the pipeline max store size" + ) + raise ValueError(error_message) - adam_weight_decay: float = 0.0 + @final + def __str__(self) -> str: + """String representation of this object.""" + items_representation = [] + for key, value in self.__dict__.items(): + if value is not None: + items_representation.append(f"{key}={value}") + joined_items = "\n ".join(items_representation) - l1_coefficient: float = 0.001 + class_name = self.__class__.__name__ - batch_size: int = 4096 + return f"{class_name}(\n {joined_items}\n)" - def to_dict(self) -> dict[str, Any]: - """Return dict representation of this object.""" - return asdict(self) + @final + def __repr__(self) -> str: + """Representation of this object.""" + return self.__str__() -@dataclass(frozen=True) +@dataclass class SweepConfig(WandbSweepConfig): """Sweep Config.""" - parameters: SweepParameterConfig + parameters: Hyperparameters - method: Method = Method.grid + method: Method = Method.RANDOM - metric: Metric = field(default_factory=lambda: Metric(name="loss")) + metric: Metric = field(default_factory=lambda: Metric(name="total_loss")) - def to_dict(self) -> dict[str, Any]: - """Return dict representation of this object.""" - dict_representation = asdict(self) - # Convert StrEnums to strings - dict_representation["method"] = dict_representation["method"].value +class RuntimeHyperparameters(TypedDict): + """Runtime hyperparameters.""" - return dict_representation + source_data: SourceDataRuntimeHyperparameters + source_model: SourceModelRuntimeHyperparameters + activation_resampler: ActivationResamplerRuntimeHyperparameters + autoencoder: AutoencoderRuntimeHyperparameters + loss: LossRuntimeHyperparameters + optimizer: OptimizerRuntimeHyperparameters + pipeline: PipelineRuntimeHyperparameters + random_seed: int diff --git a/sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr b/sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr new file mode 100644 index 00000000..81897947 --- /dev/null +++ b/sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr @@ -0,0 +1,25 @@ +# serializer version: 1 +# name: test_setup_activation_resampler + 'ActivationResampler(resample_interval=200000000, max_resamples=4, n_steps_collate=100000000, resample_dataset_size=819200, dead_neuron_threshold=0.0)' +# --- +# name: test_setup_autoencoder + ''' + SparseAutoencoder( + (_pre_encoder_bias): TiedBias(position=pre_encoder) + (_encoder): LinearEncoder( + in_features=512, out_features=2048 + (activation_function): ReLU() + ) + (_decoder): UnitNormDecoder(in_features=2048, out_features=512) + (_post_decoder_bias): TiedBias(position=post_decoder) + ) + ''' +# --- +# name: test_setup_loss_function + ''' + LossReducer( + (0): LearnedActivationsL1Loss(l1_coefficient=0.0001) + (1): L2ReconstructionLoss() + ) + ''' +# --- diff --git a/sparse_autoencoder/train/tests/test_sweep.py b/sparse_autoencoder/train/tests/test_sweep.py new file mode 100644 index 00000000..d16423f4 --- /dev/null +++ b/sparse_autoencoder/train/tests/test_sweep.py @@ -0,0 +1,83 @@ +"""Tests for sweep functionality.""" + +import pytest +from syrupy.session import SnapshotSession +import torch + +from sparse_autoencoder.train.sweep import ( + setup_activation_resampler, + 
setup_autoencoder, + setup_loss_function, +) +from sparse_autoencoder.train.sweep_config import ( + RuntimeHyperparameters, +) + + +@pytest.fixture() +def dummy_hyperparameters() -> RuntimeHyperparameters: + """Sweep config dummy fixture.""" + return { + "activation_resampler": { + "dead_neuron_threshold": 0.0, + "max_resamples": 4, + "n_steps_collate": 100_000_000, + "resample_dataset_size": 819_200, + "resample_interval": 200_000_000, + }, + "autoencoder": {"expansion_factor": 4}, + "loss": {"l1_coefficient": 0.0001}, + "optimizer": { + "adam_beta_1": 0.9, + "adam_beta_2": 0.99, + "adam_weight_decay": 0, + "amsgrad": False, + "fused": False, + "lr": 1e-05, + }, + "pipeline": { + "checkpoint_frequency": 100000000, + "log_frequency": 100, + "max_activations": 2000000000, + "max_store_size": 3145728, + "source_data_batch_size": 12, + "train_batch_size": 4096, + "validation_frequency": 314572800, + "validation_number_activations": 1024, + }, + "random_seed": 49, + "source_data": {"context_size": 128, "dataset_path": "NeelNanda/c4-code-tokenized-2b"}, + "source_model": { + "dtype": "float32", + "hook_dimension": 512, + "hook_layer": 0, + "hook_site": "mlp_out", + "name": "gelu-2l", + }, + } + + +def test_setup_activation_resampler( + dummy_hyperparameters: RuntimeHyperparameters, snapshot: SnapshotSession +) -> None: + """Test the setup_activation_resampler function.""" + activation_resampler = setup_activation_resampler(dummy_hyperparameters) + assert snapshot == str( + activation_resampler + ), "Activation resampler string representation has changed." + + +def test_setup_autoencoder( + dummy_hyperparameters: RuntimeHyperparameters, snapshot: SnapshotSession +) -> None: + """Test the setup_autoencoder function.""" + autoencoder = setup_autoencoder(dummy_hyperparameters, device=torch.device("cpu")) + assert snapshot == str(autoencoder), "Autoencoder string representation has changed." + + +def test_setup_loss_function( + dummy_hyperparameters: RuntimeHyperparameters, snapshot: SnapshotSession +) -> None: + """Test the setup_loss_function function.""" + loss_function = setup_loss_function(dummy_hyperparameters) + assert snapshot == str(loss_function), "Loss function string representation has changed." diff --git a/sparse_autoencoder/train/utils/tests/test_wandb_sweep_types.py b/sparse_autoencoder/train/utils/tests/test_wandb_sweep_types.py new file mode 100644 index 00000000..56620bc0 --- /dev/null +++ b/sparse_autoencoder/train/utils/tests/test_wandb_sweep_types.py @@ -0,0 +1,58 @@ +"""Test wandb sweep types.""" +from dataclasses import dataclass, field + +from sparse_autoencoder.train.utils.wandb_sweep_types import ( + Method, + Metric, + NestedParameter, + Parameter, + Parameters, + WandbSweepConfig, +) + + +class TestNestedParameter: + """NestedParameter tests.""" + + def test_to_dict(self) -> None: + """Test to_dict method.""" + + @dataclass + class DummyNestedParameter(NestedParameter): + nested_property: Parameter[float] = field(default_factory=lambda: Parameter(1.0)) + + dummy = DummyNestedParameter() + + # It should be in the nested "parameters" key. 
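+        # (Weights & Biases nested sweep configs group hyperparameters under a "parameters" key,
+        # which is exactly what NestedParameter.to_dict produces.)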
+ assert dummy.to_dict() == {"parameters": {"nested_property": {"value": 1.0}}} + + +class TestWandbSweepConfig: + """WandbSweepConfig tests.""" + + def test_to_dict(self) -> None: + """Test to_dict method.""" + + @dataclass + class DummyNestedParameter(NestedParameter): + nested_property: Parameter[float] = field(default_factory=lambda: Parameter(1.0)) + + @dataclass + class DummyParameters(Parameters): + nested: DummyNestedParameter = field(default_factory=lambda: DummyNestedParameter()) + top_level: Parameter[float] = field(default_factory=lambda: Parameter(1.0)) + + dummy = WandbSweepConfig( + parameters=DummyParameters(), method=Method.GRID, metric=Metric(name="total_loss") + ) + + assert dummy.to_dict() == { + "method": "grid", + "metric": {"goal": "minimize", "name": "total_loss"}, + "parameters": { + "nested": { + "parameters": {"nested_property": {"value": 1.0}}, + }, + "top_level": {"value": 1.0}, + }, + } diff --git a/sparse_autoencoder/train/utils/wandb_sweep_types.py b/sparse_autoencoder/train/utils/wandb_sweep_types.py index aeb26c9a..8dc1d98c 100644 --- a/sparse_autoencoder/train/utils/wandb_sweep_types.py +++ b/sparse_autoencoder/train/utils/wandb_sweep_types.py @@ -2,38 +2,182 @@ Weights & Biases just provide a JSON Schema, so we've converted here to dataclasses. """ -# ruff: noqa -from dataclasses import dataclass -from enum import Enum -from typing import Any, Generic, TypeVar +from abc import ABC +from dataclasses import asdict, dataclass, is_dataclass +from enum import Enum, auto +from typing import Any, Generic, TypeVar, final +from strenum import LowercaseStrEnum -class ControllerType(Enum): + +class ControllerType(LowercaseStrEnum): """Controller Type.""" - cloud = "cloud" - local = "local" + CLOUD = auto() + """Weights & Biases cloud controller. + Utilizes Weights & Biases as the sweep controller, enabling launching of multiple nodes that all + communicate with the Weights & Biases cloud service to coordinate the sweep. + """ -@dataclass(frozen=True) -class Controller: - """Controller.""" + LOCAL = auto() + """Local controller. - type: ControllerType # noqa: A003 + Manages the sweep operation locally, without the need for cloud-based coordination or external + services. + """ -class HyperbandStoppingType(Enum): +class HyperbandStoppingType(LowercaseStrEnum): """Hyperband Stopping Type.""" - hyperband = "hyperband" + HYPERBAND = auto() + """Hyperband algorithm. + + Implements the Hyperband stopping algorithm, an adaptive resource allocation and early-stopping + method to efficiently tune hyperparameters. + """ + + +class Kind(LowercaseStrEnum): + """Kind.""" + + SWEEP = auto() + + +class Method(LowercaseStrEnum): + """Method.""" + + BAYES = auto() + """Bayesian optimization. + + Employs Bayesian optimization for hyperparameter tuning, a probabilistic model-based approach + for finding the optimal set of parameters. + """ + + CUSTOM = auto() + """Custom method. + + Allows for a user-defined custom method for hyperparameter tuning, providing flexibility in the + sweep process. + """ + + GRID = auto() + """Grid search. + + Utilizes a grid search approach for hyperparameter tuning, systematically working through + multiple combinations of parameter values. + """ + + RANDOM = auto() + """Random search. + + Implements a random search strategy for hyperparameter tuning, exploring the parameter space + randomly. + """ + + +class Goal(LowercaseStrEnum): + """Goal.""" + + MAXIMIZE = auto() + """Maximization goal. 
+ + Sets the objective of the hyperparameter tuning process to maximize a specified metric. + """ + + MINIMIZE = auto() + """Minimization goal. + + Aims to minimize a specified metric during the hyperparameter tuning process. + """ + + +class Impute(LowercaseStrEnum): + """Metric value to use in bayes search for runs that fail, crash, or are killed.""" + + BEST = auto() + LATEST = auto() + WORST = auto() + + +class ImputeWhileRunning(LowercaseStrEnum): + """Appends a calculated metric even when epochs are in a running state.""" + + BEST = auto() + FALSE = auto() + LATEST = auto() + WORST = auto() -@dataclass(frozen=True) +class Distribution(LowercaseStrEnum): + """Sweep Distribution.""" + + BETA = auto() + """Beta distribution. + + Utilizes the Beta distribution, a family of continuous probability distributions defined on the + interval [0, 1], for parameter sampling. + """ + + CATEGORICAL = auto() + """Categorical distribution. + + Employs a categorical distribution for discrete variable sampling, where each category has an + equal probability of being selected. + """ + + CATEGORICAL_W_PROBABILITIES = auto() + """Categorical distribution with probabilities. + + Similar to categorical distribution but allows assigning different probabilities to each + category. + """ + + CONSTANT = auto() + """Constant distribution. + + Uses a constant value for the parameter, ensuring it remains the same across all runs. + """ + + INT_UNIFORM = auto() + """Integer uniform distribution. + + Samples integer values uniformly across a specified range. + """ + + INV_LOG_UNIFORM = auto() + """Inverse log-uniform distribution. + + Samples values according to an inverse log-uniform distribution, useful for parameters that span + several orders of magnitude. + """ + + INV_LOG_UNIFORM_VALUES = auto() + """Inverse log-uniform values distribution. + + Similar to the inverse log-uniform distribution but allows specifying exact values to be + sampled. + """ + + +@dataclass +class Controller: + """Controller.""" + + type: ControllerType # noqa: A003 + + +@dataclass class HyperbandStopping: """Hyperband Stopping Config. Speed up hyperparameter search by killing off runs that appear to have lower performance than successful training runs. + + Example: + >>> HyperbandStopping(type=HyperbandStoppingType.HYPERBAND) + HyperbandStopping(type=hyperband) """ type: HyperbandStoppingType # noqa: A003 @@ -41,8 +185,8 @@ class HyperbandStopping: eta: float | None = None """ETA. - At every eta^n steps, hyperband continues running the top 1/eta runs and stops all other - runs. + At every $\text{eta}^n$ steps, hyperband continues running the top $1/\text{eta}$ runs and stops + all other runs. 
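+
+    For example, with eta=3 only the top third of runs survive each cut-off.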
""" maxiter: int | None = None @@ -65,54 +209,33 @@ class HyperbandStopping: strict: bool | None = None """Use a more aggressive condition for termination, stops more runs.""" + @final + def __str__(self) -> str: + """String representation of this object.""" + items_representation = [] + for key, value in self.__dict__.items(): + if value is not None: + items_representation.append(f"{key}={value}") + joined_items = ", ".join(items_representation) -class Kind(Enum): - """Kind.""" - - sweep = "sweep" - - -class Method(Enum): - """Method.""" - - bayes = "bayes" - custom = "custom" - grid = "grid" - random = "random" - - -class Goal(Enum): - """Goal.""" + class_name = self.__class__.__name__ - maximize = "maximize" - minimize = "minimize" - - -class Impute(Enum): - """Metric value to use in bayes search for runs that fail, crash, or are killed.""" + return f"{class_name}({joined_items})" - best = "best" - latest = "latest" - worst = "worst" + @final + def __repr__(self) -> str: + """Representation of this object.""" + return self.__str__() -class ImputeWhileRunning(Enum): - """Appends a calculated metric even when epochs are in a running state.""" - - best = "best" - false = "false" - latest = "latest" - worst = "worst" - - -@dataclass(frozen=True) +@dataclass class Metric: """Metric to optimize.""" name: str """Name of metric.""" - goal: Goal | None = None + goal: Goal | None = Goal.MINIMIZE impute: Impute | None = None """Metric value to use in bayes search for runs that fail, crash, or are killed""" @@ -123,54 +246,66 @@ class Metric: target: float | None = None """The sweep will finish once any run achieves this value.""" + @final + def __str__(self) -> str: + """String representation of this object.""" + items_representation = [] + for key, value in self.__dict__.items(): + if value is not None: + items_representation.append(f"{key}={value}") + joined_items = ", ".join(items_representation) -class Distribution(Enum): - """Sweep Distribution.""" + class_name = self.__class__.__name__ + + return f"{class_name}({joined_items})" - beta = "beta" - categorical = "categorical" - categoricalwprobabilities = "categorical_w_probabilities" - constant = "constant" - intuniform = "int_uniform" - invloguniform = "inv_log_uniform" - invloguniformvalues = "inv_log_uniform_values" - lognormal = "log_normal" - loguniform = "log_uniform" - loguniformvalues = "log_uniform_values" - normal = "normal" - qbeta = "q_beta" - qlognormal = "q_log_normal" - qloguniform = "q_log_uniform" - qloguniformvalues = "q_log_uniform_values" - qnormal = "q_normal" - quniform = "q_uniform" - uniform = "uniform" + @final + def __repr__(self) -> str: + """Representation of this object.""" + return self.__str__() ParamType = TypeVar("ParamType", float, int, str) -@dataclass(frozen=True) +@dataclass class Parameter(Generic[ParamType]): - """Sweep Parameter.""" + """Sweep Parameter. + + https://docs.wandb.ai/guides/sweeps/define-sweep-configuration#parameters + """ + + value: ParamType | None = None + """Single value. - value: ParamType | list[ParamType] + Specifies the single valid value for this hyperparameter. Compatible with grid. 
+ """ max: ParamType | None = None # noqa: A003 + """Maximum value.""" min: ParamType | None = None # noqa: A003 - - a: float | None = None - - b: float | None = None + """Minimum value.""" distribution: Distribution | None = None + """Distribution + + If not specified, will default to categorical if values is set, to int_uniform if max and min + are set to integers, to uniform if max and min are set to floats, or to constant if value is + set. + """ q: float | None = None - """Quantization parameter for quantized distributions""" + """Quantization parameter. + + Quantization step size for quantized hyperparameters. + """ values: list[ParamType] | None = None - """Discrete values""" + """Discrete values. + + Specifies all valid values for this hyperparameter. Compatible with grid. + """ probabilities: list[float] | None = None """Probability of each value""" @@ -181,25 +316,92 @@ class Parameter(Generic[ParamType]): sigma: float | None = None """Std Dev for normal or lognormal distributions""" - parameters: dict[str, "Parameter[ParamType]"] | None = None + @final + def __str__(self) -> str: + """String representation of this object.""" + items_representation = [] + for key, value in self.__dict__.items(): + if value is not None: + items_representation.append(f"{key}={value}") + joined_items = ", ".join(items_representation) + + class_name = self.__class__.__name__ + + return f"{class_name}({joined_items})" + + @final + def __repr__(self) -> str: + """Representation of this object.""" + return self.__str__() + + +@dataclass +class NestedParameter(ABC): # noqa: B024 (abstract so that we can check against it's type) + """Nested Parameter. + + Example: + >>> from dataclasses import field + >>> @dataclass + ... class MyNestedParameter(NestedParameter): + ... a: int = field(default_factory=lambda: Parameter(1)) + ... b: int = field(default_factory=lambda: Parameter(2)) + >>> MyNestedParameter().to_dict() + {'parameters': {'a': {'value': 1}, 'b': {'value': 2}}} + """ + + def to_dict(self) -> dict[str, Any]: + """Return dict representation of this object.""" + + def dict_without_none_values(obj: Any) -> dict: # noqa: ANN401 + """Return dict without None values. + Args: + obj: The object to convert to a dict. -Parameters = dict[str, Parameter[Any]] + Returns: + The dict representation of the object. + """ + dict_none_removed = {} + dict_with_none = dict(obj) + for key, value in dict_with_none.items(): + if value is not None: + dict_none_removed[key] = value + return dict_none_removed + return {"parameters": asdict(self, dict_factory=dict_without_none_values)} -@dataclass(frozen=True) + def __dict__(self) -> dict[str, Any]: # type: ignore[override] + """Return dict representation of this object.""" + return self.to_dict() + + +@dataclass +class Parameters: + """Parameters""" + + +@dataclass class WandbSweepConfig: - """Weights & Biases Sweep Configuration.""" + """Weights & Biases Sweep Configuration. + + Example: + >>> config = WandbSweepConfig( + ... parameters={"lr": Parameter(value=1e-3)}, + ... method=Method.BAYES, + ... metric=Metric(name="loss"), + ... ) + >>> print(config.to_dict()["parameters"]) + {'lr': {'value': 0.001}} + """ parameters: Parameters | Any method: Method + """Method (search strategy).""" metric: Metric """Metric to optimize""" - apiVersion: str | None = None - command: list[Any] | None = None """Command used to launch the training script""" @@ -235,3 +437,58 @@ class WandbSweepConfig: Sweep will run no more than this number of runs, across any number of agents. 
""" + + def to_dict(self) -> dict[str, Any]: + """Return dict representation of this object. + + Recursively removes all None values. Handles special cases of dataclass + instances and values that are `NestedParameter` instances. + + Returns: + dict[str, Any]: The dict representation of the object. + """ + + def recursive_format(obj: Any) -> Any: # noqa: ANN401 + """Recursively format the dict of hyperparameters.""" + # Handle dataclasses + if is_dataclass(obj): + cleaned_obj = {} + for parameter_name in asdict(obj): + value = getattr(obj, parameter_name) + + # Remove None values. + if value is None: + continue + + # Nested parameters have their own `to_dict` method, which we can call. + if isinstance(value, NestedParameter): + cleaned_obj[parameter_name] = value.to_dict() + # Otherwise recurse. + else: + cleaned_obj[parameter_name] = recursive_format(value) + return cleaned_obj + + # Handle dicts + if isinstance(obj, dict): + cleaned_obj = {} + for key, value in obj.items(): + # Remove None values. + if value is None: + continue + + # Otherwise recurse. + cleaned_obj[key] = recursive_format(value) + return cleaned_obj + + # Handle enums + if isinstance(obj, Enum): + return obj.value + + # Handle other types (e.g. float, int, str) + return obj + + return recursive_format(self) + + def __dict__(self) -> dict[str, Any]: # type: ignore[override] + """Return dict representation of this object.""" + return self.to_dict()