From 4713734f36522845eda41a902372d3b0542869d3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 13 Jan 2025 09:43:17 +0000 Subject: [PATCH] init --- tensordict/return_types.py | 52 ++++++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/tensordict/return_types.py b/tensordict/return_types.py index 0c6668305..79d8feeb2 100644 --- a/tensordict/return_types.py +++ b/tensordict/return_types.py @@ -7,33 +7,69 @@ from tensordict.tensordict import TensorDict -@tensorclass +@tensorclass(shadow=True) class min: """A `min` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.min` operations.""" - vals: TensorDict + values: TensorDict indices: TensorDict + def __getitem__(self, item): + try: + return (self.values, self.indices)[item] + except IndexError: + raise IndexError( + f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s " + f"__getitem__ method API." + ) -@tensorclass + +@tensorclass(shadow=True) class max: """A `max` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.max` operations.""" - vals: TensorDict + values: TensorDict indices: TensorDict + def __getitem__(self, item): + try: + return (self.values, self.indices)[item] + except IndexError: + raise IndexError( + f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s " + f"__getitem__ method API." + ) + -@tensorclass +@tensorclass(shadow=True) class cummin: """A `cummin` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummin` operations.""" - vals: TensorDict + values: TensorDict indices: TensorDict + def __getitem__(self, item): + try: + return (self.values, self.indices)[item] + except IndexError: + raise IndexError( + f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s " + f"__getitem__ method API." + ) -@tensorclass + +@tensorclass(shadow=True) class cummax: """A `cummax` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummax` operations.""" - vals: TensorDict + values: TensorDict indices: TensorDict + + def __getitem__(self, item): + try: + return (self.values, self.indices)[item] + except IndexError: + raise IndexError( + f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s " + f"__getitem__ method API." + )