Reduce the memory usage of logits from O(context_length) to O(1) (#4688)
Summary:
The logits tensor is large, with shape [context_length x vocab_size], but we only ever use the logits of the last (new) token, because the model generates one new token per Transformer inference.

This PR changes the transformer to return only the logits of the last token. In the runner code, we no longer need to fetch the last token's logits specifically; we use the output directly.
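
A minimal sketch of the idea (the module and sizes below are illustrative, not the actual llama_transformer.py code): slice out the last position's hidden state before the vocab projection, so the returned logits shrink from [batch, seq_len, vocab_size] to [batch, vocab_size].

```
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Illustrative stand-in for the tail end of the llama transformer."""

    def __init__(self, dim: int = 768, vocab_size: int = 32000):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # the real model uses RMSNorm
        self.output = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [batch, seq_len, dim] hidden states from the transformer blocks.
        # Keep only the last position, so the projected logits are
        # [batch, vocab_size] instead of [batch, seq_len, vocab_size].
        h = h[:, -1, :]
        h = self.norm(h)
        return self.output(h)
```

In the exported program the logits buffer is planned statically for the full context length, so shrinking the returned tensor is what removes the O(context_length) term.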

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits
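
These numbers are consistent with fp32 logits and a vocabulary of roughly 32,000 tokens (assumed values for the story110M checkpoint); a quick back-of-the-envelope check:

```
# Back-of-the-envelope check (assumptions: fp32 logits, vocab_size = 32000).
bytes_per_elem = 4
vocab_size = 32000
max_seq_length = 1024

before = max_seq_length * vocab_size * bytes_per_elem  # full logits buffer
after = 1 * vocab_size * bytes_per_elem                # last-token logits only

print(f"before: {before / 1e6:.1f} MB per copy")  # ~131.1 MB
print(f"after:  {after / 1e6:.3f} MB")            # 0.128 MB
```

The post-change 0.128 MB matches exactly; the pre-change 262 MB is roughly two full-size copies, presumably because the planner keeps more than one logits-sized buffer alive.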

Verified with llama_runner: before and after, it generates the same text with temperature=0.

The dominant memory usage is now the KV cache.
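
For scale, a rough fp32 KV cache estimate for a story110M-like configuration (assumed: 12 layers, hidden dimension 768, max_seq_length 1024):

```
# Rough KV cache estimate under assumed story110M-like hyperparameters.
n_layers = 12
dim = 768             # n_heads * head_dim
max_seq_length = 1024
bytes_per_elem = 4    # fp32

kv_cache = 2 * n_layers * max_seq_length * dim * bytes_per_elem  # K and V
print(f"KV cache: {kv_cache / 1e6:.1f} MB")  # ~75.5 MB
```

which is why the KV cache is the target of the first TODO item below.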

TODO:
- Improve KV cache memory usage using fp16 or quantization.
- This PR only fixes the logits. Further activation memory optimization is possible now that only one token's output is returned.


Reviewed By: larryliu0820

Differential Revision: D61246566

Pulled By: iseeyuan
Martin Yuan authored and facebook-github-bot committed Aug 21, 2024
1 parent 78b0867 commit 5e15211
Showing 8 changed files with 58 additions and 35 deletions.
3 changes: 3 additions & 0 deletions examples/models/llama2/llama_transformer.py
@@ -512,6 +512,9 @@ def forward(
input_pos,
)

# Only the last token's logits are needed to generate the next token
h = h[:, -1, :]

h = self.norm(h)

logits = self.output(h)
7 changes: 5 additions & 2 deletions examples/models/llava/model.py
@@ -216,16 +216,19 @@ def prefill_embedding(
result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
return result

# Prefill using the in-house text_model of the llama transformer
def prefill(
self,
prompt_before_image: torch.Tensor,
images: torch.Tensor,
prompt_after_image: torch.Tensor,
) -> torch.Tensor:
) -> (int, torch.Tensor):
"""Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
return self.text_model.forward(None, torch.tensor([0]), embeds)
# Also return the prefilled token length, because the text model returns logits for only one token per forward call.
return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

# reference prefill using the text model in HF
def prefill_ref(
self,
prompt_before_image: torch.Tensor,
6 changes: 5 additions & 1 deletion examples/models/llava/runner/llava_image_prefiller.h
@@ -24,7 +24,7 @@ class LlavaImagePrefiller : public ImagePrefiller {
* @param start_pos The starting position in KV cache of the input in the LLM
* @return logits of the image prefill.
*/
inline Result<exec_aten::Tensor> prefill(Image& image, int64_t start_pos = 0)
inline Result<exec_aten::Tensor> prefill(Image& image, int64_t& start_pos)
override {
ManagedTensor managed_images(
image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
@@ -43,6 +43,10 @@ class LlavaImagePrefiller : public ImagePrefiller {
outputs_res[0].isTensor(),
"Non Tensor Output returned from executing image prefill");

// Update start_pos, which is only available inside this function.
// outputs_res now contains logits for just one token, so advance start_pos by
// the number of image tokens from the image encoder output instead.
start_pos += image_encoder_outputs[0].toTensor().size(1);

return outputs_res[0].toTensor();
}

4 changes: 2 additions & 2 deletions examples/models/llava/runner/llava_runner.cpp
@@ -104,8 +104,8 @@ Error LlavaRunner::generate(

// prefill images
for (auto& image : images) {
auto logits = ET_UNWRAP(image_prefiller_->prefill(image, pos));
pos += logits.size(1);
// pos is updated inside image prefill.
ET_UNWRAP(image_prefiller_->prefill(image, pos));
}

// prefill user prompt. No BOS because preset prompt already has it.
34 changes: 21 additions & 13 deletions examples/models/llava/test/test_llava.py
@@ -35,12 +35,14 @@ def setUp(self):
)

def test_prefill_logits(self):
prefill_logits = self.llava.prefill(
# For efficiency, the prefill implementation returns only the last token's logits.
_, prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
# The reference implementation in HF generates the full logits. Get the last one.
prefill_logits_ref = self.llava.prefill_ref(
self.prompt_before_image, self.resized, self.prompt_after_image
)[0]
)[0][:, -1, :]
self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))

def test_generated_output(self):
@@ -62,11 +64,11 @@ def test_generated_output(self):
)[0].strip()

# being tested, using llama_transformer
prefill_logits = self.llava.prefill(
context_len, prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
context_len = prefill_logits.shape[1]
new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
# Always generate one token at a time.
new_tokens = [torch.argmax(prefill_logits).item()]
for i in range(4):
logits = self.llava.step(
torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
@@ -93,24 +95,27 @@ def test_llava_export(self):
pte_embeds_before_img = llava_module.run_method(
"token_embedding", (prompt_before_image,)
)[0]
pte_prefill_before_img = llava_module.run_method(
llava_module.run_method(
"text_model",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
)[0]
)

start_pos += pte_prefill_before_img.shape[1]
# Update start_pos, which is used for the KV cache. The source of truth
# for the length delta is the embeddings, not the logits.
start_pos += pte_embeds_before_img.shape[1]

# pte prefill image
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
pte_prefill_img = llava_module.run_method(
llava_module.run_method(
"text_model",
(
torch.tensor([start_pos], dtype=torch.int64),
pte_embeds_img,
),
)[0]
)

start_pos += pte_prefill_img.shape[1]
# Update start_pos for each prefill (KV cache) step.
start_pos += pte_embeds_img.shape[1]

# pte prefill prompt after img
pte_embeds_after_img = llava_module.run_method(
@@ -121,8 +126,11 @@ def test_llava_export(self):
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
)[0]

# Update start_pos for each prefill (KV cache) step.
start_pos += pte_embeds_after_img.shape[1]

# being tested, using llama_transformer
new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
new_tokens = [torch.argmax(pte_prefill_after_img).item()]
# TODO: uncomment this line
# self.assertEquals(new_tokens[0], 1932) # When
for i in range(4):
@@ -134,7 +142,7 @@ def test_llava_export(self):
"text_model",
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
)[0]
new_tokens.append(torch.argmax(logits[..., -1, :]).item())
new_tokens.append(torch.argmax(logits).item())

outputs = llava_model.tokenizer.batch_decode(
torch.tensor([new_tokens]), skip_special_tokens=True
5 changes: 3 additions & 2 deletions extension/llm/runner/image_prefiller.h
@@ -22,12 +22,13 @@ class ImagePrefiller {
/**
* Prefill an LLM Module with the given image input.
* @param image The image input to the multimodal LLM.
* @param start_pos The starting position in KV cache of the input in the LLM
* @param start_pos The starting position in the KV cache of the input in the LLM.
* It is passed by reference and will be updated inside this function.
* @return The next token of the LLM Module after prefill.
*/
virtual Result<exec_aten::Tensor> prefill(
Image& image,
int64_t start_pos = 0) = 0;
int64_t& start_pos) = 0;

virtual Error load() = 0;
virtual bool is_method_loaded() = 0;
29 changes: 19 additions & 10 deletions extension/llm/runner/text_decoder_runner.h
@@ -65,23 +65,32 @@ class TextDecoderRunner {
* @return The next token.
*/
inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);

switch (logits_tensor.scalar_type()) {
// If the logits tensor is rank 3, the shape is [batch, seq_length,
// vocab_size]; take the last position's logits, sample, and return.
// Otherwise the model already outputs only the last position's logits,
// so sample directly.
case ScalarType::Float: {
float* logits = logits_tensor.mutable_data_ptr<float>();
float* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
float* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
return sampler_->sample(logits);
}
case ScalarType::Half: {
exec_aten::Half* logits =
logits_tensor.mutable_data_ptr<exec_aten::Half>();
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
return sampler_->sample(logits);
}
default:
ET_CHECK_MSG(
5 changes: 0 additions & 5 deletions extension/llm/runner/text_prefiller.cpp
@@ -50,11 +50,6 @@ Result<uint64_t> TextPrefiller::prefill(
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_LOG(
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
ET_CHECK_MSG(
outputs_res.get().size(1) == num_prompt_tokens,
"Expected number of output tokens %d does not match returned value %zu.",
num_prompt_tokens,
outputs_res.get().size(1));
// insert new token into prompt_tokens
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
uint64_t prev = prompt_tokens[0];
