Merge branch 'main' into sync-llama-cpp
jhen0409 committed Jul 27, 2024
2 parents 96e8e9a + 7972f83 commit 246e207
Showing 5 changed files with 31 additions and 12 deletions.
android/src/main/CMakeLists.txt (6 changes: 4 additions & 2 deletions)
@@ -47,15 +47,17 @@ function(build_library target_name cpu_flags)
target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING)
endif ()

- if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
+ # NOTE: If you want to debug the native code, you can uncomment if and endif
+ # Note that it will be extremely slow
+ # if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
target_compile_options(${target_name} PRIVATE -O3 -DNDEBUG)
target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden)
target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections)

target_link_options(${target_name} PRIVATE -Wl,--gc-sections)
target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL)
target_link_options(${target_name} PRIVATE -flto)
- endif ()
+ # endif ()
endfunction()

# Default target (no specific CPU features)
android/src/main/java/com/rnllama/LlamaContext.java (16 changes: 11 additions & 5 deletions)
@@ -137,7 +137,7 @@ public WritableMap completion(ReadableMap params) {
}
}

- return doCompletion(
+ WritableMap result = doCompletion(
this.context,
// String prompt,
params.getString("prompt"),
@@ -191,6 +191,10 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("emit_partial_completion") ? params.getBoolean("emit_partial_completion") : false
)
);
+ if (result.hasKey("error")) {
+   throw new IllegalStateException(result.getString("error"));
+ }
+ return result;
}

public void stopCompletion() {
@@ -215,12 +219,14 @@ public String detokenize(ReadableArray tokens) {
return detokenize(this.context, toks);
}

- public WritableMap embedding(String text) {
+ public WritableMap getEmbedding(String text) {
if (isEmbeddingEnabled(this.context) == false) {
throw new IllegalStateException("Embedding is not enabled");
}
- WritableMap result = Arguments.createMap();
- result.putArray("embedding", embedding(this.context, text));
+ WritableMap result = embedding(this.context, text);
+ if (result.hasKey("error")) {
+   throw new IllegalStateException(result.getString("error"));
+ }
return result;
}

@@ -351,7 +357,7 @@ protected static native WritableMap doCompletion(
protected static native WritableArray tokenize(long contextPtr, String text);
protected static native String detokenize(long contextPtr, int[] tokens);
protected static native boolean isEmbeddingEnabled(long contextPtr);
- protected static native WritableArray embedding(long contextPtr, String text);
+ protected static native WritableMap embedding(long contextPtr, String text);
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
protected static native void freeContext(long contextPtr);
}
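
Net effect of the LlamaContext.java changes: completion() and getEmbedding() now inspect the map returned from the native layer and throw an IllegalStateException when it carries an "error" key, instead of handing a silently broken result up to JavaScript. A minimal caller sketch (hypothetical code, not part of this commit; it only assumes the LlamaContext API shown above and React Native's bridge Promise):

import com.facebook.react.bridge.Promise;
import com.facebook.react.bridge.WritableMap;

// Hypothetical helper mirroring the AsyncTask pattern RNLlama.java uses.
void embedText(LlamaContext context, String text, Promise promise) {
  try {
    // Throws if embeddings are disabled, sampling fails to initialize,
    // or the native result map contains an "error" entry.
    WritableMap result = context.getEmbedding(text);
    promise.resolve(result); // result.getArray("embedding") holds the vector
  } catch (IllegalStateException e) {
    promise.reject("EMBEDDING_FAILED", e.getMessage());
  }
}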
android/src/main/java/com/rnllama/RNLlama.java (2 changes: 1 addition & 1 deletion)
@@ -297,7 +297,7 @@ protected WritableMap doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context not found");
}
- return context.embedding(text);
+ return context.getEmbedding(text);
} catch (Exception e) {
exception = e;
}
android/src/main/jni.cpp (13 changes: 10 additions & 3 deletions)
@@ -577,17 +577,24 @@ Java_com_rnllama_LlamaContext_embedding(
llama->params.prompt = text_chars;

llama->params.n_predict = 0;

+ auto result = createWriteableMap(env);
+ if (!llama->initSampling()) {
+   putString(env, result, "error", "Failed to initialize sampling");
+   return reinterpret_cast<jobject>(result);
+ }

llama->beginCompletion();
llama->loadPrompt();
llama->doCompletion();

std::vector<float> embedding = llama->getEmbedding();

- jobject result = createWritableArray(env);

+ auto embeddings = createWritableArray(env);
for (const auto &val : embedding) {
- pushDouble(env, result, (double) val);
+ pushDouble(env, embeddings, (double) val);
}
+ putArray(env, result, "embedding", embeddings);

env->ReleaseStringUTFChars(text, text_chars);
return result;
ios/RNLlamaContext.mm (6 changes: 5 additions & 1 deletion)
@@ -364,8 +364,12 @@ - (NSArray *)embedding:(NSString *)text {
llama->params.prompt = [text UTF8String];

llama->params.n_predict = 0;
- llama->loadPrompt();

+ if (!llama->initSampling()) {
+   @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to initialize sampling" userInfo:nil];
+ }
+ llama->beginCompletion();
+ llama->loadPrompt();
llama->doCompletion();

std::vector<float> result = llama->getEmbedding();
