From 11e8ed33bed08b4cb893ab8dda201f683017dee7 Mon Sep 17 00:00:00 2001
From: Mengtao Yuan
Date: Thu, 22 Aug 2024 19:03:04 -0700
Subject: [PATCH] Reduce the memory usage of logits from O(context_length) to O(1)

Differential Revision: D61246566

Pull Request resolved: https://github.com/pytorch/executorch/pull/4688
---
 examples/models/llama2/export_llama_lib.py    |  4 +--
 examples/models/llama2/llama_transformer.py   |  2 +-
 examples/models/llama2/model.py               |  2 +-
 examples/models/llava/model.py                |  7 ++--
 .../llava/runner/llava_image_prefiller.h      |  6 +++-
 examples/models/llava/runner/llava_runner.cpp |  4 +--
 examples/models/llava/test/test_llava.py      | 34 ++++++++++++-------
 extension/llm/runner/image_prefiller.h        |  5 +--
 extension/llm/runner/text_decoder_runner.h    | 29 ++++++++++------
 extension/llm/runner/text_prefiller.cpp       |  5 ---
 10 files changed, 59 insertions(+), 39 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 221f2f75bc..172a1d72fd 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -300,7 +300,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--generate_full_logits",
         action="store_true",
         required=False,
-        default=True,
+        default=False,
         help="Generate logits for all inputs.",
     )
     return parser
@@ -598,7 +598,7 @@ def _load_llama_model(
     params_path: str,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
-    generate_full_logits: bool = True,
+    generate_full_logits: bool = False,
     weight_type: WeightType = WeightType.LLAMA,
     enable_dynamic_shape: bool = False,
     verbose: bool = False,
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 81b47a3a5d..0c93115ee3 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -99,7 +99,7 @@ class ModelArgs:
     # Generate logits for all inputs. When it's True, it would take big memory usage
     # at runtime. Enable it only necessary (e.g., use perplexity tools that requires
     # logits for all input tokens.)
-    generate_full_logits: bool = True
+    generate_full_logits: bool = False
     enable_dynamic_shape: bool = False  # export model with dynamic shape support
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     rope_theta: Optional[float] = (
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index b375399f33..f58a2a2def 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -61,7 +61,7 @@ def __init__(self, **kwargs):
 
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-        self.generate_full_logits = kwargs.get("generate_full_logits", True)
+        self.generate_full_logits = kwargs.get("generate_full_logits", False)
 
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 9f6d8d32e8..4f975e2ed4 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -216,16 +216,19 @@ def prefill_embedding(
         result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
         return result
 
+    # Prefill using the in-house text_model of the llama transformer.
     def prefill(
         self,
         prompt_before_image: torch.Tensor,
         images: torch.Tensor,
         prompt_after_image: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> (int, torch.Tensor):
         """Avoiding the torch.where() call to find placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
-        return self.text_model.forward(None, torch.tensor([0]), embeds)
+        # Also return the prefilled token length, since the text model now produces only one logit per forward call.
+        return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)
 
+    # Reference prefill using the text model in HF.
     def prefill_ref(
         self,
         prompt_before_image: torch.Tensor,
diff --git a/examples/models/llava/runner/llava_image_prefiller.h b/examples/models/llava/runner/llava_image_prefiller.h
index e845329908..4d0a07b9a6 100644
--- a/examples/models/llava/runner/llava_image_prefiller.h
+++ b/examples/models/llava/runner/llava_image_prefiller.h
@@ -24,7 +24,7 @@ class LlavaImagePrefiller : public ImagePrefiller {
    * @param start_pos The starting position in KV cache of the input in the LLM
    * @return logits of the image prefill.
    */
-  inline Result<exec_aten::Tensor> prefill(Image& image, int64_t start_pos = 0)
+  inline Result<exec_aten::Tensor> prefill(Image& image, int64_t& start_pos)
       override {
     ManagedTensor managed_images(
         image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
@@ -43,6 +43,10 @@ class LlavaImagePrefiller : public ImagePrefiller {
         outputs_res[0].isTensor(),
         "Non Tensor Output returned from executing image prefill");
 
+    // Update start_pos; the prefilled length is only known inside this
+    // function, and outputs_res now holds only a single logit.
+    start_pos += image_encoder_outputs[0].toTensor().size(1);
+
     return outputs_res[0].toTensor();
   }
 
diff --git a/examples/models/llava/runner/llava_runner.cpp b/examples/models/llava/runner/llava_runner.cpp
index a58fdfd5e5..b186af892f 100644
--- a/examples/models/llava/runner/llava_runner.cpp
+++ b/examples/models/llava/runner/llava_runner.cpp
@@ -106,8 +106,8 @@ Error LlavaRunner::generate(
 
   // prefill images
   for (auto& image : images) {
-    auto logits = ET_UNWRAP(image_prefiller_->prefill(image, pos));
-    pos += logits.size(1);
+    // pos is updated inside image prefill.
+    ET_UNWRAP(image_prefiller_->prefill(image, pos));
   }
 
   // prefill user prompt. No BOS because preset prompt already has it.
diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py
index ef503a88fc..f464a580a8 100644
--- a/examples/models/llava/test/test_llava.py
+++ b/examples/models/llava/test/test_llava.py
@@ -35,12 +35,14 @@ def setUp(self):
         )
 
     def test_prefill_logits(self):
-        prefill_logits = self.llava.prefill(
+        # For efficiency, the implemented prefill function only outputs the last logit.
+        _, prefill_logits = self.llava.prefill(
             self.prompt_before_image, self.resized, self.prompt_after_image
         )
+        # The reference implementation in HF generates the full logits. Get the last one.
         prefill_logits_ref = self.llava.prefill_ref(
             self.prompt_before_image, self.resized, self.prompt_after_image
-        )[0]
+        )[0][:, -1, :]
         self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))
 
     def test_generated_output(self):
@@ -62,11 +64,11 @@ def test_generated_output(self):
         )[0].strip()
 
         # being tested, using llama_transformer
-        prefill_logits = self.llava.prefill(
+        context_len, prefill_logits = self.llava.prefill(
             self.prompt_before_image, self.resized, self.prompt_after_image
         )
-        context_len = prefill_logits.shape[1]
-        new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
+        # Always generate one token at a time.
+        new_tokens = [torch.argmax(prefill_logits).item()]
         for i in range(4):
             logits = self.llava.step(
                 torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
@@ -93,24 +95,27 @@ def test_llava_export(self):
         pte_embeds_before_img = llava_module.run_method(
             "token_embedding", (prompt_before_image,)
         )[0]
-        pte_prefill_before_img = llava_module.run_method(
+        llava_module.run_method(
             "text_model",
             (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
-        )[0]
+        )
 
-        start_pos += pte_prefill_before_img.shape[1]
+        # Update the start_pos. start_pos is used by the kv cache; the source of truth
+        # for the length delta is the embeddings, not the logits.
+        start_pos += pte_embeds_before_img.shape[1]
 
         # pte prefill image
         pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
-        pte_prefill_img = llava_module.run_method(
+        llava_module.run_method(
             "text_model",
             (
                 torch.tensor([start_pos], dtype=torch.int64),
                 pte_embeds_img,
             ),
-        )[0]
+        )
 
-        start_pos += pte_prefill_img.shape[1]
+        # Update start_pos for each prefill (kv cache) step.
+        start_pos += pte_embeds_img.shape[1]
 
         # pte prefill prompt after img
         pte_embeds_after_img = llava_module.run_method(
             "token_embedding", (prompt_after_image,)
         )[0]
@@ -121,8 +126,11 @@ def test_llava_export(self):
             (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
         )[0]
 
+        # Update start_pos for each prefill (kv cache) step.
+        start_pos += pte_embeds_after_img.shape[1]
+
         # being tested, using llama_transformer
-        new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
+        new_tokens = [torch.argmax(pte_prefill_after_img).item()]
         # TODO: uncomment this line
         # self.assertEquals(new_tokens[0], 1932)  # When
         for i in range(4):
@@ -134,7 +142,7 @@ def test_llava_export(self):
                 "text_model",
                 (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
             )[0]
-            new_tokens.append(torch.argmax(logits[..., -1, :]).item())
+            new_tokens.append(torch.argmax(logits).item())
 
         outputs = llava_model.tokenizer.batch_decode(
             torch.tensor([new_tokens]), skip_special_tokens=True
diff --git a/extension/llm/runner/image_prefiller.h b/extension/llm/runner/image_prefiller.h
index 879b0a6e21..93bb9a030b 100644
--- a/extension/llm/runner/image_prefiller.h
+++ b/extension/llm/runner/image_prefiller.h
@@ -26,12 +26,13 @@ class ImagePrefiller {
   /**
    * Prefill an LLM Module with the given image input.
    * @param image The image input to the multimodal LLM.
-   * @param start_pos The starting position in KV cache of the input in the LLM
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * It is passed by reference and will be updated inside this function.
    * @return The next token of the LLM Module after prefill.
    */
   virtual ::executorch::runtime::Result<exec_aten::Tensor> prefill(
       Image& image,
-      int64_t start_pos = 0) = 0;
+      int64_t& start_pos) = 0;
 
   virtual ::executorch::runtime::Error load() = 0;
   virtual bool is_method_loaded() = 0;
diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h
index 6a8e3396fe..70ee1d0136 100644
--- a/extension/llm/runner/text_decoder_runner.h
+++ b/extension/llm/runner/text_decoder_runner.h
@@ -67,23 +67,32 @@ class TextDecoderRunner {
    * @return The next token.
    */
   inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
-    ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
-    auto num_tokens = logits_tensor.size(1);
-    auto vocab_size = logits_tensor.size(2);
-
     switch (logits_tensor.scalar_type()) {
+      // If the logits_tensor rank is 3, the shape is [batch, seq_length,
+      // vocab_size]; take the logits at the last position, sample, and return.
+      // Otherwise the model outputs only the last logit, so sample it directly.
       case exec_aten::ScalarType::Float: {
         float* logits = logits_tensor.mutable_data_ptr<float>();
-        float* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        if (logits_tensor.dim() == 3) {
+          auto num_tokens = logits_tensor.size(1);
+          auto vocab_size = logits_tensor.size(2);
+          float* logits_last = logits;
+          logits_last += (num_tokens - 1) * vocab_size;
+          return sampler_->sample(logits_last);
+        }
+        return sampler_->sample(logits);
       }
       case exec_aten::ScalarType::Half: {
         exec_aten::Half* logits =
             logits_tensor.mutable_data_ptr<exec_aten::Half>();
-        exec_aten::Half* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        if (logits_tensor.dim() == 3) {
+          auto num_tokens = logits_tensor.size(1);
+          auto vocab_size = logits_tensor.size(2);
+          exec_aten::Half* logits_last = logits;
+          logits_last += (num_tokens - 1) * vocab_size;
+          return sampler_->sample(logits_last);
+        }
+        return sampler_->sample(logits);
       }
       default:
         ET_CHECK_MSG(
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index 19fc2d5936..4b9afb8326 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -55,11 +55,6 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
   ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
   ET_LOG(
       Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
-  ET_CHECK_MSG(
-      outputs_res.get().size(1) == num_prompt_tokens,
-      "Expected number of output tokens %d does not match returned value %zu.",
-      num_prompt_tokens,
-      outputs_res.get().size(1));
   // insert new token into prompt_tokens
   // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
   uint64_t prev = prompt_tokens[0];
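
Below is a minimal Python sketch of the memory effect this patch targets. It is
illustrative only and not part of the patch: ToyOutputHead and its sizes are
made up, and only the generate_full_logits flag corresponds to the real option
whose default the patch flips to False. With generate_full_logits=False the
output projection materializes logits for just the last position, so the logits
buffer is [batch, 1, vocab_size] instead of [batch, seq_len, vocab_size], i.e.
O(1) rather than O(context_length) in the prompt length.

    # Illustrative sketch, not the actual llama_transformer.py code.
    import torch
    import torch.nn as nn

    class ToyOutputHead(nn.Module):
        """Toy output projection mirroring the generate_full_logits flag."""

        def __init__(self, dim: int, vocab_size: int, generate_full_logits: bool = False):
            super().__init__()
            self.output = nn.Linear(dim, vocab_size)
            self.generate_full_logits = generate_full_logits

        def forward(self, hidden: torch.Tensor) -> torch.Tensor:
            # hidden: [batch, seq_len, dim]
            if self.generate_full_logits:
                # O(context_length): logits for every input position.
                return self.output(hidden)  # [batch, seq_len, vocab_size]
            # O(1): project only the last position, so the large
            # [seq_len, vocab_size] buffer is never allocated.
            return self.output(hidden[:, -1:, :])  # [batch, 1, vocab_size]

    head = ToyOutputHead(dim=64, vocab_size=32000, generate_full_logits=False)
    hidden = torch.randn(1, 512, 64)
    print(head(hidden).shape)  # torch.Size([1, 1, 32000]) instead of [1, 512, 32000]

On the runtime side, the matching piece is the logits_to_token() change above:
the sampler now accepts both the old rank-3 full-logits output and the new
last-logit-only output, so models exported either way keep working.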