diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py
index c354dd22364..118913699f9 100644
--- a/examples/impala/impala_multi_node_submitit.py
+++ b/examples/impala/impala_multi_node_submitit.py
@@ -128,10 +128,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Main loop
     collected_frames = 0
     num_network_updates = 0
-    start_time = time.time()
     pbar = tqdm.tqdm(total=total_frames)
     accumulator = []
-    sampling_start = time.time()
+    start_time = sampling_start = time.time()
     for i, data in enumerate(collector):

         log_info = {}
@@ -164,8 +163,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         for j in range(sgd_updates):

             # Create a single batch of trajectories
-            stacked_data = torch.stack(accumulator, dim=0)
-            stacked_data = stacked_data.to(device)
+            stacked_data = torch.stack(accumulator, dim=0).contiguous()
+            stacked_data = stacked_data.to(device, non_blocking=True)

             # Compute advantage
             stacked_data = adv_module(stacked_data)
diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py
index 47cf31cc6df..2cd1043f46f 100644
--- a/examples/impala/impala_single_node.py
+++ b/examples/impala/impala_single_node.py
@@ -47,8 +47,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
     max_grad_norm = cfg.optim.max_grad_norm
     num_test_episodes = cfg.logger.num_test_episodes
     total_network_updates = (
-        total_frames // (frames_per_batch * batch_size)
-    ) * cfg.loss.sgd_updates
+        total_frames // (frames_per_batch * batch_size)
+    ) * cfg.loss.sgd_updates

     # Create models (check utils_atari.py)
     actor, critic = make_ppo_models(cfg.env.env_name)
@@ -116,10 +116,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Main loop
     collected_frames = 0
     num_network_updates = 0
-    start_time = time.time()
     pbar = tqdm.tqdm(total=total_frames)
     accumulator = []
-    sampling_start = time.time()
+    start_time = sampling_start = time.time()
     for i, data in enumerate(collector):

         log_info = {}
@@ -136,7 +135,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 {
                     "train/reward": episode_rewards.mean().item(),
                     "train/episode_length": episode_length.sum().item()
-                    / len(episode_length),
+                    / len(episode_length),
                 }
             )

@@ -183,7 +182,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     "loss_critic", "loss_entropy", "loss_objective"
                 ).detach()
                 loss_sum = (
-                    loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+                    loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
                )

                 # Backward pass
@@ -212,7 +211,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         # Get test rewards
         with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
             if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
-                i * frames_in_batch * frame_skip
+                i * frames_in_batch * frame_skip
             ) // test_interval:
                 actor.eval()
                 eval_start = time.time()
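
Not part of the patch: a minimal, self-contained sketch of the transfer pattern the diff switches to (stack on CPU, materialize with `.contiguous()`, then copy to the training device with `non_blocking=True`). The toy `accumulator` contents and the `device` selection below are illustrative assumptions, not values taken from the examples' configs.

```python
# Sketch only: illustrates the stack -> contiguous -> non-blocking copy
# pattern used in the diff above; data shapes here are made up.
import torch
from tensordict import TensorDict

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the `accumulator` list of collected batches.
accumulator = [
    TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
    for _ in range(3)
]

# Stack trajectories into a single batch; .contiguous() materializes the
# lazy stack so the device copy moves one dense block of memory.
stacked_data = torch.stack(accumulator, dim=0).contiguous()

# non_blocking=True lets the host-to-device copy overlap with CPU work when
# the source memory is pinned; otherwise it falls back to a synchronous
# copy, so correctness is unaffected either way.
stacked_data = stacked_data.to(device, non_blocking=True)
print(stacked_data)
```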