From 5630fc825cc145a5a476e4d32456573b6af8c40b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 4 Feb 2025 16:34:40 +0000 Subject: [PATCH] [BugFix] Better deterministic sample for composite ghstack-source-id: 9dd872e39ed9b697412ac8618b870b4d94670293 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1205 --- tensordict/nn/distributions/composite.py | 50 ++++++++++++++++++------ tensordict/nn/probabilistic.py | 41 +++++++++++-------- 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py index c64f21002..53ec3441d 100644 --- a/tensordict/nn/distributions/composite.py +++ b/tensordict/nn/distributions/composite.py @@ -304,22 +304,41 @@ def maybe_deterministic_sample(dist): if hasattr(dist, "deterministic_sample"): return dist.deterministic_sample else: - try: - support = dist.support - fallback = ( - "mean" - if isinstance(support, torch.distributions.constraints._Real) - else "mode" + from tensordict.nn.probabilistic import ( + DETERMINISTIC_REGISTER, + ) + + # Fallbacks + tdist = type(dist) + if issubclass(tdist, d.Independent): + tdist = type(dist.base_dist) + interaction_type = DETERMINISTIC_REGISTER.get(tdist) + if interaction_type == "mode": + return dist.mode + if interaction_type == "mean": + return dist.mean + if interaction_type == "random": + return dist.rsample() if dist.has_rsample else dist.sample() + if interaction_type is None: + try: + support = dist.support + fallback = ( + "mean" + if isinstance(support, d.constraints._Real) + else "mode" + ) + except NotImplementedError: + # Some custom dists don't have a support + # We arbitrarily fall onto 'mean' in these cases + fallback = "mean" + else: + raise RuntimeError( + f"InteractionType {interaction_type} is unaccounted for." ) - except NotImplementedError: - # Some custom dists don't have a support - # We arbitrarily fall onto 'mean' in these cases - fallback = "mean" try: if fallback == "mean": return dist.mean elif fallback == "mode": - # Categorical dists don't have an average return dist.mode else: raise AttributeError @@ -329,7 +348,7 @@ def maybe_deterministic_sample(dist): ) finally: warnings.warn( - f"deterministic_sample wasn't found when queried in {type(dist)}. " + f"deterministic_sample wasn't found when queried on {type(dist)}. " f"{type(self).__name__} is falling back on {fallback} instead. " f"For better code quality and efficiency, make sure to either " f"provide a distribution with a deterministic_sample attribute or " @@ -423,7 +442,12 @@ def log_prob_composite( slp = 0.0 d = {} for name, dist in self.dists.items(): - d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name)) + try: + d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name)) + except AttributeError: + raise RuntimeError( + f"Expected a tensordict sample, but got a {type(sample).__name__} instead." + ) if include_sum: if lp.ndim > sample.ndim: lp = lp.flatten(sample.ndim, -1).sum(-1) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 9d61d330b..9ae1133f3 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -143,30 +143,35 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: class ProbabilisticTensorDictModule(TensorDictModuleBase): """A probabilistic TD Module. - `ProbabilisticTensorDictModule` is a non-parametric module representing a probability distribution. - It reads the distribution parameters from an input TensorDict using the specified `in_keys`. - The output is sampled given some rule, specified by the input :obj:`default_interaction_type` argument and the - :func:`~tensordict.nn.interaction_type` global function. + `ProbabilisticTensorDictModule` is a non-parametric module embedding a + probability distribution constructor. It reads the distribution parameters from an input + TensorDict using the specified `in_keys` and outputs a sample (loosely speaking) of the + distribution. - :obj:`ProbabilisticTensorDictModule` can be used to construct the distribution + The output "sample" is produced given some rule, specified by the input ``default_interaction_type`` + argument and the ``interaction_type()`` global function. + + `ProbabilisticTensorDictModule` can be used to construct the distribution (through the :meth:`~.get_dist` method) and/or sampling from this distribution - (through a regular :meth:`~.forward` to the module). + (through a regular :meth:`~.__call__` to the module). - A ``ProbabilisticTensorDictModule`` instance has two main features: + A `ProbabilisticTensorDictModule` instance has two main features: - - It reads and writes TensorDict objects + - It reads and writes from and to TensorDict objects; - It uses a real mapping R^n -> R^m to create a distribution in R^d from - which values can be sampled or computed. + which values can be sampled or computed. - When the :meth:`~.forward` method is called, a distribution is - created, and a value computed (using the 'mean', 'mode', 'median' attribute or - the 'rsample', 'sample' method). The sampling step is skipped if the supplied - TensorDict has all of the desired key-value pairs already. + When the :meth:`~.__call__` and :meth:`~.forward` method are called, a distribution is + created, and a value computed (depending on the ``interaction_type`` value, 'dist.mean', + 'dist.mode', 'dist.median' attributes could be used, as well as + the 'dist.rsample', 'dist.sample' method). The sampling step is skipped if the supplied + TensorDict has all the desired key-value pairs already. - By default, the ``ProbabilisticTensorDictModule`` distribution class is a ``Delta`` - distribution, making ``ProbabilisticTensorDictModule`` a simple wrapper around + By default, `ProbabilisticTensorDictModule` distribution class is a :class:`~torchrl.modules.distributions.Delta` + distribution, making `ProbabilisticTensorDictModule` a simple wrapper around a deterministic mapping function. + Args: in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]): key(s) that will be read from the input TensorDict and used to build the distribution. @@ -668,7 +673,10 @@ def _dist_sample( return dist.deterministic_sample else: # Fallbacks - interaction_type = DETERMINISTIC_REGISTER.get(type(dist)) + tdist = type(dist) + if issubclass(tdist, D.Independent): + tdist = type(dist.base_dist) + interaction_type = DETERMINISTIC_REGISTER.get(tdist) if interaction_type is None: try: support = dist.support @@ -971,6 +979,7 @@ def __init__( ) else: modules_list = list(modules) + modules_list = self._convert_modules(modules_list) # if the modules not including the final probabilistic module return the sampled # key we won't be sampling it again, in that case