diff --git a/tensordict/base.py b/tensordict/base.py index 354237cc6..376830c31 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6444,7 +6444,6 @@ def _get_tuple_maybe_non_tensor(self, key, default): return result.data return result - @overload def get_at(self, key, index): ... diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 615c72658..9db6d2c71 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -601,11 +601,14 @@ class TensorClass: def get(self, key, default): ... def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType: ... @overload - def get_at(self, key, index):... + def get_at(self, key, index): ... @overload def get_at(self, key, index, default): ... def get_at( - self, key: NestedKey, *args, **kwargs, + self, + key: NestedKey, + *args, + **kwargs, ) -> CompatibleType: ... def get_item_shape(self, key: NestedKey): ... def update(