Allow llama_transformer to take embedding as optional argument (pytorch#4257)

Summary:
Pull Request resolved: pytorch#4257

Llava runs token embedding separately and combines the result with
image embeddings as input to the text transformer. This PR adds support for
that.
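The flow described above can be sketched as follows. This is an illustrative example, not the actual Llava code: the embedding table size, `dim`, and the `image_embeds` tensor are made up, and the final call shape is only indicated in a comment.

```python
import torch

dim = 8
tok_embeddings = torch.nn.Embedding(100, dim)

tokens = torch.tensor([[1, 2, 3]])     # prompt token ids, shape (1, 3)
image_embeds = torch.randn(1, 5, dim)  # stand-in for an image encoder's output

# Embed the text tokens outside the transformer, then splice in the
# image embeddings along the sequence dimension.
text_embeds = tok_embeddings(tokens)               # (1, 3, dim)
h = torch.cat([image_embeds, text_embeds], dim=1)  # (1, 8, dim)

# With this PR, h can be passed directly to the text transformer,
# e.g. transformer(input_pos=..., h=h), skipping tok_embeddings inside.
```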

Reviewed By: tarun292

Differential Revision: D59759977

fbshipit-source-id: 45fe8af0bdfc090cb135d3d3f0b8ff3a18fa1b3c
larryliu0820 authored and facebook-github-bot committed Jul 16, 2024
1 parent 58b1b18 commit 8a1589d
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions examples/models/llama2/llama_transformer.py
```diff
@@ -501,13 +501,19 @@ def __init__(self, params: ModelArgs):

     def forward(
         self,
-        tokens: torch.Tensor,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
         input_pos: Optional[
-            torch.Tensor
+            torch.LongTensor
         ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
+        if (tokens is None) ^ (h is not None):
+            raise ValueError(
+                "You cannot specify both tokens and h at the same time, and must specify either one"
+            )
+        if tokens is not None and h is None:
+            h = self.tok_embeddings(tokens)
+        seqlen = h.shape[1]

         if self.use_kv_cache:
             assert (
```
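The `(tokens is None) ^ (h is not None)` check raises exactly when zero or two of the arguments are supplied, so callers must pass precisely one. A minimal self-contained sketch of that contract, using a toy module rather than the real `Transformer` (the `TinyTransformer` class, vocab size, and `dim` below are invented for illustration):

```python
from typing import Optional

import torch
import torch.nn as nn


class TinyTransformer(nn.Module):
    """Toy stand-in showing the tokens-xor-h contract from this commit."""

    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        h: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # XOR is True when both args are given or both are missing,
        # i.e. whenever the caller did not pass exactly one of them.
        if (tokens is None) ^ (h is not None):
            raise ValueError("specify exactly one of tokens or h")
        if tokens is not None and h is None:
            h = self.tok_embeddings(tokens)  # (bsz, seqlen, dim)
        seqlen = h.shape[1]  # works for either input path
        return h  # downstream attention layers elided


model = TinyTransformer()
out_from_tokens = model(tokens=torch.zeros(1, 4, dtype=torch.long))
out_from_embeds = model(h=torch.randn(1, 4, 8))
```

Both call paths yield an activation of shape `(batch, seqlen, dim)`; calling with neither argument (or both) raises `ValueError`.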
