Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 5, 2025
1 parent 0519491 commit 836eb56
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2270,6 +2270,72 @@ def test_index_prob_seq(self):
assert isinstance(seq[:2], ProbabilisticTensorDictSequential)
assert isinstance(seq[-2:], ProbabilisticTensorDictSequential)

def test_no_warning_single_key(self):
# Check that there is no warning if the number of out keys is 1 and sample log prob is set
torch.manual_seed(0)
with set_composite_lp_aggregate(None):
mod = ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
distribution_class=torch.distributions.Normal,
out_keys=[("an", "action")],
log_prob_key="sample_log_prob",
return_log_prob=True,
)
td = TensorDict(loc=torch.randn(()), scale=torch.rand(()))
mod(td.copy())
mod.log_prob(mod(td.copy()))
mod.log_prob_key

# Don't set the key and trigger the warning
mod = ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
distribution_class=torch.distributions.Normal,
out_keys=[("an", "action")],
return_log_prob=True,
)
with pytest.warns(
DeprecationWarning, match="You are querying the log-probability key"
):
mod(td.copy())
mod.log_prob(mod(td.copy()))
mod.log_prob_key

# add another variable, and trigger the warning
mod = ProbabilisticTensorDictModule(
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={
"distribution_map": {
"dirich": torch.distributions.Dirichlet,
"categ": torch.distributions.Categorical,
}
},
out_keys=[("dirich", "categ")],
return_log_prob=True,
)
with pytest.warns(
DeprecationWarning, match="You are querying the log-probability key"
), pytest.warns(
DeprecationWarning,
match="Composite log-prob aggregation wasn't defined explicitly",
):
td = TensorDict(
params=TensorDict(
dirich=TensorDict(
concentration=torch.rand(
(
10,
11,
)
)
),
categ=TensorDict(logits=torch.rand((5,))),
)
)
mod(td.copy())
mod.log_prob(mod(td.copy()))
mod.log_prob_key


class TestEnsembleModule:
def test_init(self):
Expand Down

0 comments on commit 836eb56

Please sign in to comment.