Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 13, 2025
1 parent c750598 commit 4713734
Showing 1 changed file with 44 additions and 8 deletions.
52 changes: 44 additions & 8 deletions tensordict/return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,69 @@
from tensordict.tensordict import TensorDict


@tensorclass
@tensorclass(shadow=True)
class min:
"""A `min` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.min` operations."""

vals: TensorDict
values: TensorDict
indices: TensorDict

def __getitem__(self, item):
try:
return (self.values, self.indices)[item]
except IndexError:
raise IndexError(
f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s "
f"__getitem__ method API."
)

@tensorclass

@tensorclass(shadow=True)
class max:
"""A `max` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.max` operations."""

vals: TensorDict
values: TensorDict
indices: TensorDict

def __getitem__(self, item):
try:
return (self.values, self.indices)[item]
except IndexError:
raise IndexError(
f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s "
f"__getitem__ method API."
)


@tensorclass
@tensorclass(shadow=True)
class cummin:
"""A `cummin` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummin` operations."""

vals: TensorDict
values: TensorDict
indices: TensorDict

def __getitem__(self, item):
try:
return (self.values, self.indices)[item]
except IndexError:
raise IndexError(
f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s "
f"__getitem__ method API."
)

@tensorclass

@tensorclass(shadow=True)
class cummax:
"""A `cummax` tensorclass to be used as a result for :meth:`~tensordict.TensorDict.cummax` operations."""

vals: TensorDict
values: TensorDict
indices: TensorDict

def __getitem__(self, item):
try:
return (self.values, self.indices)[item]
except IndexError:
raise IndexError(
f"Indexing a {type(self)} element follows the torch.return_types.{type(self).__name__}'s "
f"__getitem__ method API."
)

0 comments on commit 4713734

Please sign in to comment.