Reduce the memory usage of logits from O(context_length) to O(1) (#4688)
Summary:
The logits tensor is large, with shape [context_length x vocab_size], but only the logits of the last (new) token are ever used, since the model generates one new token per Transformer inference.

This PR changes the transformer to return the logits of the last token only. The runner no longer needs to index into the logits to find the last token; it uses the output directly.
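
A minimal sketch of the idea (illustrative only: toy sizes, with plain LayerNorm/Linear stand-ins rather than the actual model's norm and output head). Slicing the hidden state to the last position before the output projection yields the same last-token logits as computing the full [seq_len x vocab_size] matrix and then indexing it:

```python
import torch

# Toy sizes; the real model uses context_length x vocab_size, e.g. 1024 x 32000.
bsz, seqlen, dim, vocab_size = 1, 8, 16, 32
h = torch.randn(bsz, seqlen, dim)
norm = torch.nn.LayerNorm(dim)                          # stand-in for the model's norm
output = torch.nn.Linear(dim, vocab_size, bias=False)   # stand-in for the output head

full_logits = output(norm(h))            # old behavior: [bsz, seqlen, vocab_size]
last_logits = output(norm(h[:, -1, :]))  # new behavior: [bsz, vocab_size]
assert torch.allclose(full_logits[:, -1, :], last_logits, atol=1e-5)
```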

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits
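
A rough back-of-envelope check, assuming vocab_size = 32000 and fp32 logits (the params.json contents are not shown here, so the vocab size is an assumption):

```python
max_seq_length, vocab_size, bytes_per_elem = 1024, 32000, 4  # fp32

full_logits_mb = max_seq_length * vocab_size * bytes_per_elem / 1e6  # ~131 MB per buffer
last_logits_mb = 1 * vocab_size * bytes_per_elem / 1e6               # ~0.128 MB
print(full_logits_mb, last_logits_mb)
```

The 0.128 MB figure matches a single [1 x vocab_size] fp32 buffer; the 262 MB before is roughly two full [max_seq_length x vocab_size] buffers' worth, presumably the logits plus an intermediate copy.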

Verified with llama_runner: before and after this change, it generates the same text with temperature=0.

The dominant memory usage is now the KV cache.
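
For scale, a rough KV cache estimate, assuming a stories110M-like config (n_layers = 12, n_kv_heads = 12, head_dim = 64; these values are assumptions, since the checkpoint's params.json is not shown) with fp32 and max_seq_length = 1024:

```python
n_layers, n_kv_heads, head_dim = 12, 12, 64  # assumed stories110M-like config
max_seq_length, bytes_per_elem = 1024, 4     # fp32

# One K buffer and one V buffer of [max_seq_length x n_kv_heads x head_dim] per layer.
kv_cache_mb = 2 * n_layers * max_seq_length * n_kv_heads * head_dim * bytes_per_elem / 1e6
print(kv_cache_mb)  # ~75 MB
```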

TODO:
- Improve KV cache memory usage using fp16 or quantization.
- This PR only fixes the logits; further activation memory optimization is possible with single-token output.


Differential Revision: D61246566
Martin Yuan authored and facebook-github-bot committed Aug 14, 2024
1 parent ba3448c commit 4e1f741
Showing 3 changed files with 5 additions and 15 deletions.
3 changes: 3 additions & 0 deletions examples/models/llama2/llama_transformer.py
```diff
@@ -524,6 +524,9 @@ def forward(
             input_pos,
         )
 
+        # Only the last logit is used for the new generated token
+        h = h[:, -1, :]
+
         h = self.norm(h)
 
         logits = self.output(h)
```
12 changes: 2 additions & 10 deletions extension/llm/runner/text_decoder_runner.h
```diff
@@ -63,23 +63,15 @@ class TextDecoderRunner {
    * @return The next token.
    */
   inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
-    ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
-    auto num_tokens = logits_tensor.size(1);
-    auto vocab_size = logits_tensor.size(2);
-
     switch (logits_tensor.scalar_type()) {
       case ScalarType::Float: {
         float* logits = logits_tensor.mutable_data_ptr<float>();
-        float* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        return sampler_->sample(logits);
       }
       case ScalarType::Half: {
         exec_aten::Half* logits =
             logits_tensor.mutable_data_ptr<exec_aten::Half>();
-        exec_aten::Half* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        return sampler_->sample(logits);
       }
       default:
         ET_CHECK_MSG(
```
5 changes: 0 additions & 5 deletions extension/llm/runner/text_prefiller.cpp
```diff
@@ -50,11 +50,6 @@ Result<uint64_t> TextPrefiller::prefill(
   ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
   ET_LOG(
       Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
-  ET_CHECK_MSG(
-      outputs_res.get().size(1) == num_prompt_tokens,
-      "Expected number of output tokens %d does not match returned value %zu.",
-      num_prompt_tokens,
-      outputs_res.get().size(1));
   // insert new token into prompt_tokens
   // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
   uint64_t prev = prompt_tokens[0];
```
