
[BUG] Collectors of batched environments return more frames than requested #846

Closed
matteobettini opened this issue Jan 19, 2023 · 2 comments
Labels
bug Something isn't working Good first issue A good way to start hacking torchrl!

Comments

matteobettini (Contributor) commented Jan 19, 2023

Describe the bug

The collectors currently force the actual number of frames collected per batch to be divisible by the number of batched environments (which can be collector workers or parallel workers; with #828 these could also be vectorized dimensions in the batch size).

As a result, a user who passes a desired frames_per_batch at collector creation actually gets more frames than requested, as the following example shows:

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

gym_env = lambda: GymEnv("Pendulum-v1", device="cpu")
gym_parallel_env = lambda: ParallelEnv(10, gym_env)

pendulum_policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)

coll = SyncDataCollector(
    gym_parallel_env,
    pendulum_policy,
    total_frames=20000,
    max_frames_per_traj=5,
    frames_per_batch=145,  # not divisible by the 10 parallel envs
    split_trajs=False,
)

for data in coll:
    print("Ending", data)  # batch_size=torch.Size([10, 15]), i.e. 150 frames rather than the 145 requested
    break

This is caused, for example, by code like this:

self.frames_per_batch = -(-frames_per_batch // self.n_env)  # ceil division: the per-env frame count is rounded up

This behavior can be dangerous: users may believe they are training on x frames at each iteration when they are in fact training on x + y frames.
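
For the numbers in the example, this ceil division rounds the 145 requested frames up to the next multiple of 10; a quick standalone check:

frames_per_batch, n_env = 145, 10
# -(-a // b) is ceil(a / b) in integer arithmetic
per_env_frames = -(-frames_per_batch // n_env)  # 15
print(per_env_frames * n_env)  # 150 frames collected instead of the 145 requested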

Solutions

  1. Throw an error if frames_per_batch is not divisible by the number of batched envs (options 1 and 2 are sketched below).
  2. Throw a warning if frames_per_batch is not divisible by the number of batched envs.
  3. Find a way to return only the requested number of frames by discarding some of the collected data.
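
A minimal sketch of options 1 and 2, assuming a hypothetical helper inside the collector (the name, signature, and message are illustrative, not the actual torchrl API):

import warnings

def _validate_frames_per_batch(frames_per_batch: int, n_env: int, strict: bool = False) -> None:
    # Hypothetical check, run before the collector rounds the frame count up
    if frames_per_batch % n_env != 0:
        rounded = -(-frames_per_batch // n_env) * n_env
        msg = (
            f"frames_per_batch={frames_per_batch} is not divisible by the number "
            f"of batched environments ({n_env}); the collector would return "
            f"{rounded} frames instead."
        )
        if strict:
            raise ValueError(msg)  # option 1: hard error
        warnings.warn(msg)  # option 2: warn and keep the rounded-up value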
matteobettini added the bug label on Jan 19, 2023
vmoens (Contributor) commented Jan 19, 2023

Makes sense, I'll add that to my collector refactoring.

vmoens added the Good first issue label on Jan 19, 2023
matteobettini (Contributor, Author) commented:

#828 introduces warnings for this, but we would like to turn those into errors to resolve this issue.
