Reduce the memory usage of logits from O(context_length) to O(1)
Differential Revision: D61246566

Pull Request resolved: pytorch#4688
iseeyuan authored Aug 23, 2024
1 parent 6c26a87 commit 11e8ed3
Showing 10 changed files with 59 additions and 39 deletions.
4 changes: 2 additions & 2 deletions examples/models/llama2/export_llama_lib.py
@@ -300,7 +300,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"--generate_full_logits",
action="store_true",
required=False,
- default=True,
+ default=False,
help="Generate logits for all inputs.",
)
return parser
@@ -598,7 +598,7 @@ def _load_llama_model(
params_path: str,
use_kv_cache: bool = False,
use_sdpa_with_kv_cache: bool = False,
- generate_full_logits: bool = True,
+ generate_full_logits: bool = False,
weight_type: WeightType = WeightType.LLAMA,
enable_dynamic_shape: bool = False,
verbose: bool = False,
2 changes: 1 addition & 1 deletion examples/models/llama2/llama_transformer.py
@@ -99,7 +99,7 @@ class ModelArgs:
# Generate logits for all inputs. When True, this uses a large amount of
# memory at runtime. Enable it only when necessary (e.g., for perplexity
# tools that require logits for all input tokens).
- generate_full_logits: bool = True
+ generate_full_logits: bool = False
enable_dynamic_shape: bool = False # export model with dynamic shape support
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
rope_theta: Optional[float] = (
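For context, here is a minimal sketch of the pruning this flag controls; the names and structure are illustrative, not the exact llama_transformer.py code. With generate_full_logits=False, the hidden states are sliced to the last position before the vocabulary projection, so logits memory stays O(1) in context length:

    import torch
    import torch.nn as nn

    class TinyDecoder(nn.Module):
        """Illustrative stand-in for the transformer's output stage."""

        def __init__(self, dim: int, vocab_size: int, generate_full_logits: bool = False):
            super().__init__()
            self.generate_full_logits = generate_full_logits
            self.output = nn.Linear(dim, vocab_size, bias=False)

        def forward(self, h: torch.Tensor) -> torch.Tensor:
            # h: [batch, seq_len, dim] hidden states from the attention blocks.
            if not self.generate_full_logits:
                # Keep only the last position: logits become [batch, vocab_size],
                # O(1) in context length instead of O(context_length).
                h = h[:, -1, :]
            return self.output(h)

Perplexity tools still need the full [batch, seq_len, vocab_size] logits, which is why the flag remains available; it now simply defaults to off.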
2 changes: 1 addition & 1 deletion examples/models/llama2/model.py
@@ -61,7 +61,7 @@ def __init__(self, **kwargs):

self.use_kv_cache = kwargs.get("use_kv_cache", False)
self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
self.generate_full_logits = kwargs.get("generate_full_logits", True)
self.generate_full_logits = kwargs.get("generate_full_logits", False)
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)

self.max_seq_len = kwargs.get("max_seq_len", 128)
7 changes: 5 additions & 2 deletions examples/models/llava/model.py
@@ -216,16 +216,19 @@ def prefill_embedding(
result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
return result

+ # Prefill using the in-house text_model of the llama transformer.
def prefill(
self,
prompt_before_image: torch.Tensor,
images: torch.Tensor,
prompt_after_image: torch.Tensor,
- ) -> torch.Tensor:
+ ) -> (int, torch.Tensor):
"""Avoid the torch.where() call to find the <image> placeholder and insert the image embedding; take 3 inputs instead."""
embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
- return self.text_model.forward(None, torch.tensor([0]), embeds)
+ # Also return the prefilled token length, since the text model now emits only one logits row per forward call.
+ return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

+ # Reference prefill using the HF text model.
def prefill_ref(
self,
prompt_before_image: torch.Tensor,
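A short usage sketch of the new contract, assuming the LlavaModel methods exercised in test_llava.py below (prefill and step); because prefill now returns only the last logits row, the caller takes the context length from the returned count instead of from logits.shape[1]:

    import torch

    def greedy_decode(llava, prompt_before, images, prompt_after, steps=4):
        # prefill() returns (prefilled_token_length, last_logits).
        context_len, logits = llava.prefill(prompt_before, images, prompt_after)
        tokens = [torch.argmax(logits).item()]
        for i in range(steps):
            # Decode one token at a time at position context_len + i.
            logits = llava.step(torch.tensor([tokens[i]]), torch.tensor([context_len + i]))
            tokens.append(torch.argmax(logits).item())
        return tokens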
6 changes: 5 additions & 1 deletion examples/models/llava/runner/llava_image_prefiller.h
@@ -24,7 +24,7 @@ class LlavaImagePrefiller : public ImagePrefiller {
* @param start_pos The starting position in KV cache of the input in the LLM
* @return logits of the image prefill.
*/
- inline Result<exec_aten::Tensor> prefill(Image& image, int64_t start_pos = 0)
+ inline Result<exec_aten::Tensor> prefill(Image& image, int64_t& start_pos)
override {
ManagedTensor managed_images(
image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
@@ -43,6 +43,10 @@ class LlavaImagePrefiller : public ImagePrefiller {
outputs_res[0].isTensor(),
"Non Tensor Output returned from executing image prefill");

+ // Update start_pos, which is only available inside this function;
+ // outputs_res contains only one logits tensor.
+ start_pos += image_encoder_outputs[0].toTensor().size(1);

return outputs_res[0].toTensor();
}

4 changes: 2 additions & 2 deletions examples/models/llava/runner/llava_runner.cpp
@@ -106,8 +106,8 @@ Error LlavaRunner::generate(

// prefill images
for (auto& image : images) {
- auto logits = ET_UNWRAP(image_prefiller_->prefill(image, pos));
- pos += logits.size(1);
+ // pos is updated inside the image prefill call.
+ ET_UNWRAP(image_prefiller_->prefill(image, pos));
}

// prefill user prompt. No BOS because preset prompt already has it.
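Why this loop changed: with O(1) logits, logits.size(1) is always 1, so the old pos += logits.size(1) would advance the KV-cache position by one regardless of how many image tokens were prefilled. A small Python illustration of the failure mode (shapes are assumed for illustration only):

    import torch

    num_image_tokens = 576                                  # illustrative count
    full_logits = torch.zeros(1, num_image_tokens, 32000)   # before this change
    last_only = torch.zeros(1, 1, 32000)                    # after this change

    pos = 0
    pos += full_logits.size(1)  # old bookkeeping: pos == 576, correct
    pos = 0
    pos += last_only.size(1)    # same bookkeeping now: pos == 1, wrong
    # Hence start_pos is advanced inside prefill() from the embedding length,
    # via the int64_t& start_pos reference shown above.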
34 changes: 21 additions & 13 deletions examples/models/llava/test/test_llava.py
@@ -35,12 +35,14 @@ def setUp(self):
)

def test_prefill_logits(self):
- prefill_logits = self.llava.prefill(
+ # For efficiency, the implemented prefill function only returns the last logits.
+ _, prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
+ # The HF reference implementation generates the full logits; take the last one.
prefill_logits_ref = self.llava.prefill_ref(
self.prompt_before_image, self.resized, self.prompt_after_image
- )[0]
+ )[0][:, -1, :]
self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))

def test_generated_output(self):
@@ -62,11 +64,11 @@ def test_generated_output(self):
)[0].strip()

# being tested, using llama_transformer
- prefill_logits = self.llava.prefill(
+ context_len, prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
- context_len = prefill_logits.shape[1]
- new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
+ # Always generate one token at a time.
+ new_tokens = [torch.argmax(prefill_logits).item()]
for i in range(4):
logits = self.llava.step(
torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
@@ -93,24 +95,27 @@ def test_llava_export(self):
pte_embeds_before_img = llava_module.run_method(
"token_embedding", (prompt_before_image,)
)[0]
- pte_prefill_before_img = llava_module.run_method(
+ llava_module.run_method(
"text_model",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
- )[0]
+ )

- start_pos += pte_prefill_before_img.shape[1]
+ # Update start_pos, which is used by the KV cache. The source of truth for
+ # the length delta is the embeddings, not the logits.
+ start_pos += pte_embeds_before_img.shape[1]

# pte prefill image
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
- pte_prefill_img = llava_module.run_method(
+ llava_module.run_method(
"text_model",
(
torch.tensor([start_pos], dtype=torch.int64),
pte_embeds_img,
),
- )[0]
+ )

- start_pos += pte_prefill_img.shape[1]
+ # Advance start_pos by the number of prefilled tokens (for the KV cache).
+ start_pos += pte_embeds_img.shape[1]

# pte prefill prompt after img
pte_embeds_after_img = llava_module.run_method(
Expand All @@ -121,8 +126,11 @@ def test_llava_export(self):
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
)[0]

+ # Advance start_pos by the number of prefilled tokens (for the KV cache).
+ start_pos += pte_embeds_after_img.shape[1]

# being tested, using llama_transformer
- new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
+ new_tokens = [torch.argmax(pte_prefill_after_img).item()]
# TODO: uncomment this line
# self.assertEquals(new_tokens[0], 1932) # When
for i in range(4):
@@ -134,7 +142,7 @@
"text_model",
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
)[0]
- new_tokens.append(torch.argmax(logits[..., -1, :]).item())
+ new_tokens.append(torch.argmax(logits).item())

outputs = llava_model.tokenizer.batch_decode(
torch.tensor([new_tokens]), skip_special_tokens=True
5 changes: 3 additions & 2 deletions extension/llm/runner/image_prefiller.h
@@ -26,12 +26,13 @@ class ImagePrefiller {
/**
* Prefill an LLM Module with the given image input.
* @param image The image input to the multimodal LLM.
- * @param start_pos The starting position in KV cache of the input in the LLM
+ * @param start_pos The starting position in the KV cache of the input in the LLM.
+ * It is passed by reference and will be updated inside this function.
* @return The next token of the LLM Module after prefill.
*/
virtual ::executorch::runtime::Result<exec_aten::Tensor> prefill(
Image& image,
- int64_t start_pos = 0) = 0;
+ int64_t& start_pos) = 0;

virtual ::executorch::runtime::Error load() = 0;
virtual bool is_method_loaded() = 0;
29 changes: 19 additions & 10 deletions extension/llm/runner/text_decoder_runner.h
@@ -67,23 +67,32 @@ class TextDecoderRunner {
* @return The next token.
*/
inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
- ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
- auto num_tokens = logits_tensor.size(1);
- auto vocab_size = logits_tensor.size(2);

switch (logits_tensor.scalar_type()) {
+ // If the logits tensor has rank 3, its shape is [batch, seq_length,
+ // vocab_size]: take the last position's logits, sample, and return.
+ // Otherwise the model outputs only the last logits row; sample it directly.
case exec_aten::ScalarType::Float: {
float* logits = logits_tensor.mutable_data_ptr<float>();
- float* logits_last = logits;
- logits_last += (num_tokens - 1) * vocab_size;
- return sampler_->sample(logits_last);
+ if (logits_tensor.dim() == 3) {
+   auto num_tokens = logits_tensor.size(1);
+   auto vocab_size = logits_tensor.size(2);
+   float* logits_last = logits;
+   logits_last += (num_tokens - 1) * vocab_size;
+   return sampler_->sample(logits_last);
+ }
+ return sampler_->sample(logits);
}
case exec_aten::ScalarType::Half: {
exec_aten::Half* logits =
logits_tensor.mutable_data_ptr<exec_aten::Half>();
- exec_aten::Half* logits_last = logits;
- logits_last += (num_tokens - 1) * vocab_size;
- return sampler_->sample(logits_last);
+ if (logits_tensor.dim() == 3) {
+   auto num_tokens = logits_tensor.size(1);
+   auto vocab_size = logits_tensor.size(2);
+   exec_aten::Half* logits_last = logits;
+   logits_last += (num_tokens - 1) * vocab_size;
+   return sampler_->sample(logits_last);
+ }
+ return sampler_->sample(logits);
}
default:
ET_CHECK_MSG(
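The same dimension handling, sketched in Python with greedy argmax standing in for the sampler; this mirrors the C++ logic above rather than reproducing it:

    import torch

    def logits_to_token(logits: torch.Tensor) -> int:
        # Rank 3 means [batch, seq_length, vocab_size]: the model produced
        # logits for every position, so sample from the last one. Otherwise
        # the model already returned only the last position's logits.
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        return int(torch.argmax(logits, dim=-1).item())

    # Both shapes map to the same token:
    full = torch.randn(1, 7, 32000)
    last = full[:, -1, :]  # shape [1, 32000]
    assert logits_to_token(full) == logits_to_token(last)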
5 changes: 0 additions & 5 deletions extension/llm/runner/text_prefiller.cpp
@@ -55,11 +55,6 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_LOG(
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
- ET_CHECK_MSG(
-     outputs_res.get().size(1) == num_prompt_tokens,
-     "Expected number of output tokens %d does not match returned value %zu.",
-     num_prompt_tokens,
-     outputs_res.get().size(1));
// insert new token into prompt_tokens
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
uint64_t prev = prompt_tokens[0];
