[llava][4/N] Allow llama_transformer take embedding as optional argument #4257

Closed
wants to merge 7 commits
14 changes: 10 additions & 4 deletions examples/models/llama2/llama_transformer.py
@@ -501,13 +501,19 @@ def __init__(self, params: ModelArgs):
 
     def forward(
         self,
-        tokens: torch.Tensor,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
         input_pos: Optional[
-            torch.Tensor
+            torch.LongTensor
         ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
+        if (tokens is None) ^ (h is not None):
+            raise ValueError(
+                "You cannot specify both tokens and h at the same time, and must specify either one"
+            )
+        if tokens is not None and h is None:
+            h = self.tok_embeddings(tokens)
+        seqlen = h.shape[1]
 
         if self.use_kv_cache:
             assert (
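For context, here is a minimal, self-contained usage sketch of the calling convention this change enables. The ToyTransformer class below is hypothetical (not the actual Transformer in llama_transformer.py); it only mirrors the new input contract: a caller passes either token IDs or precomputed embeddings h (e.g. produced by an image encoder in the LLaVA flow), never both and never neither, and the sequence length is derived from h.

```python
import torch
import torch.nn as nn
from typing import Optional


class ToyTransformer(nn.Module):
    """Toy stand-in for the real Transformer, mirroring only the new input contract."""

    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,  # token IDs, shape (bsz, seqlen)
        h: Optional[torch.FloatTensor] = None,      # precomputed embeddings, shape (bsz, seqlen, dim)
    ) -> torch.Tensor:
        # Same XOR validation as in the diff: exactly one of tokens / h must be given.
        if (tokens is None) ^ (h is not None):
            raise ValueError("Specify exactly one of tokens or h")
        if tokens is not None and h is None:
            h = self.tok_embeddings(tokens)
        seqlen = h.shape[1]  # sequence length now comes from h, not from tokens
        assert seqlen == h.shape[1]
        return h


model = ToyTransformer()
tokens = torch.randint(0, 32, (1, 5))        # (bsz=1, seqlen=5)

out_text = model(tokens=tokens)              # text-only path: embedding happens inside forward
embeds = model.tok_embeddings(tokens)        # stand-in for embeddings computed outside the model
out_embed = model(h=embeds)                  # embedding path enabled by this PR

assert torch.equal(out_text, out_embed)
```

With the XOR check, the model raises early if a caller supplies both inputs or neither, so the two entry points cannot silently disagree about which sequence is being processed.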