Mixed double precision for PPO algorithm #155
base: develop

Conversation

Commits compared: 9c98abe to 06fbf2e
@@ -388,55 +398,62 @@ def compute_gae(rewards: torch.Tensor,
    # mini-batches loop
    for sampled_states, sampled_actions, sampled_log_prob, sampled_values, sampled_returns, sampled_advantages in sampled_batches:

        sampled_states = self._state_preprocessor(sampled_states, train=not epoch)
        with torch.autocast(device_type=self._device_type, enabled=self._mixed_precision):
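For reference, autocast in a training loop is normally paired with gradient scaling so fp16 gradients do not underflow. A minimal, self-contained sketch of that general pattern (the model, optimizer, and tensor names below are illustrative, not skrl's API):

```python
import torch

# Sketch of the autocast + GradScaler training pattern (illustrative names).
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mixed_precision = device_type == "cuda"  # fp16 autocast only makes sense on GPU here

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision)

states = torch.randn(32, 8)
returns = torch.randn(32, 1)

with torch.autocast(device_type=device_type, enabled=mixed_precision):
    values = model(states)  # forward pass runs in reduced precision when enabled
    loss = torch.nn.functional.mse_loss(values, returns)

scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
scaler.update()                # adjusts the scale factor for the next step
```

With `enabled=False` (e.g. on CPU), both `autocast` and `GradScaler` become no-ops, so the same code path works for full-precision training.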
Is it also necessary to apply autocast to the following block?

skrl/skrl/agents/torch/ppo/ppo.py, lines 369 to 373 in c15f3ce:
with torch.no_grad():
    self.value.train(False)
    last_values, _, _ = self.value.act({"states": self._state_preprocessor(self._current_next_states.float())}, role="value")
    self.value.train(True)
last_values = self._value_preprocessor(last_values, inverse=True)
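For context, this step evaluates the value network on the final collected states to obtain the bootstrap values used by GAE. A minimal sketch with a plain `nn.Module` standing in for skrl's model wrapper (all names illustrative):

```python
import torch

# Sketch: evaluate V(s') for the last collected states under no_grad,
# so the bootstrap values used by GAE carry no gradient history.
value = torch.nn.Linear(8, 1)     # stand-in for the value network
next_states = torch.randn(64, 8)  # stand-in for _current_next_states

value.train(False)                # eval mode (affects dropout/batchnorm)
with torch.no_grad():
    last_values = value(next_states)  # V(s_{t+1}) bootstrap for GAE
value.train(True)
```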
@@ -219,8 +227,9 @@ def act(self, states: torch.Tensor, timestep: int, timesteps: int) -> torch.Tens
        return self.policy.random_act({"states": self._state_preprocessor(states)}, role="policy")

    # sample stochastic actions
    actions, log_prob, outputs = self.policy.act({"states": self._state_preprocessor(states)}, role="policy")
    self._current_log_prob = log_prob
    with torch.autocast(device_type=self._device_type, enabled=(self._mixed_precision)):
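A minimal sketch of sampling stochastic actions under autocast during data collection (the `Categorical` policy and all names are assumptions for illustration, not skrl's implementation):

```python
import torch

# Sketch: policy inference under no_grad + autocast during rollout collection.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mixed_precision = device_type == "cuda"

policy = torch.nn.Linear(4, 2)  # stand-in for the policy network
states = torch.randn(16, 4)

with torch.no_grad(), torch.autocast(device_type=device_type, enabled=mixed_precision):
    logits = policy(states)
    dist = torch.distributions.Categorical(logits=logits)
    actions = dist.sample()           # stochastic actions
    log_prob = dist.log_prob(actions) # stored for the PPO ratio later
```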
Why is self._mixed_precision wrapped in parentheses in all with torch.autocast(device_type=self._device_type, enabled=(self._mixed_precision)): statements?
Mixed precision
Motivation:
Inspired by RLGames, we implemented automatic mixed precision to boost the performance of the PPO algorithm.
Sources:
https://pytorch.org/docs/stable/amp.html
https://pytorch.org/docs/stable/notes/amp_examples.html
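As a quick sanity check of what autocast changes, the result dtype of a matmul can be inspected directly (on a CPU-only run, `enabled` is False here, so the dtype stays float32):

```python
import torch

# Sketch: a matmul inside autocast runs in float16 on CUDA; with autocast
# disabled (e.g. on CPU here) it stays in the default float32.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
enabled = device_type == "cuda"

x = torch.randn(4, 4, device=device_type)
w = torch.randn(4, 4, device=device_type)
with torch.autocast(device_type=device_type, enabled=enabled):
    y = x @ w

print(y.dtype)  # torch.float16 on CUDA, torch.float32 otherwise
```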
Speed eval:
Large neural network (units: [2048, 1024, 1024, 512])
10000 steps
Running on top of an OIGE env simulation (constant for each run)
skrl uses a single forward pass implementation
* in this run, mixed precision was also used for inference during the data collection phase
Quality eval: