diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 99544426fd3..3c688133b01 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -524,6 +524,9 @@ def forward( input_pos, ) + # Only the last logit is used for the new generated token + h = h[:, -1, :] + h = self.norm(h) logits = self.output(h) diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h index 31b8c1b983f..6019e7ce481 100644 --- a/extension/llm/runner/text_decoder_runner.h +++ b/extension/llm/runner/text_decoder_runner.h @@ -63,23 +63,15 @@ class TextDecoderRunner { * @return The next token. */ inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) { - ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D"); - auto num_tokens = logits_tensor.size(1); - auto vocab_size = logits_tensor.size(2); - switch (logits_tensor.scalar_type()) { case ScalarType::Float: { float* logits = logits_tensor.mutable_data_ptr(); - float* logits_last = logits; - logits_last += (num_tokens - 1) * vocab_size; - return sampler_->sample(logits_last); + return sampler_->sample(logits); } case ScalarType::Half: { exec_aten::Half* logits = logits_tensor.mutable_data_ptr(); - exec_aten::Half* logits_last = logits; - logits_last += (num_tokens - 1) * vocab_size; - return sampler_->sample(logits_last); + return sampler_->sample(logits); } default: ET_CHECK_MSG( diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index a5aa668e73a..fa084cbe016 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -50,11 +50,6 @@ Result TextPrefiller::prefill( ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); ET_LOG( Info, "Prefill token result numel(): %zu", outputs_res.get().numel()); - ET_CHECK_MSG( - outputs_res.get().size(1) == num_prompt_tokens, - "Expected number of output tokens %d does not match returned value %zu.", - num_prompt_tokens, - outputs_res.get().size(1)); // insert new token into prompt_tokens // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds) uint64_t prev = prompt_tokens[0];