Reduce the memory usage of logits from O(context_length) to O(1)
Martin Yuan committed Aug 13, 2024
1 parent 728a29d commit 1e0ff89
Showing 4 changed files with 20 additions and 18 deletions.
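
Context for the title: before this change the model materialized logits for every position in the context window; afterwards only the final position's logits survive. A back-of-envelope sketch of the savings, assuming Llama-2's 32,000-token vocabulary, a 2048-token context, and fp32 logits (illustrative numbers, not taken from the commit):

    vocab_size = 32_000      # Llama-2 vocabulary (assumed for illustration)
    context_length = 2_048   # max sequence length (assumed)
    bytes_per_float = 4      # fp32

    before = context_length * vocab_size * bytes_per_float  # O(context_length)
    after = vocab_size * bytes_per_float                    # O(1)

    print(f"before: {before / 2**20:.0f} MiB")  # 250 MiB
    print(f"after:  {after / 2**10:.0f} KiB")   # 125 KiB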
13 changes: 9 additions & 4 deletions examples/models/llama2/export_llama_lib.py
@@ -397,8 +397,8 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         # to get free perf gain.
         transforms.append(replace_sdpa_with_simple_sdpa)
         transforms.append(replace_causal_mask)
-    return (
-        _load_llama_model(
+
+    llm_edge_manager = (_load_llama_model(
             modelname=modelname,
             checkpoint=checkpoint_path,
             checkpoint_dir=checkpoint_dir,
@@ -413,8 +413,13 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         )
         .set_output_dir(output_dir_path)
         .to_dtype(dtype_override)
-        .source_transform(transforms)
-    )
+        .source_transform(transforms))
+
+    inputs = llm_edge_manager.example_inputs
+
+    output = llm_edge_manager.model.forward(*inputs)
+
+    return llm_edge_manager
 
 
 def get_quantizer_and_quant_params(args):
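
The hunk above replaces the direct return with a named llm_edge_manager and runs one eager forward pass on the example inputs before returning. A minimal sketch of how that pass can be used to confirm the new output contract (the dim() check is an assumption for illustration, not part of the diff):

    def check_last_token_logits(llm_edge_manager):
        # example_inputs and model come from LLMEdgeManager (per the diff);
        # the 2D assertion below is an assumed invariant, not in the commit.
        inputs = llm_edge_manager.example_inputs
        logits = llm_edge_manager.model.forward(*inputs)
        assert logits.dim() == 2, "expected (batch, vocab_size) after slicing"
        return logits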
1 change: 1 addition & 0 deletions examples/models/llama2/llama_transformer.py
@@ -526,5 +526,6 @@ def forward(
 
         h = self.norm(h)
 
+        h = h[:, -1, :]
         logits = self.output(h)
         return logits
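
A minimal PyTorch sketch of what this one-line hunk buys, with illustrative shapes (nn.Linear stands in for the model's output projection):

    import torch
    import torch.nn as nn

    batch, seq_len, dim, vocab_size = 1, 128, 512, 1000  # illustrative sizes
    h = torch.randn(batch, seq_len, dim)                 # normalized hidden states
    output = nn.Linear(dim, vocab_size, bias=False)

    logits_all = output(h)             # (1, 128, 1000): O(seq_len) memory
    logits_last = output(h[:, -1, :])  # (1, 1000):      O(1) memory

    # Greedy decoding only ever reads the last position, so nothing is lost.
    assert torch.allclose(logits_all[:, -1, :], logits_last, atol=1e-5)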
14 changes: 5 additions & 9 deletions extension/llm/runner/text_decoder_runner.h
@@ -59,23 +59,19 @@ class TextDecoderRunner {
    * @return The next token.
    */
   inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
-    ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
-    auto num_tokens = logits_tensor.size(1);
-    auto vocab_size = logits_tensor.size(2);
+    // ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
+    // auto num_tokens = logits_tensor.size(1);
+    auto vocab_size = logits_tensor.size(1);
 
     switch (logits_tensor.scalar_type()) {
       case ScalarType::Float: {
         float* logits = logits_tensor.mutable_data_ptr<float>();
-        float* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        return sampler_->sample(logits);
       }
       case ScalarType::Half: {
         exec_aten::Half* logits =
             logits_tensor.mutable_data_ptr<exec_aten::Half>();
-        exec_aten::Half* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        return sampler_->sample(logits);
       }
       default:
         ET_CHECK_MSG(
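
The runner previously received a 3D [batch, num_tokens, vocab] tensor and stepped a raw pointer to the last row before sampling; with the model now emitting 2D [batch, vocab] logits, it samples the buffer directly. In Python terms, roughly (a sketch of the indexing, not the actual C++ API):

    import torch

    def sample(logits_row):  # stand-in for Sampler::sample (greedy, assumed)
        return int(torch.argmax(logits_row))

    vocab_size = 1000  # illustrative

    # Before: step to the last token's row of a [1, num_tokens, vocab] tensor,
    # i.e. the (num_tokens - 1) * vocab_size pointer arithmetic in the C++.
    logits_3d = torch.randn(1, 5, vocab_size)
    token_before = sample(logits_3d[0, -1])

    # After: the tensor is already [1, vocab]; sample it as-is.
    logits_2d = torch.randn(1, vocab_size)
    token_after = sample(logits_2d[0])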
10 changes: 5 additions & 5 deletions extension/llm/runner/text_prefiller.cpp
@@ -50,11 +50,11 @@ Result<uint64_t> TextPrefiller::prefill(
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_LOG(
         Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
-    ET_CHECK_MSG(
-        outputs_res.get().size(1) == num_prompt_tokens,
-        "Expected number of output tokens %d does not match returned value %zu.",
-        num_prompt_tokens,
-        outputs_res.get().size(1));
+    // ET_CHECK_MSG(
+    //     outputs_res.get().size(1) == num_prompt_tokens,
+    //     "Expected number of output tokens %d does not match returned value %zu.",
+    //     num_prompt_tokens,
+    //     outputs_res.get().size(1));
     // insert new token into prompt_tokens
     // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
     uint64_t prev = prompt_tokens[0];
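
The disabled check enforced the old shape contract, where prefill returned one logits row per prompt token; with last-token-only logits, size(1) is the vocabulary size instead. A shape sketch of the invariant that changed (illustrative numbers):

    import torch

    num_prompt_tokens, vocab_size = 7, 1000  # illustrative

    outputs_before = torch.randn(1, num_prompt_tokens, vocab_size)
    assert outputs_before.size(1) == num_prompt_tokens  # the old ET_CHECK_MSG

    outputs_after = torch.randn(1, vocab_size)
    assert outputs_after.size(1) == vocab_size          # why the check was disabled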
