Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the memory usage of logits from O(context_length) to O(1) #4688

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"--generate_full_logits",
action="store_true",
required=False,
default=True,
default=False,
help="Generate logits for all inputs.",
)
return parser
Expand Down Expand Up @@ -598,7 +598,7 @@ def _load_llama_model(
params_path: str,
use_kv_cache: bool = False,
use_sdpa_with_kv_cache: bool = False,
generate_full_logits: bool = True,
generate_full_logits: bool = False,
weight_type: WeightType = WeightType.LLAMA,
enable_dynamic_shape: bool = False,
verbose: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llama2/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ModelArgs:
# Generate logits for all inputs. When it's True, it would take big memory usage
# at runtime. Enable it only necessary (e.g., use perplexity tools that requires
# logits for all input tokens.)
generate_full_logits: bool = True
generate_full_logits: bool = False
enable_dynamic_shape: bool = False # export model with dynamic shape support
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
rope_theta: Optional[float] = (
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llama2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, **kwargs):

self.use_kv_cache = kwargs.get("use_kv_cache", False)
self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
self.generate_full_logits = kwargs.get("generate_full_logits", True)
self.generate_full_logits = kwargs.get("generate_full_logits", False)
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)

self.max_seq_len = kwargs.get("max_seq_len", 128)
Expand Down
7 changes: 5 additions & 2 deletions examples/models/llava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,19 @@ def prefill_embedding(
result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
return result

# prefill using the in house text_model of llama transformer
def prefill(
self,
prompt_before_image: torch.Tensor,
images: torch.Tensor,
prompt_after_image: torch.Tensor,
) -> torch.Tensor:
) -> (int, torch.Tensor):
"""Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
return self.text_model.forward(None, torch.tensor([0]), embeds)
# returns the prefilled token length too, because the text model generates one logits in each forward call.
return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

# reference prefill using the text model in HF
def prefill_ref(
self,
prompt_before_image: torch.Tensor,
Expand Down
6 changes: 5 additions & 1 deletion examples/models/llava/runner/llava_image_prefiller.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class LlavaImagePrefiller : public ImagePrefiller {
* @param start_pos The starting position in KV cache of the input in the LLM
* @return logits of the image prefill.
*/
inline Result<exec_aten::Tensor> prefill(Image& image, int64_t start_pos = 0)
inline Result<exec_aten::Tensor> prefill(Image& image, int64_t& start_pos)
override {
ManagedTensor managed_images(
image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
Expand All @@ -43,6 +43,10 @@ class LlavaImagePrefiller : public ImagePrefiller {
outputs_res[0].isTensor(),
"Non Tensor Output returned from executing image prefill");

// Update the start_pos, which is only available inside this function.
// outputs_res can have only one logits.
start_pos += image_encoder_outputs[0].toTensor().size(1);

return outputs_res[0].toTensor();
}

Expand Down
4 changes: 2 additions & 2 deletions examples/models/llava/runner/llava_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ Error LlavaRunner::generate(

// prefill images
for (auto& image : images) {
auto logits = ET_UNWRAP(image_prefiller_->prefill(image, pos));
pos += logits.size(1);
// pos is updated inside image prefill.
ET_UNWRAP(image_prefiller_->prefill(image, pos));
}

// prefill user prompt. No BOS because preset prompt already has it.
Expand Down
34 changes: 21 additions & 13 deletions examples/models/llava/test/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ def setUp(self):
)

def test_prefill_logits(self):
prefill_logits = self.llava.prefill(
# For efficiency, the implemented prefill function only outputs the last logits.
_, prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
# The reference implementation in HF genetates the full logits. Get the last one.
prefill_logits_ref = self.llava.prefill_ref(
self.prompt_before_image, self.resized, self.prompt_after_image
)[0]
)[0][:, -1, :]
self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))

def test_generated_output(self):
Expand All @@ -62,11 +64,11 @@ def test_generated_output(self):
)[0].strip()

# being tested, using llama_transformer
prefill_logits = self.llava.prefill(
context_len, prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
context_len = prefill_logits.shape[1]
new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
# Always generate one token at a time.
new_tokens = [torch.argmax(prefill_logits).item()]
for i in range(4):
logits = self.llava.step(
torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
Expand All @@ -93,24 +95,27 @@ def test_llava_export(self):
pte_embeds_before_img = llava_module.run_method(
"token_embedding", (prompt_before_image,)
)[0]
pte_prefill_before_img = llava_module.run_method(
llava_module.run_method(
"text_model",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
)[0]
)

start_pos += pte_prefill_before_img.shape[1]
# Update the start_pos. start_pos is used in kv cache. The source of truth
# of the delta length is from the embeddings, not from the logits.
start_pos += pte_embeds_before_img.shape[1]

# pte prefill image
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
pte_prefill_img = llava_module.run_method(
llava_module.run_method(
"text_model",
(
torch.tensor([start_pos], dtype=torch.int64),
pte_embeds_img,
),
)[0]
)

start_pos += pte_prefill_img.shape[1]
# Update the logits for each prefill (kv cache) step.
start_pos += pte_embeds_img.shape[1]

# pte prefill prompt after img
pte_embeds_after_img = llava_module.run_method(
Expand All @@ -121,8 +126,11 @@ def test_llava_export(self):
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
)[0]

# Update the logits for each prefill (kv cache) step.
start_pos += pte_embeds_after_img.shape[1]

# being tested, using llama_transformer
new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
new_tokens = [torch.argmax(pte_prefill_after_img).item()]
# TODO: uncomment this line
# self.assertEquals(new_tokens[0], 1932) # When
for i in range(4):
Expand All @@ -134,7 +142,7 @@ def test_llava_export(self):
"text_model",
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
)[0]
new_tokens.append(torch.argmax(logits[..., -1, :]).item())
new_tokens.append(torch.argmax(logits).item())

outputs = llava_model.tokenizer.batch_decode(
torch.tensor([new_tokens]), skip_special_tokens=True
Expand Down
5 changes: 3 additions & 2 deletions extension/llm/runner/image_prefiller.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ class ImagePrefiller {
/**
* Prefill an LLM Module with the given image input.
* @param image The image input to the multimodal LLM.
* @param start_pos The starting position in KV cache of the input in the LLM
* @param start_pos The starting position in KV cache of the input in the LLM.
* It's passed as reference and will be updated inside this function.
* @return The next token of the LLM Module after prefill.
*/
virtual ::executorch::runtime::Result<exec_aten::Tensor> prefill(
Image& image,
int64_t start_pos = 0) = 0;
int64_t& start_pos) = 0;

virtual ::executorch::runtime::Error load() = 0;
virtual bool is_method_loaded() = 0;
Expand Down
29 changes: 19 additions & 10 deletions extension/llm/runner/text_decoder_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,32 @@ class TextDecoderRunner {
* @return The next token.
*/
inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);

switch (logits_tensor.scalar_type()) {
// If the logit_tensor rank is 3, the shape is [batch, seq_length,
// vocab_size], get the last logits, sample and return. Else the model
// outputs the last logit, directly sample and return.
case exec_aten::ScalarType::Float: {
float* logits = logits_tensor.mutable_data_ptr<float>();
float* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
float* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
return sampler_->sample(logits);
}
case exec_aten::ScalarType::Half: {
exec_aten::Half* logits =
logits_tensor.mutable_data_ptr<exec_aten::Half>();
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
return sampler_->sample(logits);
}
default:
ET_CHECK_MSG(
Expand Down
5 changes: 0 additions & 5 deletions extension/llm/runner/text_prefiller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_LOG(
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
ET_CHECK_MSG(
outputs_res.get().size(1) == num_prompt_tokens,
"Expected number of output tokens %d does not match returned value %zu.",
num_prompt_tokens,
outputs_res.get().size(1));
// insert new token into prompt_tokens
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
uint64_t prev = prompt_tokens[0];
Expand Down
Loading