Make the memory use of the pairwise linear model constant #45

Merged · 15 commits · Jun 26, 2023

Conversation

danieldk (Contributor)

The Thinc part of the pairwise bilinear model is fairly simple before this change: we would collect the splits from all documents and then pad them. However, this caused the model to run out of memory on large docs, since it had to compute many n*n matrices (all padded to the longest sequence length). It would also perform unnecessary computations on many padding time steps.

This change makes memory use independent of the document length (given a fixed split length) by doing the following (see the sketch after the list):

- Get all splits and flatten them into a list of split representations (`with_splits`).
- Batch the splits by their padded sizes. This ensures that memory use is constant when splits have a maximum size, and the buffering it permits gives more evenly sized batches (`with_minibatch_by_padded_size`).
- Pad the splits in each batch and pass them to the Torch model. Since the outputs of the Torch model are matrices, unpadding takes this into account (`with_pad_seq_unpad_matrix`).
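
Roughly, these layers might be wired together as follows. This is a minimal sketch: the combinator names come from this PR, but their import paths, exact signatures, the `size` budget, and the inner Torch module are assumptions for illustration.

```python
from thinc.api import PyTorchWrapper

# with_splits, with_minibatch_by_padded_size and with_pad_seq_unpad_matrix
# are the layers added in this PR; they are assumed to be importable from
# the package this PR modifies. The padded-size budget and the Torch
# module are placeholders.
def build_pairwise_scorer(torch_bilinear_module, padded_size_budget=4096):
    return with_splits(                     # flatten per-doc representations into splits
        with_minibatch_by_padded_size(      # batch splits by padded size
            with_pad_seq_unpad_matrix(      # pad inputs, unpad n*n score outputs
                PyTorchWrapper(torch_bilinear_module)
            ),
            size=padded_size_budget,        # assumed keyword for the size budget
        )
    )
```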

In contrast to most `with_*` layers, `with_splits` is not symmetric: it takes as input representations for each document (`List[Floats2d]`), but it outputs pairwise score matrices per split. The reason is that the dimensions of the score matrices differ per split, so we cannot concatenate them at the document level.
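
To make the asymmetry concrete, the shapes could be pictured roughly like this (a type sketch only, based on the description above; the function name is hypothetical):

```python
from typing import List
from thinc.types import Floats2d

def with_splits_forward(docs: List[Floats2d]) -> List[Floats2d]:
    """Hypothetical shape sketch: each input array is (n_tokens, width) for one
    document; each output array is a (split_len, split_len) pairwise score
    matrix for one split. Split lengths differ, so the per-split matrices
    cannot be concatenated back into a single array per document."""
    ...
```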

danieldk added the `enhancement` (New feature or request) label on Mar 30, 2023
danieldk (Contributor, Author)

Assigning to @shadeMe, since this looks somewhat similar to the data massaging that we have in curated transformers.

danieldk closed this on Apr 11, 2023
danieldk reopened this on Apr 11, 2023
shadeMe (Contributor) left a comment

LGTM! Just a couple of minor typos; can be merged once they're fixed.

Daniël de Kok and others added 2 commits on June 26, 2023 at 11:15
Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
shadeMe merged commit a139a5e into explosion:v4 on Jun 26, 2023