Skip to content

Commit

Permalink
Minor changes.
Browse files: browse the repository at this point in the history.
  • Loading branch information
shawntan committed Aug 19, 2024
1 parent 40b2891 commit a5c7b7a
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions vllm/model_executor/models/granite.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -294,8 +294,7 @@ def __init__(
self.norm = PPMissingLayer()

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
- inputs_embeds = self.embed_tokens(input_ids)
- return inputs_embeds
+ return self.embed_tokens(input_ids)

def forward(
self,
Expand All @@ -317,7 +316,7 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

- hidden_states = hidden_states * self.config.embedding_multiplier
+ hidden_states *= self.config.embedding_multiplier

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
Expand Down Expand Up @@ -426,10 +425,10 @@ def forward(
return model_output

def compute_logits(self, hidden_states: torch.Tensor,
- sampling_metadata: SamplingMetadata) -> torch.Tensor:
+ sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
- logits = logits / self.config.logits_scaling
+ logits /= self.config.logits_scaling
return logits

def sample(
Expand Down

0 comments on commit a5c7b7a

Please sign in to comment.