Skip to content

Commit

Permalink
[BugFix] Use same signature for append_transform in all cases (#2091)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 18, 2024
1 parent 9d3530f commit fc8ccd9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
8 changes: 6 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
def mark_update(self, index: Union[int, torch.Tensor]) -> None:
self._sampler.mark_update(index)

def append_transform(self, transform: "Transform") -> None: # noqa-F821
def append_transform(self, transform: "Transform") -> ReplayBuffer: # noqa-F821
"""Appends transform at the end.
Transforms are applied in order when `sample` is called.
Expand All @@ -626,8 +626,11 @@ def append_transform(self, transform: "Transform") -> None: # noqa-F821
transform = _CallableTransform(transform)
transform.eval()
self._transform.append(transform)
return self

def insert_transform(self, index: int, transform: "Transform") -> None: # noqa-F821
def insert_transform(
self, index: int, transform: "Transform" # noqa-F821
) -> ReplayBuffer:
"""Inserts transform.
Transforms are executed in order when `sample` is called.
Expand All @@ -638,6 +641,7 @@ def insert_transform(self, index: int, transform: "Transform") -> None: # noqa-
"""
transform.eval()
self._transform.insert(index, transform)
return self

def __iter__(self):
if self._sampler.ran_out:
Expand Down
6 changes: 4 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def empty_cache(self):

def append_transform(
self, transform: Transform | Callable[[TensorDictBase], TensorDictBase]
) -> None:
) -> TransformedEnv:
"""Appends a transform to the env.
:class:`~torchrl.envs.transforms.Transform` or callable are accepted.
Expand All @@ -899,8 +899,9 @@ def append_transform(
self.transform.append(prev_transform)

self.transform.append(transform)
return self

def insert_transform(self, index: int, transform: Transform) -> None:
def insert_transform(self, index: int, transform: Transform) -> TransformedEnv:
"""Inserts a transform to the env at the desired index.
:class:`~torchrl.envs.transforms.Transform` or callable are accepted.
Expand All @@ -920,6 +921,7 @@ def insert_transform(self, index: int, transform: Transform) -> None:
self.transform = compose # parent set automatically

self.transform.insert(index, transform)
return self

def __getattr__(self, attr: str) -> Any:
try:
Expand Down

0 comments on commit fc8ccd9

Please sign in to comment.