From e294c68ca8ac1794b19398b07a1cc42cca586ea1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 17 Sep 2024 14:38:33 -0700 Subject: [PATCH] [Feature] Deterministic sample for Masked one-hot ghstack-source-id: 27787eab47324c5af152f706d81687e71b5b9803 Pull Request resolved: https://github.com/pytorch/rl/pull/2440 --- torchrl/modules/distributions/discrete.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index c48d8168887..d2ffba30686 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -389,6 +389,17 @@ def sample( ) -> torch.Tensor: ... + @property + def deterministic_sample(self): + return self.mode + + @property + def mode(self) -> torch.Tensor: + if hasattr(self, "logits"): + return (self.logits == self.logits.max(-1, True)[0]).to(torch.long) + else: + return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) + def log_prob(self, value: torch.Tensor) -> torch.Tensor: return super().log_prob(value.argmax(dim=-1))