Skip to content

Commit

Permalink
[llava][5/N] Add Llava model definition
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 1eb1b84cd332d130d62e4a0a8ea363c3ca727fd2
Pull Request resolved: #4259
  • Loading branch information
larryliu0820 committed Jul 15, 2024
1 parent 40ab5ce commit 173b0a7
Show file tree
Hide file tree
Showing 3 changed files with 387 additions and 24 deletions.
5 changes: 4 additions & 1 deletion examples/models/llava/install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ pip install protobuf
# Reinstall bitsandbytes to make it compatible.
pip install bitsandbytes -I

# numpy needs to be pinned to 1.24; 1.26.4 will error out
pip install numpy==1.24

# The deps of llava can have different versions than deps of ExecuTorch.
# For example, torch version required from llava is older than ExecuTorch.
# To make both work, recover ExecuTorch's original dependencies by rerunning
# the install_requirements.sh.
# NOTE(review): the two invocations below look like a diff artifact (the old
# line followed by the new `--pybind xnnpack` line); the actual file should
# contain only one of them — verify against the repository.
bash -x ./install_requirements.sh
bash -x ./install_requirements.sh --pybind xnnpack
63 changes: 63 additions & 0 deletions examples/models/llava/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import math

import os
import re

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import requests
import torch

import torchvision
from executorch.examples.models.llama2.llama_transformer import (
FeedForward,
KVCache,
ModelArgs,
RMSNorm,
SDPA,
)
from model import LlavaModel
from PIL import Image

from torch import nn
from torch.nn import functional as F
from torchvision.transforms import v2
from torchvision.transforms._functional_tensor import resize

from transformers import LlamaForCausalLM


def main():
    """Run greedy text generation with the eager LLaVA model.

    Loads the example (image + prompt) inputs bundled with ``LlavaModel``,
    prefills the model, then greedily decodes up to ``max_new_tokens``
    tokens, printing each decoded token as it is produced.
    """
    llava_model = LlavaModel()
    llava = llava_model.get_eager_model()

    # fp16 overflows during prefill, so run the whole model in fp32.
    llava = llava.to(torch.float32)
    inputs = llava_model.get_example_inputs()

    prefill_logits = llava.prefill(*inputs)
    # Number of positions consumed by the prefill; decoding continues from here.
    context_len = prefill_logits.shape[1]
    print(prefill_logits)

    # Seed greedy decoding with the argmax of the last prefill position.
    new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
    for i in range(llava_model.args.max_new_tokens):
        print(i, llava_model.tokenizer.decode(new_tokens[i]))
        logits = llava.forward(
            torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
        )
        # Bug fix: call .item() so the stored token is a plain int, consistent
        # with the seed token above. The original appended a 0-dim tensor,
        # which then got passed to tokenizer.decode and re-wrapped in
        # torch.tensor([...]) on the next iteration.
        new_tokens.append(torch.argmax(logits[-1, :]).item())


# Script entry point: run generation only when executed directly, not on import.
if __name__ == "__main__":
    main()
Loading

0 comments on commit 173b0a7

Please sign in to comment.