From 836eb56a24f6c2b23888a873aedeaa244e76d3d3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 5 Feb 2025 11:44:12 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_nn.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/test/test_nn.py b/test/test_nn.py index a2c3597b7..dd91ab60b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2270,6 +2270,72 @@ def test_index_prob_seq(self): assert isinstance(seq[:2], ProbabilisticTensorDictSequential) assert isinstance(seq[-2:], ProbabilisticTensorDictSequential) + def test_no_warning_single_key(self): + # Check that there is no warning if the number of out keys is 1 and sample log prob is set + torch.manual_seed(0) + with set_composite_lp_aggregate(None): + mod = ProbabilisticTensorDictModule( + in_keys=["loc", "scale"], + distribution_class=torch.distributions.Normal, + out_keys=[("an", "action")], + log_prob_key="sample_log_prob", + return_log_prob=True, + ) + td = TensorDict(loc=torch.randn(()), scale=torch.rand(())) + mod(td.copy()) + mod.log_prob(mod(td.copy())) + mod.log_prob_key + + # Don't set the key and trigger the warning + mod = ProbabilisticTensorDictModule( + in_keys=["loc", "scale"], + distribution_class=torch.distributions.Normal, + out_keys=[("an", "action")], + return_log_prob=True, + ) + with pytest.warns( + DeprecationWarning, match="You are querying the log-probability key" + ): + mod(td.copy()) + mod.log_prob(mod(td.copy())) + mod.log_prob_key + + # add another variable, and trigger the warning + mod = ProbabilisticTensorDictModule( + in_keys=["params"], + distribution_class=CompositeDistribution, + distribution_kwargs={ + "distribution_map": { + "dirich": torch.distributions.Dirichlet, + "categ": torch.distributions.Categorical, + } + }, + out_keys=[("dirich", "categ")], + return_log_prob=True, + ) + with pytest.warns( + DeprecationWarning, match="You are querying the log-probability key" + ), pytest.warns( + DeprecationWarning, + match="Composite log-prob aggregation wasn't defined explicitly", + ): + td = TensorDict( + params=TensorDict( + dirich=TensorDict( + concentration=torch.rand( + ( + 10, + 11, + ) + ) + ), + categ=TensorDict(logits=torch.rand((5,))), + ) + ) + mod(td.copy()) + mod.log_prob(mod(td.copy())) + mod.log_prob_key + class TestEnsembleModule: def test_init(self):