
Merge pull request #31 from tplr-ai/feat/temp_softmax
Feat/temp softmax
distributedstatemachine authored Jan 17, 2025
2 parents 662db2f + 8e26190 commit 0c53fce
Showing 12 changed files with 673 additions and 544 deletions.
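
For context, the central change in this PR replaces the validator's proportional weight normalisation with a temperature-based softmax over positive moving-average scores (see the neurons/validator.py hunk below). Here is a minimal standalone sketch of that logic, under the assumption that scores live in a single tensor indexed by UID; the function name and the example values are illustrative, while the masking and the temperature of 0.1 are taken from the diff.

import torch

def compute_weights(moving_avg_scores: torch.Tensor, evaluated_uids: list, temperature: float = 0.1) -> torch.Tensor:
    """Sketch of the new weighting: softmax over positive moving-average scores."""
    weights = torch.zeros_like(moving_avg_scores)

    # Only UIDs that were actually evaluated may receive weight.
    evaluated_mask = torch.zeros_like(moving_avg_scores, dtype=torch.bool)
    evaluated_mask[evaluated_uids] = True

    # Restrict to positive scores; zero or negative averages keep weight 0.
    positive_mask = (moving_avg_scores > 0) & evaluated_mask

    if positive_mask.any():
        # Lower temperature sharpens the distribution toward the top scorers.
        weights[positive_mask] = torch.softmax(moving_avg_scores[positive_mask] / temperature, dim=0)
    return weights

# Hypothetical example: UID 3 has a negative average and receives zero weight.
print(compute_weights(torch.tensor([0.0, 0.30, 0.10, -0.20]), evaluated_uids=[1, 2, 3]))

Compared with the previous proportional normalisation, dividing by a temperature of 0.1 before the softmax concentrates weight far more aggressively on the highest-scoring evaluated peers.
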
19 changes: 19 additions & 0 deletions .gitignore
@@ -41,6 +41,25 @@ uv.lock
*.swo
.DS_Store

# Node.js Specific
node_modules/
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*
.npm
.yarn
.yarn-integrity
package.json
package-lock.json
yarn.lock
pnpm-lock.yaml
.env.local
.env.development.local
.env.test.local
.env.production.local


# Project specific
wandb
*ipynb
38 changes: 38 additions & 0 deletions ecosystem.config.js
@@ -0,0 +1,38 @@
require('dotenv').config({ path: '.env' });
const RANDOM_SUFFIX = require('child_process').execSync("cat /dev/urandom | tr -dc 'a-z0-9' | fold -w 4 | head -n 1").toString().trim();
const PROJECT_NAME = `test_${RANDOM_SUFFIX}`;

module.exports = {
apps: [
{
name: "TM1",
script: "neurons/miner.py",
interpreter: "python3",
env: {
...process.env,
PROJECT_NAME: PROJECT_NAME
},
args: `--wallet.name Bistro --wallet.hotkey M1 --device cuda:3 --subtensor.network ws://127.0.0.1:9945 --debug --netuid 1 --use_wandb --project "${PROJECT_NAME}"`
},
{
name: "TM2",
script: "neurons/miner.py",
interpreter: "python3",
env: {
...process.env,
PROJECT_NAME: PROJECT_NAME
},
args: `--wallet.name Bistro --wallet.hotkey M2 --device cuda:1 --subtensor.network ws://127.0.0.1:9945 --debug --netuid 1 --use_wandb --project "${PROJECT_NAME}"`
},
{
name: "TV1",
script: "neurons/validator.py",
interpreter: "python3",
env: {
...process.env,
PROJECT_NAME: PROJECT_NAME
},
args: `--wallet.name Bistro --wallet.hotkey V1 --device cuda:2 --subtensor.network ws://127.0.0.1:9945 --debug --netuid 1 --use_wandb --project "${PROJECT_NAME}"`
}
]
}
5 changes: 4 additions & 1 deletion justfile
@@ -8,4 +8,7 @@ lint:
ruff format .

# Run both check and format in a single command
fix: lint
fix: lint

test-run:
./scripts/start.sh && pm2 log TV1
35 changes: 21 additions & 14 deletions neurons/miner.py
@@ -173,25 +173,35 @@ async def run(self):
if result:
checkpoint_data, window = result
try:
# Load state dicts from dictionary and move to device
# Load state dicts from checkpoint data
self.model.load_state_dict({k: v.to(self.config.device) for k,v in checkpoint_data['model_state_dict'].items()})
self.model.to(self.config.device)

# Move optimizer state to device
# Load optimizer state
for state in self.optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.to(self.config.device)

self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])

# Load scheduler state
self.scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])

# Load momentum and global_step
self.momentum = checkpoint_data['momentum']
self.global_step = checkpoint_data['global_step']

# Update optimizer and scheduler steps to match
self.optimizer._step_count = self.global_step
self.scheduler.last_epoch = self.global_step

# Adjust scheduler to catch up with current window
checkpoint_window = checkpoint_data.get('checkpoint_window', None)
if checkpoint_window is not None:
window_difference = self.current_window - checkpoint_window
if window_difference > 0:
for _ in range(window_difference):
self.scheduler.step()
tplr.logger.info(f"Stepped scheduler {window_difference} times to catch up with current window {self.current_window}")
else:
tplr.logger.warning("Checkpoint does not contain 'checkpoint_window'; cannot adjust scheduler")

tplr.logger.info(f"Loaded checkpoint from window {window}, global_step={self.global_step}")
except KeyError as e:
tplr.logger.error(f"Invalid checkpoint format: missing key {e}")
@@ -217,14 +227,15 @@ async def run(self):
target=self.block_listener,
args=(self.loop,),
daemon=True,
).start()
)
self.listener.start() #
self.comms.start_commitment_fetcher()
self.comms.start_background_tasks()

while True:
step_window = self.current_window
tplr.logger.info(f"\n{'-' * 40} Window: {step_window} {'-' * 40}")
self.comms.update_peers_with_buckets()
# self.comms.update_peers_with_buckets()
# Update local references
self.peers = self.comms.peers

@@ -367,10 +378,7 @@ async def run(self):
if max_global_step > self.global_step:
tplr.logger.info(f"Updating global_step from {self.global_step} to {max_global_step}")
self.global_step = max_global_step
# Update optimizer and scheduler steps
self.optimizer._step_count = self.global_step
self.scheduler.last_epoch = self.global_step


# Decompress state and apply to grad.
for n, p in self.model.named_parameters():
idxs_key = n + 'idxs'
@@ -391,7 +399,6 @@
vals,
xshapes[n],
totalks[n],
# median=True
)
)
# Set recomputed gathered gradient.
107 changes: 49 additions & 58 deletions neurons/validator.py
@@ -23,7 +23,6 @@
import asyncio
import argparse
import threading
import os
from contextlib import contextmanager
from time import perf_counter

@@ -212,24 +211,34 @@ async def run(self):
if result:
checkpoint_data, window = result
try:
# Load state dicts from dictionary and move to device
# Load state dicts from checkpoint data
self.model.load_state_dict({k: v.to(self.config.device) for k,v in checkpoint_data['model_state_dict'].items()})
self.model.to(self.config.device)

# Move optimizer state to device
# Load optimizer state
for state in self.optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.to(self.config.device)

self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])

# Load scheduler state
self.scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])

# Load momentum and global_step
self.momentum = checkpoint_data['momentum']
self.global_step = checkpoint_data['global_step']

# Update optimizer and scheduler steps to match
self.optimizer._step_count = self.global_step
self.scheduler.last_epoch = self.global_step
# Adjust scheduler to catch up with current window
checkpoint_window = checkpoint_data.get('checkpoint_window', None)
if checkpoint_window is not None:
window_difference = self.current_window - checkpoint_window
if window_difference > 0:
for _ in range(window_difference):
self.scheduler.step()
tplr.logger.info(f"Stepped scheduler {window_difference} times to catch up with current window {self.current_window}")
else:
tplr.logger.warning("Checkpoint does not contain 'checkpoint_window'; cannot adjust scheduler")

tplr.logger.info(f"Loaded checkpoint from window {window}, global_step={self.global_step}")
except KeyError as e:
@@ -254,17 +263,14 @@
# self.comms.track_active_peers()

while True:
step_window = self.current_window

tplr.logger.info(f'Step window: {step_window}, Scheduler epoch: {self.scheduler.last_epoch}, Global step: {self.global_step}')
# 1. Wait for validator offset - single wait loop
while self.sync_window >= (self.current_window - self.hparams.validator_offset):
tplr.logger.info(f'Waiting for validator window offset, synced: {self.sync_window}, current:{self.current_window}, offset:{self.hparams.validator_offset}')
await asyncio.sleep(12)
tplr.logger.info(f'Step window: {step_window}, Scheduler epoch: {self.scheduler.last_epoch}, Global step: {self.global_step}')
tplr.logger.info(f'Sync Window: {self.sync_window}, Scheduler epoch: {self.scheduler.last_epoch}, Global step: {self.global_step}')
# 2. Process one window at a time
self.sync_window += 1
step_window = self.sync_window + 1
tplr.logger.info(f'Processing window: {self.sync_window} current: {self.current_window}')

self.comms.update_peers_with_buckets()
@@ -281,7 +287,7 @@
state_dict=None,
my_uid=self.uid,
uids=self.peers,
window=step_window,
window=self.sync_window,
key='gradient',
timeout=5,
device=self.config.device,
@@ -300,7 +306,7 @@
# Get individual miner's gradient
eval_result = await self.comms.get(
uid=str(eval_uid),
window=step_window,
window=self.sync_window,
key='gradient',
timeout=30,
local=False,
@@ -370,7 +376,6 @@ async def run(self):
vals,
self.xshapes[n],
self.totalks[n],
# median=False
)
).to(self.config.device) # Ensure final gradient is on correct device

@@ -420,18 +425,25 @@
self.scores[eval_uid] = score
self.moving_avg_scores[eval_uid] = self.ma_alpha * self.moving_avg_scores[eval_uid] + (1 - self.ma_alpha) * score

# Calculate weights - only positive moving averages get weights
# Calculate weights using temperature-based softmax
weights = torch.zeros_like(self.moving_avg_scores)
evaluated_mask = torch.zeros_like(self.moving_avg_scores, dtype=torch.bool)
evaluated_mask[list(self.evaluated_uids)] = True

# Only consider positive moving averages for weight calculation
# Create mask for positive scores
positive_mask = (self.moving_avg_scores > 0) & evaluated_mask
evaluated_scores = self.moving_avg_scores * positive_mask

total_score = evaluated_scores.sum()
if total_score > 0:
weights[positive_mask] = evaluated_scores[positive_mask] / total_score

if positive_mask.any():
# Only apply softmax to positive scores
positive_scores = self.moving_avg_scores[positive_mask]
temperature = 0.1 # Lower temperature = sharper distribution
positive_weights = torch.softmax(positive_scores / temperature, dim=0)

# Assign weights back to the original tensor
weights[positive_mask] = positive_weights
else:
# If no positive scores, all weights remain 0
tplr.logger.info("No positive scores found, all weights set to 0")

# Log only evaluated UIDs
tplr.logger.info('Updated scores for evaluated UIDs:')
@@ -486,47 +498,28 @@ async def run(self):
if self.global_step % self.hparams.checkpoint_frequency == 0:
tplr.logger.info(f"Creating checkpoint at global_step {self.global_step}")

# Create CPU copy of the checkpoint data to avoid GPU memory competition
# Create CPU copy of the checkpoint data
checkpoint_data = {
'model_state_dict': {k: v.cpu().clone() for k, v in self.model.state_dict().items()},
'optimizer_state_dict': {k: v.cpu().clone() if torch.is_tensor(v) else v
for k, v in self.optimizer.state_dict().items()},
'scheduler_state_dict': self.scheduler.state_dict(),
'momentum': {k: v.cpu().clone() for k, v in self.momentum.items()},
'global_step': self.global_step
'global_step': self.global_step,
'checkpoint_window': self.current_window
}

async def _save():
start_time = time.time()
try:
# Use a separate thread for CPU-intensive serialization
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: torch.save(checkpoint_data, '/tmp/temp_checkpoint.pt'))

await self.comms.put(
state_dict=checkpoint_data,
uid=str(self.uid),
window=self.current_window,
key='checkpoint',
global_step=self.global_step,
local=False
)
elapsed_time = time.time() - start_time
tplr.logger.info(f"Successfully saved checkpoint at global_step {self.global_step} (took {elapsed_time:.2f}s)")

self.wandb.log({
"validator/save_time": elapsed_time,
"validator/global_step": self.global_step,
}, step=self.global_step)

except Exception as e:
tplr.logger.error(f"Failed to save checkpoint: {e}")
finally:
# Cleanup temp file
if os.path.exists('/tmp/temp_checkpoint.pt'):
os.remove('/tmp/temp_checkpoint.pt')

asyncio.create_task(_save())
# Launch checkpoint saving as a background task
asyncio.create_task(
self.comms.put(
state_dict=checkpoint_data,
uid=str(self.uid),
window=self.current_window,
key='checkpoint',
global_step=self.global_step,
local=False
)
)

# Now apply the gathered gradients
if gather_result is not None:
@@ -535,8 +528,7 @@ async def _save():
if max_global_step > self.global_step:
tplr.logger.info(f"Updating global_step from {self.global_step} to {max_global_step}")
self.global_step = max_global_step
self.optimizer._step_count = self.global_step
self.scheduler.last_epoch = self.global_step


with timer("update_model_with_gathered", self.wandb, self.global_step):
self.optimizer.zero_grad()
@@ -581,12 +573,11 @@ async def _save():

# Increment global_step
self.global_step += 1
self.optimizer._step_count = self.global_step # Ensure optimizer's step count matches

# Log steps to wandb
self.wandb.log({
"validator/global_step": self.global_step,
"validator/optimizer_step_count": self.optimizer._step_count,
# "validator/optimizer_step_count": self.optimizer._step_count,
"validator/scheduler_last_epoch": self.scheduler.last_epoch,
}, step=self.global_step)

1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
"pip",
"wandb",
"python-dotenv",
"zstandard"
]

[tool.pytest.ini_options]
2 changes: 1 addition & 1 deletion scripts/benchmarks/buckets.py
@@ -187,7 +187,7 @@ async def main():
results = {"AWS": [], "CF": []}
for _ in trange(10):
for platform in ("AWS", "CF"):
logger.info(f"{'-'*10} Starting benchmarks for {platform} {'-'*10}")
logger.info(f"{'-' * 10} Starting benchmarks for {platform} {'-' * 10}")
bucket, client = await get_bucket_and_client(platform)
durations = await benchmark(client, bucket)
results[platform].append(durations)