[BE] No warning if user sets the log_prob_key explicitly and only one variable is sampled from the ProbTDMod

ghstack-source-id: 90621391506cdd67563a1e791ade093e7db5f5df
Pull Request resolved: #1209
vmoens committed Feb 5, 2025
1 parent 76e1810 commit e3c1578
Showing 3 changed files with 80 additions and 5 deletions.
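For illustration, here is a minimal sketch of the case this commit fixes, adapted from the new test below (assumptions: tensordict's ProbabilisticTensorDictModule and the TensorDict keyword constructor, as used in test_nn.py): with log_prob_key set explicitly and a single sampled variable, querying log_prob_key should no longer emit the deprecation warning.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import ProbabilisticTensorDictModule

# One sampled variable and an explicit log_prob_key: after this change,
# accessing mod.log_prob_key should not emit the DeprecationWarning.
mod = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    distribution_class=torch.distributions.Normal,
    out_keys=[("an", "action")],
    log_prob_key="sample_log_prob",
    return_log_prob=True,
)
td = TensorDict(loc=torch.randn(()), scale=torch.rand(()))
mod(td.copy())           # writes ("an", "action") and "sample_log_prob"
print(mod.log_prob_key)  # "sample_log_prob", expected to be warning-free
```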
4 changes: 3 additions & 1 deletion tensordict/nn/distributions/composite.py
@@ -130,7 +130,9 @@ def __init__(
             dist_params = params.get(name)
             kwargs = extra_kwargs.get(name, {})
             if dist_params is None:
-                raise KeyError
+                raise KeyError(
+                    f"no param {name} found in params with keys {params.keys(True, True)}"
+                )
             dist = dist_class(**dist_params, **kwargs)
             dists[write_name] = dist
         self.dists = dists
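A hypothetical repro of the error path improved here (assumptions: CompositeDistribution is importable from tensordict.nn and params.get returns None for a missing entry, as the diff suggests). With this change the KeyError message should name the missing param and list the keys that are present, instead of being empty.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

# The "dirich" parameters are intentionally missing from `params`.
params = TensorDict(categ=TensorDict(logits=torch.rand((5,))))
try:
    CompositeDistribution(
        params,
        distribution_map={
            "dirich": torch.distributions.Dirichlet,
            "categ": torch.distributions.Categorical,
        },
    )
except KeyError as err:
    print(err)  # expected to mention "dirich" and the available keys
```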
15 changes: 11 additions & 4 deletions tensordict/nn/probabilistic.py
@@ -328,6 +328,9 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
     """

+    # To be removed in v0.9
+    _trigger_warning_lpk: bool = False
+
     def __init__(
         self,
         in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey],
@@ -396,9 +399,11 @@ def __init__(
                 "composite_lp_aggregate is set to True but log_prob_keys were passed. "
                 "When composite_lp_aggregate() returns ``True``, log_prob_key must be used instead."
             )
+        self._trigger_warning_lpk = len(self._out_keys) > 1
         if log_prob_key is None:
             if composite_lp_aggregate(nowarn=True):
                 log_prob_key = "sample_log_prob"
+                self._trigger_warning_lpk = True
             elif len(out_keys) == 1:
                 log_prob_key = _add_suffix(out_keys[0], "_log_prob")
             elif len(out_keys) > 1 and not composite_lp_aggregate(nowarn=True):
@@ -451,13 +456,15 @@ def log_prob_key(self):
                 f"unless there is one and only one element in log_prob_keys (got log_prob_keys={self.log_prob_keys}). "
                 f"When composite_lp_aggregate() returns ``False``, try to use {type(self).__name__}.log_prob_keys instead."
             )
-        if _composite_lp_aggregate.get_mode() is None:
+        if _composite_lp_aggregate.get_mode() is None and self._trigger_warning_lpk:
             warnings.warn(
                 f"You are querying the log-probability key of a {type(self).__name__} where the "
-                f"composite_lp_aggregate has not been set. "
+                f"composite_lp_aggregate has not been set and the log-prob key has not been chosen. "
                 f"Currently, it is assumed that composite_lp_aggregate() will return True: the log-probs will be aggregated "
-                f"in a {self._log_prob_key} entry. From v0.9, this behaviour will be changed and individual log-probs will "
-                f"be written in `('path', 'to', 'leaf', '<sample_name>_log_prob')`. To prepare for this change, "
+                f"in a {self._log_prob_key} entry. "
+                f"From v0.9, this behaviour will be changed and individual log-probs will "
+                f"be written in `('path', 'to', 'leaf', '<sample_name>_log_prob')`. "
+                f"To prepare for this change, "
                 f"call `set_composite_lp_aggregate(mode: bool).set()` at the beginning of your script (or set the "
                 f"COMPOSITE_LP_AGGREGATE env variable). Use mode=True "
                 f"to keep the current behaviour, and mode=False to use per-leaf log-probs.",
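As the amended warning text suggests, the message can also be avoided altogether by picking the aggregation mode up front; a minimal sketch (assuming set_composite_lp_aggregate is importable from tensordict.nn, as in the test file below):

```python
from tensordict.nn import set_composite_lp_aggregate

# mode=True keeps today's aggregated "sample_log_prob" entry;
# mode=False opts into the per-leaf `*_log_prob` entries planned for v0.9.
set_composite_lp_aggregate(False).set()
```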
66 changes: 66 additions & 0 deletions test/test_nn.py
@@ -2270,6 +2270,72 @@ def test_index_prob_seq(self):
         assert isinstance(seq[:2], ProbabilisticTensorDictSequential)
         assert isinstance(seq[-2:], ProbabilisticTensorDictSequential)

+    def test_no_warning_single_key(self):
+        # Check that there is no warning if the number of out keys is 1 and sample log prob is set
+        torch.manual_seed(0)
+        with set_composite_lp_aggregate(None):
+            mod = ProbabilisticTensorDictModule(
+                in_keys=["loc", "scale"],
+                distribution_class=torch.distributions.Normal,
+                out_keys=[("an", "action")],
+                log_prob_key="sample_log_prob",
+                return_log_prob=True,
+            )
+            td = TensorDict(loc=torch.randn(()), scale=torch.rand(()))
+            mod(td.copy())
+            mod.log_prob(mod(td.copy()))
+            mod.log_prob_key
+
+            # Don't set the key and trigger the warning
+            mod = ProbabilisticTensorDictModule(
+                in_keys=["loc", "scale"],
+                distribution_class=torch.distributions.Normal,
+                out_keys=[("an", "action")],
+                return_log_prob=True,
+            )
+            with pytest.warns(
+                DeprecationWarning, match="You are querying the log-probability key"
+            ):
+                mod(td.copy())
+                mod.log_prob(mod(td.copy()))
+                mod.log_prob_key
+
+            # add another variable, and trigger the warning
+            mod = ProbabilisticTensorDictModule(
+                in_keys=["params"],
+                distribution_class=CompositeDistribution,
+                distribution_kwargs={
+                    "distribution_map": {
+                        "dirich": torch.distributions.Dirichlet,
+                        "categ": torch.distributions.Categorical,
+                    }
+                },
+                out_keys=[("dirich", "categ")],
+                return_log_prob=True,
+            )
+            with pytest.warns(
+                DeprecationWarning, match="You are querying the log-probability key"
+            ), pytest.warns(
+                DeprecationWarning,
+                match="Composite log-prob aggregation wasn't defined explicitly",
+            ):
+                td = TensorDict(
+                    params=TensorDict(
+                        dirich=TensorDict(
+                            concentration=torch.rand(
+                                (
+                                    10,
+                                    11,
+                                )
+                            )
+                        ),
+                        categ=TensorDict(logits=torch.rand((5,))),
+                    )
+                )
+                mod(td.copy())
+                mod.log_prob(mod(td.copy()))
+                mod.log_prob_key
+

 class TestEnsembleModule:
     def test_init(self):
