diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index d23c2bec3..82f4eadf7 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1301,6 +1301,12 @@ class WrapModule(TensorDictModuleBase): Keyword Args: inplace (bool, optional): If ``True``, the input TensorDict will be modified in-place. Otherwise, a new TensorDict will be returned (if the function does not modify it in-place and returns it). Defaults to ``False``. + in_keys (list of NestedKey, optional): if provided, indicates what entries are read by the module. + This will not be checked and is provided just for the purpose of informing :class:`~tensordict.nn.TensorDictSequential` + about the input keys of the wrapped module. Defaults to `[]`. + out_keys (list of NestedKey, optional): if provided, indicates what entries are written by the module. + This will not be checked and is provided just for the purpose of informing :class:`~tensordict.nn.TensorDictSequential` + about the output keys of the wrapped module. Defaults to `[]`. Examples: >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule @@ -1320,11 +1326,20 @@ class WrapModule(TensorDictModuleBase): out_keys = [] def __init__( - self, func: Callable[[TensorDictBase], TensorDictBase], *, inplace: bool = False + self, + func: Callable[[TensorDictBase], TensorDictBase], + *, + inplace: bool = False, + in_keys: List[NestedKey] | None = None, + out_keys: List[NestedKey] | None = None, ) -> None: super().__init__() self.func = func self.inplace = inplace + if in_keys is not None: + self.in_keys = in_keys + if out_keys is not None: + self.out_keys = out_keys def forward(self, data: TensorDictBase) -> TensorDictBase: result = self.func(data)