[BUG] Memory leak? #845
Comments
- Yeah, there are multiple alternatives to padding in the torch world, but none of them really has all the features we need. There is for example [...] This refactoring of [...]
- thanks @matteobettini
- Yes, makes sense
- If [...] Also for me it makes sense to have one device for execution and one for results
- Closing this as: [...]
Describe the bug
I’m addressing a “memory-leak” issue, but I’m not sure it’s a real memory leak.
With pong, the memory on the GPU I use for data collection keeps increasing.
The strange thing is that it correlates with the performance of the training: the better the training, the higher the memory consumption. The obvious explanation (not the full story) is that better perf <=> longer trajectories. Hence, for some reason, longer trajectories cause the memory to increase.
A few things could explain that:
- `split_trajs` pads the trajectories into a `[B x max_T]` tensordict, where `max_T` is the maximum length of the trajectories collected. Now imagine you have 8 workers and a batch size of 128 elements per worker. 7 workers collect trajectories that are all shorter than 10 steps for a batch of length 128 (i.e. roughly 7 x 128 // 10 ≈ 90 small trajectories), and one worker collects a single long trajectory of length 128. `split_trajs` will then deliver a batch with B ≈ 90 and `max_T = 128`, but about 90% of the values will be zeros (a rough illustration follows).
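As a rough, self-contained illustration of how much of the padded batch is wasted in the scenario above (plain PyTorch; the sizes are made up to match the example):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

obs_dim = 8  # arbitrary per-step feature size, just for the illustration
# ~90 short trajectories of 9 steps each, plus one long 128-step trajectory.
short_trajs = [torch.randn(9, obs_dim) for _ in range(90)]
long_traj = [torch.randn(128, obs_dim)]
trajs = short_trajs + long_traj

# Padding stacks everything into a [B x max_T x obs_dim] tensor, zero-filled.
padded = pad_sequence(trajs, batch_first=True)  # shape [91, 128, 8]

useful = sum(t.numel() for t in trajs)
print(f"padded shape: {tuple(padded.shape)}")
print(f"fraction of padding: {1 - useful / padded.numel():.0%}")  # ~92%
```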
Possible solutions
The main thing that worried me, and that made me use this trajectory split in the first place, is that consuming different trajectories laid out sequentially may break some algos.
From my experiments with the advantage functions (TD0, TDLambda, GAE), only TDLambda suffers from this, and it is likely because we're not using the `done` flag appropriately (a done-aware sketch follows).
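This is not TorchRL's actual implementation, just a minimal sketch of what "using the `done` flag appropriately" means for TD(lambda): the recursion has to stop at episode boundaries so values never bootstrap across a reset, which is what makes sequentially concatenated trajectories safe.

```python
import torch

def td_lambda_returns(rewards, values, next_values, done, gamma=0.99, lmbda=0.95):
    # All inputs are flat [T] tensors over sequentially concatenated trajectories;
    # `done` is True on the last step of each trajectory.
    not_done = (~done).float()
    advantage = torch.zeros(())
    returns = torch.empty_like(rewards)
    for t in reversed(range(rewards.shape[0])):
        # TD error; bootstrapping from next_values is cut on terminal steps.
        delta = rewards[t] + gamma * next_values[t] * not_done[t] - values[t]
        # The lambda recursion is also cut at the boundary, so trajectory i+1
        # never leaks into trajectory i.
        advantage = delta + gamma * lmbda * not_done[t] * advantage
        returns[t] = advantage + values[t]
    return returns
```

With this masking, computing the targets on a flat `[B*T]` batch should give the same per-step result as on the padded `[B, max_T]` layout, which is what would let `split_trajs` be turned off by default.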
Two changes would address this:
- Keep `split_trajs` but turn it off by default. Fix the value functions to make them work in this context.
- Make a distinction between the `device` and the `passing_device`: the `passing_device` (or else the `device`) is the device where the data is dumped at each iteration (see the sketch after this list).
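For the second point, a rough sketch of what separating the execution device from the `passing_device` could look like inside a collection loop. The parameter names and the environment API (`env_reset`, `env_step`) are placeholders for illustration, not an existing interface:

```python
import torch

def collect(policy, env_reset, env_step, n_steps,
            device="cuda:0", passing_device="cpu"):
    """Execute the policy on `device`, but dump every transition on `passing_device`.

    `env_reset`/`env_step` stand in for whatever env API is used; the only point
    illustrated here is where the collected tensors end up living.
    """
    transitions = []
    obs = env_reset().to(device)
    for _ in range(n_steps):
        with torch.no_grad():
            action = policy(obs)
        next_obs, reward, done = env_step(action)
        # Ship the transition to the passing device right away, so the
        # execution GPU never accumulates the collected batch.
        transitions.append({
            "obs": obs.to(passing_device),
            "action": action.to(passing_device),
            "reward": reward.to(passing_device),
            "done": done.to(passing_device),
        })
        obs = env_reset().to(device) if bool(done) else next_obs.to(device)
    return transitions
```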
Some more context
I tried using `gc.collect()`, but with the pong example I was running it didn't change anything.
@albertbou92 I know you had a similar issue; I'd be interested in having your perspective on this.
@ShahRutav I believe that in your case `split_trajs` does not have an impact, so I doubt it is the cause of the problem. I'll keep digging. (A quick way to tell allocator caching from a real leak is sketched below.)
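On the question of whether this is a real leak: `nvidia-smi` largely reflects the memory reserved by PyTorch's caching allocator, which only ratchets up to fit the largest padded `[B x max_T]` batch seen so far, whereas a genuine leak would show `torch.cuda.memory_allocated()` growing even after the batch is released. A minimal check (plain PyTorch, nothing specific to this repo):

```python
import gc
import torch

def report_cuda_memory(tag, device="cuda:0"):
    """Print live tensor memory vs. memory merely cached by the allocator."""
    gc.collect()
    allocated = torch.cuda.memory_allocated(device) / 2**20  # live tensors (MiB)
    reserved = torch.cuda.memory_reserved(device) / 2**20    # cached by allocator (MiB)
    print(f"[{tag}] allocated: {allocated:.1f} MiB | reserved: {reserved:.1f} MiB")

# Called once per collection iteration (e.g. report_cuda_memory(f"iter {i}")):
# if `allocated` stays flat while `reserved` grows with the longest trajectory
# seen so far, the growth is padding plus allocator caching rather than a leak,
# and torch.cuda.empty_cache() would release the reserved part.
```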