From 5e1521170290de693d6330b31a40a5bd619c40b1 Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Wed, 21 Aug 2024 16:34:29 -0700
Subject: [PATCH] Reduce the memory usage of logits from O(context_length) to O(1) (#4688)

Summary:
The logits tensor is large, with shape [context_length x vocab_size], but we only ever use the logits of the last (new) token, because the model generates one new token per Transformer inference. This PR changes the transformer to return only the logits of the last token. In the runner code, we no longer have to fetch the logits of the last token explicitly; we can use the output directly.

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```

Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llama_runner: before and after, it generates the same text with temperature=0.

The dominant memory usage is now the KV cache.

TODO:
- Improve KV cache memory usage with fp16 or quantization.
- This PR only fixes the logits. Further activation memory optimization is possible with the one-token output.

Reviewed By: larryliu0820

Differential Revision: D61246566

Pulled By: iseeyuan
---
 examples/models/llama2/llama_transformer.py   |  3 ++
 examples/models/llava/model.py                |  7 ++--
 .../llava/runner/llava_image_prefiller.h      |  6 +++-
 examples/models/llava/runner/llava_runner.cpp |  4 +--
 examples/models/llava/test/test_llava.py      | 34 ++++++++++++-------
 extension/llm/runner/image_prefiller.h        |  5 +--
 extension/llm/runner/text_decoder_runner.h    | 29 ++++++++++------
 extension/llm/runner/text_prefiller.cpp       |  5 ---
 8 files changed, 58 insertions(+), 35 deletions(-)

diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 4ae12b0f647..71a52d9f93d 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -512,6 +512,9 @@ def forward(
                 input_pos,
             )
 
+        # Only the last logit is used for the newly generated token
+        h = h[:, -1, :]
+
         h = self.norm(h)
 
         logits = self.output(h)
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 9f6d8d32e8e..4f975e2ed4b 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -216,16 +216,19 @@ def prefill_embedding(
         result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
         return result
 
+    # Prefill using the in-house text_model of the llama transformer.
     def prefill(
         self,
         prompt_before_image: torch.Tensor,
         images: torch.Tensor,
         prompt_after_image: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> (int, torch.Tensor):
         """Avoiding the torch.where() call to find placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
-        return self.text_model.forward(None, torch.tensor([0]), embeds)
+        # Also return the prefilled token length, because the text model generates only one logit in each forward call.
+        return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)
 
+    # Reference prefill using the text model from HF.
     def prefill_ref(
         self,
         prompt_before_image: torch.Tensor,
diff --git a/examples/models/llava/runner/llava_image_prefiller.h b/examples/models/llava/runner/llava_image_prefiller.h
index e8453299085..4d0a07b9a66 100644
--- a/examples/models/llava/runner/llava_image_prefiller.h
+++ b/examples/models/llava/runner/llava_image_prefiller.h
@@ -24,7 +24,7 @@ class LlavaImagePrefiller : public ImagePrefiller {
    * @param start_pos The starting position in KV cache of the input in the LLM
    * @return logits of the image prefill.
    */
-  inline Result<exec_aten::Tensor> prefill(Image& image, int64_t start_pos = 0)
+  inline Result<exec_aten::Tensor> prefill(Image& image, int64_t& start_pos)
       override {
     ManagedTensor managed_images(
         image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
@@ -43,6 +43,10 @@
         outputs_res[0].isTensor(),
         "Non Tensor Output returned from executing image prefill");
 
+    // Update start_pos, which is only available inside this function.
+    // outputs_res holds only one logit, so use the image encoder output size.
+    start_pos += image_encoder_outputs[0].toTensor().size(1);
+
     return outputs_res[0].toTensor();
   }
diff --git a/examples/models/llava/runner/llava_runner.cpp b/examples/models/llava/runner/llava_runner.cpp
index c5ce03b88d7..48b8b78e61e 100644
--- a/examples/models/llava/runner/llava_runner.cpp
+++ b/examples/models/llava/runner/llava_runner.cpp
@@ -104,8 +104,8 @@ Error LlavaRunner::generate(
 
   // prefill images
   for (auto& image : images) {
-    auto logits = ET_UNWRAP(image_prefiller_->prefill(image, pos));
-    pos += logits.size(1);
+    // pos is updated inside the image prefill call.
+    ET_UNWRAP(image_prefiller_->prefill(image, pos));
   }
 
   // prefill user prompt. No BOS because preset prompt already has it.
diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py
index ef503a88fc3..f464a580a87 100644
--- a/examples/models/llava/test/test_llava.py
+++ b/examples/models/llava/test/test_llava.py
@@ -35,12 +35,14 @@ def setUp(self):
         )
 
     def test_prefill_logits(self):
-        prefill_logits = self.llava.prefill(
+        # For efficiency, the implemented prefill function only outputs the last logit.
+        _, prefill_logits = self.llava.prefill(
             self.prompt_before_image, self.resized, self.prompt_after_image
         )
+        # The reference implementation in HF generates the full logits. Get the last one.
         prefill_logits_ref = self.llava.prefill_ref(
             self.prompt_before_image, self.resized, self.prompt_after_image
-        )[0]
+        )[0][:, -1, :]
         self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))
 
     def test_generated_output(self):
@@ -62,11 +64,11 @@ def test_generated_output(self):
         )[0].strip()
 
         # being tested, using llama_transformer
-        prefill_logits = self.llava.prefill(
+        context_len, prefill_logits = self.llava.prefill(
             self.prompt_before_image, self.resized, self.prompt_after_image
         )
-        context_len = prefill_logits.shape[1]
-        new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
+        # Always generate one token at a time.
+        new_tokens = [torch.argmax(prefill_logits).item()]
         for i in range(4):
             logits = self.llava.step(
                 torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
@@ -93,24 +95,27 @@ def test_llava_export(self):
         pte_embeds_before_img = llava_module.run_method(
             "token_embedding", (prompt_before_image,)
         )[0]
-        pte_prefill_before_img = llava_module.run_method(
+        llava_module.run_method(
             "text_model",
             (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
-        )[0]
+        )
 
-        start_pos += pte_prefill_before_img.shape[1]
+        # Update start_pos, which is used by the kv cache. The source of truth
+        # for the delta length is the embeddings, not the logits.
+        start_pos += pte_embeds_before_img.shape[1]
 
         # pte prefill image
         pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
-        pte_prefill_img = llava_module.run_method(
+        llava_module.run_method(
             "text_model",
             (
                 torch.tensor([start_pos], dtype=torch.int64),
                 pte_embeds_img,
             ),
-        )[0]
+        )
 
-        start_pos += pte_prefill_img.shape[1]
+        # Update start_pos for each prefill (kv cache) step.
+        start_pos += pte_embeds_img.shape[1]
 
         # pte prefill prompt after img
         pte_embeds_after_img = llava_module.run_method(
@@ -121,8 +126,11 @@
             (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
         )[0]
 
+        # Update start_pos for each prefill (kv cache) step.
+        start_pos += pte_embeds_after_img.shape[1]
+
         # being tested, using llama_transformer
-        new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
+        new_tokens = [torch.argmax(pte_prefill_after_img).item()]
         # TODO: uncomment this line
         # self.assertEquals(new_tokens[0], 1932)  # When
         for i in range(4):
@@ -134,7 +142,7 @@
                 "text_model",
                 (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
             )[0]
-            new_tokens.append(torch.argmax(logits[..., -1, :]).item())
+            new_tokens.append(torch.argmax(logits).item())
 
         outputs = llava_model.tokenizer.batch_decode(
             torch.tensor([new_tokens]), skip_special_tokens=True
diff --git a/extension/llm/runner/image_prefiller.h b/extension/llm/runner/image_prefiller.h
index 64b623be36f..1805893d816 100644
--- a/extension/llm/runner/image_prefiller.h
+++ b/extension/llm/runner/image_prefiller.h
@@ -22,12 +22,13 @@ class ImagePrefiller {
   /**
    * Prefill an LLM Module with the given image input.
    * @param image The image input to the multimodal LLM.
-   * @param start_pos The starting position in KV cache of the input in the LLM
+   * @param start_pos The starting position in the KV cache of the input in the LLM.
+   * It is passed by reference and will be updated inside this function.
    * @return The next token of the LLM Module after prefill.
    */
   virtual Result<exec_aten::Tensor> prefill(
       Image& image,
-      int64_t start_pos = 0) = 0;
+      int64_t& start_pos) = 0;
 
   virtual Error load() = 0;
   virtual bool is_method_loaded() = 0;
diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h
index 49ddea66299..dfb5498062b 100644
--- a/extension/llm/runner/text_decoder_runner.h
+++ b/extension/llm/runner/text_decoder_runner.h
@@ -65,23 +65,32 @@
    * @return The next token.
   */
  inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
-    ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
-    auto num_tokens = logits_tensor.size(1);
-    auto vocab_size = logits_tensor.size(2);
     switch (logits_tensor.scalar_type()) {
+      // If the logits_tensor rank is 3, the shape is [batch, seq_length,
+      // vocab_size]; get the logits of the last token, sample and return.
+      // Otherwise the model outputs only the last logit; sample it directly.
      case ScalarType::Float: {
        float* logits = logits_tensor.mutable_data_ptr<float>();
-        float* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        if (logits_tensor.dim() == 3) {
+          auto num_tokens = logits_tensor.size(1);
+          auto vocab_size = logits_tensor.size(2);
+          float* logits_last = logits;
+          logits_last += (num_tokens - 1) * vocab_size;
+          return sampler_->sample(logits_last);
+        }
+        return sampler_->sample(logits);
      }
      case ScalarType::Half: {
        exec_aten::Half* logits = logits_tensor.mutable_data_ptr<exec_aten::Half>();
-        exec_aten::Half* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        if (logits_tensor.dim() == 3) {
+          auto num_tokens = logits_tensor.size(1);
+          auto vocab_size = logits_tensor.size(2);
+          exec_aten::Half* logits_last = logits;
+          logits_last += (num_tokens - 1) * vocab_size;
+          return sampler_->sample(logits_last);
+        }
+        return sampler_->sample(logits);
      }
      default:
        ET_CHECK_MSG(
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index beafb21434d..45782bb9fa8 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -50,11 +50,6 @@ Result<uint64_t> TextPrefiller::prefill(
   ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
   ET_LOG(
       Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
-  ET_CHECK_MSG(
-      outputs_res.get().size(1) == num_prompt_tokens,
-      "Expected number of output tokens %d does not match returned value %zu.",
-      num_prompt_tokens,
-      outputs_res.get().size(1));
   // insert new token into prompt_tokens
   // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
   uint64_t prev = prompt_tokens[0];
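
Note: the core idea of the change, independent of the patch above, is that slicing the hidden state to the last position before the output projection shrinks the logits from [batch, context_length, vocab_size] to [batch, vocab_size] without changing the greedily sampled token. The sketch below is illustrative only; it uses a stand-in linear projection with made-up sizes, not the actual llama_transformer module.
```
import torch

# Stand-in for the transformer's final output projection (illustrative sizes only).
batch, context_length, dim, vocab_size = 1, 1024, 64, 32000
hidden = torch.randn(batch, context_length, dim)
output_proj = torch.nn.Linear(dim, vocab_size, bias=False)

# Before: logits for every position, O(context_length) memory.
full_logits = output_proj(hidden)            # [1, 1024, 32000]

# After: keep only the last position before projecting, O(1) memory.
last_logits = output_proj(hidden[:, -1, :])  # [1, 32000]

# The last row of the full logits matches the single-row logits, so greedy
# decoding (temperature = 0) picks the same next token either way.
assert torch.allclose(full_logits[:, -1, :], last_logits, atol=1e-5)
print(torch.argmax(last_logits).item(), full_logits.numel(), "vs", last_logits.numel())
```
With these illustrative sizes and fp32, the full logits take roughly 131 MB per 1024-token context versus about 0.128 MB for a single row; the exact savings in the summary above depend on the real model's dtype, vocabulary, and context length.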