The original Glow code uses gradient checkpointing, a very efficient way of reducing peak memory consumption. The following single-line change adds gradient checkpointing in a way that reduces memory consumption from 11 GB to 2 GB. It allowed me to increase the batch size from 64 to 256 with no issue. I think 512 is possible, maybe even 1024 if we use float16 for some of the layers.
```python
def forward(self, x, ldj, reverse=False):
    x_change, x_id = x.chunk(2, dim=1)
    # st = self.nn(x_id)  # change this line to the one below
    st = torch.utils.checkpoint.checkpoint(self.nn, x_id)
    s, t = st[:, 0::2, ...], st[:, 1::2, ...]
    s = self.scale * torch.tanh(s)
```
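For reference, here is a minimal, self-contained sketch of the trade-off `torch.utils.checkpoint.checkpoint` makes: the wrapped module's intermediate activations are discarded after the forward pass and recomputed during backward, so outputs and gradients match the plain call while peak memory drops. The small `net` below is a stand-in for `self.nn`, purely for illustration:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for the coupling layer's `self.nn`.
net = nn.Sequential(
    nn.Conv2d(2, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)

x = torch.randn(16, 2, 8, 8, requires_grad=True)

# Plain forward: every intermediate activation is kept for backward.
y_plain = net(x)

# Checkpointed forward: activations inside `net` are dropped and
# recomputed on the backward pass, trading extra compute for memory.
y_ckpt = checkpoint(net, x, use_reentrant=False)

# Both paths produce the same output, and gradients still flow.
assert torch.allclose(y_plain, y_ckpt)
y_ckpt.sum().backward()
assert x.grad is not None
```

Note that checkpointing roughly doubles the forward compute for the wrapped segment, which is why the batch size can be raised so much once memory is no longer the bottleneck.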
Thanks for the suggestion. I added a reference to this issue in the README. If you'd like to add support for checkpointing as a command line argument, feel free to open a pull request and I'll happily review.
glow/models/glow/coupling.py, line 28 in 59ed99f