Skip to content

Commit

Permalink
[BugFix] Better deterministic sample for composite
Browse files Browse the repository at this point in the history
ghstack-source-id: 9dd872e39ed9b697412ac8618b870b4d94670293
Pull Request resolved: #1205
  • Loading branch information
vmoens committed Feb 4, 2025
1 parent e900b24 commit 5630fc8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 29 deletions.
50 changes: 37 additions & 13 deletions tensordict/nn/distributions/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,22 +304,41 @@ def maybe_deterministic_sample(dist):
if hasattr(dist, "deterministic_sample"):
return dist.deterministic_sample
else:
try:
support = dist.support
fallback = (
"mean"
if isinstance(support, torch.distributions.constraints._Real)
else "mode"
from tensordict.nn.probabilistic import (
DETERMINISTIC_REGISTER,
)

# Fallbacks
tdist = type(dist)
if issubclass(tdist, d.Independent):
tdist = type(dist.base_dist)
interaction_type = DETERMINISTIC_REGISTER.get(tdist)
if interaction_type == "mode":
return dist.mode
if interaction_type == "mean":
return dist.mean
if interaction_type == "random":
return dist.rsample() if dist.has_rsample else dist.sample()
if interaction_type is None:
try:
support = dist.support
fallback = (
"mean"
if isinstance(support, d.constraints._Real)
else "mode"
)
except NotImplementedError:
# Some custom dists don't have a support
# We arbitrarily fall onto 'mean' in these cases
fallback = "mean"
else:
raise RuntimeError(
f"InteractionType {interaction_type} is unaccounted for."
)
except NotImplementedError:
# Some custom dists don't have a support
# We arbitrarily fall onto 'mean' in these cases
fallback = "mean"
try:
if fallback == "mean":
return dist.mean
elif fallback == "mode":
# Categorical dists don't have an average
return dist.mode
else:
raise AttributeError
Expand All @@ -329,7 +348,7 @@ def maybe_deterministic_sample(dist):
)
finally:
warnings.warn(
f"deterministic_sample wasn't found when queried in {type(dist)}. "
f"deterministic_sample wasn't found when queried on {type(dist)}. "
f"{type(self).__name__} is falling back on {fallback} instead. "
f"For better code quality and efficiency, make sure to either "
f"provide a distribution with a deterministic_sample attribute or "
Expand Down Expand Up @@ -423,7 +442,12 @@ def log_prob_composite(
slp = 0.0
d = {}
for name, dist in self.dists.items():
d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
try:
d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
except AttributeError:
raise RuntimeError(
f"Expected a tensordict sample, but got a {type(sample).__name__} instead."
)
if include_sum:
if lp.ndim > sample.ndim:
lp = lp.flatten(sample.ndim, -1).sum(-1)
Expand Down
41 changes: 25 additions & 16 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,30 +143,35 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
class ProbabilisticTensorDictModule(TensorDictModuleBase):
"""A probabilistic TD Module.
`ProbabilisticTensorDictModule` is a non-parametric module representing a probability distribution.
It reads the distribution parameters from an input TensorDict using the specified `in_keys`.
The output is sampled given some rule, specified by the input :obj:`default_interaction_type` argument and the
:func:`~tensordict.nn.interaction_type` global function.
`ProbabilisticTensorDictModule` is a non-parametric module embedding a
probability distribution constructor. It reads the distribution parameters from an input
TensorDict using the specified `in_keys` and outputs a sample (loosely speaking) of the
distribution.
:obj:`ProbabilisticTensorDictModule` can be used to construct the distribution
The output "sample" is produced given some rule, specified by the input ``default_interaction_type``
argument and the ``interaction_type()`` global function.
`ProbabilisticTensorDictModule` can be used to construct the distribution
(through the :meth:`~.get_dist` method) and/or sampling from this distribution
(through a regular :meth:`~.forward` to the module).
(through a regular :meth:`~.__call__` to the module).
A ``ProbabilisticTensorDictModule`` instance has two main features:
A `ProbabilisticTensorDictModule` instance has two main features:
- It reads and writes TensorDict objects
- It reads and writes from and to TensorDict objects;
- It uses a real mapping R^n -> R^m to create a distribution in R^d from
which values can be sampled or computed.
which values can be sampled or computed.
When the :meth:`~.forward` method is called, a distribution is
created, and a value computed (using the 'mean', 'mode', 'median' attribute or
the 'rsample', 'sample' method). The sampling step is skipped if the supplied
TensorDict has all of the desired key-value pairs already.
When the :meth:`~.__call__` and :meth:`~.forward` method are called, a distribution is
created, and a value computed (depending on the ``interaction_type`` value, 'dist.mean',
'dist.mode', 'dist.median' attributes could be used, as well as
the 'dist.rsample', 'dist.sample' method). The sampling step is skipped if the supplied
TensorDict has all the desired key-value pairs already.
By default, the ``ProbabilisticTensorDictModule`` distribution class is a ``Delta``
distribution, making ``ProbabilisticTensorDictModule`` a simple wrapper around
By default, `ProbabilisticTensorDictModule` distribution class is a :class:`~torchrl.modules.distributions.Delta`
distribution, making `ProbabilisticTensorDictModule` a simple wrapper around
a deterministic mapping function.
Args:
in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]): key(s) that will be read from the input TensorDict
and used to build the distribution.
Expand Down Expand Up @@ -668,7 +673,10 @@ def _dist_sample(
return dist.deterministic_sample
else:
# Fallbacks
interaction_type = DETERMINISTIC_REGISTER.get(type(dist))
tdist = type(dist)
if issubclass(tdist, D.Independent):
tdist = type(dist.base_dist)
interaction_type = DETERMINISTIC_REGISTER.get(tdist)
if interaction_type is None:
try:
support = dist.support
Expand Down Expand Up @@ -971,6 +979,7 @@ def __init__(
)
else:
modules_list = list(modules)
modules_list = self._convert_modules(modules_list)

# if the modules not including the final probabilistic module return the sampled
# key we won't be sampling it again, in that case
Expand Down

0 comments on commit 5630fc8

Please sign in to comment.