feat: scale 1st order of Adam in miner optimiser on checkpoints
distributedstatemachine committed Dec 13, 2024
1 parent e2ac251 commit 926e831
Showing 4 changed files with 82 additions and 35 deletions.
3 changes: 2 additions & 1 deletion hparams.json
@@ -31,5 +31,6 @@
   "validator_weights_temperature": 10,
   "validator_window_eval_size": 3,
   "validator_sample_rate": 0.01,
-  "validator_non_submission_decay": 0.9
+  "validator_non_submission_decay": 0.9,
+  "validator_learning_rate_scale_factor": 0.1
 }
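Note: the new key ties the validator's optimizer to the miner's. The validator constructs AdamW with learning_rate scaled by this factor (see neurons/validator.py below), and the same ratio is what load_checkpoint uses to rescale Adam's first moment when a miner resumes from a validator checkpoint. A minimal sketch of the arithmetic, with a hypothetical base learning_rate (the actual value is not shown in this hunk):

# Hypothetical numbers; only the 0.1 scale factor comes from this file.
base_lr = 4e-4                           # assumed miner learning_rate
scale = 0.1                              # validator_learning_rate_scale_factor
validator_lr = base_lr * scale           # LR the validator's AdamW uses
scaling_factor = validator_lr / base_lr  # == 0.1, applied to exp_avg on miner resume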
6 changes: 5 additions & 1 deletion neurons/miner.py
@@ -212,7 +212,11 @@ def __init__(self):
         self.global_step = asyncio.run(
             self.checkpoint_manager.load_from_highest_stake(
                 metagraph=self.metagraph,
-                buckets=self.buckets
+                buckets=self.buckets,
+                optimizer=self.optimizer,
+                scheduler=self.scheduler,
+                is_validator=False,
+                hparams=self.hparams
             )
         )

27 changes: 23 additions & 4 deletions neurons/validator.py
@@ -180,7 +180,7 @@ def __init__(self):

         self.optimizer = torch.optim.AdamW(
             self.model.parameters(),
-            lr=self.hparams.learning_rate,
+            lr=self.hparams.learning_rate * self.hparams.validator_learning_rate_scale_factor,
             betas=(self.hparams.optimizer_beta1, self.hparams.optimizer_beta2),
             weight_decay=self.hparams.optimizer_weight_decay,
             foreach=True
@@ -205,7 +205,11 @@ def __init__(self):
         self.global_step = asyncio.run(
             self.checkpoint_manager.load_from_highest_stake(
                 metagraph=self.metagraph,
-                buckets=self.buckets
+                buckets=self.buckets,
+                optimizer=self.optimizer,
+                scheduler=self.scheduler,
+                is_validator=True,  # Indicate validator
+                hparams=self.hparams
             )
         )

@@ -545,8 +549,23 @@ async def run(self):
                         else:
                             tplr.logger.warning("No valid parameter tensors found - setting score to 0.0")
                             score = 0.0
-
-                        tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Computed score: [bold dark_sea_green]{score:.4f}[/bold dark_sea_green]")
+                        self.optimizer.zero_grad()
+                        # # Compute the score for this slice.
+                        # st = tplr.T()
+                        # score = 0.0
+                        # for i, (name_i, param_i) in enumerate( self.model.named_parameters() ):
+                        #     if param_i.grad is None:
+                        #         continue  # Skip parameters without gradients
+                        #     idxs_i = indices[name_i].to(self.model.device)
+                        #     grad_i = param_i.grad.view(-1).clone()[idxs_i].to(self.model.device)
+                        #     slice_i = eval_slice_data[name_i].view(-1).to(self.model.device)
+                        #     theta_i = param_i.data.view(-1)[idxs_i]
+                        #     delta_i = theta_i - slice_i
+                        #     sim_i = torch.nn.functional.cosine_similarity(delta_i, grad_i, dim=0).item()
+                        #     weight_i = param_i.data.view(-1)[idxs_i].norm().item() + 1e-8
+                        #     score += weight_i * sim_i
+                        tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Computed score: [bold dark_sea_green]{score:.4f}[/bold dark_sea_green]")
+                        # tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Computed score: [bold dark_sea_green]{score:.4f}[/bold dark_sea_green]")


                         # Assign and log scores.
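For readability, the scoring loop that this hunk leaves commented out computed, for each parameter with a gradient, the cosine similarity between (current weights minus the evaluated slice) and the local gradient on the sampled indices, weighted by the norm of those weights. A standalone restatement of those commented lines, as a sketch only (indices, eval_slice_data, and the device handling are taken from the snippet, not from the project's actual API):

import torch

def sliced_cosine_score(model, indices, eval_slice_data, device) -> float:
    # Per-parameter score: norm-weighted cosine similarity between
    # (theta - slice) and the gradient, restricted to the sampled indices.
    score = 0.0
    for name, param in model.named_parameters():
        if param.grad is None:
            continue  # skip parameters without gradients
        idxs = indices[name].to(device)
        grad = param.grad.view(-1).clone()[idxs].to(device)
        # Assumption: eval_slice_data[name] already holds just the sampled values,
        # so it is not indexed again (mirrors the commented-out code).
        slice_vals = eval_slice_data[name].view(-1).to(device)
        theta = param.data.view(-1)[idxs]
        delta = theta - slice_vals
        sim = torch.nn.functional.cosine_similarity(delta, grad, dim=0).item()
        weight = theta.norm().item() + 1e-8
        score += weight * sim
    return score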
81 changes: 52 additions & 29 deletions src/templar/checkpoint.py
@@ -56,45 +56,62 @@ async def save_checkpoint(
 async def load_checkpoint(
     filename: str,
     model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer = None,
-    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    optimizer: torch.optim.Optimizer,
+    scheduler: torch.optim.lr_scheduler._LRScheduler,
     device: str = "cpu",
-):
+    is_validator: bool = False,
+    hparams=None,
+) -> int:
     """
     Loads the checkpoint from the specified filename asynchronously.
     Uses asyncio.to_thread to avoid blocking the main event loop.
+    Adjusts optimizer and scheduler for miners.
     """
     try:
         logger.info(f"Loading checkpoint from {filename}")
         checkpoint = await asyncio.to_thread(
-            torch.load, filename, map_location=device, weights_only=True
+            torch.load, filename, map_location=device
         )
-        logger.info("Loading model state dict")
+
+        # Load the model state
         model.load_state_dict(checkpoint["model_state_dict"])
-        if optimizer and "optimizer_state_dict" in checkpoint:
-            logger.info("Loading optimizer state dict")
-            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-        if scheduler and "scheduler_state_dict" in checkpoint:
-            logger.info("Loading scheduler state dict")
-            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
         global_step = checkpoint.get("global_step", 0)
-        logger.info(f"Loaded checkpoint at global step {global_step}")
-        additional_state = {
-            k: checkpoint[k]
-            for k in checkpoint
-            if k
-            not in [
-                "global_step",
-                "model_state_dict",
-                "optimizer_state_dict",
-                "scheduler_state_dict",
-            ]
-        }
-        logger.info("Successfully loaded checkpoint")
-        return global_step, additional_state
+        logger.info(f"Loaded model state at global step {global_step}")
+
+        # Load optimizer state
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+
+        # Adjust optimizer state if miner
+        if not is_validator:
+            # Retrieve validator's learning rate from optimizer state
+            validator_lr = optimizer.param_groups[0]['lr']
+            miner_lr = hparams.learning_rate  # Miner's learning rate
+
+            # Compute scaling factor
+            scaling_factor = validator_lr / miner_lr
+
+            # Scale optimizer's internal states
+            for state in optimizer.state.values():
+                if "exp_avg" in state:
+                    state["exp_avg"].mul_(scaling_factor)
+                if "exp_avg_sq" in state:
+                    # Optionally adjust exp_avg_sq if needed
+                    pass
+
+            # Update optimizer's learning rate to miner's learning rate
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = miner_lr
+
+            logger.info("Adjusted optimizer states for miner.")
+
+        else:
+            logger.info("Loaded optimizer states for validator.")
+
+        return global_step

     except Exception as e:
         logger.error(f"Failed to load checkpoint from {filename}: {e}")
-        return 0, {}
+        return 0


 async def download_checkpoint_from_neuron(
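The rescaling above can be sanity-checked numerically. Adam's step is roughly lr * exp_avg / (sqrt(exp_avg_sq) + eps), so a miner resuming from a validator checkpoint with a learning rate 1/scale times larger (10x for the 0.1 factor in hparams.json) would otherwise take a first step about 10x larger than the validator would have. Multiplying exp_avg by validator_lr / miner_lr cancels that ratio. A minimal sketch, with hypothetical values and bias correction and weight decay omitted:

import torch

eps = 1e-8
validator_lr, miner_lr = 4e-5, 4e-4        # hypothetical; ratio 0.1 as in hparams.json
exp_avg = torch.tensor([0.02, -0.01])      # first moment inherited from the checkpoint
exp_avg_sq = torch.tensor([4e-4, 1e-4])    # second moment (left untouched by this commit)

def adam_step(lr, m, v):
    # Simplified Adam update, ignoring bias correction and weight decay.
    return lr * m / (v.sqrt() + eps)

validator_step = adam_step(validator_lr, exp_avg, exp_avg_sq)
naive_miner_step = adam_step(miner_lr, exp_avg, exp_avg_sq)     # ~10x the validator's step
scaled_exp_avg = exp_avg * (validator_lr / miner_lr)            # what load_checkpoint does
adjusted_miner_step = adam_step(miner_lr, scaled_exp_avg, exp_avg_sq)

assert torch.allclose(validator_step, adjusted_miner_step)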
@@ -836,6 +853,10 @@ async def load_from_highest_stake(
         self,
         metagraph,
         buckets,
+        optimizer,
+        scheduler,
+        is_validator: bool = False,
+        hparams=None,
     ) -> int:
         """
         Attempts to load checkpoint from the highest stake neuron.
@@ -864,8 +885,10 @@
                     filename=checkpoint_file,
                     model=self.model,
                     device=self.device,
-                    optimizer=self.optimizer if self.optimizer else None,
-                    scheduler=self.scheduler if self.scheduler else None,
+                    optimizer=optimizer,
+                    scheduler=scheduler,
+                    is_validator=is_validator,
+                    hparams=hparams
                 )
                 logger.info(f"Resumed from global step {global_step}")
                 return global_step if global_step is not None else 0
