Commit

Add join run cli command (#177)
alan-cooney authored Jan 6, 2024
1 parent b8cc0cc commit 6b9c2a5
Showing 10 changed files with 114 additions and 513 deletions.
1 change: 0 additions & 1 deletion docs/content/SUMMARY.md
@@ -1,7 +1,6 @@

* [Home](index.md)
* [Demo](demo.ipynb)
* [Flexible demo](flexible_demo.ipynb)
* [Source dataset pre-processing](pre-process-datasets.ipynb)
* [Reference](reference/)
* [Contributing](contributing.md)
157 changes: 88 additions & 69 deletions docs/content/demo.ipynb
@@ -13,11 +13,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quick Start Training Demo\n",
"# Training Demo\n",
"\n",
"This is a quick start demo to get training a SAE right away. All you need to do is choose a few\n",
"hyperparameters (like the model to train on), and then set it off. By default it trains SAEs on all\n",
"MLP layers from GPT2 small."
"hyperparameters (like the model to train on), and then set it off.\n",
"\n",
"In this demo we'll train a sparse autoencoder on all MLP layer outputs in GPT-2 small (effectively\n",
"training an SAE on each layer in parallel)."
]
},
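Concretely, "all MLP layer outputs" means the twelve `hook_mlp_out` activation points in GPT-2 small, the same hook names used in the sweep configuration later in this notebook. A minimal sketch of those names:

```python
# GPT-2 small has 12 transformer blocks; an SAE is trained on each block's MLP output hook.
cache_names = [f"blocks.{layer}.hook_mlp_out" for layer in range(12)]
print(cache_names[0], cache_names[-1])  # blocks.0.hook_mlp_out blocks.11.hook_mlp_out
```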
{
@@ -68,6 +70,7 @@
"\n",
"from sparse_autoencoder import (\n",
" ActivationResamplerHyperparameters,\n",
" AutoencoderHyperparameters,\n",
" Hyperparameters,\n",
" LossHyperparameters,\n",
" Method,\n",
@@ -76,12 +79,10 @@
" PipelineHyperparameters,\n",
" SourceDataHyperparameters,\n",
" SourceModelHyperparameters,\n",
" sweep,\n",
" SweepConfig,\n",
" sweep,\n",
")\n",
"\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"demo.ipynb\""
]
},
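As an aside, the two environment variables set in the cell above are standard settings for the underlying libraries rather than anything specific to this project; a brief annotated sketch (the comments reflect the usual tokenizers and wandb behaviour):

```python
import os

# Avoid the HuggingFace tokenizers warning about forking after parallelism has been used
# (common when a DataLoader spawns worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Tell Weights & Biases which notebook is running, so the run is labelled correctly in the UI.
os.environ["WANDB_NOTEBOOK_NAME"] = "demo.ipynb"
```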
@@ -97,76 +98,73 @@
"metadata": {},
"source": [
"Customize any hyperparameters you want below (by default we're sweeping over l1 coefficient and\n",
"learning rate):"
"learning rate).\n",
"\n",
"Note we are using the RANDOM sweep approach (try random combinations of hyperparameters), which\n",
"works surprisingly well but will need to be stopped at some point (as otherwise it will continue\n",
"forever). If you want to run pre-defined runs consider using `Parameter(values=[0.01, 0.05...])` for\n",
"example rather than `Parameter(max=0.03, min=0.008)` for each parameter you are sweeping over. You\n",
"can then set the strategy to `Method.GRID`."
]
},
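As a hedged illustration of the grid alternative described above (the value lists are made-up examples; only `Parameter`, `Method`, and the hyperparameter classes come from this library):

```python
# Sketch: explicit value lists plus Method.GRID make the sweep enumerate every
# combination once and then finish, instead of sampling forever.
grid_loss = LossHyperparameters(
    l1_coefficient=Parameter(values=[0.008, 0.01, 0.03]),  # example values only
)
grid_optimizer = OptimizerHyperparameters(
    lr=Parameter(values=[1e-5, 1e-4, 1e-3]),  # example values only
)
# These would then be passed into Hyperparameters(...) and
# SweepConfig(parameters=..., method=Method.GRID), mirroring the training function below.
```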
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SweepConfig(parameters=Hyperparameters(\n",
" source_data=SourceDataHyperparameters(dataset_path=Parameter(value=alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2), context_size=Parameter(value=128), dataset_column_name=Parameter(value=input_ids), dataset_dir=None, dataset_files=None, pre_download=Parameter(value=False), pre_tokenized=Parameter(value=True), tokenizer_name=None)\n",
" source_model=SourceModelHyperparameters(name=Parameter(value=gpt2-small), cache_names=Parameter(value=['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']), hook_dimension=Parameter(value=768), dtype=Parameter(value=float32))\n",
" activation_resampler=ActivationResamplerHyperparameters(resample_interval=Parameter(value=200000000), max_n_resamples=Parameter(value=4), n_activations_activity_collate=Parameter(value=100000000), resample_dataset_size=Parameter(value=200000), threshold_is_dead_portion_fires=Parameter(value=1e-06))\n",
" autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=2))\n",
" loss=LossHyperparameters(l1_coefficient=Parameter(max=0.01, min=0.004))\n",
" optimizer=OptimizerHyperparameters(lr=Parameter(max=0.001, min=1e-05), adam_beta_1=Parameter(value=0.9), adam_beta_2=Parameter(value=0.99), adam_weight_decay=Parameter(value=0.0), amsgrad=Parameter(value=False), fused=Parameter(value=False))\n",
" pipeline=PipelineHyperparameters(log_frequency=Parameter(value=100), source_data_batch_size=Parameter(value=16), train_batch_size=Parameter(value=1024), max_store_size=Parameter(value=300000), max_activations=Parameter(value=1000000000), checkpoint_frequency=Parameter(value=100000000), validation_frequency=Parameter(value=100000000), validation_n_activations=Parameter(value=8192))\n",
" random_seed=Parameter(value=49)\n",
"), method=<Method.RANDOM: 'random'>, metric=Metric(name=train/loss/total_loss, goal=minimize), command=None, controller=None, description=None, earlyterminate=None, entity=None, imageuri=None, job=None, kind=None, name=None, program=None, project=None)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"n_layers_gpt2_small = 12\n",
"def train_gpt_small_mlp_layers(\n",
" expansion_factor: int = 4,\n",
" n_layers: int = 12,\n",
") -> None:\n",
" \"\"\"Run a new sweep experiment on GPT 2 Small's MLP layers.\n",
"\n",
"sweep_config = SweepConfig(\n",
" parameters=Hyperparameters(\n",
" activation_resampler=ActivationResamplerHyperparameters(\n",
" resample_interval=Parameter(200_000_000),\n",
" n_activations_activity_collate=Parameter(100_000_000),\n",
" threshold_is_dead_portion_fires=Parameter(1e-6),\n",
" max_n_resamples=Parameter(4),\n",
" resample_dataset_size=Parameter(200_000),\n",
" ),\n",
" loss=LossHyperparameters(\n",
" l1_coefficient=Parameter(max=1e-2, min=4e-3),\n",
" ),\n",
" optimizer=OptimizerHyperparameters(\n",
" lr=Parameter(max=1e-3, min=1e-5),\n",
" ),\n",
" source_model=SourceModelHyperparameters(\n",
" name=Parameter(\"gpt2-small\"),\n",
" # Train in parallel on all MLP layers\n",
" cache_names=Parameter(\n",
" [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers_gpt2_small)]\n",
" Args:\n",
" expansion_factor: Expansion factor for the autoencoder.\n",
" n_layers: Number of layers to train on. Max is 12.\n",
"\n",
" \"\"\"\n",
" sweep_config = SweepConfig(\n",
" parameters=Hyperparameters(\n",
" loss=LossHyperparameters(\n",
" l1_coefficient=Parameter(max=0.03, min=0.008),\n",
" ),\n",
" optimizer=OptimizerHyperparameters(\n",
" lr=Parameter(max=0.001, min=0.00001),\n",
" ),\n",
" source_model=SourceModelHyperparameters(\n",
" name=Parameter(\"gpt2\"),\n",
" cache_names=Parameter(\n",
" [f\"blocks.{layer}.hook_mlp_out\" for layer in range(n_layers)]\n",
" ),\n",
" hook_dimension=Parameter(768),\n",
" ),\n",
" source_data=SourceDataHyperparameters(\n",
" dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
" context_size=Parameter(256),\n",
" pre_tokenized=Parameter(value=True),\n",
" pre_download=Parameter(value=False), # Default to streaming the dataset\n",
" ),\n",
" autoencoder=AutoencoderHyperparameters(\n",
" expansion_factor=Parameter(value=expansion_factor)\n",
" ),\n",
" pipeline=PipelineHyperparameters(\n",
" max_activations=Parameter(1_000_000_000),\n",
" checkpoint_frequency=Parameter(100_000_000),\n",
" validation_frequency=Parameter(100_000_000),\n",
" max_store_size=Parameter(1_000_000),\n",
" ),\n",
" activation_resampler=ActivationResamplerHyperparameters(\n",
" resample_interval=Parameter(200_000_000),\n",
" n_activations_activity_collate=Parameter(100_000_000),\n",
" threshold_is_dead_portion_fires=Parameter(1e-6),\n",
" max_n_resamples=Parameter(4),\n",
" ),\n",
" hook_dimension=Parameter(768),\n",
" ),\n",
" source_data=SourceDataHyperparameters(\n",
" dataset_path=Parameter(\"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"),\n",
" context_size=Parameter(128),\n",
" pre_tokenized=Parameter(value=True),\n",
" ),\n",
" pipeline=PipelineHyperparameters(\n",
" max_activations=Parameter(1_000_000_000),\n",
" checkpoint_frequency=Parameter(100_000_000),\n",
" validation_frequency=Parameter(100_000_000),\n",
" train_batch_size=Parameter(1024),\n",
" max_store_size=Parameter(300_000),\n",
" ),\n",
" ),\n",
" method=Method.RANDOM,\n",
")\n",
"sweep_config"
" method=Method.RANDOM,\n",
" )\n",
"\n",
" sweep(sweep_config=sweep_config)"
]
},
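As a usage sketch, the function defined above can be called with non-default arguments; the values here are arbitrary examples:

```python
# Sketch: wider autoencoders (expansion factor 8) trained on just the first six MLP layers.
train_gpt_small_mlp_layers(expansion_factor=8, n_layers=6)
```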
{
@@ -176,13 +174,34 @@
"### Run the sweep"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This will start a sweep with just one agent (the current machine). If you have multiple GPUs, it\n",
"will use them automatically. Similarly it will work on Apple silicon devices by automatically using MPS."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"sweep(sweep_config=sweep_config)"
"train_gpt_small_mlp_layers()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Want to speed things up? You can trivially add extra machines to the sweep, each of which will peel\n",
"of some runs from the sweep agent (stored on Wandb). To do this, on another machine simply run:\n",
"\n",
"```bash\n",
"pip install sparse_autoencoder\n",
"join-sae-sweep --id=SWEEP_ID_SHOWN_ON_WANDB\n",
"```"
]
}
],
