diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 5f1d8fb666..81ba0fd014 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -501,13 +501,19 @@ def __init__(self, params: ModelArgs):
 
     def forward(
         self,
-        tokens: torch.Tensor,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
         input_pos: Optional[
-            torch.Tensor
+            torch.LongTensor
         ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
+        if (tokens is None) ^ (h is not None):
+            raise ValueError(
+                "You cannot specify both tokens and h at the same time, and must specify either one"
+            )
+        if tokens is not None and h is None:
+            h = self.tok_embeddings(tokens)
+        seqlen = h.shape[1]
 
         if self.use_kv_cache:
             assert (
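
The reworked `forward()` now accepts either token IDs or precomputed embeddings, so callers can inject their own embedding tensor instead of going through `tok_embeddings`. Below is a minimal runnable sketch of the new calling convention; `ToyTransformer` is a hypothetical stand-in that reproduces only the input-handling logic from the diff (the real class, its decoder layers, `input_pos`, and the KV-cache path are omitted):

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    """Illustrative stand-in; only the input dispatch from the diff is kept."""

    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        h: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # The XOR is True exactly when both or neither input is given,
        # so it enforces "exactly one of tokens / h".
        if (tokens is None) ^ (h is not None):
            raise ValueError(
                "You cannot specify both tokens and h at the same time, "
                "and must specify either one"
            )
        if tokens is not None and h is None:
            h = self.tok_embeddings(tokens)
        seqlen = h.shape[1]  # sequence length now comes from h, not tokens
        return h  # the real model would continue through the decoder layers


model = ToyTransformer()
ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
out_tokens = model(tokens=ids)                   # token-ID path
out_embeds = model(h=model.tok_embeddings(ids))  # precomputed-embeddings path
assert torch.allclose(out_tokens, out_embeds)
```

Note that `seqlen` is derived from `h.shape[1]` in both paths, which is what lets the rest of the method stay unchanged regardless of which input the caller provided.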