
Update on "[Test only] multimodal android binding"
[ghstack-poisoned]
larryliu0820 committed Jul 25, 2024
2 parents f9ee05f + 2124930 commit c894dd7
Showing 5 changed files with 61 additions and 34 deletions.
2 changes: 1 addition & 1 deletion examples/models/llava/main.cpp
@@ -104,6 +104,6 @@ int32_t main(int32_t argc, char** argv) {
       .width = static_cast<int32_t>(image_tensor.size(2)),
       .height = static_cast<int32_t>(image_tensor.size(1))};
   // generate
-  runner.generate(image, prompt, seq_len);
+  runner.generate({image}, prompt, seq_len);
   return 0;
 }
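
Note on the one-line change above: since generate() now takes a std::vector<Image> by value, the braced list {image} builds a one-element vector in place, and more images could be passed the same way. Below is a minimal, self-contained sketch of that calling pattern, using a stand-in Image struct and a stub generate(); both are hypothetical, and only the brace-initialization idiom comes from this commit.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Stand-in for the runner's Image type, mirroring the aggregate field order
// used in the JNI layer below ({data, width, height, channels}).
struct Image {
  std::vector<uint8_t> data;
  int32_t width;
  int32_t height;
  int32_t channels;
};

// Stub with the same parameter shape as the new generate() overload.
void generate(std::vector<Image> images, const std::string& prompt, int32_t seq_len) {
  std::printf("images=%zu prompt=%s seq_len=%d\n", images.size(), prompt.c_str(), seq_len);
}

int main() {
  Image image{{0, 1, 2}, 1, 1, 3};
  generate({image}, "a prompt", 1024);         // one image, as in main.cpp above
  generate({image, image}, "a prompt", 1024);  // several images work the same way
  return 0;
}
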
27 changes: 14 additions & 13 deletions examples/models/llava/runner/multimodal_runner.cpp
@@ -271,7 +271,7 @@ Result<torch::executor::Tensor> MultiModalRunner::step(
 }
 
 Error MultiModalRunner::generate(
-    Image& image,
+    std::vector<Image> images,
     const std::string& prompt,
     int32_t seq_len,
     std::function<void(const std::string&)> token_callback,
@@ -313,18 +313,19 @@ Error MultiModalRunner::generate(
   ET_LOG(Info, "pos: %d", pos);
 
   // prefill image
-  auto image_prefill_res = prefill_image(image, pos);
-  ET_LOG(
-      Info,
-      "prefill image res sizes(0): %zu, sizes(1): %zu, sizes(2): %zu",
-      image_prefill_res.get().size(0),
-      image_prefill_res.get().size(1),
-      image_prefill_res.get().size(2));
-
-  // update pos to include prefilled image tokens
-  pos += image_prefill_res.get().size(1);
-  ET_LOG(Info, "pos: %d", pos);
-
+  for (auto image : images) {
+    auto image_prefill_res = prefill_image(image, pos);
+    ET_LOG(
+        Info,
+        "prefill image res sizes(0): %zu, sizes(1): %zu, sizes(2): %zu",
+        image_prefill_res.get().size(0),
+        image_prefill_res.get().size(1),
+        image_prefill_res.get().size(2));
+
+    // update pos to include prefilled image tokens
+    pos += image_prefill_res.get().size(1);
+    ET_LOG(Info, "pos: %d", pos);
+  }
   // prefill prompt. Do not append bos because preset prompt has it.
   auto prompt_prefill_res = prefill_prompt(prompt, pos, false, token_callback);
   ET_LOG(
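
The loop above prefills each image in turn and advances pos by the number of tokens that image contributed, i.e. size(1) of the prefill result, before the prompt prefill continues from the accumulated position. Below is a small, self-contained sketch of that bookkeeping, with hypothetical per-image token counts standing in for the real prefill_image() outputs.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  int32_t pos = 0;  // running position in the KV cache

  // Hypothetical token counts; in the runner they come from
  // image_prefill_res.get().size(1) for each prefilled image.
  std::vector<int64_t> image_token_counts = {576, 576};

  for (int64_t count : image_token_counts) {
    pos += static_cast<int32_t>(count);  // mirrors: pos += image_prefill_res.get().size(1);
    std::printf("pos: %d\n", pos);
  }
  // The prompt prefill that follows starts from this accumulated position.
  return 0;
}

One side note on the loop itself: for (auto image : images) copies each Image, including its pixel buffer, on every iteration; iterating with const auto& would avoid those copies.
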
2 changes: 1 addition & 1 deletion examples/models/llava/runner/multimodal_runner.h
@@ -80,7 +80,7 @@ class MultiModalRunner {
       std::function<void(const std::string&)> token_callback = {});
 
   Error generate(
-      Image& image,
+      std::vector<Image> images,
       const std::string& prompt,
       int32_t seq_len = 1024,
       std::function<void(const std::string&)> token_callback = {},
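
For reference, a usage sketch of the updated header signature. Only the generate(std::vector<Image>, const std::string&, int32_t, token_callback, ...) shape and the Image aggregate field order come from this commit; the include path, the 336x336 RGB dimensions, and the helper function name are assumptions for illustration.

#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

#include "examples/models/llava/runner/multimodal_runner.h"  // assumed include path

// Assumes MultiModalRunner, Image, and Error are visible in this scope
// (namespace/using declarations elided).
Error generate_with_images(
    MultiModalRunner& runner,
    std::vector<uint8_t> pixels,
    const std::string& prompt) {
  // Aggregate init in the same field order the JNI layer uses:
  // {data, width, height, channels}. 336x336x3 is an assumed image size.
  Image image{std::move(pixels), 336, 336, 3};

  std::vector<Image> images = {image};
  return runner.generate(
      images,
      prompt,
      /*seq_len=*/1024,
      [](const std::string& piece) {
        std::printf("%s", piece.c_str());  // stream generated text as it arrives
      });
}
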
46 changes: 34 additions & 12 deletions examples/models/llava/test_pte.py
@@ -1,44 +1,65 @@
-from executorch.extension.pybindings.portable_lib import _load_for_executorch, _get_operator_names
+import sys
 
+import torch
 
-from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # noqa
+from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache  # noqa
 from executorch.examples.models.llava.model import LlavaModel
+from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
-import torch
-import sys
 
 def main():
     args = sys.argv[1:]
     llava_module = _load_for_executorch(args[0])
 
     llava_model = LlavaModel()
 
-    prompt_before_image, resized, prompt_after_image = llava_model.get_inputs_for_prefill()
+    prompt_before_image, resized, prompt_after_image = (
+        llava_model.get_inputs_for_prefill()
+    )
 
     start_pos = 0
     # pte prefill prompt before img
-    pte_embeds_before_img = llava_module.run_method("token_embedding", (prompt_before_image,))[0]
-    pte_prefill_before_img = llava_module.run_method("text_model", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img))[0]
+    pte_embeds_before_img = llava_module.run_method(
+        "token_embedding", (prompt_before_image,)
+    )[0]
+    pte_prefill_before_img = llava_module.run_method(
+        "text_model",
+        (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
+    )[0]
     print(pte_prefill_before_img)
 
     start_pos += pte_prefill_before_img.shape[1]
 
     # pte prefill image
     pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
-    pte_prefill_img = llava_module.run_method("text_model", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_img,))[0]
+    pte_prefill_img = llava_module.run_method(
+        "text_model",
+        (
+            torch.tensor([start_pos], dtype=torch.int64),
+            pte_embeds_img,
+        ),
+    )[0]
     print(pte_prefill_img)
 
     start_pos += pte_prefill_img.shape[1]
 
     # pte prefill prompt after img
-    pte_embeds_after_img = llava_module.run_method("token_embedding", (prompt_after_image,))[0]
-    pte_prefill_after_img = llava_module.run_method("text_model", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img))[0]
+    pte_embeds_after_img = llava_module.run_method(
+        "token_embedding", (prompt_after_image,)
+    )[0]
+    pte_prefill_after_img = llava_module.run_method(
+        "text_model",
+        (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
+    )[0]
     print(pte_prefill_after_img)
 
     # being tested, using llama_transformer
     new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
     for i in range(4):
         print(i, llava_model.tokenizer.decode(new_tokens[i]))
-        token_embeds = llava_module.run_method("token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),))[0]
+        token_embeds = llava_module.run_method(
+            "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
+        )[0]
         logits = llava_module.run_method(
             "text_model",
             (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
@@ -50,5 +50,6 @@ def main():
     )[0].strip()
     print(outputs)
 
+
 if __name__ == "__main__":
     main()
18 changes: 11 additions & 7 deletions extension/android/jni/jni_layer_multimodal.cpp
@@ -134,15 +134,19 @@ class ExecuTorchMultiModalJni
       jint startPos,
       facebook::jni::alias_ref<ExecuTorchMultiModalCallbackJni> callback) {
     auto image_size = image->size();
-    std::vector<jint> image_data_jint(image_size);
-    std::vector<uint8_t> image_data(image_size);
-    image->getRegion(0, image_size, image_data_jint.data());
-    for (int i = 0; i < image_size; i++) {
-      image_data[i] = image_data_jint[i];
+    std::vector<Image> images;
+    if (image_size != 0) {
+      std::vector<jint> image_data_jint(image_size);
+      std::vector<uint8_t> image_data(image_size);
+      image->getRegion(0, image_size, image_data_jint.data());
+      for (int i = 0; i < image_size; i++) {
+        image_data[i] = image_data_jint[i];
+      }
+      Image image_runner{image_data, width, height, channels};
+      images.push_back(image_runner);
     }
-    Image image_runner{image_data, width, height, channels};
     runner_->generate(
-        image_runner,
+        images,
         prompt->toStdString(),
         1024,
         [callback](std::string result) { callback->onResult(result); },
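
The JNI change above only builds an Image when the incoming pixel array is non-empty, so a call without an image now reaches the runner with an empty vector (text-only generation). Below is a minimal, self-contained sketch of that guard and of the per-element int-to-byte narrowing, using a plain std::vector<int32_t> in place of the JNI array and a stand-in Image struct; both stand-ins are hypothetical.

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in with the same field order the JNI code uses for aggregate init.
struct Image {
  std::vector<uint8_t> data;
  int32_t width;
  int32_t height;
  int32_t channels;
};

std::vector<Image> to_images(
    const std::vector<int32_t>& raw, int32_t width, int32_t height, int32_t channels) {
  std::vector<Image> images;
  if (!raw.empty()) {
    // Narrow each int pixel value to uint8_t, element by element, as the JNI loop does.
    std::vector<uint8_t> data;
    data.reserve(raw.size());
    for (int32_t v : raw) {
      data.push_back(static_cast<uint8_t>(v));
    }
    images.push_back(Image{std::move(data), width, height, channels});
  }
  return images;  // stays empty when no image bytes were provided
}

int main() {
  auto with_image = to_images({255, 0, 0}, 1, 1, 3);
  auto without_image = to_images({}, 0, 0, 0);
  std::printf("with: %zu, without: %zu\n", with_image.size(), without_image.size());  // 1, 0
  return 0;
}
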

