From 3af843abc369fe1fb7f6d4a733e19a86b450b9fa Mon Sep 17 00:00:00 2001 From: ChenDRAG <40993476+ChenDRAG@users.noreply.github.com> Date: Mon, 11 Apr 2022 22:42:49 +0800 Subject: [PATCH 1/3] Fix action scaling bug in sac Solve issue in #588 --- tianshou/policy/modelfree/sac.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index ac5b3e07a..81444d508 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -121,15 +121,9 @@ def forward( # type: ignore # apply correction for Tanh squashing when computing logprob from Gaussian # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. - if self.action_scaling and self.action_space is not None: - low, high = self.action_space.low, self.action_space.high # type: ignore - action_scale = to_torch_as((high - low) / 2.0, act) - else: - action_scale = 1.0 # type: ignore squashed_action = torch.tanh(act) log_prob = log_prob - torch.log( - action_scale * (1 - squashed_action.pow(2)) + self.__eps - ).sum(-1, keepdim=True) + (1 - squashed_action.pow(2)) + self.__eps).sum(-1, keepdim=True) return Batch( logits=logits, act=squashed_action, From bfaf9f0edeb45025210efc91fa138ed998cabf2c Mon Sep 17 00:00:00 2001 From: ChenDRAG <40993476+ChenDRAG@users.noreply.github.com> Date: Mon, 11 Apr 2022 22:51:45 +0800 Subject: [PATCH 2/3] fix --- tianshou/policy/modelfree/sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 81444d508..14404767b 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -5,7 +5,7 @@ import torch from torch.distributions import Independent, Normal -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy From aee24fb56f76a7d1d9128fc8382a0bdd367ad3f2 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Mon, 11 Apr 2022 11:03:37 -0400 Subject: [PATCH 3/3] fix ci --- tianshou/policy/modelfree/sac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 14404767b..abe707d01 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -122,8 +122,8 @@ def forward( # type: ignore # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. squashed_action = torch.tanh(act) - log_prob = log_prob - torch.log( - (1 - squashed_action.pow(2)) + self.__eps).sum(-1, keepdim=True) + log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + + self.__eps).sum(-1, keepdim=True) return Batch( logits=logits, act=squashed_action,