Skip to content

Commit

Permalink
[Test only] multimodal android binding
Browse files Browse the repository at this point in the history
  • Loading branch information
kirklandsign committed Jul 23, 2024
1 parent caeeb96 commit d1ea23e
Show file tree
Hide file tree
Showing 7 changed files with 345 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/android.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
strategy:
matrix:
tokenizer: [bpe, tiktoken]
tokenizer: [bpe]
with:
runner: linux.2xlarge
docker-image: executorch-ubuntu-22.04-clang12-android
Expand Down
16 changes: 16 additions & 0 deletions build/build_android_llm_demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,27 @@ build_android_native_library() {

cmake --build "${CMAKE_OUT}"/examples/models/llama2 -j "${CMAKE_JOBS}" --config Release

cmake examples/models/llava \
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
-DANDROID_ABI="$ANDROID_ABI" \
-DANDROID_PLATFORM=android-23 \
-DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
-DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_XNNPACK=ON \
-DCMAKE_BUILD_TYPE=Release \
-B"${CMAKE_OUT}"/examples/models/llava

cmake --build "${CMAKE_OUT}"/examples/models/llava -j "${CMAKE_JOBS}" --config Release

cmake extension/android \
-DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \
-DANDROID_ABI="${ANDROID_ABI}" \
-DANDROID_PLATFORM=android-23 \
-DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
-DEXECUTORCH_BUILD_LLAMA_JNI=ON \
-DEXECUTORCH_BUILD_MULTIMODAL_JNI=ON \
-DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \
-DCMAKE_BUILD_TYPE=Release \
-B"${CMAKE_OUT}"/extension/android
Expand All @@ -89,6 +104,7 @@ build_aar() {
# Zip all necessary files into the AAR file
zip -r executorch.aar libs jni/*/libexecutorch.so AndroidManifest.xml
zip -r executorch-llama.aar libs jni/*/libexecutorch_llama_jni.so AndroidManifest.xml
zip -r executorch-multimodal.aar libs jni/*/libexecutorch_multimodal_jni.so AndroidManifest.xml
popd
}

Expand Down
2 changes: 1 addition & 1 deletion examples/models/llava/runner/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ if(EXECUTORCH_USE_TIKTOKEN)
${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
)
list(APPEND _multimodal_runner__srcs
${CMAKE_CURRENT_SOURCE_DIR}/../tokenizer/llama_tiktoken.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../llama2/tokenizer/llama_tiktoken.cpp
)
set(_preprocessor_flag -DET_USE_TIKTOKEN)
endif()
Expand Down
63 changes: 63 additions & 0 deletions extension/android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,66 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
target_link_libraries(executorch_llama_jni re2::re2)
endif()
endif()

if(EXECUTORCH_BUILD_MULTIMODAL_JNI)
set(MULTIMODAL_RUNNER_PATH
${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llava/runner/libmultimodal_runner.a
)
add_library(multimodal_runner STATIC IMPORTED)
set_property(
TARGET multimodal_runner PROPERTY IMPORTED_LOCATION ${MULTIMODAL_RUNNER_PATH}
)

target_link_options_shared_lib(quantized_ops_lib)

if(TARGET pthreadpool)
set(MULTIMODAL_JNI_SRCS jni/jni_layer_multimodal.cpp
../../backends/xnnpack/threadpool/cpuinfo_utils.cpp
)
else()
set(MULTIMODAL_JNI_SRCS jni/jni_layer_multimodal.cpp)
endif()
add_library(executorch_multimodal_jni SHARED ${MULTIMODAL_JNI_SRCS})
if(TARGET pthreadpool)
target_compile_definitions(executorch_multimodal_jni PRIVATE ET_USE_THREADPOOL=1)
target_include_directories(
executorch_multimodal_jni
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/cpuinfo/include
)
target_include_directories(
executorch_multimodal_jni
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/pthreadpool/include
)
endif()
target_include_directories(
executorch_multimodal_jni PRIVATE ${_common_include_directories}
)
target_link_libraries(
executorch_multimodal_jni
${link_libraries}
multimodal_runner
custom_ops
cpublas
eigen_blas
quantized_kernels
quantized_ops_lib
)
target_compile_options(executorch_multimodal_jni PUBLIC ${_common_compile_options})
if(EXECUTORCH_USE_TIKTOKEN)
set(ABSL_ENABLE_INSTALL ON)
set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
add_subdirectory(
${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/abseil-cpp
${CMAKE_CURRENT_BINARY_DIR}/abseil-cpp
)
add_subdirectory(
${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/re2
${CMAKE_CURRENT_BINARY_DIR}/re2
)
set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
target_link_libraries(executorch_multimodal_jni re2::re2)
endif()
endif()
179 changes: 179 additions & 0 deletions extension/android/jni/jni_layer_multimodal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cassert>
#include <chrono>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <executorch/examples/models/llava/runner/multimodal_runner.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/runtime.h>

#if defined(ET_USE_THREADPOOL)
#include <executorch/backends/xnnpack/threadpool/cpuinfo_utils.h>
#include <executorch/backends/xnnpack/threadpool/threadpool.h>
#endif

#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>

#ifdef __ANDROID__
#include <android/log.h>

// For Android, write to logcat
void et_pal_emit_log_message(
et_timestamp_t timestamp,
et_pal_log_level_t level,
const char* filename,
const char* function,
size_t line,
const char* message,
size_t length) {
int android_log_level = ANDROID_LOG_UNKNOWN;
if (level == 'D') {
android_log_level = ANDROID_LOG_DEBUG;
} else if (level == 'I') {
android_log_level = ANDROID_LOG_INFO;
} else if (level == 'E') {
android_log_level = ANDROID_LOG_ERROR;
} else if (level == 'F') {
android_log_level = ANDROID_LOG_FATAL;
}

__android_log_print(android_log_level, "MULTIMODAL", "%s", message);
}
#endif

using namespace torch::executor;

namespace executorch_jni {

class ExecuTorchMultiModalCallbackJni
: public facebook::jni::JavaClass<ExecuTorchMultiModalCallbackJni> {
public:
constexpr static const char* kJavaDescriptor =
"Lorg/pytorch/executorch/MultiModalCallback;";

void onResult(std::string result) const {
static auto cls = ExecuTorchMultiModalCallbackJni::javaClassStatic();
static const auto method =
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
method(self(), s);
}

void onStats(const MultiModalRunner::Stats& result) const {
static auto cls = ExecuTorchMultiModalCallbackJni::javaClassStatic();
static const auto method = cls->getMethod<void(jfloat)>("onStats");
double eval_time =
(double)(result.inference_end_ms - result.prompt_eval_end_ms);

float tps = result.num_generated_tokens / eval_time *
result.SCALING_FACTOR_UNITS_PER_SECOND;

method(self(), tps);
}
};

class ExecuTorchMultiModalJni
: public facebook::jni::HybridClass<ExecuTorchMultiModalJni> {
private:
friend HybridBase;
std::unique_ptr<MultiModalRunner> runner_;

public:
constexpr static auto kJavaDescriptor =
"Lorg/pytorch/executorch/MultiModalModule;";

static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> model_path,
facebook::jni::alias_ref<jstring> tokenizer_path,
jfloat temperature) {
return makeCxxInstance(model_path, tokenizer_path, temperature);
}

ExecuTorchMultiModalJni(
facebook::jni::alias_ref<jstring> model_path,
facebook::jni::alias_ref<jstring> tokenizer_path,
jfloat temperature) {
#if defined(ET_USE_THREADPOOL)
// Reserve 1 thread for the main thread.
uint32_t num_performant_cores =
torch::executorch::cpuinfo::get_num_performant_cores() - 1;
if (num_performant_cores > 0) {
ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores);
torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
num_performant_cores);
}
#endif

runner_ = std::make_unique<MultiModalRunner>(
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str(),
temperature);
}

jint generate(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels,
facebook::jni::alias_ref<jstring> prompt,
jint startPos,
facebook::jni::alias_ref<ExecuTorchMultiModalCallbackJni> callback) {
auto image_size = image->size();
std::vector<jint> image_data_jint(image_size);
std::vector<uint8_t> image_data(image_size);
image->getRegion(0, image_size, image_data_jint.data());
for (int i = 0; i < image_size; i++) {
image_data[i] = image_data_jint[i];
}
Image image_runner{image_data, width, height, channels};
runner_->generate(
image_runner,
prompt->toStdString(),
startPos,
1024,
[callback](std::string result) { callback->onResult(result); },
[callback](const MultiModalRunner::Stats& result) {
callback->onStats(result);
});
return 0;
}

void stop() {
runner_->stop();
}

jint load() {
return static_cast<jint>(runner_->load());
}

static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", ExecuTorchMultiModalJni::initHybrid),
makeNativeMethod("generate", ExecuTorchMultiModalJni::generate),
makeNativeMethod("stop", ExecuTorchMultiModalJni::stop),
makeNativeMethod("load", ExecuTorchMultiModalJni::load),
});
}
};

} // namespace executorch_jni

JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(
vm, [] { executorch_jni::ExecuTorchMultiModalJni::registerNatives(); });
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

import com.facebook.jni.annotations.DoNotStrip;

public interface MultiModalCallback {
/**
* Called when a new result is available from JNI. Users will keep getting onResult() invocations
* until generate() finishes.
*
* @param result Last generated token
*/
@DoNotStrip
public void onResult(String result);

/**
* Called when the statistics for the generate() is available.
*
* @param tps Tokens/second for generated tokens.
*/
@DoNotStrip
public void onStats(float tps);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;

public class MultiModalModule {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("executorch_multimodal_jni");
}

private final HybridData mHybridData;

@DoNotStrip
private static native HybridData initHybrid(
String modulePath, String tokenizerPath, float temperature);

/** Constructs a MultiModal Module for a model with given path, tokenizer, and temperature. */
public MultiModalModule(String modulePath, String tokenizerPath, float temperature) {
mHybridData = initHybrid(modulePath, tokenizerPath, temperature);
}

public void resetNative() {
mHybridData.resetNative();
}

/**
* Start generating tokens from the module.
*
* @param prompt Input prompt
* @param MultiModalCallback callback object to receive results.
*/
@DoNotStrip
public native int generate(int[] image, int width, int height, int channels, String prompt, int startPos, MultiModalCallback MultiModalCallback);

/** Stop current generate() before it finishes. */
@DoNotStrip
public native void stop();

/** Force loading the module. Otherwise the model is loaded during first generate(). */
@DoNotStrip
public native int load();
}

0 comments on commit d1ea23e

Please sign in to comment.