Skip to content

Commit

Permalink
format
Browse files (browse the repository at this point in the history)
  • Loading branch information
albertbou92 committed Nov 22, 2023
1 parent c68fd40 commit 638c0d6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
7 changes: 3 additions & 4 deletions examples/impala/impala_multi_node_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,9 @@ def main(cfg: "DictConfig"): # noqa: F821
# Main loop
collected_frames = 0
num_network_updates = 0
start_time = time.time()
pbar = tqdm.tqdm(total=total_frames)
accumulator = []
sampling_start = time.time()
start_time = sampling_start = time.time()
for i, data in enumerate(collector):

log_info = {}
Expand Down Expand Up @@ -164,8 +163,8 @@ def main(cfg: "DictConfig"): # noqa: F821
for j in range(sgd_updates):

# Create a single batch of trajectories
stacked_data = torch.stack(accumulator, dim=0)
stacked_data = stacked_data.to(device)
stacked_data = torch.stack(accumulator, dim=0).contiguous()
stacked_data = stacked_data.to(device, non_blocking=True)

# Compute advantage
stacked_data = adv_module(stacked_data)
Expand Down
13 changes: 6 additions & 7 deletions examples/impala/impala_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def main(cfg: "DictConfig"): # noqa: F821
max_grad_norm = cfg.optim.max_grad_norm
num_test_episodes = cfg.logger.num_test_episodes
total_network_updates = (
total_frames // (frames_per_batch * batch_size)
) * cfg.loss.sgd_updates
total_frames // (frames_per_batch * batch_size)
) * cfg.loss.sgd_updates

# Create models (check utils_atari.py)
actor, critic = make_ppo_models(cfg.env.env_name)
Expand Down Expand Up @@ -116,10 +116,9 @@ def main(cfg: "DictConfig"): # noqa: F821
# Main loop
collected_frames = 0
num_network_updates = 0
start_time = time.time()
pbar = tqdm.tqdm(total=total_frames)
accumulator = []
sampling_start = time.time()
start_time = sampling_start = time.time()
for i, data in enumerate(collector):

log_info = {}
Expand All @@ -136,7 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
/ len(episode_length),
}
)

Expand Down Expand Up @@ -183,7 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821
"loss_critic", "loss_entropy", "loss_objective"
).detach()
loss_sum = (
loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
)

# Backward pass
Expand Down Expand Up @@ -212,7 +211,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
i * frames_in_batch * frame_skip
) // test_interval:
actor.eval()
eval_start = time.time()
Expand Down

0 comments on commit 638c0d6

Please sign in to comment.