[llava] Quantize embedding
Differential Revision: D61939945

Pull Request resolved: #4955
larryliu0820 authored Aug 29, 2024
Parent 4a8b8ee · commit 1774638
Showing 6 changed files with 26 additions and 27 deletions.
.ci/scripts/test_llava.sh (1 addition, 1 deletion)

@@ -91,7 +91,7 @@ run_and_verify() {
   RESULT=$(cat result.txt)
   # set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes <unk> tokens.
   if [[ "$(uname)" == "Darwin" ]]; then
-    EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress on a basketball court. There are several players on the court, with one player in the foreground holding a basketball, and"
+    EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with several players on the court. One of the players is dribbling the ball, while the others are in various"
   else
     # set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes <unk> tokens.
     EXPECTED_PREFIX="ASSISTANT:"
examples/models/llama2/source_transformation/quantize.py (1 addition, 0 deletions)

@@ -399,6 +399,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     group_size=group_size,
+                    dtype=child.weight.dtype,
                     packed=packed,
                 ),
             )
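The `dtype` threaded through here is what lets the quantized embedding return activations in the model's original dtype. As a rough sketch of the idea (illustrative code only, not ExecuTorch's `QuantizedGroupEmbedding`):

import torch


def quantize_embedding_per_group(weight: torch.Tensor, group_size: int):
    # Quantize each row of `weight` in groups of `group_size` columns to int8,
    # keeping one fp32 scale per (row, group).
    vocab_size, embedding_dim = weight.shape
    grouped = weight.reshape(vocab_size, embedding_dim // group_size, group_size)
    scales = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(grouped / scales), -128, 127).to(torch.int8)
    return q.reshape(vocab_size, embedding_dim), scales.squeeze(-1)


def embedding_lookup(q, scales, tokens, group_size, dtype):
    # Dequantize only the looked-up rows, in the requested activation dtype.
    embedding_dim = q.shape[1]
    rows = q[tokens].reshape(*tokens.shape, embedding_dim // group_size, group_size)
    out = rows.to(dtype) * scales[tokens].unsqueeze(-1).to(dtype)
    return out.reshape(*tokens.shape, embedding_dim)


weight = torch.randn(128, 64, dtype=torch.bfloat16)
q, scales = quantize_embedding_per_group(weight.float(), group_size=32)
emb = embedding_lookup(q, scales, torch.tensor([[1, 5, 9]]), 32, weight.dtype)
assert emb.dtype == torch.bfloat16  # activations stay in the model dtype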
examples/models/llava/export_llava.py (12 additions, 3 deletions)

@@ -17,6 +17,7 @@
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama2.source_transformation.quantize import (
+    EmbeddingQuantHandler,
     get_quant_weight_transform,
 )
 from executorch.examples.models.llama2.source_transformation.sdpa import (

@@ -157,12 +158,20 @@ def forward(self, images):
 
 
 def export_token_embedding(llava, prompt):
-    embed = llava.embed_tokens
-    token_dim_1 = Dim("token_dim_1", min=2, max=3518)
+    def quant_embedding(model):
+        return EmbeddingQuantHandler(
+            model,
+            bitwidth=8,
+            group_size=32,
+            packed=False,
+        ).quantized_model()
+
+    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
+    token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
     dynamic_shapes = [{1: token_dim_1}]
     with torch.no_grad():
         token_embedding_ep = torch.export.export(
-            embed, (prompt,), dynamic_shapes=dynamic_shapes
+            quantized_token_embed.embed_tokens, (prompt,), dynamic_shapes=dynamic_shapes
         )
     return token_embedding_ep
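Two things happen in this hunk: the language model's embedding table is swapped for a grouped int8 one via `EmbeddingQuantHandler`, and the result is exported with a dynamic token dimension bounded by the model's real `max_seq_len` instead of the hardcoded 3518. A minimal, self-contained sketch of the export pattern, using a plain `nn.Embedding` stand-in and made-up sizes rather than LLaVA's quantized embedding:

import torch
from torch.export import Dim, export


class TokenEmbedding(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(1000, 64)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.embed(tokens)


prompt = torch.randint(0, 1000, (1, 8))
# Dim 1 (the token count) is dynamic, so one exported program handles any
# prompt length up to the bound, instead of baking in a fixed size.
token_dim_1 = Dim("token_dim_1", min=2, max=768)
with torch.no_grad():
    ep = export(TokenEmbedding(), (prompt,), dynamic_shapes=[{1: token_dim_1}])
print(ep.module()(torch.randint(0, 1000, (1, 20))).shape)  # torch.Size([1, 20, 64])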
examples/models/llava/model.py (3 additions, 11 deletions)

@@ -21,7 +21,6 @@
 from executorch.examples.models.model_base import EagerModelBase
 from PIL import Image
 
-from torch import nn
 from torch.export import Dim
 from torchvision.transforms.v2 import functional as F
 

@@ -60,11 +59,6 @@ def __init__(
             use_hf_rope=True,
             max_seq_len=max_seq_len,
         )
-        self.embed_tokens = nn.Embedding(
-            self.model_.config.text_config.vocab_size,
-            self.model_.config.text_config.hidden_size,
-            self.model_.config.pad_token_id,
-        )
         self.text_model = Transformer(self.text_model_args)
         # use custom op for SDPA.
         if use_sdpa_with_kv_cache_op:

@@ -75,11 +69,6 @@
             strict=False,
             assign=True,
         )
-        self.embed_tokens.load_state_dict(
-            state_dict=self.model_.language_model.model.embed_tokens.state_dict(),
-            strict=True,
-            assign=True,
-        )
 
     def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
         state_dict = self.model_.language_model.state_dict()

@@ -133,6 +122,9 @@ def _feature_select(self, image_outputs):
     def get_model(self):
        return self.model_.get_model()
 
+    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+        return self.model_.language_model.model.embed_tokens(tokens)
+
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         images = images.to(dtype=self.model_.dtype)
         if type(images) is list:
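Net effect of the model.py changes: the wrapper no longer builds and syncs a second `nn.Embedding`; `embed_tokens` is now a method that forwards to the single embedding owned by the underlying language model, so transformations like quantization apply to one copy of the weights. A hypothetical sketch of the delegation pattern (names are illustrative):

import torch


class LlavaWrapper:
    def __init__(self, language_model: torch.nn.Module):
        self.language_model = language_model  # owns the only weight copy

    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        # Delegation means a transformation (e.g. quantization) applied to
        # language_model.embed_tokens is automatically seen here too.
        return self.language_model.embed_tokens(tokens)


lm = torch.nn.Module()
lm.embed_tokens = torch.nn.Embedding(1000, 64)
print(LlavaWrapper(lm).embed_tokens(torch.tensor([[1, 2, 3]])).shape)
# torch.Size([1, 3, 64])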
examples/models/llava/test/test_llava.py (5 additions, 6 deletions)

@@ -15,12 +15,11 @@
 # import order matters. We need to import portable_lib first since it contains the static op registry
 # which will be used in the import of custom ops. Otherwise, the registration of custom ops will be skipped.
-# I don't know how to mute UFMT so I'm just using if True: to avoid the error
-if True:
-    from executorch.extension.pybindings.portable_lib import (
-        _load_for_executorch_from_buffer,
-    )
-    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa: F401
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
+from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.kernels import quantized  # noqa # usort: skip
 
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
examples/models/llava/test/test_pte.py (4 additions, 6 deletions)

@@ -14,10 +14,8 @@
 from PIL import Image
 
 # Custom ops has to be loaded after portable_lib.
-# I don't know how to stop UFMT so I'm just using if True: to avoid lint error
-if True:
-    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa
-
+from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.kernels import quantized  # noqa # usort: skip
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.DEBUG, format=FORMAT)

@@ -54,7 +52,7 @@ def main():
     )[0]
     print(pte_prefill_before_img)
 
-    start_pos += pte_prefill_before_img.shape[1]
+    start_pos += prompt_before_image.shape[1]
 
     # pte prefill image
     logging.warning("Image encoder started")

@@ -71,7 +69,7 @@
     logging.warning("Image token prefill finished")
     print(pte_prefill_img)
 
-    start_pos += pte_prefill_img.shape[1]
+    start_pos += pte_embeds_img.shape[1]
 
     # pte prefill prompt after img
     logging.warning("Text token prefill started")
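The two `start_pos` fixes advance the KV-cache position by the length of the inputs fed to each prefill call rather than by the shape of the returned logits. A toy illustration of the bookkeeping (names and sizes here are assumptions, not the real runner):

import torch

start_pos = 0
prompt_before_image = torch.randint(0, 32000, (1, 5))  # 5 text tokens
pte_embeds_img = torch.randn(1, 576, 2048)             # 576 image patch embeddings

# ... prefill(prompt_before_image, start_pos) ...
start_pos += prompt_before_image.shape[1]  # advance by input length: 5

# ... prefill(pte_embeds_img, start_pos) ...
start_pos += pte_embeds_img.shape[1]       # advance by input length: 576

assert start_pos == 5 + 576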
