Commit 2dfa7ae

amend

vmoens committed Apr 19, 2024
1 parent 4f374d9 commit 2dfa7ae
Showing 2 changed files with 8 additions and 29 deletions.
sota-implementations/dreamer/dreamer.py (27 changes: 3 additions & 24 deletions)
@@ -22,7 +22,6 @@
 # mixed precision training
 from torch.cuda.amp import GradScaler
 from torch.nn.utils import clip_grad_norm_
-from torch.profiler import profile, ProfilerActivity
 from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import RSSMRollout
@@ -189,11 +188,7 @@ def compile_rssms(module):
                 with torch.autocast(
                     device_type=device.type,
                     dtype=torch.bfloat16,
-                ) if use_autocast else contextlib.nullcontext(), (
-                    profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])
-                    if (i == 1 and k == 1)
-                    else contextlib.nullcontext()
-                ) as prof:
+                ) if use_autocast else contextlib.nullcontext():
                     model_loss_td, sampled_tensordict = world_model_loss(
                         sampled_tensordict
                     )
@@ -202,8 +197,6 @@ def compile_rssms(module):
                         + model_loss_td["loss_model_reco"]
                         + model_loss_td["loss_model_reward"]
                     )
-                    if i == 1 and k == 1:
-                        prof.export_chrome_trace("trace_world_model.json")

                 world_model_opt.zero_grad()
                 if use_autocast:
@@ -223,16 +216,9 @@ def compile_rssms(module):
                 t_loss_actor_init = time.time()
                 with torch.autocast(
                     device_type=device.type, dtype=torch.bfloat16
-                ) if use_autocast else contextlib.nullcontext(), (
-                    profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])
-                    if (i == 1 and k == 1)
-                    else contextlib.nullcontext()
-                ) as prof:
+                ) if use_autocast else contextlib.nullcontext():
                     actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict)

-                    if i == 1 and k == 1:
-                        prof.export_chrome_trace("trace_actor.json")
-
                 actor_opt.zero_grad()
                 if use_autocast:
                     scaler2.scale(actor_loss_td["loss_actor"]).backward()
@@ -251,16 +237,9 @@ def compile_rssms(module):
                 t_loss_critic_init = time.time()
                 with torch.autocast(
                     device_type=device.type, dtype=torch.bfloat16
-                ) if use_autocast else contextlib.nullcontext(), (
-                    profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])
-                    if (i == 1 and k == 1)
-                    else contextlib.nullcontext()
-                ) as prof:
+                ) if use_autocast else contextlib.nullcontext():
                     value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)

-                    if i == 1 and k == 1:
-                        prof.export_chrome_trace("trace_critic.json")
-
                 value_opt.zero_grad()
                 if use_autocast:
                     scaler3.scale(value_loss_td["loss_value"]).backward()
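For reference, the construct the amended loop keeps is a torch.autocast context gated by a flag, with contextlib.nullcontext() as the fallback. Below is a minimal, self-contained sketch of that pattern; the model, optimizer, data and the use_autocast flag are stand-ins for illustration, not the script's actual objects.

import contextlib

import torch
from torch.cuda.amp import GradScaler

# Stand-in flag and objects; dreamer.py derives these from its config and loss modules.
use_autocast = torch.cuda.is_available()
device = torch.device("cuda" if use_autocast else "cpu")
model = torch.nn.Linear(8, 1, device=device)
opt = torch.optim.Adam(model.parameters())
scaler = GradScaler(enabled=use_autocast)
x = torch.randn(32, 8, device=device)

# Enter autocast only when mixed precision is requested; otherwise use a no-op context.
with torch.autocast(
    device_type=device.type, dtype=torch.bfloat16
) if use_autocast else contextlib.nullcontext():
    loss = model(x).mean()

opt.zero_grad()
if use_autocast:
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
else:
    loss.backward()
    opt.step()

The profiling this commit drops can still be reproduced ad hoc by wrapping the same block in torch.profiler.profile(activities=[...]) and calling export_chrome_trace(...) on the profiler, exactly as the deleted lines did.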
torchrl/modules/models/model_based.py (10 changes: 5 additions & 5 deletions)
@@ -13,7 +13,7 @@
     TensorDictSequential,
 )
 from torch import nn
-
+from tensordict import LazyStackedTensorDict
 # from torchrl.modules.tensordict_module.rnn import GRUCell
 from torch.nn import GRUCell
 from torchrl._utils import timeit
@@ -259,12 +259,12 @@ def forward(self, tensordict):

             tensordict_out.append(_tensordict)
             if t < time_steps - 1:
-                _tensordict = step_mdp(
-                    _tensordict.select(*self.out_keys, strict=False), keep_other=False
-                )
+                _tensordict = _tensordict.select(*self.in_keys, strict=False)
                 _tensordict = update_values[t + 1].update(_tensordict)

-        return torch.stack(tensordict_out, tensordict.ndim - 1)
+        out = torch.stack(tensordict_out, tensordict.ndim - 1)
+        assert not any(isinstance(val, LazyStackedTensorDict) for val in out.values(True)), out
+        return out


 class RSSMPrior(nn.Module):
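The model_based.py change stacks the per-step tensordicts into out and asserts that none of its values, nested nodes included, is a LazyStackedTensorDict. Below is a minimal sketch of that check on toy tensordicts; the keys and shapes are invented for illustration, and whether the nested entry comes out dense or lazy can depend on the tensordict version in use, which is the situation the assert would surface.

import torch
from tensordict import LazyStackedTensorDict, TensorDict

# Toy per-step tensordicts with one nested node, standing in for the rollout output.
steps = [
    TensorDict(
        {"state": torch.randn(3), "next": TensorDict({"state": torch.randn(3)}, [])},
        batch_size=[],
    )
    for _ in range(4)
]

out = torch.stack(steps, 0)

# values(True) iterates nested nodes as well as leaves, mirroring out.values(True) in the diff.
lazy_nodes = [v for v in out.values(True) if isinstance(v, LazyStackedTensorDict)]
print(type(out).__name__, "lazy nested nodes:", len(lazy_nodes))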
