Skip to content

Commit

Permalink
[BugFix] Fix non-full TensorStorage indexing (#1730)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 4, 2023
1 parent 7166f3c commit f6188e3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
8 changes: 8 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,14 @@ def test_index(self, rbtype, storage, size, prefetch):
b = b.all()
assert b

def test_index_nonfull(self, rbtype, storage, size, prefetch):
# checks that indexing the buffer before it's full gives the accurate view of the data
rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
data = self._get_data(rbtype, size=size - 1)
rb.extend(data)
assert len(rb[: size - 1]) == size - 1
assert len(rb[size - 2 :]) == 1


def test_multi_loops():
"""Tests that one can iterate multiple times over a buffer without rep."""
Expand Down
6 changes: 5 additions & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,11 +448,15 @@ def set( # noqa: F811
self._storage[cursor] = data

def get(self, index: Union[int, Sequence[int], slice]) -> Any:
if self._len < self.max_size:
storage = self._storage[: self._len]
else:
storage = self._storage
if not self.initialized:
raise RuntimeError(
"Cannot get an item from an unitialized LazyMemmapStorage"
)
out = self._storage[index]
out = storage[index]
if is_tensor_collection(out):
out = _reset_batch_size(out)
return out.unlock_()
Expand Down

1 comment on commit f6188e3

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: f6188e3 Previous: 7166f3c Ratio
benchmarks/test_objectives_benchmarks.py::test_values[vec_generalized_advantage_estimate-True-True] 90.63985757560513 iter/sec (stddev: 0.06667627757390637) 298.3572046535438 iter/sec (stddev: 0.012587699716367012) 3.29

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.