[BugFix] inline TDParams kwargs in prob modules
ghstack-source-id: 9fda35811b4656bd9939c9fb31cb253d7751b55c
Pull Request resolved: #1093
vmoens committed Nov 20, 2024
1 parent c11024e commit 0b7ce93
Showing 3 changed files with 51 additions and 6 deletions.
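For context, a minimal sketch of the scenario this commit fixes, adapted from the test added below (so the module and key names come from that test, not from new API): a ProbabilisticTensorDictModule whose distribution_kwargs is a TensorDictParams can now be compiled, because under torch.compile those kwargs are inlined into a plain dict before being unpacked into the distribution constructor.

# Sketch only: mirrors test_prob_module_with_kwargs added in this commit.
import torch
from tensordict import TensorDict
from tensordict.nn import (
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    TensorDictModule as Mod,
    TensorDictParams,
    TensorDictSequential as Seq,
)

# distribution kwargs carried as module state (a buffer, since no_convert=True)
kwargs = TensorDictParams(TensorDict(scale=1.0), no_convert=True)
prob_mod = Seq(
    Mod(torch.nn.Linear(3, 3), in_keys=["inp"], out_keys=["loc"]),
    Prob(
        in_keys=["loc"],
        out_keys=["sample"],
        return_log_prob=True,
        distribution_class=torch.distributions.Normal,
        distribution_kwargs=kwargs,
        default_interaction_type=InteractionType.RANDOM,
    ),
)
prob_mod_c = torch.compile(prob_mod, fullgraph=True)
prob_mod_c(TensorDict(inp=torch.randn(3)))  # compiles now that the kwargs are inlined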
3 changes: 3 additions & 0 deletions tensordict/nn/params.py
@@ -341,6 +341,9 @@ def __init__(
        self._locked_tensordicts = []
        self._get_post_hook = []

+    def __iter__(self):
+        yield from self._param_td.__iter__()
+
    def register_get_post_hook(self, hook):
        """Register a hook to be called after any get operation on leaf tensors."""
        if not callable(hook):
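The __iter__ added above simply forwards to the wrapped _param_td, so a TensorDictParams iterates the same way as the tensordict it wraps (for a batched tensordict, over the leading batch dimension). A hedged sketch, with illustrative shapes:

# Sketch under the assumption above: iteration is delegated to the wrapped TensorDict.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

params = TensorDictParams(TensorDict({"scale": torch.ones(4)}, batch_size=[4]), no_convert=True)
for sub in params:  # same elements as iterating the underlying TensorDict
    print(sub["scale"].shape)  # torch.Size([]) for each of the 4 elements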
31 changes: 25 additions & 6 deletions tensordict/nn/probabilistic.py
@@ -8,10 +8,6 @@
import re
import warnings

-try:
-    from enum import StrEnum
-except ImportError:
-    from .utils import StrEnum
from textwrap import indent
from typing import Any, Dict, List, Optional

@@ -30,6 +26,16 @@

from torch.utils._contextlib import _DecoratorContextManager

+try:
+    from torch.compiler import is_compiling
+except ImportError:
+    from torch._dynamo import is_compiling
+
+try:
+    from enum import StrEnum
+except ImportError:
+    from .utils import StrEnum
+
__all__ = ["ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential"]


@@ -350,11 +356,13 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution:
                if isinstance(dist_key, tuple):
                    dist_key = dist_key[-1]
                dist_kwargs[dist_key] = tensordict.get(td_key)
-            dist = self.distribution_class(**dist_kwargs, **self.distribution_kwargs)
+            dist = self.distribution_class(
+                **dist_kwargs, **_dynamo_friendly_to_dict(self.distribution_kwargs)
+            )
        except TypeError as err:
            if "an unexpected keyword argument" in str(err):
                raise TypeError(
-                    "distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_keys must match."
+                    "distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_keys must match. "
                    f"Got this error message: \n{indent(str(err), 4 * ' ')}\nwith dist_keys={self.dist_keys}"
                )
            elif re.search(r"missing.*required positional arguments", str(err)):
@@ -623,3 +631,14 @@ def forward(
    ) -> TensorDictBase:
        tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs)
        return self.module[-1](tensordict_out, _requires_sample=self._requires_sample)
+
+
+def _dynamo_friendly_to_dict(data):
+    if not is_compiling():
+        return data
+    if isinstance(data, TensorDictBase):
+        items = list(data.items())
+        if not items:
+            return {}
+        return dict(items)
+    return data
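A hedged reading of the helper above: when dynamo is tracing, the entries of a TensorDictBase (such as the TensorDictParams holding the distribution kwargs) are materialized into a plain dict before being unpacked with **; outside of compilation the object is returned unchanged. Roughly, for the kwargs used in the new test:

# Illustrative only: what the compiled path sees after _dynamo_friendly_to_dict.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

kwargs = TensorDictParams(TensorDict(scale=1.0), no_convert=True)
inlined = dict(kwargs.items())  # plain dict, e.g. {"scale": tensor(1.)}
dist = torch.distributions.Normal(loc=torch.zeros(3), **inlined)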
23 changes: 23 additions & 0 deletions test/test_compile.py
@@ -24,6 +24,8 @@
)
from tensordict.nn import (
    CudaGraphModule,
+    InteractionType,
+    ProbabilisticTensorDictModule as Prob,
    TensorDictModule,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
@@ -662,6 +664,27 @@ def test_dispatch_tensor(self, mode):
        mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
        torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))

+    def test_prob_module_with_kwargs(self, mode):
+        kwargs = TensorDictParams(TensorDict(scale=1.0), no_convert=True)
+        dist_cls = torch.distributions.Normal
+        mod = Mod(torch.nn.Linear(3, 3), in_keys=["inp"], out_keys=["loc"])
+        prob_mod = Seq(
+            mod,
+            Prob(
+                in_keys=["loc"],
+                out_keys=["sample"],
+                return_log_prob=True,
+                distribution_class=dist_cls,
+                distribution_kwargs=kwargs,
+                default_interaction_type=InteractionType.RANDOM,
+            ),
+        )
+        # check that the scale is in the buffers
+        assert len(list(prob_mod.buffers())) == 1
+        prob_mod(TensorDict(inp=torch.randn(3)))
+        prob_mod_c = torch.compile(prob_mod, fullgraph=True, mode=mode)
+        prob_mod_c(TensorDict(inp=torch.randn(3)))
+

@pytest.mark.skipif(
    TORCH_VERSION <= version.parse("2.4.0"), reason="requires torch>2.4"
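The "check that the scale is in the buffers" assertion in the test above relies on no_convert=True registering the kwarg tensor as a buffer rather than a parameter (the linear layer contributes parameters but no buffers, so the single buffer is the scale). A small hedged sketch of that behaviour:

# Sketch of the buffer registration the test asserts on; counts are expectations, not new API.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

kwargs = TensorDictParams(TensorDict(scale=1.0), no_convert=True)
print(len(list(kwargs.buffers())), len(list(kwargs.parameters())))  # expected: 1 0

The new test itself can be run in isolation with pytest -k test_prob_module_with_kwargs.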
