Update torchrl/objectives/dreamer.py
vmoens authored Apr 22, 2024
1 parent 2dfa7ae commit 4e74969
Showing 1 changed file with 0 additions and 34 deletions.
34 changes: 0 additions & 34 deletions torchrl/objectives/dreamer.py
@@ -153,40 +153,6 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor:
         reward_loss = reward_loss.mean().unsqueeze(-1)
-        # import ipdb; ipdb.set_trace()
 
-        # Alternative:
-        # def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]:
-        #     tensordict = tensordict.copy()
-        #     tensordict.rename_key_(
-        #         ("next", self.tensor_keys.reward),
-        #         ("next", self.tensor_keys.true_reward),
-        #     )
-        #     tensordict = self.world_model(tensordict)
-        #     # compute model loss
-        #     kl_loss = self.kl_loss(
-        #         tensordict.get(("next", self.tensor_keys.prior_mean)),
-        #         tensordict.get(("next", self.tensor_keys.prior_std)),
-        #         tensordict.get(("next", self.tensor_keys.posterior_mean)),
-        #         tensordict.get(("next", self.tensor_keys.posterior_std)),
-        #     )
-        #
-        #     dist: IndependentNormal = self.decoder.get_dist(tensordict)
-        #     reco_loss = -dist.log_prob(
-        #         tensordict.get(("next", self.tensor_keys.pixels))
-        #     ).mean()
-        #     # x = tensordict.get(("next", self.tensor_keys.pixels))
-        #     # loc = dist.base_dist.loc
-        #     # scale = dist.base_dist.scale
-        #     # reco_loss = -self.normal_log_probability(x, loc, scale).mean()
-        #
-        #     dist: IndependentNormal = self.reward_model.get_dist(tensordict)
-        #     reward_loss = -dist.log_prob(
-        #         tensordict.get(("next", self.tensor_keys.true_reward))
-        #     ).mean()
-        #     # x = tensordict.get(("next", self.tensor_keys.true_reward))
-        #     # loc = dist.base_dist.loc
-        #     # scale = dist.base_dist.scale
-        #     # reward_loss = -self.normal_log_probability(x, loc, scale).mean()
 
         return (
             TensorDict(
                 {
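Note: the deleted block referenced two helpers, self.kl_loss and self.normal_log_probability, whose definitions are outside this hunk. As a minimal sketch — assuming diagonal Gaussians and the standard closed forms, not code taken from torchrl, with hypothetical names — they would look like:

import math

import torch


def gaussian_kl(prior_mean, prior_std, posterior_mean, posterior_std):
    # Elementwise KL(posterior || prior) between diagonal Gaussians:
    #   log(s_p / s_q) + (s_q**2 + (m_q - m_p)**2) / (2 * s_p**2) - 1/2
    # The argument order mirrors the deleted self.kl_loss call.
    return (
        torch.log(prior_std / posterior_std)
        + (posterior_std.pow(2) + (posterior_mean - prior_mean).pow(2))
        / (2 * prior_std.pow(2))
        - 0.5
    )


def normal_log_probability(x, loc, scale):
    # Elementwise Gaussian log-density log N(x; loc, scale**2).
    return (
        -0.5 * ((x - loc) / scale).pow(2)
        - scale.log()
        - 0.5 * math.log(2 * math.pi)
    )

An IndependentNormal's log_prob sums over the event dims before the outer .mean(), so a drop-in replacement for the -dist.log_prob(...).mean() calls in the deleted block would need to apply the same reduction; the elementwise form above leaves that choice to the caller.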
