Skip to content

Commit

Permalink
[BugFix] Fix safe probabilistic backward by removing in-place modif
Browse files Browse the repository at this point in the history
ghstack-source-id: 574eb1f9b662c1eb5be25e97020e11b3fadf625e
Pull Request resolved: pytorch#2755
  • Loading branch information
vmoens committed Feb 4, 2025
1 parent ee4006a commit 2f8c118
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 83 deletions.
27 changes: 11 additions & 16 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@
unravel_key,
)
from tensordict.base import NO_DEFAULT
from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey
from tensordict.utils import (
_getitem_batch_size,
expand_as_right,
is_non_tensor,
NestedKey,
)
from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for

try:
Expand Down Expand Up @@ -1848,9 +1853,8 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
gathered = mask_expand & val
oob = ~gathered.any(-1)
new_val = torch.multinomial(mask_expand[oob].float(), 1)
val = val.clone()
val[oob] = 0
val[oob] = torch.scatter(val[oob], -1, new_val, 1)
new_val = torch.scatter(torch.zeros_like(val[oob]), -1, new_val, 1)
val = val.masked_scatter(expand_as_right(oob, val), new_val)
return val

def is_in(self, val: torch.Tensor) -> bool:
Expand Down Expand Up @@ -2300,18 +2304,9 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
if self.device != val.device:
low = low.to(val.device)
high = high.to(val.device)
try:
val = torch.maximum(torch.minimum(val, high), low)
except ValueError:
low = low.expand_as(val)
high = high.expand_as(val)
val[val < low] = low[val < low]
val[val > high] = high[val > high]
except RuntimeError:
low = low.expand_as(val)
high = high.expand_as(val)
val[val < low] = low[val < low]
val[val > high] = high[val > high]
low = low.expand_as(val)
high = high.expand_as(val)
val = torch.clamp(val, low, high)
return val

def is_in(self, val: torch.Tensor) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9035,7 +9035,7 @@ def _reset(


class BatchSizeTransform(Transform):
"""A transform to modify the batch-size of an environmt.
"""A transform to modify the batch-size of an environment.
This transform has two distinct usages: it can be used to set the
batch-size for non-batch-locked (e.g. stateless) environments to
Expand Down
17 changes: 5 additions & 12 deletions torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,17 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
for _spec, _key in zip(values, keys):
if _spec is None:
continue
item = tensordict_out.get(_key, None)
item = tensordict_out.get(_key)
if item is None:
# this will happen when an exploration (e.g. OU) writes a key only
# during exploration, but is missing otherwise.
# it's fine since what we want here it to make sure that a key
# is within bounds if it is present
continue
if not _spec.is_in(item):
try:
tensordict_out.set_(
_key,
_spec.project(tensordict_out.get(_key)),
)
except RuntimeError:
tensordict_out.set(
_key,
_spec.project(tensordict_out.get(_key)),
)
tensordict_out.set(
_key,
_spec.project(item),
)
except RuntimeError as err:
if re.search(
"attempting to use a Tensor in some data-dependent control flow", str(err)
Expand Down
198 changes: 144 additions & 54 deletions torchrl/modules/tensordict_module/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from __future__ import annotations

import warnings
from typing import Dict, List, Optional, Type, Union
from typing import Dict, List, Optional, Union

import torch

from tensordict import TensorDictBase, unravel_key_list

Expand All @@ -23,95 +25,181 @@


class SafeProbabilisticModule(ProbabilisticTensorDictModule):
""":class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as argument to control the output domain.
""":class:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :class:`~torchrl.envs.TensorSpec` as an argument to control the output domain.
`SafeProbabilisticModule` is a non-parametric module embedding a
probability distribution constructor. It reads the distribution parameters from an input
TensorDict using the specified `in_keys` and outputs a sample (loosely speaking) of the
distribution.
The output "sample" is produced given some rule, specified by the input ``default_interaction_type``
argument and the ``interaction_type()`` global function.
`SafeProbabilisticModule` is a non-parametric module representing a
probability distribution. It reads the distribution parameters from an input
TensorDict using the specified `in_keys`. The output is sampled given some rule,
specified by the input ``default_interaction_type`` argument and the
``interaction_type()`` global function.
`SafeProbabilisticModule` can be used to construct the distribution
(through the :meth:`~.get_dist` method) and/or sampling from this distribution
(through a regular :meth:`~.__call__` to the module).
:obj:`SafeProbabilisticModule` can be used to construct the distribution
(through the :obj:`get_dist()` method) and/or sampling from this distribution
(through a regular :obj:`__call__()` to the module).
A `SafeProbabilisticModule` instance has two main features:
A :obj:`SafeProbabilisticModule` instance has two main features:
- It reads and writes TensorDict objects
- It reads and writes from and to TensorDict objects;
- It uses a real mapping R^n -> R^m to create a distribution in R^d from
which values can be sampled or computed.
which values can be sampled or computed.
When the :obj:`__call__` / :obj:`forward` method is called, a distribution is
created, and a value computed (using the 'mean', 'mode', 'median' attribute or
the 'rsample', 'sample' method). The sampling step is skipped if the supplied
TensorDict has all of the desired key-value pairs already.
When the :meth:`~.__call__` and :meth:`~.forward` method are called, a distribution is
created, and a value computed (depending on the ``interaction_type`` value, 'dist.mean',
'dist.mode', 'dist.median' attributes could be used, as well as
the 'dist.rsample', 'dist.sample' method). The sampling step is skipped if the supplied
TensorDict has all the desired key-value pairs already.
By default, SafeProbabilisticModule distribution class is a Delta
distribution, making SafeProbabilisticModule a simple wrapper around
By default, `SafeProbabilisticModule` distribution class is a :class:`~torchrl.modules.distributions.Delta`
distribution, making `SafeProbabilisticModule` a simple wrapper around
a deterministic mapping function.
This class differs from :class:`tensordict.nn.ProbabilisticTensorDictModule` in that it accepts a :attr:`spec`
keyword argument which can be used to control whether samples belong to the distribution or not. The :attr:`safe`
keyword argument controls whether the samples values should be checked against the spec.
Args:
in_keys (NestedKey or list of NestedKey or dict): key(s) that will be read from the
input TensorDict and used to build the distribution. Importantly, if it's an
list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by
the distribution class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for
the Normal distribution and similar. If in_keys is a dictionary, the keys
are the keys of the distribution and the values are the keys in the
in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]): key(s) that will be read from the input TensorDict
and used to build the distribution.
Importantly, if it's a list of NestedKey or a NestedKey, the leaf (last element) of those keys must match the keywords used by
the distribution class of interest, e.g. ``"loc"`` and ``"scale"`` for
the :class:`~torch.distributions.Normal` distribution and similar.
If in_keys is a dictionary, the keys are the keys of the distribution and the values are the keys in the
tensordict that will get match to the corresponding distribution keys.
out_keys (NestedKey or list of NestedKey): keys where the sampled values will be
written. Importantly, if these keys are found in the input TensorDict, the
sampling step will be skipped.
out_keys (NestedKey | List[NestedKey] | None): key(s) where the sampled values will be written.
Importantly, if these keys are found in the input TensorDict, the sampling step will be skipped.
spec (TensorSpec): specs of the first output tensor. Used when calling
td_module.random() to generate random values in the target space.
Keyword Args:
safe (bool, optional): if ``True``, the value of the sample is checked against the
input spec. Out-of-domain sampling can occur because of exploration policies
or numerical under/overflow issues. As for the :obj:`spec` argument, this
check will only occur for the distribution sample, but not the other tensors
returned by the input module. If the sample is out of bounds, it is
projected back onto the desired space using the `TensorSpec.project` method.
Default is ``False``.
default_interaction_type (tensordict.nn.InteractionType, optional): default method to be used to retrieve
the output value. Should be one of: ``InteractionType.MODE``, ``InteractionType.MEDIAN``, ``InteractionType.MEAN`` or ``InteractionType.RANDOM``
default_interaction_type (InteractionType, optional): keyword-only argument.
Default method to be used to retrieve
the output value. Should be one of InteractionType: MODE, MEDIAN, MEAN or RANDOM
(in which case the value is sampled randomly from the distribution). Default
is ``InteractionType.MODE``.
Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will
fist look for the interaction mode dictated by the `interaction_type()`
global function. If this returns `None` (its default value), then the
`default_interaction_type` of the :class:`~.ProbabilisticTDModule`
instance will be used. Note that DataCollector instances will use
:func:`tensordict.nn.set_interaction_type` to
:class:`tensordict.nn.InteractionType.RANDOM` by default.
distribution_class (Type, optional): a torch.distributions.Distribution class to
be used for sampling. Default is Delta.
distribution_kwargs (dict, optional): kwargs to be passed to the distribution.
return_log_prob (bool, optional): if ``True``, the log-probability of the
is MODE.
.. note:: When a sample is drawn, the
:class:`ProbabilisticTensorDictModule` instance will
first look for the interaction mode dictated by the
:func:`~tensordict.nn.probabilistic.interaction_type`
global function. If this returns `None` (its default value), then the
`default_interaction_type` of the `ProbabilisticTDModule`
instance will be used. Note that
:class:`~torchrl.collectors.collectors.DataCollectorBase`
instances will use `set_interaction_type` to
:class:`tensordict.nn.InteractionType.RANDOM` by default.
.. note::
In some cases, the mode, median or mean value may not be
readily available through the corresponding attribute.
To paliate this, :class:`~ProbabilisticTensorDictModule` will first attempt
to get the value through a call to ``get_mode()``, ``get_median()`` or ``get_mean()``
if the method exists.
distribution_class (Type or Callable[[Any], Distribution], optional): keyword-only argument.
A :class:`torch.distributions.Distribution` class to
be used for sampling.
Default is :class:`~tensordict.nn.distributions.Delta`.
.. note::
If the distribution class is of type
:class:`~tensordict.nn.distributions.CompositeDistribution`, the ``out_keys``
can be inferred directly form the ``"distribution_map"`` or ``"name_map"``
keywork arguments provided through this class' ``distribution_kwargs``
keyword argument, making the ``out_keys`` optional in such cases.
distribution_kwargs (dict, optional): keyword-only argument.
Keyword-argument pairs to be passed to the distribution.
.. note:: if your kwargs contain tensors that you would like to transfer to device with the module, or
tensors that should see their dtype modified when calling `module.to(dtype)`, you can wrap the kwargs
in a :class:`~tensordict.nn.TensorDictParams` to do this automatically.
return_log_prob (bool, optional): keyword-only argument.
If ``True``, the log-probability of the
distribution sample will be written in the tensordict with the key
`'sample_log_prob'`. Default is ``False``.
log_prob_key (NestedKey, optional): key where to write the log_prob if return_log_prob = True.
Defaults to `"action_log_prob"`.
cache_dist (bool, optional): EXPERIMENTAL: if ``True``, the parameters of the
`log_prob_key`. Default is ``False``.
log_prob_keys (List[NestedKey], optional): keys where to write the log_prob if ``return_log_prob=True``.
Defaults to `'<sample_key_name>_log_prob'`, where `<sample_key_name>` is each of the :attr:`out_keys`.
.. note:: This is only available when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to ``False``.
log_prob_key (NestedKey, optional): key where to write the log_prob if ``return_log_prob=True``.
Defaults to `'sample_log_prob'` when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to `True`
or `'<sample_key_name>_log_prob'` otherwise.
.. note:: When there is more than one sample, this is only available when :func:`~tensordict.nn.probabilistic.composite_lp_aggregate` is set to ``True``.
cache_dist (bool, optional): keyword-only argument.
EXPERIMENTAL: if ``True``, the parameters of the
distribution (i.e. the output of the module) will be written to the
tensordict along with the sample. Those parameters can be used to re-compute
the original distribution later on (e.g. to compute the divergence between
the distribution used to sample the action and the updated distribution in
PPO). Default is ``False``.
n_empirical_estimate (int, optional): number of samples to compute the empirical
mean when it is not available. Default is 1000
n_empirical_estimate (int, optional): keyword-only argument.
Number of samples to compute the empirical
mean when it is not available. Defaults to 1000.
.. warning:: Running checks takes time! Using `safe=True` will guarantee that the samples are within the spec bounds
given some heuristic coded in :meth:`~torchrl.data.TensorSpec.project`, but that requires checking whether the
values are within the spec space, which will induce some overhead.
.. seealso:: :class`The composite distribution in tensordict <~tensordict.nn.CompositeDistribution>` can be used
to create multi-head policies.
Example:
>>> from torchrl.modules import SafeProbabilisticModule
>>> from torchrl.data import Bounded
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import InteractionType
>>> mod = SafeProbabilisticModule(
... in_keys=["loc", "scale"],
... out_keys=["action"],
... distribution_class=torch.distributions.Normal,
... safe=True,
... spec=Bounded(low=-1, high=1, shape=()),
... default_interaction_type=InteractionType.RANDOM
... )
>>> _ = torch.manual_seed(0)
>>> data = TensorDict(
... loc=torch.zeros(10, requires_grad=True),
... scale=torch.full((10,), 10.0),
... batch_size=(10,))
>>> data = mod(data)
>>> print(data["action"]) # All actions are within bound
tensor([ 1., -1., -1., 1., -1., -1., 1., 1., -1., -1.],
grad_fn=<ClampBackward0>)
>>> data["action"].mean().backward()
>>> print(data["loc"].grad) # clamp anihilates gradients
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
"""

def __init__(
self,
in_keys: Union[NestedKey, List[NestedKey], Dict[str, NestedKey]],
out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None,
in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey],
out_keys: NestedKey | List[NestedKey] | None = None,
spec: Optional[TensorSpec] = None,
*,
safe: bool = False,
default_interaction_type: str = InteractionType.DETERMINISTIC,
distribution_class: Type = Delta,
distribution_kwargs: Optional[dict] = None,
default_interaction_type: InteractionType = InteractionType.DETERMINISTIC,
distribution_class: type = Delta,
distribution_kwargs: dict | None = None,
return_log_prob: bool = False,
log_prob_keys: List[NestedKey] | None = None,
log_prob_key: NestedKey | None = None,
cache_dist: bool = False,
n_empirical_estimate: int = 1000,
num_samples: int | torch.Size | None = None,
):
super().__init__(
in_keys=in_keys,
Expand All @@ -120,9 +208,11 @@ def __init__(
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=return_log_prob,
log_prob_key=log_prob_key,
cache_dist=cache_dist,
n_empirical_estimate=n_empirical_estimate,
log_prob_keys=log_prob_keys,
log_prob_key=log_prob_key,
num_samples=num_samples,
)
if spec is not None:
spec = spec.clone()
Expand Down

0 comments on commit 2f8c118

Please sign in to comment.