diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py
index 781dc4261..a68136014 100644
--- a/tensordict/nn/distributions/composite.py
+++ b/tensordict/nn/distributions/composite.py
@@ -13,44 +13,38 @@ class CompositeDistribution(d.Distribution):
-    """A composition of distributions.
+    """A composite distribution that groups multiple distributions together using the TensorDict interface.
 
-    Groups distributions together with the TensorDict interface. Methods
-    (``log_prob_composite``, ``entropy_composite``, ``cdf``, ``icdf``, ``rsample``, ``sample`` etc.)
-    will return a tensordict, possibly modified in-place if the input was a tensordict.
+    This class allows for operations such as `log_prob_composite`, `entropy_composite`, `cdf`, `icdf`, `rsample`, and `sample`
+    to be performed on a collection of distributions, returning a TensorDict. The input TensorDict may be modified in-place.
 
     Args:
-        params (TensorDictBase): a nested key-tensor map where the root entries
-            point to the sample names, and the leaves are the distribution parameters.
-            Entry names must match those of ``distribution_map``.
-
-        distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]):
-            indicated the distribution types to be used. The names of the distributions
-            will match the names of the samples in the tensordict.
+        params (TensorDictBase): A nested key-tensor map where the root entries correspond to sample names, and the leaves
+            are the distribution parameters. Entry names must match those specified in `distribution_map`.
+        distribution_map (Dict[NestedKey, Type[torch.distribution.Distribution]]): Specifies the distribution types to be used.
+            The names of the distributions should match the sample names in the `TensorDict`.
 
     Keyword Arguments:
-        name_map (Dict[NestedKey, NestedKey]]): a dictionary representing where each
-            sample should be written. If not provided, the key names from ``distribution_map``
-            will be used.
-        extra_kwargs (Dict[NestedKey, Dict]): a possibly incomplete dictionary of
-            extra keyword arguments for the distributions to be built.
-        aggregate_probabilities (bool): if ``True``, the :meth:`~.log_prob` and :meth:`~.entropy` methods will
-            sum the probabilities and entropies of the individual distributions and return a single tensor.
-            If ``False``, the single log-probabilities will be registered in the input tensordict (for :meth:`~.log_prob`)
-            or retuned as leaves of the output tensordict (for :meth:`~.entropy`).
-            This parameter can be overridden at runtime by passing the ``aggregate_probabilities`` argument to
-            ``log_prob`` and ``entropy``.
-            Defaults to ``False``.
-        log_prob_key (NestedKey, optional): key where to write the log_prob.
-            Defaults to `'sample_log_prob'`.
-        entropy_key (NestedKey, optional): key where to write the entropy.
-            Defaults to `'entropy'`.
-
-    .. note::
-        In this distribution class, the batch-size of the input tensordict containing the params
-        (``params``) is indicative of the batch_shape of the distribution. For instance,
-        the ``"sample_log_prob"`` entry resulting from a call to ``log_prob``
-        will be of the shape of the params (+ any supplementary batch dimension).
+        name_map (Dict[NestedKey, NestedKey], optional): A mapping of where each sample should be written. If not provided,
+            the key names from `distribution_map` will be used.
+        extra_kwargs (Dict[NestedKey, Dict], optional): A dictionary of additional keyword arguments for constructing the distributions.
+        aggregate_probabilities (bool, optional): If `True`, the `log_prob` and `entropy` methods will sum the probabilities and entropies
+            of the individual distributions and return a single tensor. If `False`, individual log-probabilities will be stored in the input
+            TensorDict (for `log_prob`) or returned as leaves of the output TensorDict (for `entropy`). This can be overridden at runtime
+            by passing the `aggregate_probabilities` argument to `log_prob` and `entropy`. Defaults to `False`.
+        log_prob_key (NestedKey, optional): The key where the log probability will be stored. Defaults to `'sample_log_prob'`.
+        entropy_key (NestedKey, optional): The key where the entropy will be stored. Defaults to `'entropy'`.
+        inplace (bool, optional): Whether to modify the input TensorDict in-place. Defaults to `True`.
+
+            .. warning:: The default value of ``inplace`` will switch to ``False`` in v0.9 in the constructor.
+
+        include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict. Defaults to `True`.
+
+            .. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
+
+    .. note:: The batch size of the input TensorDict containing the parameters (`params`) determines the batch shape of
+        the distribution. For example, the `"sample_log_prob"` entry resulting from a call to `log_prob` will have the
+        shape of the parameters plus any additional batch dimensions.
 
     Examples:
         >>> params = TensorDict({
@@ -88,6 +82,8 @@ def __init__(
         aggregate_probabilities: bool | None = None,
         log_prob_key: NestedKey = "sample_log_prob",
         entropy_key: NestedKey = "entropy",
+        inplace: bool | None = None,
+        include_sum: bool | None = None,
     ):
         self._batch_shape = params.shape
         if extra_kwargs is None:
@@ -122,6 +118,8 @@ def __init__(
         self.entropy_key = entropy_key
         self.aggregate_probabilities = aggregate_probabilities
+        self.include_sum = include_sum
+        self.inplace = inplace
 
     @property
     def aggregate_probabilities(self):
@@ -223,16 +221,32 @@ def rsample(self, shape=None) -> TensorDictBase:
         )
 
     def log_prob(
-        self, sample: TensorDictBase, *, aggregate_probabilities: bool | None = None
+        self,
+        sample: TensorDictBase,
+        *,
+        aggregate_probabilities: bool | None = None,
+        include_sum: bool | None = None,
+        inplace: bool | None = None,
     ) -> torch.Tensor | TensorDictBase:  # noqa: D417
-        """Computes and returns the summed log-prob.
+        """Compute the summed log-probability of a given sample.
 
         Args:
-            sample (TensorDictBase): the sample to compute the log probability.
+            sample (TensorDictBase): The input sample to compute the log probability for.
 
         Keyword Args:
             aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
                 from the class.
+            include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
+                Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
+                Has no effect if ``aggregate_probabilities`` is set to ``True``.
+
+                .. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
+
+            inplace (bool, optional): Whether to update the input sample in-place or return a new TensorDict.
+                Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
+                Has no effect if ``aggregate_probabilities`` is set to ``True``.
+
+                .. warning:: The default value of ``inplace`` will switch to ``False`` in v0.9 in the constructor.
 
         If ``self.aggregate_probabilities`` is ``True``, this method will return a single tensor with
         the summed log-probabilities. If ``self.aggregate_probabilities`` is ``False``, this method will
@@ -243,7 +257,9 @@ def log_prob(
         if aggregate_probabilities is None:
             aggregate_probabilities = self.aggregate_probabilities
         if not aggregate_probabilities:
-            return self.log_prob_composite(sample, include_sum=True)
+            return self.log_prob_composite(
+                sample, include_sum=include_sum, inplace=inplace
+            )
         slp = 0.0
         for name, dist in self.dists.items():
             lp = dist.log_prob(sample.get(name))
@@ -253,47 +269,105 @@ def log_prob(
         return slp
 
     def log_prob_composite(
-        self, sample: TensorDictBase, include_sum=True
+        self,
+        sample: TensorDictBase,
+        *,
+        include_sum: bool | None = None,
+        inplace: bool | None = None,
    ) -> TensorDictBase:
-        """Writes a ``_log_prob`` entry for each sample in the input tensordict, along with a ``"sample_log_prob"`` entry with the summed log-prob.
+        """Computes the log-probability of each component in the input sample and returns a TensorDict with individual log-probabilities.
+
+        Args:
+            sample (TensorDictBase): The input sample to compute the log probabilities for.
 
-        This method is called by the :meth:`~.log_prob` method when ``self.aggregate_probabilities`` is ``False``.
+        Keyword Args:
+            include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
+                Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
+
+                .. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
+
+            inplace (bool, optional): Whether to update the input sample in-place or return a new TensorDict.
+                Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
+
+                .. warning:: The default value of ``inplace`` will switch to ``False`` in v0.9 in the constructor.
+
+        Returns:
+            TensorDictBase: A TensorDict containing the individual log-probabilities for each component in the input sample,
+                along with a "sample_log_prob" entry containing the summed log-probability if `include_sum` is True.
         """
-        slp = 0.0
+        if include_sum is None:
+            include_sum = self.include_sum
+
+        if include_sum is None:
+            include_sum = True
+            warnings.warn(
+                "`include_sum` wasn't set when building the `CompositeDistribution` or when calling log_prob_composite. "
+                "The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `log_prob_composite` accordingly.",
+                category=DeprecationWarning,
+            )
+        if inplace is None:
+            inplace = self.inplace
+        if inplace is None:
+            inplace = True
+            warnings.warn(
+                "`inplace` wasn't set when building the `CompositeDistribution` or when calling log_prob_composite. "
+                "The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `log_prob_composite` accordingly.",
+                category=DeprecationWarning,
+            )
+        if include_sum:
+            slp = 0.0
         d = {}
         for name, dist in self.dists.items():
             d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
-            if lp.ndim > sample.ndim:
-                lp = lp.flatten(sample.ndim, -1).sum(-1)
-            slp = slp + lp
+            if include_sum:
+                if lp.ndim > sample.ndim:
+                    lp = lp.flatten(sample.ndim, -1).sum(-1)
+                slp = slp + lp
         if include_sum:
             d[self.log_prob_key] = slp
-        sample.update(d)
+        if inplace:
+            sample.update(d)
+        else:
+            return sample.empty(recurse=True).update(d).filter_empty_()
         return sample
 
     def entropy(
-        self, samples_mc: int = 1, *, aggregate_probabilities: bool | None = None
+        self,
+        samples_mc: int = 1,
+        *,
+        aggregate_probabilities: bool | None = None,
+        include_sum: bool | None = None,
     ) -> torch.Tensor | TensorDictBase:  # noqa: D417
-        """Computes and returns the summed entropies.
+        """Computes and returns the entropy of the composite distribution.
+
+        This method calculates the entropy for each component distribution and optionally sums them.
 
         Args:
-            samples_mc (int): the number samples to draw if the entropy does not have a closed form formula.
-                Defaults to ``1``.
+            samples_mc (int): The number of samples to draw if the entropy does not have a closed-form solution.
+                Defaults to `1`.
 
         Keyword Args:
-            aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
-                from the class.
-
-        If ``self.aggregate_probabilities`` is ``True``, this method will return a single tensor with
-        the summed entropies. If ``self.aggregate_probabilities`` is ``False``, this method will call
-        the `:meth:`~.entropy_composite` method and return a tensordict with the entropies of each sample
-        in the input tensordict along with an ``entropy`` entry with the summed entropy. In both cases,
-        the output shape will match the shape of the distribution ``batch_shape``.
+            aggregate_probabilities (bool, optional): If provided, overrides the default `aggregate_probabilities`
+                setting from the class. Determines whether to return a single summed entropy tensor or a TensorDict
+                with individual entropies. Defaults to ``False`` if not set in the class.
+            include_sum (bool, optional): Whether to include the summed entropy in the output TensorDict.
+                Defaults to `self.include_sum`, which is set through the class constructor. Has no effect if
+                `aggregate_probabilities` is set to `True`.
+
+                .. warning:: The default value of `include_sum` will switch to `False` in v0.9 in the constructor.
+
+        Returns:
+            torch.Tensor or TensorDictBase: If `aggregate_probabilities` is `True`, returns a single tensor with
+                the summed entropies. If `aggregate_probabilities` is `False`, returns a TensorDict with the entropies
+                of each component distribution.
+
+        .. note:: If a distribution does not implement a closed-form solution for entropy, Monte Carlo sampling is used
+            to estimate it.
         """
         if aggregate_probabilities is None:
             aggregate_probabilities = self.aggregate_probabilities
         if not aggregate_probabilities:
-            return self.entropy_composite(samples_mc, include_sum=True)
+            return self.entropy_composite(samples_mc, include_sum=include_sum)
         se = 0.0
         for _, dist in self.dists.items():
             try:
@@ -306,11 +380,44 @@ def entropy(
             se = se + e
         return se
 
-    def entropy_composite(self, samples_mc=1, include_sum=True) -> TensorDictBase:
-        """Writes a ``_entropy`` entry for each sample in the input tensordict, along with a ``"entropy"`` entry with the summed entropies.
+    def entropy_composite(
+        self,
+        samples_mc=1,
+        *,
+        include_sum: bool | None = None,
+    ) -> TensorDictBase:
+        """Computes the entropy for each component distribution and returns a TensorDict with individual entropies.
+
+        This method is used by the `entropy` method when `self.aggregate_probabilities` is `False`.
+
+        Args:
+            samples_mc (int): The number of samples to draw if the entropy does not have a closed-form solution.
+                Defaults to `1`.
+
+        Keyword Args:
+            include_sum (bool, optional): Whether to include the summed entropy in the output TensorDict.
+                Defaults to `self.include_sum`, which is set through the class constructor.
+
+                .. warning:: The default value of `include_sum` will switch to `False` in v0.9 in the constructor.
 
-        This method is called by the :meth:`~.entropy` method when ``self.aggregate_probabilities`` is ``False``.
+        Returns:
+            TensorDictBase: A TensorDict containing the individual entropies for each component distribution,
+                along with an "entropy" entry containing the summed entropies if `include_sum` is `True`.
+
+        .. note:: If a distribution does not implement a closed-form solution for entropy, Monte Carlo sampling is used
+            to estimate it.
         """
+        if include_sum is None:
+            include_sum = self.include_sum
+
+        if include_sum is None:
+            include_sum = True
+            warnings.warn(
+                "`include_sum` wasn't set when building the `CompositeDistribution` or when calling entropy_composite. "
+                "The current default is ``True`` but from v0.9 it will be changed to ``False``. Please adapt your call to `entropy_composite` accordingly.",
+                category=DeprecationWarning,
+            )
+
         se = 0.0
         d = {}
         for name, dist in self.dists.items():
@@ -320,9 +427,10 @@ def entropy_composite(samples_mc=1, include_sum=True) -> TensorDictBase:
                 x = dist.rsample((samples_mc,))
                 e = -dist.log_prob(x).mean(0)
             d[_add_suffix(name, "_entropy")] = e
-            if e.ndim > len(self.batch_shape):
-                e = e.flatten(len(self.batch_shape), -1).sum(-1)
-            se = se + e
+            if include_sum:
+                if e.ndim > len(self.batch_shape):
+                    e = e.flatten(len(self.batch_shape), -1).sum(-1)
+                se = se + e
         if include_sum:
             d[self.entropy_key] = se
         return TensorDict(
@@ -331,6 +439,16 @@
         )
 
     def cdf(self, sample: TensorDictBase) -> TensorDictBase:
+        """Computes the cumulative distribution function (CDF) for each component distribution in the composite distribution.
+
+        This method calculates the CDF for each component distribution and updates the input TensorDict with the results.
+
+        Args:
+            sample (TensorDictBase): A TensorDict containing samples for which to compute the CDF.
+
+        Returns:
+            TensorDictBase: The input TensorDict updated with `<sample_name>_cdf` entries for each component distribution.
+        """
         cdfs = {
             _add_suffix(name, "_cdf"): dist.cdf(sample.get(name))
             for name, dist in self.dists.items()
@@ -339,14 +457,20 @@
         return sample
 
     def icdf(self, sample: TensorDictBase) -> TensorDictBase:
-        """Computes the inverse CDF.
+        """Computes the inverse cumulative distribution function (inverse CDF) for each component distribution.
 
-        Requires the input tensordict to have one of `<sample_name>+'_cdf'` entry
-        or a `<sample_name>` entry.
+        This method requires the input TensorDict to have either a `<sample_name>_cdf` entry or a `<sample_name>` entry
+        for each component distribution. It calculates the inverse CDF and updates the TensorDict with the results.
 
         Args:
-            sample (TensorDictBase): a tensordict containing `<sample_name>_log_prob` where
-                `<sample_name>` is the name of the sample provided during construction.
+            sample (TensorDictBase): A TensorDict containing either `<sample_name>_cdf` or `<sample_name>` entries
+                for each component distribution.
+
+        Returns:
+            TensorDictBase: The input TensorDict updated with `<sample_name>_icdf` entries for each component distribution.
+
+        Raises:
+            KeyError: If neither `<sample_name>` nor `<sample_name>_cdf` can be found in the input TensorDict for a
+                component distribution.
         """
         for name, dist in self.dists.items():
             # TODO: v0.7: remove the None
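Usage sketch (not part of the patch): a minimal, hedged example of the new `include_sum`/`inplace` flags introduced above, mirroring the updated tests below; the entry names (`cont`, `("nested", "disc")`) are illustrative only.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CompositeDistribution
    from torch import distributions

    params = TensorDict(
        {
            "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
            ("nested", "disc"): {"logits": torch.randn(3, 10)},
        },
        [3],
    )
    dist = CompositeDistribution(
        params,
        distribution_map={
            "cont": distributions.Normal,
            ("nested", "disc"): distributions.Categorical,
        },
        aggregate_probabilities=False,
        include_sum=False,  # skip the aggregated "sample_log_prob" entry
        inplace=False,  # return a fresh TensorDict instead of writing into `sample`
    )
    sample = dist.sample((4,))
    lp = dist.log_prob_composite(sample)
    assert lp is not sample  # new TensorDict; the input is left untouched
    assert "cont_log_prob" in lp.keys()  # per-component log-probabilities
    assert "sample_log_prob" not in lp.keys()  # no summed entry was requested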
diff --git a/test/test_nn.py b/test/test_nn.py
index ba9faf93d..4bab031a9 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -12,6 +12,7 @@
 import pytest
 import torch
+
 from tensordict import NonTensorData, NonTensorStack, tensorclass, TensorDict
 from tensordict._C import unravel_key_list
 from tensordict.nn import (
@@ -71,6 +72,12 @@
         "ignore:You are using `torch.load` with `weights_only=False`"
     ),
     pytest.mark.filterwarnings("ignore:enable_nested_tensor is True"),
+    pytest.mark.filterwarnings(
+        "ignore:`include_sum` wasn't set when building the `CompositeDistribution`"
+    ),
+    pytest.mark.filterwarnings(
+        "ignore:`inplace` wasn't set when building the `CompositeDistribution`"
+    ),
 ]
 
 
@@ -2233,20 +2240,6 @@ def test_log_prob(self):
             },
             [3],
         )
-        # Capture the warning for upcoming changes in aggregate_probabilities
-        dist = CompositeDistribution(
-            params,
-            distribution_map={
-                "cont": distributions.Normal,
-                ("nested", "disc"): distributions.RelaxedOneHotCategorical,
-            },
-            extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}},
-        )
-
-        sample = dist.rsample((4,))
-        with pytest.warns(FutureWarning, match="aggregate_probabilities"):
-            lp = dist.log_prob(sample)
-
         dist = CompositeDistribution(
             params,
             distribution_map={
@@ -2262,7 +2255,9 @@ def test_log_prob(self):
         assert isinstance(lp, torch.Tensor)
         assert lp.requires_grad
 
-    def test_log_prob_composite(self):
+    @pytest.mark.parametrize("inplace", [None, True, False])
+    @pytest.mark.parametrize("include_sum", [None, True, False])
+    def test_log_prob_composite(self, inplace, include_sum):
         params = TensorDict(
             {
                 "cont": {
@@ -2273,17 +2268,6 @@
             },
             [3],
         )
-        # Capture the warning for upcoming changes in aggregate_probabilities
-        dist = CompositeDistribution(
-            params,
-            distribution_map={
-                "cont": distributions.Normal,
-                ("nested", "disc"): distributions.RelaxedOneHotCategorical,
-            },
-            extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}},
-        )
-        with pytest.warns(FutureWarning, match="aggregate_probabilities"):
-            dist.log_prob(dist.sample())
         dist = CompositeDistribution(
             params,
             distribution_map={
@@ -2292,12 +2276,25 @@
             },
             extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}},
             aggregate_probabilities=False,
+            inplace=inplace,
+            include_sum=include_sum,
         )
+        if include_sum is None:
+            include_sum = True
+        if inplace is None:
+            inplace = True
         sample = dist.rsample((4,))
-        sample = dist.log_prob_composite(sample, include_sum=True)
-        assert sample.get("cont_log_prob").requires_grad
-        assert sample.get(("nested", "disc_log_prob")).requires_grad
-        assert "sample_log_prob" in sample.keys()
+        sample_lp = dist.log_prob_composite(sample)
+        assert sample_lp.get("cont_log_prob").requires_grad
+        assert sample_lp.get(("nested", "disc_log_prob")).requires_grad
+        if inplace:
+            assert sample_lp is sample
+        else:
+            assert sample_lp is not sample
+        if include_sum:
+            assert "sample_log_prob" in sample_lp.keys()
+        else:
+            assert "sample_log_prob" not in sample_lp.keys()
 
     def test_entropy(self):
         params = TensorDict(
@@ -2310,16 +2307,6 @@
             },
             [3],
         )
-        # Capture the warning for upcoming changes in aggregate_probabilities
-        dist = CompositeDistribution(
-            params,
-            distribution_map={
-                "cont": distributions.Normal,
-                ("nested", "disc"): distributions.Categorical,
-            },
-        )
-        with pytest.warns(FutureWarning, match="aggregate_probabilities"):
-            dist.log_prob(dist.sample())
         dist = CompositeDistribution(
             params,
             distribution_map={
@@ -2333,7 +2320,8 @@ def test_entropy(self):
         assert isinstance(ent, torch.Tensor)
         assert ent.requires_grad
 
-    def test_entropy_composite(self):
+    @pytest.mark.parametrize("include_sum", [None, True, False])
+    def test_entropy_composite(self, include_sum):
         params = TensorDict(
             {
                 "cont": {
@@ -2344,16 +2332,6 @@
             },
             [3],
         )
-        # Capture the warning for upcoming changes in aggregate_probabilities
-        dist = CompositeDistribution(
-            params,
-            distribution_map={
-                "cont": distributions.Normal,
-                ("nested", "disc"): distributions.Categorical,
-            },
-        )
-        with pytest.warns(FutureWarning, match="aggregate_probabilities"):
-            dist.log_prob(dist.sample())
         dist = CompositeDistribution(
             params,
             distribution_map={
@@ -2361,12 +2339,18 @@
                 ("nested", "disc"): distributions.Categorical,
             },
             aggregate_probabilities=False,
+            include_sum=include_sum,
         )
+        if include_sum is None:
+            include_sum = True
         sample = dist.entropy()
         assert sample.shape == params.shape == dist._batch_shape
         assert sample.get("cont_entropy").requires_grad
         assert sample.get(("nested", "disc_entropy")).requires_grad
-        assert "entropy" in sample.keys()
+        if include_sum:
+            assert "entropy" in sample.keys()
+        else:
+            assert "entropy" not in sample.keys()
 
     def test_cdf(self):
         params = TensorDict(
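Usage sketch (not part of the patch, continuing the snippet above with the same `params` and imports): entropy with the new `include_sum` flag; entry names are again illustrative only.

    dist = CompositeDistribution(
        params,
        distribution_map={
            "cont": distributions.Normal,
            ("nested", "disc"): distributions.Categorical,
        },
        aggregate_probabilities=False,
        include_sum=True,  # also write the aggregated "entropy" entry
        inplace=False,
    )
    ent = dist.entropy()
    assert ent.shape == params.shape  # batch shape follows `params`
    assert "cont_entropy" in ent.keys()  # per-component entropies
    assert "entropy" in ent.keys()  # summed entry, present since include_sum=True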