Skip to content

Commit

Permalink
Simplify logging (#65)
Browse files Browse the repository at this point in the history
Just use one main tqdm loop as nested loops aren't well supported with VSCode.
  • Loading branch information
alan-cooney authored Nov 13, 2023
1 parent 6a69e67 commit 7861f32
Showing 3 changed files with 53 additions and 72 deletions.
13 changes: 4 additions & 9 deletions sparse_autoencoder/train/generate_activations.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@
from jaxtyping import Int
import torch
from torch import Tensor
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer

from sparse_autoencoder.activation_store.base_store import (
@@ -57,6 +56,9 @@ def generate_activations(
than strict limit.
device: Device to run the model on.
"""
# Set model to evaluation (inference) mode
model.eval()

if isinstance(device, torch.device):
model.to(device, print_details=False)

@@ -70,17 +72,10 @@ def generate_activations(
total: int = num_items - num_items % activations_per_batch

# Loop through the dataloader until the store reaches the desired size
with torch.no_grad(), tqdm(
desc="Generate Activations",
total=total - total % activations_per_batch,
colour="green",
leave=False,
dynamic_ncols=True,
) as progress_bar:
with torch.no_grad():
for batch in source_data:
if len(store) + activations_per_batch > total:
break

input_ids: Int[Tensor, "batch pos"] = batch["input_ids"].to(device)
model.forward(input_ids, stop_at_layer=layer + 1) # type: ignore (TLens is typed incorrectly)
progress_bar.update(activations_per_batch)
18 changes: 8 additions & 10 deletions sparse_autoencoder/train/pipeline.py
Original file line number Diff line number Diff line change
@@ -114,23 +114,23 @@ def pipeline( # noqa: PLR0913
total_steps: int = 0
activations_since_resampling: int = 0
neuron_activity: Int[Tensor, " learned_features"] = torch.zeros(
autoencoder.n_learned_features, dtype=torch.int32, device=device
autoencoder.n_learned_features,
dtype=torch.int32,
device=device,
)
total_activations: int = 0
generate_train_iterations: int = 0

# Run loop until source data is exhausted:
with logging_redirect_tqdm(), tqdm(
desc="Total activations trained on",
dynamic_ncols=True,
colour="blue",
total=max_activations,
postfix={"Generate/train iterations": 0},
postfix={"Current mode": "initializing"},
) as progress_bar:
while total_activations < max_activations:
activation_store.empty() # In case it was filled by a different run

# Add activations to the store
activation_store.empty() # In case it was filled by a different run
progress_bar.set_postfix({"Current mode": "generating"})
generate_activations(
src_model,
src_model_activation_layer,
@@ -150,6 +150,7 @@ def pipeline( # noqa: PLR0913
activation_store.shuffle()

# Train the autoencoder
progress_bar.set_postfix({"Current mode": "training"})
train_steps, learned_activations_fired_count = train_autoencoder(
activation_store=activation_store,
autoencoder=autoencoder,
@@ -169,6 +170,7 @@ def pipeline( # noqa: PLR0913

# Resample neurons if required
if activations_since_resampling >= resample_frequency:
progress_bar.set_postfix({"Current mode": "resampling"})
activations_since_resampling = 0
resample_dead_neurons(
neuron_activity=neuron_activity,
@@ -180,7 +182,3 @@ def pipeline( # noqa: PLR0913
optimizer.reset_state_all_parameters()

activation_store.empty()

progress_bar.update(1)
generate_train_iterations += 1
progress_bar.set_postfix({"Generate/train iterations": generate_train_iterations})
94 changes: 41 additions & 53 deletions sparse_autoencoder/train/train_autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Training Pipeline."""
from jaxtyping import Float, Int
import torch
from torch import Tensor, device, set_grad_enabled
from torch import Tensor, device
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import wandb

from sparse_autoencoder.activation_store.base_store import ActivationStore
@@ -46,63 +45,52 @@ def train_autoencoder(
batch_size=sweep_parameters.batch_size,
)

n_dataset_items: int = len(activation_store)
batch_size: int = sweep_parameters.batch_size

learned_activations_fired_count: Int[Tensor, " learned_feature"] = torch.zeros(
autoencoder.n_learned_features, dtype=torch.int32, device=device
)

step = 0
with set_grad_enabled(True), tqdm( # noqa: FBT003
desc="Train Autoencoder",
total=n_dataset_items,
colour="green",
leave=False,
dynamic_ncols=True,
) as progress_bar:
for step, batch in enumerate(activations_dataloader):
# Zero the gradients
optimizer.zero_grad()

# Move the batch to the device (in place)
batch = batch.to(device) # noqa: PLW2901

# Forward pass
learned_activations, reconstructed_activations = autoencoder(batch)

# Get metrics
reconstruction_loss_mse: Float[Tensor, " item"] = reconstruction_loss(
batch,
reconstructed_activations,
)
l1_loss_learned_activations: Float[Tensor, " item"] = l1_loss(learned_activations)
total_loss: Float[Tensor, " item"] = sae_training_loss(
reconstruction_loss_mse,
l1_loss_learned_activations,
sweep_parameters.l1_coefficient,
)

# Store count of how many neurons have fired
step: int = 0 # Initialize step
for step, store_batch in enumerate(activations_dataloader):
# Zero the gradients
optimizer.zero_grad()

# Move the batch to the device (in place)
batch = store_batch.detach().to(device)

# Forward pass
learned_activations, reconstructed_activations = autoencoder(batch)

# Get metrics
reconstruction_loss_mse: Float[Tensor, " item"] = reconstruction_loss(
batch,
reconstructed_activations,
)
l1_loss_learned_activations: Float[Tensor, " item"] = l1_loss(learned_activations)
total_loss: Float[Tensor, " item"] = sae_training_loss(
reconstruction_loss_mse,
l1_loss_learned_activations,
sweep_parameters.l1_coefficient,
)

# Store count of how many neurons have fired
with torch.no_grad():
fired = learned_activations > 0
learned_activations_fired_count.add_(fired.sum(dim=0))

# Backwards pass
total_loss.mean().backward()

optimizer.step()

# Log
if step % log_interval == 0 and wandb.run is not None:
wandb.log(
{
"reconstruction_loss": reconstruction_loss_mse.mean().item(),
"l1_loss": l1_loss_learned_activations.mean().item(),
"loss": total_loss.mean().item(),
},
)
# Backwards pass
total_loss.mean().backward()
optimizer.step()

# Log
if step % log_interval == 0 and wandb.run is not None:
wandb.log(
{
"reconstruction_loss": reconstruction_loss_mse.mean().item(),
"l1_loss": l1_loss_learned_activations.mean().item(),
"loss": total_loss.mean().item(),
},
)

progress_bar.update(batch_size)
current_step = previous_steps + step + 1

current_step = previous_steps + step + 1
return current_step, learned_activations_fired_count
return current_step, learned_activations_fired_count

0 comments on commit 7861f32

Please sign in to comment.