From cb7ef7a172640bb24b99690227bc335bbe994bfa Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 12 Jul 2024 15:43:58 -0700
Subject: [PATCH] [llava][4/N] Allow llama_transformer to take embeddings as
 an optional argument

Summary: Llava runs token embedding separately and combines the result with
image embeddings to form the input to the text transformer. This PR adds
support for that.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 examples/models/llama2/llama_transformer.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 4fa832eccf..0135aeefbb 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -468,13 +468,19 @@ def __init__(self, params: ModelArgs):
 
     def forward(
         self,
-        tokens: torch.Tensor,
+        tokens: Optional[torch.LongTensor] = None,  # token ids
         input_pos: Optional[
-            torch.Tensor
+            torch.LongTensor
         ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # pre-computed embeddings
     ) -> torch.Tensor:
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
+        if (tokens is None) ^ (h is not None):
+            raise ValueError(
+                "You must specify exactly one of tokens and h"
+            )
+        if tokens is not None and h is None:
+            h = self.tok_embeddings(tokens)
+        _bsz, seqlen, _ = h.shape  # h is (bsz, seqlen, dim)
 
         if self.use_kv_cache:
             assert (
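
Usage sketch: with the new signature, a multimodal caller such as Llava can
embed text tokens itself, splice in image embeddings, and pass the combined
sequence via h, while text-only callers keep passing tokens. This is a minimal
illustration, assuming it runs from the executorch repo root; the tiny
ModelArgs values and the image_embeds tensor are made-up placeholders, and
only the tokens/h calling convention comes from this patch.

    import torch

    from examples.models.llama2.llama_transformer import ModelArgs, Transformer

    # Tiny made-up config so the sketch runs quickly; real use loads
    # hyperparameters from a checkpoint.
    params = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=128, max_seq_len=2048)
    model = Transformer(params)

    # Text-only path, unchanged: pass token ids and let the model embed them.
    tokens = torch.randint(0, params.vocab_size, (1, 8), dtype=torch.long)
    logits = model(tokens=tokens)

    # Multimodal path enabled by this patch: embed the tokens outside the
    # model, concatenate image embeddings, and pass the result as h.
    tok_embeds = model.tok_embeddings(tokens)         # (1, 8, dim)
    image_embeds = torch.randn(1, 576, params.dim)    # stand-in for CLIP/projector output
    h = torch.cat([tok_embeds, image_embeds], dim=1)  # (1, 8 + 576, dim)
    logits = model(h=h)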