[llava][4/N] Allow llama_transformer take embedding as optional argument
Summary: Llava runs token embedding separately and combines the result with
image embeddings to form the input to the text transformer. This PR adds
support for that by letting `forward` accept precomputed embeddings.
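As an illustration of the flow the summary describes (shapes, names, and the tiny embedding table below are hypothetical, not taken from this PR): token embeddings are computed outside the transformer, image embeddings from a vision encoder are spliced in along the sequence dimension, and the combined tensor is what would be passed as the new `h` argument.

```python
import torch

# Hypothetical sizes for illustration only.
vocab_size, dim = 32, 8
tok_embeddings = torch.nn.Embedding(vocab_size, dim)

prompt_tokens = torch.tensor([[1, 2, 3]])   # (bsz, seqlen)
text_emb = tok_embeddings(prompt_tokens)    # (bsz, seqlen, dim)
image_emb = torch.randn(1, 5, dim)          # stand-in for vision-encoder output

# Concatenate along the sequence dimension; this combined tensor is
# what a Llava-style runner would pass as `h` instead of `tokens`.
h = torch.cat([image_emb, text_emb], dim=1)
assert h.shape == (1, 8, dim)
```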

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 9cf46d027d5502ec03f65c597c2769b9cb2e6506
Pull Request resolved: #4257
larryliu0820 committed Jul 12, 2024
1 parent 71cd798 commit 2357d0b
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions examples/models/llama2/llama_transformer.py
```diff
@@ -468,13 +468,19 @@ def __init__(self, params: ModelArgs):

     def forward(
         self,
-        tokens: torch.Tensor,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
         input_pos: Optional[
-            torch.Tensor
+            torch.LongTensor
         ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
+        if (tokens is None) ^ (h is not None):
+            raise ValueError(
+                "You must specify exactly one of tokens and h"
+            )
+        if tokens is not None and h is None:
+            h = self.tok_embeddings(tokens)
+        _bsz, seqlen, _ = h.shape

         if self.use_kv_cache:
             assert (
```
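The XOR guard in the diff accepts exactly one of `tokens` and `h`. A minimal standalone sketch of the same check (the helper name is hypothetical, for illustration only):

```python
def check_args(tokens, h):
    # Mirrors the guard in forward(): exactly one of tokens / h
    # must be provided; both or neither is an error.
    if (tokens is None) ^ (h is not None):
        raise ValueError("You must specify exactly one of tokens and h")
    return "tokens" if tokens is not None else "h"

assert check_args("tok", None) == "tokens"   # tokens-only path
assert check_args(None, "emb") == "h"        # embeddings-only path
for bad in [("tok", "emb"), (None, None)]:   # both / neither -> error
    try:
        check_args(*bad)
        raise AssertionError("expected ValueError")
    except ValueError:
        pass
```

Note how the XOR reads: `tokens is None` and `h is not None` are equal exactly when one argument was supplied, so the expression is true only in the two invalid cases.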
